returnn.frontend.dropout

Dropout

returnn.frontend.dropout.dropout(source: Tensor, drop_prob: float | Tensor, *, axis: Dim | Sequence[Dim] | bool | None = None, on_forward: bool = False) → Tensor

Applies dropout.

Dropout will only be applied during training (unless you set on_forward=True).

When dropout is applied, the output is scaled by 1/(1 - drop_prob) (inverted dropout), so that the expected value of the output matches the input.
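For illustration, a minimal NumPy sketch of this inverted-dropout semantic (conceptual only, not RETURNN's actual implementation): with drop_prob=0.1, surviving values are multiplied by 1/0.9 ≈ 1.11.

    import numpy as np

    def inverted_dropout(x: np.ndarray, drop_prob: float) -> np.ndarray:
        # Keep each value independently with probability (1 - drop_prob) ...
        keep_mask = np.random.random(x.shape) >= drop_prob
        # ... and rescale the survivors by 1 / (1 - drop_prob),
        # so that the expected value of the output equals the input.
        return x * keep_mask / (1.0 - drop_prob)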

Parameters:
  • source

  • drop_prob – 0.0 means no dropout is applied; 1.0 (100%) would mask everything. Each value in the tensor is dropped independently with this probability. The broadcasted axes are those not specified in axis.

  • axis – axis (or axes) to apply dropout on; multiple axes can be specified. This defines the set of axes over which the dropout mask is not broadcast, i.e. the mask varies along these axes and is broadcast over all remaining axes. If None (default), the mask is not broadcast over any axis, i.e. dropout is elementwise. False is the same as None, which allows writing axis=use_dropout_broadcast and ...feature_dim (see the usage sketch below). (RETURNN also has the noise_shape option, but the axis option provides the same functionality.)

  • on_forward – if True, apply dropout during both training and inference (i.e. always); otherwise only during training.
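A usage sketch (FeedForwardBlock, its dims, and the drop_prob value are hypothetical illustrations, not part of RETURNN; rf.Module and rf.Linear are the regular RETURNN frontend building blocks): feature-dim dropout inside a module, with the mask broadcast over batch and time as in historical RETURNN:

    import returnn.frontend as rf
    from returnn.tensor import Tensor, Dim

    class FeedForwardBlock(rf.Module):
        """Hypothetical example module."""

        def __init__(self, in_dim: Dim, out_dim: Dim):
            super().__init__()
            self.linear = rf.Linear(in_dim, out_dim)
            self.out_dim = out_dim

        def __call__(self, x: Tensor) -> Tensor:
            x = self.linear(x)
            # axis=self.out_dim: the mask varies only along the feature axis
            # and is broadcast over all other axes (batch, time).
            # Applied only during training (on_forward=False by default).
            return rf.dropout(x, drop_prob=0.1, axis=self.out_dim)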

returnn.frontend.dropout.dropout_broadcast_default() → bool

Check the global RETURNN config for whether we should broadcast the dropout mask over the unrelated (non-specified) dimensions.

Historically in RETURNN, when we applied dropout on the feature dimension, we broadcast the dropout mask over the other dimensions (e.g. time and batch).

This function provides a simple, globally configurable way to control this behavior, via the config option rf_dropout_broadcast.

The default for now is to keep the historical RETURNN behavior, unless we find that this is really not a good idea; then we might change the default via a new behavior version.

Also see the option rf_att_dropout_broadcast, which does the same for attention dropout, although there the default was already changed with behavior version 19.

Returns:

whether broadcasting should be used for dropout. Note that this does not actually affect dropout(); any user of dropout() should check this explicitly.
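A sketch of that explicit check (apply_dropout is a hypothetical helper; feature_dim stands for the caller's feature Dim), mirroring the axis=use_dropout_broadcast and ...feature_dim idiom from above:

    import returnn.frontend as rf
    from returnn.tensor import Tensor, Dim

    def apply_dropout(x: Tensor, feature_dim: Dim) -> Tensor:
        # Resolve the global broadcast setting explicitly;
        # dropout() itself does not consult it.
        use_dropout_broadcast = rf.dropout_broadcast_default()
        # True  -> axis=feature_dim: mask is broadcast over all other axes.
        # False -> axis=False (same as None): elementwise mask, no broadcast.
        return rf.dropout(x, drop_prob=0.1, axis=use_dropout_broadcast and feature_dim)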