# Source code for kaldi.matrix.functions


from . import _compressed_matrix
from . import _kaldi_matrix
from . import _kaldi_matrix_ext
from . import _kaldi_vector
from . import _kaldi_vector_ext
import _matrix_common # FIXME: Relative/absolute import is buggy in Python 3.
from . import _sparse_matrix
from . import _sp_matrix
from . import _tp_matrix

from ._matrix_functions import *
from ._sp_matrix import SolverOptions
from ._sp_matrix import solve_quadratic_problem
from ._sp_matrix import solve_quadratic_matrix_problem
from ._sp_matrix import solve_double_quadratic_matrix_problem


def approx_equal(a, b, tol=0.01):
    """Tells whether two vectors (or matrices) are approximately equal.

    The comparison itself is delegated to ``a.approx_equal(b, tol)``.

    Args:
        a (Vector or Matrix or SpMatrix or DoubleVector or DoubleMatrix or
            DoubleSpMatrix): The first object.
        b (Vector or Matrix or SpMatrix or DoubleVector or DoubleMatrix or
            DoubleSpMatrix): The second object.
        tol (float): The tolerance for the equality check. Defaults to
            ``0.01``.

    Returns:
        True if input objects have the same type and size, and
        :math:`\\Vert a-b \\Vert \\leq \\mathrm{tol} \\times \\Vert a \\Vert`.

    Raises:
        TypeError: If the first object is not a vector or matrix instance.
    """
    accepted_types = (_kaldi_vector.VectorBase,
                      _kaldi_vector.DoubleVectorBase,
                      _kaldi_matrix.MatrixBase,
                      _kaldi_matrix.DoubleMatrixBase,
                      _sp_matrix.SpMatrix,
                      _sp_matrix.DoubleSpMatrix)
    if not isinstance(a, accepted_types):
        raise TypeError("a is not a vector or matrix instance")
    return a.approx_equal(b, tol)
def assert_equal(a, b, tol=0.01):
    """Asserts given vectors (or matrices) are approximately equal.

    Args:
        a (Vector or Matrix or SpMatrix or DoubleVector or DoubleMatrix or
            DoubleSpMatrix): The first object.
        b (Vector or Matrix or SpMatrix or DoubleVector or DoubleMatrix or
            DoubleSpMatrix): The second object.
        tol (float): The tolerance for the equality check. Defaults to
            ``0.01``.

    Raises:
        TypeError: If the first object is not a vector or matrix instance.
        AssertionError: If input objects do not have the same type or size,
            or :math:`\\Vert a-b \\Vert > \\mathrm{tol} \\times \\Vert a \\Vert`.
    """
    # Raise explicitly instead of using an ``assert`` statement: assert
    # statements are removed when Python runs with ``-O``, which would make
    # this check silently do nothing.
    if not approx_equal(a, b, tol):
        raise AssertionError(
            "a and b are not approximately equal (tol={})".format(tol))
