# Source code for kaldi.decoder

from ._grammar_fst import *
from ._decodable_matrix import *
from ._decodable_mapped import *
from ._decodable_sum import *
from ._faster_decoder import *
from ._biglm_faster_decoder import *
from ._lattice_faster_decoder import *
from ._lattice_faster_decoder_ext import *
from ._lattice_biglm_faster_decoder import *
from ._lattice_faster_online_decoder import *
from ._lattice_faster_online_decoder_ext import *
from ._training_graph_compiler import *
from ._training_graph_compiler_ext import *
from .. import fstext as _fst
from .. import lat as _lat


class _DecoderBase(object):
    """Base class defining the Python API for decoders."""

    def get_best_path(self, use_final_probs=True):
        """Gets best path as a lattice.

        Args:
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            LatticeVectorFst: The best path.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.LatticeVectorFst()
        success = self._get_best_path(ofst, use_final_probs)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst
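
    # A minimal usage sketch (hedged; names other than `get_best_path` are
    # illustrative): after the decoder has processed a decodable object, e.g.
    # via `decoder.decode(decodable)`, the word sequence can be read off the
    # returned linear lattice, e.g. with
    # `kaldi.fstext.utils.get_linear_symbol_sequence`, which returns the input
    # labels (transition-ids), output labels (word ids) and total weight:
    #
    #   from kaldi.fstext.utils import get_linear_symbol_sequence
    #   best_path = decoder.get_best_path()
    #   alignment, words, weight = get_linear_symbol_sequence(best_path)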


class _LatticeDecoderBase(_DecoderBase):
    """Base class defining the Python API for lattice generating decoders."""

    def get_raw_lattice(self, use_final_probs=True):
        """Gets raw state-level lattice.

        The output raw lattice will be topologically sorted.

        Args:
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            LatticeVectorFst: The state-level lattice.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.LatticeVectorFst()
        success = self._get_raw_lattice(ofst, use_final_probs)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst

    def get_lattice(self, use_final_probs=True):
        """Gets the lattice-determinized compact lattice.

        The output is a deterministic compact lattice with a unique path for
        each word sequence.

        Args:
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            CompactLatticeVectorFst: The lattice-determinized compact lattice.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.CompactLatticeVectorFst()
        success = self._get_lattice(ofst, use_final_probs)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst
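
    # A hedged sketch of typical downstream use: compact lattices returned by
    # this method are usually written to a Kaldi archive, e.g. with the table
    # writer from `kaldi.util.table` (assuming an utterance id `utt_id`):
    #
    #   from kaldi.util.table import CompactLatticeWriter
    #   with CompactLatticeWriter("ark:lat.ark") as writer:
    #       writer[utt_id] = decoder.get_lattice()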


class _LatticeOnlineDecoderBase(_LatticeDecoderBase):
    """Base class defining the Python API for lattice generating online decoders."""

    def get_raw_lattice_pruned(self, beam, use_final_probs=True):
        """Prunes and returns raw state-level lattice.

        Behaves like :meth:`get_raw_lattice` but only processes tokens whose
        extra-cost is smaller than the best-cost plus the specified beam. It is
        worthwhile to call this function only if :attr:`beam` is less than the
        lattice-beam specified in the decoder options. Otherwise, it returns
        essentially the same thing as :meth:`get_raw_lattice`, but more slowly.

        The output raw lattice will be topologically sorted.

        Args:
            beam (float): Pruning beam.
            use_final_probs (bool): If ``True`` and a final state of the graph
                is reached, then the output will include final probabilities
                given by the graph. Otherwise all final probabilities are
                treated as one.

        Returns:
            LatticeVectorFst: The state-level lattice.

        Raises:
            RuntimeError: In the unusual circumstances where no tokens survive.
        """
        ofst = _fst.LatticeVectorFst()
        success = self._get_raw_lattice_pruned(ofst, use_final_probs, beam)
        if not success:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst
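
    # For example (a sketch), an online application might request a heavily
    # pruned lattice as a quick partial result while decoding is still in
    # progress, using a beam tighter than the configured lattice-beam:
    #
    #   partial = decoder.get_raw_lattice_pruned(beam=4.0,
    #                                            use_final_probs=False)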


class FasterDecoder(_DecoderBase, _faster_decoder.FasterDecoder):
    """Faster decoder.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (FasterDecoderOptions): Decoder options.
    """
    def __init__(self, fst, opts):
        super(FasterDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope


class BiglmFasterDecoder(_DecoderBase,
                         _biglm_faster_decoder.BiglmFasterDecoder):
    """Faster decoder for decoding with big language models.

    This is as :class:`FasterDecoder`, but does online composition between
    the decoding graph :attr:`fst` and the difference language model
    :attr:`lm_diff_fst`.

    Args:
        fst (StdFst): Decoding graph.
        opts (BiglmFasterDecoderOptions): Decoder options.
        lm_diff_fst (StdDeterministicOnDemandFst): The deterministic on-demand
            FST representing the difference in scores between the LM to decode
            with and the LM the decoding graph :attr:`fst` was compiled with.
    """
    def __init__(self, fst, opts, lm_diff_fst):
        super(BiglmFasterDecoder, self).__init__(fst, opts, lm_diff_fst)
        self._fst = fst                  # keep references to FSTs
        self._lm_diff_fst = lm_diff_fst  # to keep them in scope
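
# A hedged usage sketch: `lm_diff_fst` is typically built by composing the new
# LM with a negated (inverted-weight) copy of the LM that `fst` was compiled
# with, both wrapped as deterministic-on-demand FSTs (PyKaldi provides such
# wrappers in its fstext package). Given such an object and a decodable:
#
#   opts = BiglmFasterDecoderOptions()
#   decoder = BiglmFasterDecoder(graph, opts, lm_diff_fst)
#   decoder.decode(decodable)
#   best_path = decoder.get_best_path()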


class LatticeFasterDecoder(_LatticeDecoderBase,
                           _lattice_faster_decoder.LatticeFasterDecoder):
    """Lattice generating faster decoder.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """
    def __init__(self, fst, opts):
        super(LatticeFasterDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope
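
# A minimal end-to-end sketch (hedged; assumes `graph` is an StdVectorFst
# decoding graph HCLG, `trans_model` is the matching TransitionModel, and
# `loglikes` is a Matrix of per-frame pdf log-likelihoods, e.g. produced by a
# neural-network acoustic model):
#
#   opts = LatticeFasterDecoderOptions()
#   opts.beam = 13.0
#   opts.lattice_beam = 6.0
#   decoder = LatticeFasterDecoder(graph, opts)
#   decoder.decode(DecodableMatrixScaledMapped(trans_model, loglikes, 1.0))
#   best_path = decoder.get_best_path()
#   lattice = decoder.get_lattice()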


class LatticeFasterGrammarDecoder(
        _LatticeDecoderBase,
        _lattice_faster_decoder_ext.LatticeFasterGrammarDecoder):
    """Lattice generating faster grammar decoder.

    Args:
        fst (GrammarFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """
    def __init__(self, fst, opts):
        super(LatticeFasterGrammarDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope


class LatticeBiglmFasterDecoder(
        _LatticeDecoderBase,
        _lattice_biglm_faster_decoder.LatticeBiglmFasterDecoder):
    """Lattice generating faster decoder for decoding with big language models.

    This is as :class:`LatticeFasterDecoder`, but does online composition
    between the decoding graph :attr:`fst` and the difference language model
    :attr:`lm_diff_fst`.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
        lm_diff_fst (StdDeterministicOnDemandFst): The deterministic on-demand
            FST representing the difference in scores between the LM to decode
            with and the LM the decoding graph :attr:`fst` was compiled with.
    """
    def __init__(self, fst, opts, lm_diff_fst):
        super(LatticeBiglmFasterDecoder, self).__init__(fst, opts, lm_diff_fst)
        self._fst = fst                  # keep references to FSTs
        self._lm_diff_fst = lm_diff_fst  # to keep them in scope


class LatticeFasterOnlineDecoder(
        _LatticeOnlineDecoderBase,
        _lattice_faster_online_decoder.LatticeFasterOnlineDecoder):
    """Lattice generating faster online decoder.

    Similar to :class:`LatticeFasterDecoder` but computes the best path
    without generating the entire raw lattice and finding the best path
    through it. Instead, it traces back through the lattice.

    Args:
        fst (StdFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """
    def __init__(self, fst, opts):
        super(LatticeFasterOnlineDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope

    # Lattice determinization is missing from the wrapped C++ class, so we
    # reimplement the public method here in Python.
    def get_lattice(self, use_final_probs=True):
        raw_fst = self.get_raw_lattice(use_final_probs).invert().arcsort()
        lat_opts = _lat.DeterminizeLatticePrunedOptions()
        config = self.get_options()
        lat_opts.max_mem = config.det_opts.max_mem
        ofst = _fst.CompactLatticeVectorFst()
        _lat.determinize_lattice_pruned(raw_fst, config.lattice_beam,
                                        ofst, lat_opts)
        ofst.connect()
        if ofst.num_states() == 0:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst
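
# A hedged sketch of chunk-wise online use, assuming the underlying C++
# decoder's InitDecoding/AdvanceDecoding/FinalizeDecoding methods are exposed
# with the usual PyKaldi snake_case names in your build:
#
#   decoder = LatticeFasterOnlineDecoder(graph, opts)
#   decoder.init_decoding()
#   while more_frames_available():          # hypothetical helper
#       decoder.advance_decoding(decodable)
#       partial = decoder.get_best_path(use_final_probs=False)
#   decoder.finalize_decoding()
#   lattice = decoder.get_lattice()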


class LatticeFasterOnlineGrammarDecoder(
        _LatticeOnlineDecoderBase,
        _lattice_faster_online_decoder_ext.LatticeFasterOnlineGrammarDecoder):
    """Lattice generating faster online grammar decoder.

    Similar to :class:`LatticeFasterGrammarDecoder` but computes the best path
    without generating the entire raw lattice and finding the best path
    through it. Instead, it traces back through the lattice.

    Args:
        fst (GrammarFst): Decoding graph `HCLG`.
        opts (LatticeFasterDecoderOptions): Decoder options.
    """
    def __init__(self, fst, opts):
        super(LatticeFasterOnlineGrammarDecoder, self).__init__(fst, opts)
        self._fst = fst  # keep a reference to FST to keep it in scope

    # Lattice determinization is missing from the wrapped C++ class, so we
    # reimplement the public method here in Python.
    def get_lattice(self, use_final_probs=True):
        raw_fst = self.get_raw_lattice(use_final_probs).invert().arcsort()
        lat_opts = _lat.DeterminizeLatticePrunedOptions()
        config = self.get_options()
        lat_opts.max_mem = config.det_opts.max_mem
        ofst = _fst.CompactLatticeVectorFst()
        _lat.determinize_lattice_pruned(raw_fst, config.lattice_beam,
                                        ofst, lat_opts)
        ofst.connect()
        if ofst.num_states() == 0:
            raise RuntimeError("Decoding failed. No tokens survived.")
        return ofst


class TrainingGraphCompiler(_training_graph_compiler_ext.TrainingGraphCompiler):
    """Training graph compiler."""

    def __init__(self, trans_model, ctx_dep, lex_fst, disambig_syms, opts):
        """
        Args:
            trans_model (TransitionModel): Transition model `H`.
            ctx_dep (ContextDependency): Context dependency model `C`.
            lex_fst (StdVectorFst): Lexicon `L`.
            disambig_syms (List[int]): Disambiguation symbols.
            opts (TrainingGraphCompilerOptions): Compiler options.
        """
        super(TrainingGraphCompiler, self).__init__(
            trans_model, ctx_dep, lex_fst, disambig_syms, opts)
        # keep references to these objects to keep them in scope
        self._trans_model = trans_model
        self._ctx_dep = ctx_dep
        self._lex_fst = lex_fst

    def compile_graph(self, word_fst):
        """Compiles a single training graph from a weighted acceptor.

        Args:
            word_fst (StdVectorFst): Weighted acceptor `G` at the word level.

        Returns:
            StdVectorFst: The training graph `HCLG`.
        """
        ofst = super(TrainingGraphCompiler, self).compile_graph(word_fst)
        return _fst.StdVectorFst(ofst)

    def compile_graphs(self, word_fsts):
        """Compiles training graphs from weighted acceptors.

        Args:
            word_fsts (List[StdVectorFst]): Weighted acceptors at the word
                level.

        Returns:
            List[StdVectorFst]: The training graphs.
        """
        ofsts = super(TrainingGraphCompiler, self).compile_graphs(word_fsts)
        for i, fst in enumerate(ofsts):
            ofsts[i] = _fst.StdVectorFst(fst)
        return ofsts

    def compile_graph_from_text(self, transcript):
        """Compiles a single training graph from a transcript.

        Args:
            transcript (List[int]): The input transcript.

        Returns:
            StdVectorFst: The training graph `HCLG`.
        """
        ofst = super(TrainingGraphCompiler,
                     self).compile_graph_from_text(transcript)
        return _fst.StdVectorFst(ofst)

    def compile_graphs_from_text(self, transcripts):
        """Compiles training graphs from transcripts.

        Args:
            transcripts (List[List[int]]): The input transcripts.

        Returns:
            List[StdVectorFst]: The training graphs.
        """
        ofsts = super(TrainingGraphCompiler,
                      self).compile_graphs_from_text(transcripts)
        for i, fst in enumerate(ofsts):
            ofsts[i] = _fst.StdVectorFst(fst)
        return ofsts
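
# A hedged usage sketch (assumes the transition model, context dependency
# tree, lexicon FST, and disambiguation symbol ids have been loaded elsewhere,
# and that `transcript` is a list of word ids for one utterance):
#
#   opts = TrainingGraphCompilerOptions()
#   compiler = TrainingGraphCompiler(trans_model, ctx_dep, lex_fst,
#                                    disambig_syms, opts)
#   train_graph = compiler.compile_graph_from_text(transcript)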


__all__ = [name for name in dir()
           if name[0] != '_' and not name.endswith('Base')]