returnn.frontend.run_ctx

Run context

We can be in the param-init stage, in the main training loop, or in the forwarding loop.

class returnn.frontend.run_ctx.RunCtx(*, stage: str, train_flag: bool | Tensor = False, step: int | Tensor = 0, epoch: int | Tensor = 1, expected_outputs: TensorDict | None = None)

We can be in the param-init stage, in the main training (or eval) loop, or in the forwarding loop (doing recognition, beam search, dumping arbitrary outputs, …).

In training/eval, we expect that some loss is being defined via mark_as_loss(). In forwarding, we expect that some output is being defined via mark_as_output().

Parameters:

stage

  • “init”

  • “train_step”, also used for eval; for mark_as_loss() and total_loss()

  • “forward_step”, for mark_as_output()

property stage: str
Returns:

“init”, “train_step”, “forward_step”

property train_flag: bool | Tensor
Returns:

whether we are in training mode, i.e. the model is updated and we are supposed to use dropout and similar mechanisms. In a graph-based backend, this can be dynamic, i.e. a Tensor rather than a Python bool.

train_flag_ctx(train_flag: bool | Tensor)

Context manager to temporarily set the train_flag.

Usage example, e.g. to disable dropout for some code:

import returnn.frontend as rf

with rf.get_run_ctx().train_flag_ctx(False):
    ...  # code here runs with train_flag=False, e.g. without dropout
Parameters:

train_flag – whether we are in training mode
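The save-and-restore behaviour of train_flag_ctx() can be illustrated with a small stand-alone sketch in plain Python. ToyRunCtx and effective_dropout are hypothetical names for illustration only, not part of RETURNN. Restoring the previous flag in a finally clause (so it is also restored when the body raises) is a design choice of this sketch.

```python
from contextlib import contextmanager
from typing import Iterator


class ToyRunCtx:
    """Hypothetical stand-in for a run context that only tracks train_flag."""

    def __init__(self, train_flag: bool = False) -> None:
        self.train_flag = train_flag

    @contextmanager
    def train_flag_ctx(self, train_flag: bool) -> Iterator[None]:
        # Save the current flag, override it for the body of the with-block,
        # and restore it afterwards, also if the body raises an exception.
        old_flag = self.train_flag
        self.train_flag = train_flag
        try:
            yield
        finally:
            self.train_flag = old_flag


def effective_dropout(ctx: ToyRunCtx, rate: float) -> float:
    """Hypothetical helper: apply dropout only in training mode."""
    return rate if ctx.train_flag else 0.0


ctx = ToyRunCtx(train_flag=True)
with ctx.train_flag_ctx(False):
    print(effective_dropout(ctx, 0.1))  # prints 0.0: dropout disabled inside the block
print(effective_dropout(ctx, 0.1))  # prints 0.1: original flag restored afterwards
```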

property step: int | Tensor
Returns:

the global train step, starting at 0 and not reset after an epoch, i.e. it keeps counting across epochs. In a graph-based backend, this can be dynamic.
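The difference between the global step and a per-epoch counter can be shown with a toy loop in plain Python. iterate_steps is a hypothetical helper, not RETURNN code; it only assumes, as the signature defaults suggest, that the step starts at 0 and the epoch at 1.

```python
from typing import Iterator, List, Tuple


def iterate_steps(batches_per_epoch: List[int], first_epoch: int = 1) -> Iterator[Tuple[int, int, int]]:
    """Yield (epoch, global_step, step_in_epoch) for a toy training run.

    batches_per_epoch: number of batches in each epoch; epochs may differ in size.
    """
    step = 0  # global step: starts at 0 and is never reset
    for i, num_batches in enumerate(batches_per_epoch):
        epoch = first_epoch + i
        for step_in_epoch in range(num_batches):  # this counter restarts every epoch
            yield epoch, step, step_in_epoch
            step += 1


for epoch, step, step_in_epoch in iterate_steps([2, 3]):
    print(epoch, step, step_in_epoch)
```

With two epochs of 2 and 3 batches, the global step runs from 0 to 4, while the per-epoch counter restarts at 0 in the second epoch.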

get_step_tensor() → Tensor
Returns:

step as tensor

property epoch: int | Tensor
Returns:

the current epoch

get_epoch_tensor() → Tensor
Returns:

epoch as tensor

mark_as_loss(loss: Tensor | Any, name: str, *, dims: Sequence[Dim] | None = None, scale: float = 1.0, as_error: bool = False, use_normalized_loss: bool = False, use_flatten_frames: bool = True, custom_inv_norm_factor: Tensor | None = None) → None

Mark the given loss tensor as a loss. This has the effect that it is specially handled by RETURNN. Specifically, the optimizer can use it in training, and it is used for reporting per batch or per epoch, and for learning rate scheduling.

This currently uses AsIsLoss in RETURNN but this is an implementation detail and might change.

Parameters:
  • loss – E.g. shape [B,T] or [B]. A Tensor is usually expected, but a raw tensor is also possible. You should not reduce the axes over which RETURNN should collect epoch-wise statistics, so that RETURNN can properly accumulate the loss over batches. You should reduce_sum over the axes where you do not want normalization. E.g. if you calculate frame-wise CE with shape [B,T] and want sequence-level CE instead, calculate reduce_sum(loss, axis=T) to get shape [B] and pass only those sequence-level CE losses here.

  • name – name of the loss. This name is used for reporting by RETURNN, and also for LR scheduling.

  • dims – the dims of loss; needed in case loss is a raw tensor rather than a Tensor

  • scale – scale the loss by this factor for the training optimizer (but not for any reporting). Setting it to 0.0 has the effect that this loss is not used by the optimizer.

  • as_error – if True, this loss is reported as an error instead of a loss, and it is not used by the training optimizer. By convention, this is something like the frame error or the edit distance, which is usually not differentiable anyway.

  • use_normalized_loss – if True, the loss used in optimization is normalized via reduce_mean instead of reduce_sum. E.g. if the overall normalization is sum(loss)/sum(num_frames), the optimizer also uses this; otherwise, the optimizer just uses sum(loss).

  • use_flatten_frames – If True, this uses returnn.tf.util.basic.flatten_with_seq_len_mask(), i.e. a “packed” sequence with the padded frames removed, and accumulates over that. This can be more efficient, also because it further optimizes the preceding computations, e.g. by skipping the softmax computation on the padded frames. This can also avoid issues with inf/nan in some cases. If False, the loss is masked to 0 in the padded frames and accumulated over that. Typically, the default True is both more efficient and better.

  • custom_inv_norm_factor – The standard inv norm factor is sum(target_seq_len) if the target has a time-axis, or sum(output_seq_len) if there is no target and the output has a time-axis, or 1 otherwise. (See Loss.init() for details.) This is used for proper normalization of accumulated loss/error per epoch and also proper normalization per batch for reporting, no matter if use_normalized_loss is True or False. If you want to change this norm factor, you can set this. Basically, for all reporting, it uses sum(loss) / sum(custom_inv_norm_factor).
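The advice for loss (reduce over T to get sequence-level losses of shape [B]) and the two padding strategies of use_flatten_frames can be illustrated on plain Python lists standing in for Tensors. masked_seq_sums and flattened_frames are hypothetical helpers for this sketch only; the efficiency benefits of packing described above are not modelled.

```python
from typing import List


def masked_seq_sums(frame_loss: List[List[float]], seq_lens: List[int]) -> List[float]:
    """Reduce a padded [B,T] frame-wise loss to [B], ignoring padded frames.

    Padded frames (t >= seq_len) are set to 0 before summing, which mirrors
    the masking described for use_flatten_frames=False.
    """
    return [
        sum(value if t < seq_len else 0.0 for t, value in enumerate(row))
        for row, seq_len in zip(frame_loss, seq_lens)
    ]


def flattened_frames(frame_loss: List[List[float]], seq_lens: List[int]) -> List[float]:
    """Pack only the valid frames into one flat list, mirroring use_flatten_frames=True."""
    return [value for row, seq_len in zip(frame_loss, seq_lens) for value in row[:seq_len]]


# Batch of 2 sequences with lengths 3 and 2; 99.0 is garbage in a padded frame.
frame_loss = [[1.0, 2.0, 3.0], [4.0, 5.0, 99.0]]
seq_lens = [3, 2]

seq_loss = masked_seq_sums(frame_loss, seq_lens)  # [6.0, 9.0]: sequence-level losses, shape [B]
packed = flattened_frames(frame_loss, seq_lens)  # [1.0, 2.0, 3.0, 4.0, 5.0]
assert sum(seq_loss) == sum(packed) == 15.0  # both strategies give the same total
per_frame = sum(packed) / sum(seq_lens)  # 3.0, i.e. sum(loss) / sum(num_frames)
```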

mark_as_output(tensor: Tensor | Any, name: str, *, dims: Sequence[Dim] | None = None) → None

Mark this as an output. This has the effect that RETURNN will in any case construct the corresponding layer. Also see mark_as_default_output().

This is intended mostly for forwarding, or for exporting the model (TF graph, TFLite, ONNX, etc.). You should specify dims so that the output shape (the order of the dims) is well-defined; if not specified, we check whether some default is possible, like BTF or BF.

Parameters:
  • tensor – the tensor to mark as output

  • name – the name of the output

  • dims – this specifies the order of the dims of the output, such that it is well-defined for some external application. If not specified, we try to infer BTF or BF as default, if that works, otherwise it will be an error.
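The rule for dims (use the explicit order if given, otherwise try BTF, then BF, otherwise fail) can be sketched in plain Python. resolve_output_dims is a hypothetical toy that works on dim names such as "B", "T", "F"; RETURNN works with Dim objects, and its actual inference logic may differ in detail.

```python
from typing import List, Optional, Sequence


def resolve_output_dims(tensor_dims: Sequence[str], dims: Optional[Sequence[str]] = None) -> List[str]:
    """Toy version of fixing the dim order of an output.

    tensor_dims: names of the dims the tensor has, in any order.
    dims: explicit order requested by the caller, or None to try the
        defaults "B","T","F" and then "B","F".
    """
    if dims is not None:
        # An explicit order must mention exactly the dims of the tensor.
        if sorted(dims) != sorted(tensor_dims):
            raise ValueError(f"dims {list(dims)} do not match tensor dims {list(tensor_dims)}")
        return list(dims)
    for default in (["B", "T", "F"], ["B", "F"]):
        if sorted(default) == sorted(tensor_dims):
            return default
    raise ValueError(f"cannot infer a default order for {list(tensor_dims)}, pass dims explicitly")


print(resolve_output_dims(["F", "B", "T"]))  # ['B', 'T', 'F']
print(resolve_output_dims(["T", "B"], dims=["T", "B"]))  # ['T', 'B']
```

A tensor with only "B" and "T" has no default order in this toy, so it raises unless dims is passed.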

mark_as_default_output(tensor: Tensor | Any, *, shape: Sequence[Dim] | None = None) → None

Calls mark_as_output(tensor, “output”, dims=shape).

Mark this as the default output. See Frontend.mark_as_default_output() for more details.

Parameters:
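  • tensor – the tensor to mark as the default output

  • shape – the order of the dims of the output, as for dims in mark_as_output()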

check_outputs_complete()

If expected outputs are given, check that all expected outputs are present.
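The completeness check amounts to a set difference between expected and marked output names. The sketch below is a hypothetical plain-Python toy: in RETURNN, expected_outputs is a TensorDict, and this toy neither checks shapes or dims nor decides whether RETURNN tolerates extra outputs. Allowing extra outputs is simply a choice of the sketch.

```python
from typing import Iterable, Optional


def check_outputs_complete(expected: Optional[Iterable[str]], marked: Iterable[str]) -> None:
    """Toy check: raise if an expected output name was never marked as output.

    expected: names of the expected outputs, or None if nothing is expected.
    marked: names that were passed to mark_as_output() during the step.
    """
    if expected is None:
        return  # no expectations given, nothing to check
    missing = sorted(set(expected) - set(marked))
    if missing:
        raise ValueError(f"expected outputs not marked: {missing}")


check_outputs_complete(["output"], ["output", "extra"])  # extra outputs are fine in this toy
check_outputs_complete(None, [])  # nothing expected, nothing to check
```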

total_loss() → Tensor | float
Returns:

total loss, as it is used for backpropagation
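Following the descriptions of scale and as_error in mark_as_loss(), the total training loss can be pictured as a scaled sum over all non-error losses. The sketch below works on plain floats; ReducedLoss and total_loss are hypothetical toy names, and RETURNN's implementation works on Tensors and may differ in detail.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ReducedLoss:
    """Hypothetical record of one marked loss after reduction to a scalar."""

    value: float  # summed or mean loss, depending on use_normalized_loss
    scale: float = 1.0
    as_error: bool = False


def total_loss(losses: List[ReducedLoss]) -> float:
    """Toy total loss: sum of scale * value over all losses that are not errors."""
    # Errors (as_error=True) are only reported, never optimized;
    # a scale of 0.0 makes a loss contribute nothing.
    return sum((loss.scale * loss.value for loss in losses if not loss.as_error), 0.0)


losses = [
    ReducedLoss(2.0),  # contributes 2.0
    ReducedLoss(4.0, scale=0.5),  # contributes 2.0
    ReducedLoss(3.0, scale=0.0),  # reported only
    ReducedLoss(0.25, as_error=True),  # e.g. a frame error, reported only
]
print(total_loss(losses))  # 4.0
```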

class returnn.frontend.run_ctx.Loss(loss: Tensor, name: str, scale: float = 1.0, as_error: bool = False, use_normalized_loss: bool = False, use_flatten_frames: bool = True, custom_inv_norm_factor: Tensor | None = None, _summed_loss_cached: Tensor | None = None, _mean_loss_cached: Tensor | None = None)

Loss as marked via RunCtx.mark_as_loss().

We collect all relevant information here.

loss: Tensor
name: str
scale: float = 1.0
as_error: bool = False
use_normalized_loss: bool = False
use_flatten_frames: bool = True
custom_inv_norm_factor: Tensor | None = None
get_summed_loss() → Tensor
Returns:

sum of loss (scalar)

get_mean_loss() → Tensor
Returns:

mean of loss (scalar)

get_inv_norm_factor() → int | Tensor
Returns:

inverse norm factor (scalar)

get_scaled_reduced_loss() → Tensor
Returns:

scaled reduced loss (scalar), as it is supposed to be used for calculating the train gradient
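How the four reductions relate can be sketched on plain floats, following the parameter descriptions of mark_as_loss(): the mean divides the sum by the inverse norm factor, and the optimizer uses the mean only with use_normalized_loss. ToyLoss is a hypothetical model, not RETURNN's Loss class; in particular, falling back to the number of loss entries when no custom factor is given is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyLoss:
    """Hypothetical plain-float model of the reductions listed for Loss.

    values: the marked loss values, e.g. one sequence-level loss per sequence.
    inv_norm: stand-in for custom_inv_norm_factor, e.g. the sequence lengths.
    """

    values: List[float]
    scale: float = 1.0
    use_normalized_loss: bool = False
    inv_norm: Optional[List[float]] = None

    def get_summed_loss(self) -> float:
        return sum(self.values)

    def get_inv_norm_factor(self) -> float:
        # With a custom factor, normalize by its sum, as for reporting:
        # sum(loss) / sum(custom_inv_norm_factor). Falling back to the
        # number of entries is an assumption of this sketch.
        if self.inv_norm is not None:
            return sum(self.inv_norm)
        return len(self.values)

    def get_mean_loss(self) -> float:
        return self.get_summed_loss() / self.get_inv_norm_factor()

    def get_scaled_reduced_loss(self) -> float:
        # The optimizer uses the normalized (mean) loss only if
        # use_normalized_loss is set, otherwise the plain sum.
        reduced = self.get_mean_loss() if self.use_normalized_loss else self.get_summed_loss()
        return self.scale * reduced


loss = ToyLoss(values=[6.0, 9.0], scale=0.5, inv_norm=[3, 2])
print(loss.get_summed_loss(), loss.get_mean_loss(), loss.get_scaled_reduced_loss())  # 15.0 3.0 7.5
```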

returnn.frontend.run_ctx.get_run_ctx() → RunCtx
Returns:

current run context, see RunCtx

returnn.frontend.run_ctx.get_run_ctx_step() → Tensor
Returns:

shortcut for get_run_ctx().get_step_tensor()

returnn.frontend.run_ctx.init_train_step_run_ctx(*, train_flag: bool | Tensor = True, step: int | Tensor = 0, epoch: int | Tensor = 1)

Call this before the train_step function is called, when you write your own training loop.

Also see init_forward_step_run_ctx().

Parameters:
  • train_flag – whether we intend to do actual training. You might want to use dropout only in this case. (In the case of PyTorch, we would also call module.train() first, which also stores this flag internally.) If False, we call the same train_step function, but we intend to do evaluation with the same loss.

  • step – you might want to schedule dropout or other things depending on the step

  • epoch – you might want to schedule dropout or other things depending on the epoch
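As suggested for step and epoch, a dropout rate can be scheduled on the step and switched off when train_flag is False. The linear warm-up below is purely a hypothetical schedule for illustration, and scheduled_dropout and dropout_for_step are not RETURNN functions. In RETURNN, the values would come from the run context (e.g. get_run_ctx().step), which in a graph-based backend may be a Tensor rather than an int.

```python
def scheduled_dropout(step: int, *, final_rate: float = 0.1, warmup_steps: int = 1000) -> float:
    """Hypothetical schedule: ramp dropout linearly from 0 to final_rate over warmup_steps."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return final_rate
    return final_rate * step / warmup_steps


def dropout_for_step(step: int, train_flag: bool) -> float:
    """Dropout rate to use in a step; disabled when train_flag is False (evaluation)."""
    return scheduled_dropout(step) if train_flag else 0.0


for step in (0, 500, 1000, 5000):
    print(step, dropout_for_step(step, train_flag=True), dropout_for_step(step, train_flag=False))
```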

returnn.frontend.run_ctx.init_forward_step_run_ctx(*, expected_outputs: TensorDict | None = None, step: int | Tensor = 0, epoch: int | Tensor = 1)

Call this before the forward_step function is called, when you write your own forward loop.

Also see init_train_step_run_ctx().
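The pattern behind these init functions (set up a current context before each step call, then let the step function read it) can be modelled in plain Python. Every name below carries a toy_ or Toy prefix because it is hypothetical and not RETURNN's implementation. The fallback to an "init" context before any loop is set up is an assumption of the sketch, and expected_outputs, PyTorch's module.train() and Tensor-valued flags are not modelled.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class ToyRunCtx:
    """Hypothetical run context holding only stage, train_flag, step and epoch."""

    stage: str
    train_flag: bool = False
    step: int = 0
    epoch: int = 1


_current: Optional[ToyRunCtx] = None  # None: no train/forward loop set up yet
_init_ctx = ToyRunCtx(stage="init")


def toy_init_train_step_run_ctx(*, train_flag: bool = True, step: int = 0, epoch: int = 1) -> None:
    global _current
    _current = ToyRunCtx(stage="train_step", train_flag=train_flag, step=step, epoch=epoch)


def toy_init_forward_step_run_ctx(*, step: int = 0, epoch: int = 1) -> None:
    global _current
    _current = ToyRunCtx(stage="forward_step", step=step, epoch=epoch)


def toy_get_run_ctx() -> ToyRunCtx:
    # Falling back to the param-init stage is an assumption of this sketch.
    return _current if _current is not None else _init_ctx


def toy_train_step(batch: Any) -> Tuple[str, int, int, Any]:
    # The step function reads stage, epoch and step from the context
    # instead of taking them as arguments.
    ctx = toy_get_run_ctx()
    return ctx.stage, ctx.epoch, ctx.step, batch


def toy_training_loop(batches_per_epoch: List[List[Any]]) -> List[Tuple[str, int, int, Any]]:
    records = []
    step = 0  # global step, not reset per epoch
    for epoch, batches in enumerate(batches_per_epoch, start=1):
        for batch in batches:
            # Set up the context before every call of the step function.
            toy_init_train_step_run_ctx(train_flag=True, step=step, epoch=epoch)
            records.append(toy_train_step(batch))
            step += 1
    return records
```

For example, toy_training_loop([["a", "b"], ["c"]]) records the stage "train_step" for every batch, with epochs 1, 1, 2 and global steps 0, 1, 2.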