from __future__ import division, print_function
import math
from . import decoder as _dec
from . import fstext as _fst
from .fstext import utils as _fst_utils
from . import matrix as _mat
from .matrix import common as _mat_comm
from . import nnet3 as _nnet3
from .util import io as _util_io
__all__ = ['Segmenter', 'NnetSAD', 'SegmentationProcessor']
class Segmenter(object):
    """Base class for speech segmenters.

    Args:
        graph (StdVectorFst): Segmentation graph.
        beam (float): Logarithmic decoding beam.
        max_active (int): Maximum number of active states in decoding.
        acoustic_scale (float): Acoustic score scale.
    """

    def __init__(self, graph, beam=8, max_active=1000, acoustic_scale=0.1):
        opts = _dec.FasterDecoderOptions()
        opts.beam = beam
        opts.max_active = max_active
        self.decoder = _dec.FasterDecoder(graph, opts)
        self.acoustic_scale = acoustic_scale

    def _make_decodable(self, loglikes):
        """Constructs a new decodable object from input log-likelihoods.

        Args:
            loglikes (object): Input log-likelihoods.

        Returns:
            DecodableMatrixScaled: A decodable object for computing scaled
            log-likelihoods.
        """
        if loglikes.num_rows == 0:
            raise ValueError("Empty loglikes matrix.")
        return _dec.DecodableMatrixScaled(loglikes, self.acoustic_scale)

    def segment(self, input):
        """Segments input.

        Output is a dictionary with the following `(key, value)` pairs:

        ============== ============================== ==========================
        key            value                          value type
        ============== ============================== ==========================
        "alignment"    Frame-level segmentation       `List[int]`
        "best_path"    Best lattice path              `CompactLattice`
        "likelihood"   Log-likelihood of best path    `float`
        "weight"       Cost of best path              `LatticeWeight`
        ============== ============================== ==========================

        The "weight" output is a lattice weight consisting of (graph-score,
        acoustic-score).

        Args:
            input (object): Input to segment.

        Returns:
            A dictionary representing segmentation output.

        Raises:
            RuntimeError: If segmentation fails.
        """
        self.decoder.decode(self._make_decodable(input))
        if not self.decoder.reached_final():
            raise RuntimeError("No final state was active on the last frame.")
        try:
            best_path = self.decoder.get_best_path()
        except RuntimeError:
            raise RuntimeError("Empty segmentation output.")
        alignment, _, path_weight = \
            _fst_utils.get_linear_symbol_sequence(best_path)
        # Total cost along the best path is graph score plus acoustic score;
        # the log-likelihood is its negation.
        loglike = -(path_weight.value1 + path_weight.value2)
        if self.acoustic_scale != 0.0:
            # Undo the acoustic scaling on the output lattice.
            inv_scale = _fst_utils.acoustic_lattice_scale(
                1.0 / self.acoustic_scale)
            _fst_utils.scale_lattice(inv_scale, best_path)
        best_path = _fst_utils.convert_lattice_to_compact_lattice(best_path)
        return {
            "alignment": alignment,
            "best_path": best_path,
            "likelihood": loglike,
            "weight": path_weight,
        }