def create_eigenvalue_matrix(real, imag, D=None):
    """Creates the eigenvalue matrix.

    Eigenvalue matrix :math:`D` is part of the decomposition used in eig.
    :math:`D` will be block-diagonal with blocks of size 1 (for real
    eigenvalues) or 2x2 for complex pairs. If a complex pair is
    :math:`\\lambda +- i\\mu`, :math:`D` will have a corresponding 2x2 block
    :math:`[\\lambda, \\mu; -\\mu, \\lambda]`.

    This function will throw if any complex eigenvalues are not in complex
    conjugate pairs (or the members of such pairs are not consecutively
    numbered). The **D** you supply must have correct dimensions.

    Args:
        real (Vector or DoubleVector): The real part of the eigenvalues.
        imag (Vector or DoubleVector): The imaginary part of the eigenvalues.
        D (Matrix or DoubleMatrix or None): The output matrix. If provided,
            the eigenvalue matrix is written into this matrix and ``None`` is
            returned. If ``None``, a new square matrix of size ``real.dim``
            is allocated, filled and returned. Defaults to ``None``.

    Returns:
        Matrix or DoubleMatrix: The eigenvalue matrix if **D** is ``None``,
        otherwise ``None``.

    Raises:
        RuntimeError: If `real.dim != imag.dim`
        TypeError: If input types are not supported.
    """
    # Select the single- or double-precision implementation together with the
    # matching output matrix class. The classes are taken from _kaldi_matrix
    # explicitly; the bare names ``Matrix``/``DoubleMatrix`` are not imported
    # by name in this module.
    if (isinstance(real, _kaldi_vector.VectorBase)
            and isinstance(imag, _kaldi_vector.VectorBase)):
        fill = _kaldi_matrix._create_eigenvalue_matrix
        matrix_class = _kaldi_matrix.Matrix
    elif (isinstance(real, _kaldi_vector.DoubleVectorBase)
            and isinstance(imag, _kaldi_vector.DoubleVectorBase)):
        fill = _kaldi_matrix._create_eigenvalue_double_matrix
        matrix_class = _kaldi_matrix.DoubleMatrix
    else:
        raise TypeError("real and imag should be vectors with the same data "
                        "type.")

    if D is None:
        D = matrix_class(real.dim, real.dim)
        fill(real, imag, D)
        return D

    # In-place mode: write into the caller's matrix. Return here so that a
    # successful call does not fall through to the TypeError.
    fill(real, imag, D)
    return None
def sort_svd(s, U, Vt=None, sort_on_absolute_value=True):
    """Sorts singular-value decomposition in-place.

    SVD is :math:`U\\ diag(s)\\ V^T`.

    This function is as generic as possible, to be applicable to other types
    of problems. Requires `s.dim == U.num_cols`, and sorts from greatest to
    least absolute value, moving the columns of **U**, and the rows of
    **Vt**, if provided, around in the same way.

    Note:
        The "absolute value" part won't matter if this is an actual SVD,
        since singular values are non-negative.

    Args:
        s (Vector or DoubleVector): The singular values.
        U (Matrix or DoubleMatrix): The :math:`U` part of SVD. Its precision
            must match that of **s**.
        Vt (Matrix or DoubleMatrix or None): The :math:`V^T` part of SVD.
            Defaults to ``None``.
        sort_on_absolute_value (bool): How to sort **s**. If True, sort from
            greatest to least absolute value. Otherwise, sort from greatest
            to least value. Defaults to ``True``.

    Raises:
        RuntimeError: If `s.dim != U.num_cols`.
        TypeError: If input types are not supported.
    """
    if (isinstance(s, _kaldi_vector.VectorBase)
            and isinstance(U, _kaldi_matrix.MatrixBase)):
        _kaldi_matrix._sort_svd(s, U, Vt, sort_on_absolute_value)
        # The sort is in-place; return so a successful call does not fall
        # through to the TypeError below.
        return
    if (isinstance(s, _kaldi_vector.DoubleVectorBase)
            and isinstance(U, _kaldi_matrix.DoubleMatrixBase)):
        _kaldi_matrix._sort_double_svd(s, U, Vt, sort_on_absolute_value)
        return
    raise TypeError("s and U should respectively be a vector and matrix with "
                    "matching data types.")
def filter_matrix_rows(matrix, keep_rows):
    """Builds a new matrix from the selected rows of **matrix**.

    Row `r` of **matrix** is copied to the output exactly when
    `keep_rows[r]` is true.

    Args:
        matrix (Matrix or SparseMatrix or CompressedMatrix or GeneralMatrix
            or DoubleMatrix or DoubleSparseMatrix): The input matrix.
        keep_rows (List[bool]): One flag per row of **matrix**, telling
            whether to keep that row.

    Returns:
        A new matrix constructed with the rows to keep.

    Raises:
        RuntimeError: If `matrix.num_rows != len(keep_rows)`.
        TypeError: If input matrix type is not supported.
    """
    # Checked in order; the first type that matches selects the
    # implementation.
    handlers = (
        (_kaldi_matrix.Matrix,
         _sparse_matrix._filter_matrix_rows),
        (_sparse_matrix.SparseMatrix,
         _sparse_matrix._filter_sparse_matrix_rows),
        (_compressed_matrix.CompressedMatrix,
         _sparse_matrix._filter_compressed_matrix_rows),
        (_sparse_matrix.GeneralMatrix,
         _sparse_matrix._filter_general_matrix_rows),
        (_kaldi_matrix.DoubleMatrix,
         _sparse_matrix._filter_matrix_rows_double),
        (_sparse_matrix.DoubleSparseMatrix,
         _sparse_matrix._filter_sparse_matrix_rows_double),
    )
    for matrix_type, filter_rows in handlers:
        if isinstance(matrix, matrix_type):
            return filter_rows(matrix, keep_rows)
    raise TypeError("input matrix type is not supported.")
