Source code for kaldi.decoder._decoder

from .. import fstext as _fst
from .. import lat as _lat

from ._faster_decoder import *
from ._biglm_faster_decoder import *
from ._lattice_faster_decoder import *
from ._lattice_faster_decoder_ext import *
from ._lattice_biglm_faster_decoder import *
from ._lattice_faster_online_decoder import *
from ._lattice_faster_online_decoder_ext import *


class _DecoderBase(object):
    """Base class defining the Python API for decoders."""

    def get_best_path(self, use_final_probs=True):
        """Gets best path as a lattice.

        Args:
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            LatticeVectorFst: The best path.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.LatticeVectorFst()
        success = self._get_best_path(ofst, use_final_probs)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst


class _LatticeDecoderBase(_DecoderBase):
    """Base class defining the Python API for lattice generating decoders."""

    def get_raw_lattice(self, use_final_probs=True):
        """Gets raw state-level lattice.

        The output raw lattice will be topologically sorted.

        Args:
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            LatticeVectorFst: The state-level lattice.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.LatticeVectorFst()
        success = self._get_raw_lattice(ofst, use_final_probs)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst

    def get_lattice(self, use_final_probs=True):
        """Gets the lattice-determinized compact lattice.

        The output is a deterministic compact lattice with a unique path for
        each word sequence.

        Args:
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            CompactLatticeVectorFst: The lattice-determinized compact lattice.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.CompactLatticeVectorFst()
        success = self._get_lattice(ofst, use_final_probs)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst


class _LatticeOnlineDecoderBase(_LatticeDecoderBase):
    """Base class defining the Python API for lattice generating online decoders."""

    def get_raw_lattice_pruned(self, beam, use_final_probs=True):
        """Prunes and returns raw state-level lattice.

        Behaves like :meth:`get_raw_lattice` but only processes tokens whose
        extra-cost is smaller than the best-cost plus the specified beam. It is
        worthwhile to call this function only if :attr:`beam` is less than the
        lattice-beam specified in the decoder options. Otherwise, it returns
        essentially the same thing as :meth:`get_raw_lattice`, but more slowly.

        The output raw lattice will be topologically sorted.

        Args:
            beam (float): Pruning beam.
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            LatticeVectorFst: The state-level lattice.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.LatticeVectorFst()
        success = self._get_raw_lattice_pruned(ofst, use_final_probs, beam)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst


class FasterDecoder(_DecoderBase, FasterDecoder):
    """Faster decoder.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (FasterDecoderOptions): Decoder options.
    """

    def __init__(self, fst, opts):
        super(FasterDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope

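# Usage sketch (not part of the original module): decoding a single utterance
# with FasterDecoder. `graph` is assumed to be an HCLG `StdFst` and `decodable`
# any object implementing the decodable interface; the `decode` call and the
# `beam`/`max_active` option fields mirror the wrapped Kaldi decoder API and
# are assumptions, not additions to this module's API.
def _example_faster_decoding(graph, decodable):
    opts = FasterDecoderOptions()
    opts.beam = 13.0
    opts.max_active = 7000
    decoder = FasterDecoder(graph, opts)
    decoder.decode(decodable)       # process all available frames
    return decoder.get_best_path()  # single-path LatticeVectorFst
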
class BiglmFasterDecoder(_DecoderBase, BiglmFasterDecoder):
    """Faster decoder for decoding with big language models.

    This is as :class:`FasterDecoder`, but does online composition between
    the decoding graph :attr:`fst` and the difference language model
    :attr:`lm_diff_fst`.

    Args:
        fst (StdFst): Decoding graph.
        opts (BiglmFasterDecoderOptions): Decoder options.
        lm_diff_fst (StdDeterministicOnDemandFst): The deterministic on-demand
            FST representing the difference in scores between the LM to decode
            with and the LM the decoding graph :attr:`fst` was compiled with.
    """

    def __init__(self, fst, opts, lm_diff_fst):
        super(BiglmFasterDecoder, self).__init__(fst, opts, lm_diff_fst)
        self._fst = fst                  # keep references to FSTs
        self._lm_diff_fst = lm_diff_fst  # to keep them in scope

class LatticeFasterDecoder(_LatticeDecoderBase, LatticeFasterDecoder):
    """Lattice generating faster decoder.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """

    def __init__(self, fst, opts):
        super(LatticeFasterDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope

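# Usage sketch (not part of the original module): generating lattices for one
# utterance with LatticeFasterDecoder. `graph` and `decodable` are assumed to
# be prepared by the caller; the `decode` call and the `beam`/`lattice_beam`
# option fields mirror the wrapped Kaldi decoder API and are assumptions here.
def _example_lattice_decoding(graph, decodable):
    opts = LatticeFasterDecoderOptions()
    opts.beam = 13.0
    opts.lattice_beam = 6.0
    decoder = LatticeFasterDecoder(graph, opts)
    decoder.decode(decodable)
    raw = decoder.get_raw_lattice()  # topologically sorted state-level lattice
    clat = decoder.get_lattice()     # determinized CompactLatticeVectorFst
    return raw, clat
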
class LatticeFasterGrammarDecoder(_LatticeDecoderBase,
                                  LatticeFasterGrammarDecoder):
    """Lattice generating faster grammar decoder.

    Args:
        fst (GrammarFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """

    def __init__(self, fst, opts):
        super(LatticeFasterGrammarDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope

class LatticeBiglmFasterDecoder(_LatticeDecoderBase,
                                LatticeBiglmFasterDecoder):
    """Lattice generating faster decoder for decoding with big language models.

    This is as :class:`LatticeFasterDecoder`, but does online composition
    between the decoding graph :attr:`fst` and the difference language model
    :attr:`lm_diff_fst`.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
        lm_diff_fst (StdDeterministicOnDemandFst): The deterministic on-demand
            FST representing the difference in scores between the LM to decode
            with and the LM the decoding graph :attr:`fst` was compiled with.
    """

    def __init__(self, fst, opts, lm_diff_fst):
        super(LatticeBiglmFasterDecoder, self).__init__(fst, opts, lm_diff_fst)
        self._fst = fst                  # keep references to FSTs
        self._lm_diff_fst = lm_diff_fst  # to keep them in scope

class LatticeFasterOnlineDecoder(_LatticeOnlineDecoderBase,
                                 LatticeFasterOnlineDecoder):
    """Lattice generating faster online decoder.

    Similar to :class:`LatticeFasterDecoder` but computes the best path
    without generating the entire raw lattice and finding the best path
    through it. Instead, it traces back through the lattice.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """

    def __init__(self, fst, opts):
        super(LatticeFasterOnlineDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope

    # GetLattice is missing from the wrapped C++ class, so we override the
    # public method here instead.
    def get_lattice(self, use_final_probs=True):
        """Gets the lattice-determinized compact lattice (see the base class)."""
        raw_fst = self.get_raw_lattice(use_final_probs).invert().arcsort()
        lat_opts = _lat.DeterminizeLatticePrunedOptions()
        config = self.get_options()
        lat_opts.max_mem = config.det_opts.max_mem
        ofst = _fst.CompactLatticeVectorFst()
        _lat.determinize_lattice_pruned(raw_fst, config.lattice_beam,
                                        ofst, lat_opts)
        ofst.connect()
        if ofst.num_states() == 0:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst

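# Usage sketch (not part of the original module): chunk-wise decoding with
# LatticeFasterOnlineDecoder. `graph` is assumed to be an HCLG `StdFst` and
# `decodable` an online decodable whose frames are already available; the
# `init_decoding`/`advance_decoding`/`finalize_decoding`/`num_frames_decoded`
# calls and `num_frames_ready` mirror the wrapped Kaldi API and are assumptions.
def _example_online_decoding(graph, decodable, chunk_frames=50):
    opts = LatticeFasterDecoderOptions()
    opts.beam = 13.0
    opts.lattice_beam = 6.0
    decoder = LatticeFasterOnlineDecoder(graph, opts)
    decoder.init_decoding()
    # A real online setup would interleave feature input with these calls.
    while decoder.num_frames_decoded() < decodable.num_frames_ready():
        decoder.advance_decoding(decodable, chunk_frames)
    decoder.finalize_decoding()
    best_path = decoder.get_best_path()                # no full raw lattice needed
    pruned = decoder.get_raw_lattice_pruned(beam=4.0)  # tighter than lattice-beam
    return best_path, pruned
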
class LatticeFasterOnlineGrammarDecoder(_LatticeOnlineDecoderBase,
                                        LatticeFasterOnlineGrammarDecoder):
    """Lattice generating faster online grammar decoder.

    Similar to :class:`LatticeFasterGrammarDecoder` but computes the best path
    without generating the entire raw lattice and finding the best path
    through it. Instead, it traces back through the lattice.

    Args:
        fst (GrammarFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """

    def __init__(self, fst, opts):
        super(LatticeFasterOnlineGrammarDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope

    # GetLattice is missing from the wrapped C++ class, so we override the
    # public method here instead.
    def get_lattice(self, use_final_probs=True):
        """Gets the lattice-determinized compact lattice (see the base class)."""
        raw_fst = self.get_raw_lattice(use_final_probs).invert().arcsort()
        lat_opts = _lat.DeterminizeLatticePrunedOptions()
        config = self.get_options()
        lat_opts.max_mem = config.det_opts.max_mem
        ofst = _fst.CompactLatticeVectorFst()
        _lat.determinize_lattice_pruned(raw_fst, config.lattice_beam,
                                        ofst, lat_opts)
        ofst.connect()
        if ofst.num_states() == 0:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst

__all__ = [name for name in dir() if name[0] != '_' and not name.endswith('Base')]