RETURNN documentation

TensorFlow Beam Search

RETURNN can perform beam search on an arbitrary network architecture with an arbitrary number of outputs. Largely because of this flexibility, there is no single place in the code where beam search happens. The following classes and functions are related to beam search:

  • returnn.tf.layers.rec.ChoiceLayer

  • returnn.tf.layers.rec.DecideLayer

  • returnn.tf.util.data.SearchBeam

  • returnn.tf.layers.base.SearchChoices

  • returnn.tf.layers.basic.SelectSearchSourcesLayer

  • returnn.tf.layers.rec._SubnetworkRecCell._opt_search_resolve()

  • returnn.tf.network.TFNetwork.get_search_choices()

For an example implementation of search, please have a look at Recurrent Nets with Independent Step Count.

Choice- and Decide Layer

ChoiceLayer and DecideLayer are the only layers that actively manipulate the beam. In all other layers, the beam is folded into the batch dimension, and those layers are not even aware of it.
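To make “folded into the batch dimension” concrete, here is a minimal NumPy sketch; it is not RETURNN code. It assumes the beam axis is merged into the batch axis as its inner, faster-varying part, and it uses a hypothetical beam-unaware linear_layer. Because the layer only ever sees the merged axis, it treats every hypothesis like an ordinary batch entry.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, beam_size, in_dim, out_dim = 2, 3, 4, 5

rng = np.random.default_rng(0)
# Activations as they conceptually look during search: [batch, beam, in_dim].
x = rng.standard_normal((batch_size, beam_size, in_dim))
# Weights of some beam-unaware layer (here a plain linear transform).
w = rng.standard_normal((in_dim, out_dim))


def linear_layer(inputs):
    """A layer that only knows about a (possibly enlarged) batch axis."""
    return inputs @ w


# Fold the beam into the batch axis: [batch * beam, in_dim].
# The beam index varies fastest, i.e. row b * beam_size + k holds (b, k).
x_flat = x.reshape(batch_size * beam_size, in_dim)
y_flat = linear_layer(x_flat)  # [batch * beam, out_dim]

# Unfolding recovers the per-hypothesis view again.
y = y_flat.reshape(batch_size, beam_size, out_dim)
```

Unfolding with the same reshape gives one output per (batch entry, hypothesis) pair, which is why ordinary layers need no beam-specific code in this picture.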

During search, ChoiceLayer creates a beam by taking the top (i.e. largest) k elements of its input vector, which is typically the output of a softmax. If the input already has a beam, the output contains the top k elements found across all vectors of the incoming beam. DecideLayer simply outputs the best entry of the beam. It is used after the actual search to output the single best hypothesis; if you skip it, you can output the whole n-best list instead.
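The top-k step can be sketched in NumPy as follows. This is a simplified illustration, not ChoiceLayer's actual implementation: choice_step and decide are hypothetical names, and the sketch assumes that per-label log-probabilities are added to the accumulated score of each incoming hypothesis before the top k are taken over all (hypothesis, label) pairs. Besides the new labels, it returns for every new entry the index of the incoming hypothesis it extends, the kind of information that beam selection and backtracking need.

```python
import numpy as np


def choice_step(beam_scores, log_probs, beam_size):
    """One search step in the spirit of ChoiceLayer (simplified sketch).

    beam_scores: [batch, in_beam], accumulated log-scores of incoming hypotheses.
    log_probs:   [batch, in_beam, vocab], e.g. log-softmax output per hypothesis.
    Returns new scores, labels and source beam indices, each [batch, beam_size].
    """
    batch, in_beam, vocab = log_probs.shape
    # Score of extending every incoming hypothesis with every label.
    total = beam_scores[:, :, None] + log_probs  # [batch, in_beam, vocab]
    flat = total.reshape(batch, in_beam * vocab)
    # Top k over all (hypothesis, label) pairs, sorted best first.
    top = np.argsort(-flat, axis=1, kind="stable")[:, :beam_size]
    new_scores = np.take_along_axis(flat, top, axis=1)
    # Flat index -> (incoming hypothesis, label).
    src_beams, labels = np.divmod(top, vocab)
    return new_scores, labels, src_beams


def decide(scores, labels):
    """Like DecideLayer: keep only the best entry of each beam."""
    best = np.argmax(scores, axis=1)
    return labels[np.arange(labels.shape[0]), best]
```

At the first step, a beam of size 1 (a single hypothesis with score 0) can be passed in; each later step feeds the returned scores back in, together with the model's log_probs for the new hypotheses.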

During training, by default, there is no beam: ChoiceLayer outputs the ground-truth label and DecideLayer does nothing.

Beam Selection

In addition to these two layers, there is logic between the layers that manipulates the beam. ChoiceLayers can occur at any place in the network, so in general, different layers operate on different beams. When a layer has several inputs, it must be ensured that the incoming beams correspond to each other. This means that all incoming beams have the same size and that the n-th entries of all beams derive from a common origin (a ChoiceLayer) somewhere along the network graph. To achieve this, the graph is parsed for each layer, the most recent ChoiceLayer is determined, and all other inputs are “translated” to this most recent beam. In more detail, we trace back the dependencies of each entry in the most recent beam until we reach another ChoiceLayer, and there we collect the single corresponding entry we arrived at. Doing this for all entries in the most recent beam and for all ChoiceLayers in the network, we create new “selected” beams for the other inputs, which are then used as sources for the current layer instead of the original incoming ones. In the code, this happens in:

network._create_layer() ->
network._create_layer_desc() ->
SearchChoices.translate_to_common_search_beam() ->
SearchChoices.translate_to_this_search_beam() ->
SelectSearchSourcesLayer.select_if_needed() ->
SelectSearchSourcesLayer.__init__() ->
select_src_beams()
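Translating an input to the most recent beam boils down to following backpointers. The NumPy sketch below assumes that each ChoiceLayer records, per batch entry, which entry of its incoming beam every new hypothesis was derived from. Composing these arrays from the most recent ChoiceLayer back to an older one yields, for each entry of the most recent beam, its single corresponding entry in the older beam; gathering along the beam axis then produces the “selected” beam. It mirrors the idea behind select_src_beams() but is not its actual code, and translate_to_beam is a hypothetical name.

```python
import numpy as np


def translate_to_beam(x, backpointers):
    """Map x from an older beam onto the most recent beam (simplified sketch).

    x: [batch, old_beam, ...], output of a layer living in the older beam.
    backpointers: list of [batch, beam] integer arrays, ordered from the most
    recent ChoiceLayer back to the one directly following the ChoiceLayer
    that created x's beam. Entry k of each array is the index, in the
    preceding beam, of the hypothesis that entry k was derived from.
    Returns x reordered to [batch, new_beam, ...].
    """
    rows = np.arange(x.shape[0])[:, None]  # batch index, broadcast over beam
    # Index into the beam that precedes the most recent ChoiceLayer.
    idx = backpointers[0]
    # Follow the backpointers further back, one ChoiceLayer at a time.
    for src_beams in backpointers[1:]:
        idx = src_beams[rows, idx]
    # idx now points into x's beam; gather the corresponding entries.
    return x[rows, idx]
```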

Backtracking

Finally, when there is a beam inside a recurrent layer (which is actually where it most commonly occurs), there is an additional step: after all time steps have been evaluated, the beams of the recurrent layer's outputs are resolved over time. This is what is better known as backtracking: we create the full n-best sequences for all outputs instead of outputting the contents of the beam at each time step. This is implemented in:

_SubnetworkRecCell.get_output() ->
_SubnetworkRecCell._construct_output_layers_moved_out() ->
_SubnetworkRecCell.get_loop_acc_layer() ->
_SubnetworkRecCell._opt_search_resolve()
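Backtracking can be sketched with the same backpointer representation. In this hypothetical NumPy version (not the code of _opt_search_resolve()), labels and src_beams hold, per time step, the label chosen by each beam entry and the entry of the previous step's beam that it extends. Walking backwards from the final beam recovers complete sequences.

```python
import numpy as np


def backtrack(labels, src_beams):
    """Resolve per-step beams into full n-best sequences (simplified sketch).

    labels, src_beams: [time, batch, beam]. At step t, hypothesis k chose label
    labels[t, b, k] and extends hypothesis src_beams[t, b, k] of step t - 1.
    Returns sequences of shape [batch, beam, time], ordered like the final beam.
    """
    num_steps, batch, beam = labels.shape
    rows = np.arange(batch)[:, None]
    out = np.empty((batch, beam, num_steps), dtype=labels.dtype)
    # Start from every entry of the beam after the last step.
    idx = np.tile(np.arange(beam), (batch, 1))
    for t in range(num_steps - 1, -1, -1):
        out[:, :, t] = labels[t][rows, idx]  # label chosen at step t
        idx = src_beams[t][rows, idx]  # move to the predecessor hypothesis
    return out
```

For example, with per-step labels [[5, 6]], [[7, 8]], [[9, 4]] (one batch entry, beam size 2) and backpointers [[0, 0]], [[1, 0]], [[1, 1]], reading the beam contents step by step would suggest the sequences 5 7 9 and 6 8 4, whereas backtracking yields 5 8 9 and 5 8 4.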

For all of this, the utility function that parses the dependency graph for the most recent ChoiceLayer is returnn.tf.network.TFNetwork.get_search_choices().
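What “most recent” means can be illustrated with a toy dependency graph. In this sketch, Layer, nearest_choices, ancestors and most_recent_choice are hypothetical and far simpler than TFNetwork.get_search_choices(): for a given layer, it collects the nearest ChoiceLayers along every input path and then picks the one that has all the others among its dependencies.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass(eq=False)  # eq=False keeps identity-based hashing
class Layer:
    """Hypothetical stand-in for a network layer (not a RETURNN class)."""
    name: str
    sources: List["Layer"] = field(default_factory=list)
    is_choice: bool = False


def nearest_choices(layer: Layer) -> Set[Layer]:
    """ChoiceLayers reachable from layer without passing through another one."""
    if layer.is_choice:
        return {layer}
    found: Set[Layer] = set()
    for src in layer.sources:
        found |= nearest_choices(src)
    return found


def ancestors(layer: Layer) -> Set[Layer]:
    """All layers that layer (transitively) depends on."""
    seen: Set[Layer] = set()
    stack = list(layer.sources)
    while stack:
        cur = stack.pop()
        if cur not in seen:
            seen.add(cur)
            stack.extend(cur.sources)
    return seen


def most_recent_choice(layer: Layer) -> Optional[Layer]:
    """The candidate ChoiceLayer that has all other candidates as dependencies."""
    candidates = nearest_choices(layer)
    if not candidates:
        return None  # no ChoiceLayer upstream, hence no beam
    for cand in candidates:
        deps = ancestors(cand)
        if all(other is cand or other in deps for other in candidates):
            return cand
    # Beams that cannot be ordered are treated as an error in this sketch.
    raise ValueError(f"no common most recent ChoiceLayer for {layer.name}")
```

In a chain a → c1 → h → c2, a layer reading from h and c2 resolves to c2, so its input from h (which lives in c1's beam) would be the one to translate.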

Copyright © 2014–2023, RETURNN contributors
Made with Sphinx and @pradyunsg's Furo