def vec_vec(v1, v2):
    """Computes the dot product of two vectors.

    The precision of **v1** selects which kinds of **v2** are accepted: a
    dense or sparse vector of the same precision.

    Args:
        v1 (Vector or DoubleVector): The first vector.
        v2 (Vector or DoubleVector or SparseVector or DoubleSparseVector):
            The second vector.

    Returns:
        The dot product of v1 and v2.

    Raises:
        RuntimeError: In case of size mismatch.
        TypeError: If input types are not supported.
    """
    if isinstance(v1, _kaldi_vector.VectorBase):
        options = ((_kaldi_vector.VectorBase, _kaldi_vector._vec_vec),
                   (_sparse_matrix.SparseVector, _sparse_matrix._vec_svec))
    elif isinstance(v1, _kaldi_vector.DoubleVectorBase):
        options = ((_kaldi_vector.DoubleVectorBase,
                    _kaldi_vector._vec_vec_double),
                   (_sparse_matrix.DoubleSparseVector,
                    _sparse_matrix._vec_svec_double))
    else:
        options = ()

    for second_type, dot in options:
        if isinstance(v2, second_type):
            return dot(v1, v2)
    raise TypeError("v1 and v2 should be vectors with the same data type.")
def vec_mat_vec(v1, M, v2):
    """Computes a vector-matrix-vector product.

    Performs the operation :math:`v_1\\ M\\ v_2`. Both vectors and the
    matrix must share the same precision.

    Args:
        v1 (Vector or DoubleVector): The first input vector.
        M (Matrix or DoubleMatrix or SpMatrix or DoubleSpMatrix): The input
            matrix.
        v2 (Vector or DoubleVector): The second input vector.

    Returns:
        The vector-matrix-vector product.

    Raises:
        RuntimeError: In case of size mismatch.
        TypeError: If the given combination of input types is not supported.
    """
    if (isinstance(v1, _kaldi_vector.VectorBase)
            and isinstance(v2, _kaldi_vector.VectorBase)):
        options = ((_kaldi_matrix.MatrixBase, _kaldi_vector_ext._vec_mat_vec),
                   (_sp_matrix.SpMatrix, _sp_matrix._vec_sp_vec))
    elif (isinstance(v1, _kaldi_vector.DoubleVectorBase)
            and isinstance(v2, _kaldi_vector.DoubleVectorBase)):
        options = ((_kaldi_matrix.DoubleMatrixBase,
                    _kaldi_vector_ext._vec_mat_vec_double),
                   (_sp_matrix.DoubleSpMatrix,
                    _sp_matrix._vec_sp_vec_double))
    else:
        options = ()

    for matrix_type, product in options:
        if isinstance(M, matrix_type):
            return product(v1, M, v2)
    raise TypeError("given combination of input types is not supported")
def trace_mat(A):
    """Computes the trace of :math:`A`.

    Args:
        A (Matrix or DoubleMatrix): The input matrix.

    Returns:
        The trace of **A**.

    Raises:
        TypeError: If the input matrix type is not supported.
    """
    implementations = (
        (_kaldi_matrix.MatrixBase, _kaldi_matrix._trace_mat),
        (_kaldi_matrix.DoubleMatrixBase, _kaldi_matrix._trace_double_mat),
    )
    for matrix_type, trace in implementations:
        if isinstance(A, matrix_type):
            return trace(A)
    raise TypeError("input matrix type is not supported")