class NnetSAD(Segmenter):
    """Neural network based speech activity detection (SAD).

    Args:
        model (Nnet): SAD model. Model output should be log-posteriors for
            [silence, speech, garbage] labels.
        transform (Matrix): Transformation applied to SAD label posteriors.
            It should be a 3x2 matrix mapping [silence, speech, garbage]
            posteriors to [silence, speech] pseudo-likelihoods.
        graph (StdVectorFst): SAD graph. Silence and speech arcs should be
            labeled respectively with 1 and 2.
        beam (float): Logarithmic decoding beam.
        max_active (int): Maximum number of active states in decoding.
        decodable_opts (NnetSimpleComputationOptions): Configuration options
            for the SAD model.
    """

    def __init__(self, model, transform, graph, beam=8, max_active=1000,
                 decodable_opts=None):
        if not isinstance(model, _nnet3.Nnet):
            raise TypeError("model argument should be a Nnet object")
        self.model = model
        self.priors = _mat.Vector()
        self.transform = transform
        # Put the model into inference mode and collapse it for speed.
        _nnet3.set_batchnorm_test_mode(True, model)
        _nnet3.set_dropout_test_mode(True, model)
        _nnet3.collapse_model(_nnet3.CollapseModelConfig(), model)
        if not decodable_opts:
            self.decodable_opts = _nnet3.NnetSimpleComputationOptions()
        elif isinstance(decodable_opts, _nnet3.NnetSimpleComputationOptions):
            self.decodable_opts = decodable_opts
        else:
            raise TypeError("decodable_opts should be either None or a "
                            "NnetSimpleComputationOptions object")
        self.compiler = _nnet3.CachingOptimizingCompiler.new_with_optimize_opts(
            model, self.decodable_opts.optimize_config)
        super(NnetSAD, self).__init__(graph, beam, max_active,
                                      self.decodable_opts.acoustic_scale)

    def _make_decodable(self, features):
        """Constructs a new decodable object from input features.

        Args:
            features (Matrix): Input features.

        Returns:
            DecodableMatrixScaled: A decodable object for computing scaled
            log-likelihoods.
        """
        if features.num_rows == 0:
            raise ValueError("Empty feature matrix.")
        computer = _nnet3.DecodableNnetSimple(
            self.decodable_opts, self.model, self.priors,
            features, self.compiler, None, None, 0)
        posteriors = _mat.Matrix(computer.num_frames(), computer.output_dim())
        for frame in range(computer.num_frames()):
            computer.get_output_for_frame(frame, posteriors[frame])
        # Model outputs log-posteriors; exponentiate before applying the
        # posterior -> pseudo-likelihood transform.
        posteriors.apply_exp_()
        # FIXME: Keep a reference to log_likes on the instance so it stays in
        # scope while the returned decodable object is in use.
        self._log_likes = _mat.Matrix(posteriors.num_rows,
                                      self.transform.num_rows)
        self._log_likes.add_mat_mat_(posteriors, self.transform,
                                     _mat_comm.MatrixTransposeType.NO_TRANS,
                                     _mat_comm.MatrixTransposeType.TRANS,
                                     1.0, 0.0)
        self._log_likes.apply_log_()
        return _dec.DecodableMatrixScaled(self._log_likes, self.acoustic_scale)

    @staticmethod
    def read_model(model_rxfilename):
        """Reads SAD model from an extended filename."""
        with _util_io.xopen(model_rxfilename) as ki:
            return _nnet3.Nnet().read(ki.stream(), ki.binary)

    @staticmethod
    def read_average_posteriors(post_rxfilename):
        """Reads average SAD label posteriors from an extended filename."""
        with _util_io.xopen(post_rxfilename) as ki:
            return _mat.Vector().read_(ki.stream(), ki.binary)

    @staticmethod
    def make_sad_graph(transition_scale=1.0, self_loop_scale=0.1,
                       min_silence_duration=0.03, min_speech_duration=0.3,
                       max_speech_duration=10.0, frame_shift=0.01,
                       edge_silence_probability=0.5,
                       transition_probability=0.1):
        """Makes a decoding graph with a simple HMM topology suitable for SAD.

        Output graph uses label 1 for 'silence' and label 2 for 'speech'.

        Args:
            transition_scale (float): Scale on transition log-probabilities
                relative to LM weights.
            self_loop_scale (float): Scale on self-loop log-probabilities
                relative to LM weights.
            min_silence_duration (float): Minimum duration for silence.
            min_speech_duration (float): Minimum duration for speech.
            max_speech_duration (float): Maximum duration for speech.
            frame_shift (float): Frame shift in seconds.
            edge_silence_probability (float): Probability of silence at the
                edges.
            transition_probability (float): Transition probability for silence
                to speech or vice-versa.

        Returns:
            StdVectorFst: A simple decoding graph suitable for SAD.
        """
        # Convert durations in seconds to state counts (rounded to nearest).
        min_states_silence = int(min_silence_duration / frame_shift + 0.5)
        min_states_speech = int(min_speech_duration / frame_shift + 0.5)
        max_states_speech = int(max_speech_duration / frame_shift + 0.5)
        symbols = _fst.SymbolTable()
        symbols.add_symbol("<eps>")
        symbols.add_symbol("silence")
        symbols.add_symbol("speech")
        compiler = _fst.StdFstCompiler(symbols, symbols)

        def add_arc(src, dst, label, cost):
            # One arc in AT&T text format: src dst ilabel olabel cost.
            print("{src} {dst} {label} {label} {cost}".format(
                src=src, dst=dst, label=label, cost=cost), file=compiler)

        def add_final(state, cost):
            # One final state with its final cost.
            print("{state} {cost}".format(state=state, cost=cost),
                  file=compiler)

        silence_start = 1
        # Initial transition to silence.
        add_arc(0, silence_start, "silence",
                -math.log(edge_silence_probability))
        silence_last = silence_start + min_states_silence - 1
        # Chain of silence states enforcing the minimum silence duration.
        for state in range(silence_start, silence_last):
            add_arc(state, state + 1, "silence", 0.0)
        # Self-loop on the last silence state: silence may last indefinitely.
        add_arc(silence_last, silence_last, "silence", 0.0)
        speech_start = silence_last + 1
        # Initial transition to speech.
        add_arc(0, speech_start, "speech",
                -math.log(1.0 - edge_silence_probability))
        # Silence to speech transition.
        add_arc(silence_last, speech_start, "speech",
                -math.log(transition_probability))
        # Chain of speech states enforcing the minimum speech duration.
        for state in range(speech_start, speech_start + min_states_speech - 1):
            add_arc(state, state + 1, "speech", 0.0)
        speech_last = speech_start + max_states_speech - 1
        # Chain enforcing the maximum speech duration. Every state past the
        # minimum duration may also transition back to silence.
        for state in range(speech_start + min_states_speech - 1, speech_last):
            add_arc(state, state + 1, "speech", 0.0)
            add_arc(state, silence_start, "silence",
                    -math.log(transition_probability))
        # Forced transition to silence after the maximum speech duration.
        add_arc(speech_last, silence_start, "silence", 0.0)
        # Final costs: any non-initial state may end the utterance.
        for state in range(1, speech_start):
            add_final(state, -math.log(edge_silence_probability))
        for state in range(speech_start, speech_last + 1):
            add_final(state, -math.log(1.0 - edge_silence_probability))
        return compiler.compile()
