Alignment Layers

Forced Alignment Layer

class returnn.tf.layers.basic.ForcedAlignmentLayer(align_target, topology, input_type, **kwargs)

Calculates a forced alignment via the Viterbi algorithm.

Parameters:
  • align_target (LayerBase) – layer providing the target label sequence to align to
  • topology (str) – e.g. “ctc” or “rna” (RNA is CTC without the label loop)
  • input_type (str) – “log_prob” or “prob”
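To make the Viterbi step concrete, here is a minimal NumPy sketch of forced alignment over the CTC topology. It is not RETURNN's implementation (the layer runs inside TensorFlow), and the function name `ctc_viterbi_align` is hypothetical. The input corresponds to `input_type="log_prob"`. The `label_loop=False` option is only a literal reading of "RNA is CTC without the label loop": it removes the self-loop on label states and keeps everything else. That this matches the layer's “rna” topology exactly is an assumption.

```python
import numpy as np


def ctc_viterbi_align(log_probs, labels, blank=0, label_loop=True):
    """Viterbi forced alignment of `labels` to frame-wise +log scores.

    log_probs: [T, V] array of log-probabilities (e.g. log-softmax output).
    labels: target label indices, without blanks.
    label_loop: False drops the self-loop on label states (assumed "rna" reading).
    Returns one symbol (label or blank) per frame.
    """
    T = log_probs.shape[0]
    # Extended CTC state sequence: blank, l1, blank, l2, ..., lN, blank.
    states = [blank]
    for lab in labels:
        states += [lab, blank]
    S = len(states)
    score = np.full((T, S), -np.inf)
    backptr = np.zeros((T, S), dtype=np.int64)
    # A path may start in the leading blank or in the first label.
    score[0, 0] = log_probs[0, states[0]]
    if S > 1:
        score[0, 1] = log_probs[0, states[1]]
    for t in range(1, T):
        for s in range(S):
            is_blank = states[s] == blank
            cands = [s - 1] if s >= 1 else []  # advance from the previous state
            if is_blank or label_loop:
                cands.append(s)  # stay in the same state
            if s >= 2 and not is_blank and states[s] != states[s - 2]:
                cands.append(s - 2)  # skip the blank between two different labels
            best = max(cands, key=lambda p: score[t - 1, p])
            score[t, s] = score[t - 1, best] + log_probs[t, states[s]]
            backptr[t, s] = best
    # A path must end in the last label or in the trailing blank.
    ends = [S - 1, S - 2] if S > 1 else [0]
    s = max(ends, key=lambda e: score[T - 1, e])
    if score[T - 1, s] == -np.inf:
        raise ValueError("no valid alignment (too few frames for these labels?)")
    path = [s]
    for t in range(T - 1, 0, -1):
        s = int(backptr[t, s])
        path.append(s)
    path.reverse()
    return [states[p] for p in path]
```

For example, with blank index 0, frames whose most likely symbols are 1, 1, blank, 2, 2 and the target `[1, 2]`, the sketch returns `[1, 1, 0, 2, 2]`. A repeated target such as `[1, 1]` forces a blank between the two labels, as in CTC.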
layer_class = 'forced_align'
classmethod get_sub_layer_out_data_from_opts(layer_name, parent_layer_kwargs)
Parameters:
  • layer_name (str) – sub layer name
  • parent_layer_kwargs (dict[str]) – kwargs of the parent layer
Return type:

(Data, TFNetwork, type)|None

get_sub_layer(layer_name)
Parameters: layer_name (str) –
Return type: LayerBase|None
get_dep_layers()
Return type: list[LayerBase]
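classmethod transform_config_dict(d, network, get_layer)
Parameters:
  • d (dict[str]) – layer config dict
  • network (TFNetwork) –
  • get_layer – function to get or construct another layer by name
classmethod get_out_data_from_opts(name, sources, **kwargs)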
Parameters:
  • name (str) –
  • sources (list[LayerBase]) –
Return type:

Data
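A minimal sketch of how the parameters above might appear in a RETURNN network config. `output_log_prob` is a hypothetical layer name standing for whatever layer produces frame-wise +log scores over the vocabulary (including blank), and passing `align_target` as the layer name `"data:classes"` assumes that `transform_config_dict` resolves it to a LayerBase.

```python
# Hypothetical config fragment; "output_log_prob" is a placeholder layer name.
network = {
    "align": {
        "class": "forced_align",  # layer_class of ForcedAlignmentLayer
        "from": "output_log_prob",  # frame-wise +log scores, e.g. log-softmax output
        "align_target": "data:classes",  # assumed: resolved to the target label layer
        "topology": "ctc",  # or "rna"
        "input_type": "log_prob",  # the source holds log-probabilities
    },
}
```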

Fast Baum-Welch Layer

class returnn.tf.layers.basic.FastBaumWelchLayer(align_target, align_target_key=None, ctc_opts=None, sprint_opts=None, input_type='log_prob', tdp_scale=1.0, am_scale=1.0, min_prob=0.0, staircase_seq_len_source=None, **kwargs)

Calls fast_baum_welch() or fast_baum_welch_by_sprint_automata(). The input is expected to be +log scores, e.g. the output of a log-softmax.
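To illustrate what "+log scores" means here: log-probabilities that are at most 0 and whose exponentials sum to 1 per frame, as opposed to negative log scores. The NumPy function below is only an illustrative sketch; in an actual network the log-softmax is computed in TensorFlow by the preceding layer.

```python
import numpy as np


def log_softmax(logits, axis=-1):
    """Turn unnormalized logits into +log probabilities (log-softmax)."""
    # Subtract the max for numerical stability; this does not change the result.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])  # [frames, classes]
scores = log_softmax(logits)
print(scores)  # every entry is <= 0
print(np.exp(scores).sum(axis=-1))  # each frame's probabilities sum to 1
```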

Parameters:
  • align_target (str) – e.g. “sprint”, “ctc” or “staircase”
  • align_target_key (str|None) – data key of the target labels, e.g. “classes”; used e.g. for align_target “ctc”
  • ctc_opts (dict[str]) – used for align_target “ctc”
  • sprint_opts (dict[str]) – used for Sprint (RASR) for align_target “sprint”
  • input_type (str) – “log_prob” or “prob”
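  • tdp_scale (float) – scale for the transition (TDP) scores of the alignment automaton
  • am_scale (float) – scale for the acoustic model scores, i.e. the input
  • min_prob (float) – clips probabilities from below at this value (in [0,1])
  • staircase_seq_len_source (LayerBase|None) – used for align_target “staircase”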
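Baum-Welch replaces Viterbi's single best path with a sum over all paths. The forward-backward pass below is a plain NumPy sketch over the standard CTC state graph. It is not `fast_baum_welch()` itself, and treating align_target “ctc” as exactly this graph is an assumption. The function name `ctc_soft_alignment` is hypothetical. It returns per-frame posteriors over the vocabulary (a soft alignment) and the total log-likelihood of the target sequence.

```python
import numpy as np


def _logsumexp(xs):
    """Stable log(sum(exp(xs))) of a 1-D array; returns -inf if all entries are -inf."""
    m = np.max(xs)
    if m == -np.inf:
        return -np.inf
    return m + np.log(np.sum(np.exp(xs - m)))


def ctc_soft_alignment(log_probs, labels, blank=0):
    """Forward-backward over the CTC state graph.

    log_probs: [T, V] +log scores. Returns (post, log_likelihood), where post[t, v]
    is the posterior probability of emitting symbol v at frame t.
    """
    T, V = log_probs.shape
    states = [blank]  # extended sequence: blank, l1, blank, ..., lN, blank
    for lab in labels:
        states += [lab, blank]
    S = len(states)

    def preds(s):
        # Allowed predecessors of state s: itself, s-1, and s-2 (blank skip).
        p = [s]
        if s >= 1:
            p.append(s - 1)
        if s >= 2 and states[s] != blank and states[s] != states[s - 2]:
            p.append(s - 2)
        return p

    succs = [[n for n in range(S) if s in preds(n)] for s in range(S)]
    emit = log_probs[:, states]  # [T, S] score of each state's symbol per frame

    alpha = np.full((T, S), -np.inf)  # forward log-scores
    alpha[0, : min(2, S)] = emit[0, : min(2, S)]  # start in blank or first label
    for t in range(1, T):
        for s in range(S):
            alpha[t, s] = _logsumexp(alpha[t - 1, preds(s)]) + emit[t, s]

    beta = np.full((T, S), -np.inf)  # backward log-scores
    beta[T - 1, max(0, S - 2):] = 0.0  # end in last label or trailing blank
    for t in range(T - 2, -1, -1):
        for s in range(S):
            nxt = succs[s]
            beta[t, s] = _logsumexp(beta[t + 1, nxt] + emit[t + 1, nxt])

    log_like = _logsumexp(alpha[T - 1, max(0, S - 2):])
    if log_like == -np.inf:
        raise ValueError("no valid alignment (too few frames for these labels?)")
    gamma = np.exp(alpha + beta - log_like)  # [T, S] state posteriors
    post = np.zeros((T, V))
    for s, v in enumerate(states):
        post[:, v] += gamma[:, s]  # merge states that emit the same symbol
    return post, log_like
```

Each row of `post` sums to 1. When only one path is possible (e.g. target `[1, 1]` in three frames), the soft alignment degenerates to the hard one-hot Viterbi alignment.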
layer_class = 'fast_bw'
recurrent = True
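classmethod transform_config_dict(d, network, get_layer)
Parameters:
  • d (dict[str]) – layer config dict
  • network (TFNetwork) –
  • get_layer – function to get or construct another layer by name
classmethod get_out_data_from_opts(name, sources, **kwargs)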
Parameters:
  • name (str) –
  • sources (list[LayerBase]) –
Return type:

Data
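A minimal config sketch for the CTC case, using only the parameters documented above. `output_log_prob` is a hypothetical layer name for a source of +log scores. `ctc_opts` is left out on the assumption that its defaults suffice; for align_target “sprint” you would instead pass `sprint_opts` for your RASR setup.

```python
# Hypothetical config fragment; "output_log_prob" is a placeholder layer name.
network = {
    "fast_bw": {
        "class": "fast_bw",  # layer_class of FastBaumWelchLayer
        "from": "output_log_prob",  # +log scores, e.g. log-softmax output
        "align_target": "ctc",
        "align_target_key": "classes",  # data key of the target labels
        "input_type": "log_prob",  # the default
    },
}
```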