def trace_mat_mat(A, B, transA=_matrix_common.MatrixTransposeType.NO_TRANS,
                  lower=False):
    """Returns the trace of :math:`A\\ B`.

    Precision of input matrices should match. Supported combinations:

    * Matrix with Matrix or SparseMatrix
    * SpMatrix with Matrix or SpMatrix
    * DoubleMatrix with DoubleMatrix or DoubleSparseMatrix
    * DoubleSpMatrix with DoubleSpMatrix

    Args:
        A (Matrix or DoubleMatrix or SpMatrix or DoubleSpMatrix): The first
            input matrix.
        B (Matrix or DoubleMatrix or SpMatrix or DoubleSpMatrix or
            SparseMatrix or DoubleSparseMatrix): The second input matrix.
        transA (_matrix_common.MatrixTransposeType): Whether to use **A** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.
            Not used when **A** is a symmetric (Sp) matrix.
        lower (bool): Whether to count lower-triangular elements only once.
            Active only if both inputs are symmetric matrices. Defaults to
            ``False``.

    Returns:
        The trace of the product.

    Raises:
        TypeError: If the given combination of matrix types is not
            supported.
    """
    # ``lower`` used to be read here without being a parameter, which raised
    # NameError for every SpMatrix x SpMatrix call; it is now a trailing
    # keyword argument with the documented default.
    if isinstance(A, _kaldi_matrix.MatrixBase):
        if isinstance(B, _kaldi_matrix.MatrixBase):
            return _kaldi_matrix._trace_mat_mat(A, B, transA)
        if isinstance(B, _sparse_matrix.SparseMatrix):
            return _sparse_matrix._trace_mat_smat(A, B, transA)
    elif isinstance(A, _sp_matrix.SpMatrix):
        if isinstance(B, _kaldi_matrix.MatrixBase):
            return _sp_matrix._trace_sp_mat(A, B)
        if isinstance(B, _sp_matrix.SpMatrix):
            if lower:
                return _sp_matrix._trace_sp_sp_lower(A, B)
            return _sp_matrix._trace_sp_sp(A, B)
    elif isinstance(A, _kaldi_matrix.DoubleMatrixBase):
        if isinstance(B, _kaldi_matrix.DoubleMatrixBase):
            return _kaldi_matrix._trace_double_mat_mat(A, B, transA)
        if isinstance(B, _sparse_matrix.DoubleSparseMatrix):
            return _sparse_matrix._trace_double_mat_smat(A, B, transA)
    elif isinstance(A, _sp_matrix.DoubleSpMatrix):
        if isinstance(B, _sp_matrix.DoubleSpMatrix):
            if lower:
                return _sp_matrix._trace_double_sp_sp_lower(A, B)
            return _sp_matrix._trace_double_sp_sp(A, B)
    raise TypeError("given combination of matrix types is not supported")
def trace_mat_mat_mat(A, B, C,
                      transA=_matrix_common.MatrixTransposeType.NO_TRANS,
                      transB=_matrix_common.MatrixTransposeType.NO_TRANS,
                      transC=_matrix_common.MatrixTransposeType.NO_TRANS):
    """Computes the trace of :math:`A\\ B\\ C`.

    All three inputs must share the same precision. **B** may be either a
    general matrix or a symmetric (Sp) matrix; **A** and **C** must be
    general matrices.

    Args:
        A (Matrix or DoubleMatrix): The first input matrix.
        B (Matrix or DoubleMatrix or SpMatrix or DoubleSpMatrix): The second
            input matrix.
        C (Matrix or DoubleMatrix): The third input matrix.
        transA (_matrix_common.MatrixTransposeType): Whether to use **A** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.
        transB (_matrix_common.MatrixTransposeType): Whether to use **B** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.
            Not passed on when **B** is a symmetric matrix.
        transC (_matrix_common.MatrixTransposeType): Whether to use **C** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.

    Returns:
        The trace of the product.

    Raises:
        TypeError: If the given combination of matrix types is not
            supported.
    """
    if isinstance(A, _kaldi_matrix.MatrixBase):
        if isinstance(C, _kaldi_matrix.MatrixBase):
            if isinstance(B, _kaldi_matrix.MatrixBase):
                return _kaldi_matrix._trace_mat_mat_mat(A, transA, B, transB,
                                                        C, transC)
            if isinstance(B, _sp_matrix.SpMatrix):
                return _sp_matrix._trace_mat_sp_mat(A, transA, B, C, transC)
    elif isinstance(A, _kaldi_matrix.DoubleMatrixBase):
        if isinstance(C, _kaldi_matrix.DoubleMatrixBase):
            if isinstance(B, _kaldi_matrix.DoubleMatrixBase):
                return _kaldi_matrix._trace_double_mat_mat_mat(
                    A, transA, B, transB, C, transC)
            if isinstance(B, _sp_matrix.DoubleSpMatrix):
                return _sp_matrix._trace_double_mat_sp_mat(A, transA, B, C,
                                                           transC)
    raise TypeError("given combination of matrix types is not supported")