class SegmentationProcessor(object):
    """Segmentation post-processor.

    This class is used for converting segmentation labels to a list of
    segments. Output includes only those segments labeled with the target
    labels.

    Post-processing operations include::

        * filtering out short segments
        * padding segments
        * merging consecutive segments

    Args:
        target_labels (List[int]): Target labels. Typically the speech labels.
        frame_shift (float): Frame shift in seconds.
        segment_padding (float): Additional padding on target segments.
            Padding does not go beyond the adjacent segment. This is typically
            used for padding speech segments with silence. Must be an integral
            multiple of frame shift.
        min_segment_dur (float): Minimum duration (in seconds) required for a
            segment to be included. This is before any padding. Segments
            shorter than this duration will be removed.
        max_merged_segment_dur (float): Merge consecutive segments as long as
            the merged segment is no longer than this many seconds. The
            segments are only merged if their boundaries are touching. This is
            after padding by --segment-padding seconds. 0 means do not merge.
            Use 'inf' to not limit the duration.

    Attributes:
        stats (SegmentationProcessor.Stats): Global segmentation
            post-processing stats.
    """
    def __init__(self, target_labels, frame_shift=0.01, segment_padding=0.2,
                 min_segment_dur=0, max_merged_segment_dur=0):
        padding_frames = segment_padding / frame_shift
        # Check integrality with a tolerance instead of float.is_integer():
        # binary floating point makes e.g. 0.3 / 0.1 == 2.9999999999999996,
        # which would incorrectly reject a valid padding value.
        if abs(padding_frames - round(padding_frames)) > 1e-6:
            raise ValueError("segment_padding = {} is not an integral "
                             "multiple of frame_shift = {}"
                             .format(segment_padding, frame_shift))
        self.target_labels = target_labels
        self.frame_shift = frame_shift
        # All durations below are stored in frames.
        self.segment_padding = int(round(padding_frames))
        self.min_segment_dur = int(math.ceil(min_segment_dur / frame_shift))
        if math.isinf(max_merged_segment_dur):
            # 'inf' is documented to mean "no limit"; int(inf) would raise
            # OverflowError, so keep it as a float — it is only used in
            # comparisons.
            self.max_merged_segment_dur = max_merged_segment_dur
        else:
            self.max_merged_segment_dur = int(
                max_merged_segment_dur / frame_shift)
        self.stats = self.Stats()

    class Stats(object):
        """Stores segmentation post-processing stats."""
        def __init__(self):
            # Segment counts before/after each post-processing step.
            self.num_segments_initial = 0
            self.num_short_segments_filtered = 0
            self.num_merges = 0
            self.num_segments_final = 0
            # Durations in seconds.
            self.initial_duration = 0.0
            self.padding_duration = 0.0
            self.filter_short_duration = 0.0
            self.final_duration = 0.0

        def add(self, other):
            """Adds stats from another."""
            self.num_segments_initial += other.num_segments_initial
            self.num_short_segments_filtered += \
                other.num_short_segments_filtered
            self.num_merges += other.num_merges
            self.num_segments_final += other.num_segments_final
            self.initial_duration += other.initial_duration
            self.filter_short_duration += other.filter_short_duration
            self.padding_duration += other.padding_duration
            self.final_duration += other.final_duration

        def __str__(self):
            return ("num-segments-initial={num_segments_initial}, "
                    "num-short-segments-filtered="
                    "{num_short_segments_filtered}, "
                    "num-merges={num_merges}, "
                    "num-segments-final={num_segments_final}, "
                    "initial-duration={initial_duration}, "
                    "filter-short-duration={filter_short_duration}, "
                    "padding-duration={padding_duration}, "
                    "final-duration={final_duration}".format(
                        num_segments_initial=self.num_segments_initial,
                        num_short_segments_filtered=
                        self.num_short_segments_filtered,
                        num_merges=self.num_merges,
                        num_segments_final=self.num_segments_final,
                        initial_duration=self.initial_duration,
                        filter_short_duration=self.filter_short_duration,
                        padding_duration=self.padding_duration,
                        final_duration=self.final_duration))

    def process(self, alignment):
        """Converts frame-level segmentation labels to a list of segments.

        Args:
            alignment (List[int]): Frame-level segmentation labels.

        Returns:
            Tuple[List[Tuple[int, int, int]], SegmentationProcessor.Stats]:
            List of segments, where each entry is a (segment-beg, segment-end,
            label) tuple, along with segmentation post-processing stats.
        """
        stats = self.Stats()
        segments = self.initialize_segments(alignment, stats)
        segments = self.filter_short_segments(segments, stats)
        segments = self.pad_segments(segments, stats, len(alignment))
        segments = self.merge_consecutive_segments(segments, stats)
        # Accumulate per-utterance stats into the global stats.
        self.stats.add(stats)
        return segments, stats

    def initialize_segments(self, alignment, stats):
        """Initializes segments.

        The alignment is frame-level segmentation labels. Output includes only
        those segments labeled with the target labels.
        """
        segments = []
        if not alignment:
            return segments
        num_target_frames, seg_begin, seg_label = 0, 0, alignment[0]
        for i, label in enumerate(alignment[1:], 1):
            if label != seg_label:
                # Label changed: close the current run if it is a target.
                if seg_label in self.target_labels:
                    segments.append((seg_begin, i, seg_label))
                    num_target_frames += i - seg_begin
                seg_begin, seg_label = i, label
        # Close the trailing run.
        if seg_label in self.target_labels:
            segments.append((seg_begin, len(alignment), seg_label))
            num_target_frames += len(alignment) - seg_begin
        stats.num_segments_initial = len(segments)
        stats.num_segments_final = len(segments)
        stats.initial_duration = num_target_frames * self.frame_shift
        stats.final_duration = stats.initial_duration
        return segments

    def filter_short_segments(self, segments, stats):
        """Filters out short segments."""
        if self.min_segment_dur <= 0:
            return segments
        filtered_segments = []
        for segment in segments:
            dur = segment[1] - segment[0]
            if dur < self.min_segment_dur:
                stats.filter_short_duration += dur * self.frame_shift
                stats.num_short_segments_filtered += 1
            else:
                filtered_segments.append(segment)
        stats.num_segments_final = len(filtered_segments)
        stats.final_duration -= stats.filter_short_duration
        return filtered_segments

    def pad_segments(self, segments, stats, num_utt_frames=None):
        """Pads segments on both sides.

        Ensures that the segments do not go beyond the neighboring segments or
        utterance boundaries.
        """
        num_padded_frames, padded_segments = 0, []
        for i, (seg_beg, seg_end, label) in enumerate(segments):
            seg_beg -= self.segment_padding  # try padding on the left side
            num_padded_frames += self.segment_padding
            if seg_beg < 0:
                # Padded segment start is before the beginning of the
                # utterance. Reduce padding.
                num_padded_frames += seg_beg
                seg_beg = 0
            if i >= 1 and prev_seg_end > seg_beg:
                # Padded segment start is before the end of previous segment.
                # Reduce padding.
                num_padded_frames -= prev_seg_end - seg_beg
                seg_beg = prev_seg_end
            seg_end += self.segment_padding
            num_padded_frames += self.segment_padding
            if num_utt_frames is not None and seg_end > num_utt_frames:
                # Padded segment end is beyond the max duration.
                # Reduce padding.
                num_padded_frames -= seg_end - num_utt_frames
                seg_end = num_utt_frames
            if i + 1 < len(segments) and seg_end > segments[i + 1][0]:
                # Padded segment end is beyond the start of the next
                # (unpadded) segment. Reduce padding.
                num_padded_frames -= seg_end - segments[i + 1][0]
                seg_end = segments[i + 1][0]
            padded_segments.append((seg_beg, seg_end, label))
            prev_seg_end = seg_end
        stats.padding_duration = num_padded_frames * self.frame_shift
        stats.final_duration += stats.padding_duration
        return padded_segments

    def merge_consecutive_segments(self, segments, stats):
        """Merges consecutive segments.

        Done after padding. Consecutive segments that share a boundary are
        merged if they have the same label and the merged segment is no longer
        than 'max_merged_segment_dur'.
        """
        if self.max_merged_segment_dur <= 0 or not segments:
            return segments
        merged_segments = [segments[0]]
        for seg_beg, seg_end, label in segments[1:]:
            prev_seg_beg, prev_seg_end, prev_label = merged_segments[-1]
            if (seg_beg == prev_seg_end and label == prev_label and
                    seg_end - prev_seg_beg <= self.max_merged_segment_dur):
                # Extend the last merged segment in place.
                merged_segments[-1] = (prev_seg_beg, seg_end, label)
                stats.num_merges += 1
            else:
                merged_segments.append((seg_beg, seg_end, label))
        stats.num_segments_final = len(merged_segments)
        return merged_segments

    def write(self, key, segments, file_handle):
        """Writes segments to file in Kaldi segments-file format."""
        for begin, end, label in segments:
            # Avoid shadowing the builtin `id`.
            segment_id = "{key}-{label}-{begin:07d}-{end:07d}".format(
                key=key, label=label, begin=begin, end=end)
            print("{id} {key} {begin:.2f} {end:.2f}".format(
                id=segment_id, key=key, begin=begin * self.frame_shift,
                end=end * self.frame_shift),
                file=file_handle)