def trace_mat_mat_mat_mat(A, B, C, D,
                          transA=_matrix_common.MatrixTransposeType.NO_TRANS,
                          transB=_matrix_common.MatrixTransposeType.NO_TRANS,
                          transC=_matrix_common.MatrixTransposeType.NO_TRANS,
                          transD=_matrix_common.MatrixTransposeType.NO_TRANS):
    """Computes the trace of :math:`A\\ B\\ C\\ D`.

    All four inputs must share the same precision. Supported combinations
    are four general matrices, or general **A** and **C** with symmetric
    (Sp) **B** and **D**.

    Args:
        A (Matrix or DoubleMatrix): The first input matrix.
        B (Matrix or DoubleMatrix or SpMatrix or DoubleSpMatrix): The second
            input matrix.
        C (Matrix or DoubleMatrix): The third input matrix.
        D (Matrix or DoubleMatrix or SpMatrix or DoubleSpMatrix): The fourth
            input matrix.
        transA (_matrix_common.MatrixTransposeType): Whether to use **A** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.
        transB (_matrix_common.MatrixTransposeType): Whether to use **B** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.
            Not passed on when **B** is a symmetric matrix.
        transC (_matrix_common.MatrixTransposeType): Whether to use **C** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.
        transD (_matrix_common.MatrixTransposeType): Whether to use **D** or
            its transpose. Defaults to ``MatrixTransposeType.NO_TRANS``.
            Not passed on when **D** is a symmetric matrix.

    Returns:
        The trace of the product.

    Raises:
        TypeError: If the given combination of matrix types is not
            supported.
    """
    if isinstance(A, _kaldi_matrix.MatrixBase):
        if isinstance(C, _kaldi_matrix.MatrixBase):
            if (isinstance(B, _kaldi_matrix.MatrixBase)
                    and isinstance(D, _kaldi_matrix.MatrixBase)):
                return _kaldi_matrix._trace_mat_mat_mat_mat(
                    A, transA, B, transB, C, transC, D, transD)
            if (isinstance(B, _sp_matrix.SpMatrix)
                    and isinstance(D, _sp_matrix.SpMatrix)):
                return _sp_matrix._trace_mat_sp_mat_sp(A, transA, B, C,
                                                       transC, D)
    elif isinstance(A, _kaldi_matrix.DoubleMatrixBase):
        if isinstance(C, _kaldi_matrix.DoubleMatrixBase):
            if (isinstance(B, _kaldi_matrix.DoubleMatrixBase)
                    and isinstance(D, _kaldi_matrix.DoubleMatrixBase)):
                return _kaldi_matrix._trace_double_mat_mat_mat_mat(
                    A, transA, B, transB, C, transC, D, transD)
            if (isinstance(B, _sp_matrix.DoubleSpMatrix)
                    and isinstance(D, _sp_matrix.DoubleSpMatrix)):
                return _sp_matrix._trace_double_mat_sp_mat_sp(
                    A, transA, B, C, transC, D)
    raise TypeError("given combination of matrix types is not supported")
################################################################################

# Export every module-level name that does not start with an underscore,
# including names brought in by the imports above.
__all__ = [name for name in dir() if not name.startswith('_')]