returnn.frontend.normalization
Normalization functions such as batch norm
- returnn.frontend.normalization.moments(x: Tensor, axis: Dim | Sequence[Dim], *, use_mask: bool = True, correction: int | float | Tensor = 0, distributed: bool = False) -> Tuple[Tensor, Tensor]
- Parameters:
x – input
axis – the axis (or axes) to be reduced, to calculate statistics over
use_mask – whether to use a mask for dynamic spatial dims in the reduction
correction –
The variance will be estimated by
sum((x - mean)**2) / (n - correction), where n is the number of elements in the axis (or the axes) (with use_mask=True, taking masking into account, using num_elements_of_shape()). The default correction=0 will return the biased variance estimation. correction=1 is the Bessel correction and will return the unbiased variance estimation. In PyTorch, there was an argument unbiased for this, but this changed recently to correction (PyTorch issue #61492, Python Array API Standard <https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html>).
In PyTorch, the default is correction=1, which is the unbiased variance estimation, while in most other frameworks, the default is correction=0, which is the biased variance estimation.
distributed – If True and a Torch DDP process group exists (world size > 1), compute the statistics over the global batch across all workers, by all-reducing the per-worker sum / sum-of-squares (differentiable) and count. This matches torch.nn.SyncBatchNorm. Default False keeps the per-worker (local) statistics.
- Returns:
tuple (mean, variance). Each of the two tensors has the same shape as the input with the reduced axis (or axes) removed.
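The effect of correction can be checked with a small pure-Python sketch. It does not use RETURNN: moments_1d is a hypothetical helper that reduces a flat list of floats and applies the formula sum((x - mean)**2) / (n - correction) directly.

```python
from typing import Sequence, Tuple


def moments_1d(values: Sequence[float], correction: float = 0.0) -> Tuple[float, float]:
    """Hypothetical helper: (mean, variance) of a flat list of floats.

    variance = sum((x - mean)**2) / (n - correction)
    """
    n = len(values)
    mean = sum(values) / n
    squared_deviations = sum((v - mean) ** 2 for v in values)
    return mean, squared_deviations / (n - correction)


data = [1.0, 2.0, 3.0, 4.0]
mean, var_biased = moments_1d(data, correction=0)  # biased estimate: 5.0 / 4
_, var_unbiased = moments_1d(data, correction=1)   # Bessel correction: 5.0 / 3
```

With only four elements the two estimates differ noticeably (1.25 versus roughly 1.667); the gap shrinks as n grows.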
- class returnn.frontend.normalization.LayerNorm(in_dim: Dim | Sequence[Dim], *, eps: float = 1e-06, with_bias: bool = True)
Note that we just normalize over the feature-dim axis here. This is consistent with the default behavior of tf.keras.layers.LayerNormalization and also how it is commonly used in many models, including Transformer.
However, there are cases where it would be common to normalize over all axes except batch-dim, or all axes except batch and time. For a more generic variant, see norm().
By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
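As an illustration of normalizing over the feature axis only, here is a minimal pure-Python sketch. It is not the RETURNN implementation: layer_norm_frame is a hypothetical helper, the input is assumed to be laid out as [time][feature], and the biased variance plus eps inside the square root is an assumption.

```python
import math
from typing import List, Sequence


def layer_norm_frame(
    frame: Sequence[float],
    gamma: Sequence[float],
    beta: Sequence[float],
    eps: float = 1e-6,
) -> List[float]:
    """Hypothetical helper: normalize one frame over its feature axis."""
    n = len(frame)
    mean = sum(frame) / n
    var = sum((v - mean) ** 2 for v in frame) / n  # biased variance (assumption)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(frame, gamma, beta)]


# Hypothetical input laid out as [time][feature].
sequence = [[1.0, 2.0, 3.0], [10.0, 20.0, 60.0]]
gamma = [1.0, 1.0, 1.0]  # scale
beta = [0.0, 0.0, 0.0]   # bias
# Each time frame is normalized independently of all other frames.
normalized = [layer_norm_frame(frame, gamma, beta) for frame in sequence]
```

Because the statistics are computed per frame, the very different scales of the two frames do not influence each other.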
- class returnn.frontend.normalization.RMSNorm(in_dim: Dim | Sequence[Dim], *, eps: float = 1e-06, with_bias: bool = False)
Root Mean Square Layer Normalization (RMSNorm).
Alternative to LayerNorm that uses the root-mean-square of the input as the normalization factor. I.e. the main difference to layer norm is: No subtraction of mean.
Note, the bias here is optional, and disabled by default (in line with most implementations of RMSNorm), unlike our LayerNorm, where the bias is enabled by default.
By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
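The difference to layer norm can be sketched in a few lines of plain Python. This is not the RETURNN code: rms_norm_frame is a hypothetical helper, and placing eps inside the square root is an assumption.

```python
import math
from typing import List, Sequence


def rms_norm_frame(frame: Sequence[float], scale: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Hypothetical helper: divide by the root-mean-square, no mean subtraction."""
    mean_square = sum(v * v for v in frame) / len(frame)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)  # eps placement is an assumption
    return [s * v * inv_rms for v, s in zip(frame, scale)]


frame = [3.0, 4.0]
# eps=0.0 only to make the arithmetic easy to follow.
out = rms_norm_frame(frame, scale=[1.0, 1.0], eps=0.0)
```

The output keeps the sign and the ratio of the inputs and has a mean square of 1, but its mean is not zero, since the mean is never subtracted.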
- class returnn.frontend.normalization.GroupNorm(in_dim: Dim | Sequence[Dim], *, num_groups: int | Dim, eps: float = 1e-06)
Note: this is non-standard. It reduces the statistics only over the in-group channels, independently per spatial/time position, i.e. it does not pool over the spatial dims. This differs from torch.nn.GroupNorm / the GroupNorm paper. For the standard, spatially-pooled (batch-independent) variant, use GroupNormSpatial.
By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
- class returnn.frontend.normalization.GroupNormSpatial(in_dim: Dim | Sequence[Dim], *, num_groups: int | Dim, eps: float = 1e-06)
Standard (spatially-pooled) group normalization, equivalent to torch.nn.GroupNorm: the mean/variance are computed over the in-group channels AND the given spatial dim(s), per batch element and per group (so batch-independent).
Unlike GroupNorm (which reduces only over the in-group channels, i.e. per spatial position), the spatial dim(s) must be passed explicitly to __call__, because they cannot be reliably inferred from the input tensor. spatial_dim may be a single Dim or a sequence of dims. With num_groups=1 this normalizes over all channels and spatial positions per sequence (what torchaudio’s Conformer uses); it is not the same as LayerNorm.
By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
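The two reduction patterns can be contrasted with a pure-Python sketch on a single sequence laid out as [time][channel]. Neither function is the RETURNN implementation; both names are hypothetical, consecutive channels are assumed to form a group, and the biased variance is an assumption.

```python
import math
from typing import List, Sequence


def group_norm_per_position(x: Sequence[Sequence[float]], num_groups: int, eps: float = 1e-6) -> List[List[float]]:
    """Statistics over the in-group channels only, separately per time step."""
    group_size = len(x[0]) // num_groups
    out = []
    for frame in x:
        row: List[float] = []
        for g in range(num_groups):
            chunk = frame[g * group_size:(g + 1) * group_size]
            mean = sum(chunk) / len(chunk)
            var = sum((v - mean) ** 2 for v in chunk) / len(chunk)
            row.extend((v - mean) / math.sqrt(var + eps) for v in chunk)
        out.append(row)
    return out


def group_norm_spatial(x: Sequence[Sequence[float]], num_groups: int, eps: float = 1e-6) -> List[List[float]]:
    """Statistics pooled over the in-group channels AND all time steps."""
    group_size = len(x[0]) // num_groups
    out = [[0.0] * len(x[0]) for _ in x]
    for g in range(num_groups):
        lo, hi = g * group_size, (g + 1) * group_size
        pooled = [v for frame in x for v in frame[lo:hi]]
        mean = sum(pooled) / len(pooled)
        var = sum((v - mean) ** 2 for v in pooled) / len(pooled)
        for t, frame in enumerate(x):
            for c in range(lo, hi):
                out[t][c] = (frame[c] - mean) / math.sqrt(var + eps)
    return out


x = [[1.0, 3.0, 10.0, 20.0], [5.0, 7.0, 30.0, 40.0]]  # [time][channel]
per_position = group_norm_per_position(x, num_groups=2)
spatial = group_norm_spatial(x, num_groups=2)
```

In the per-position variant both time steps map to nearly the same values (about -1 and +1 within each group), so the level difference between the frames is lost. The spatially pooled variant preserves it, because each group shares one mean and variance across the whole sequence.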
- class returnn.frontend.normalization.BatchNorm(in_dim: Dim, *, affine: bool = True, momentum: float = 0.1, eps: float = 0.001, track_running_stats: bool = True, use_mask: bool | None = None, distributed: bool | None = None)
Batch normalization. https://arxiv.org/abs/1502.03167
Note that the default arguments differ from corresponding batch norm in RETURNN. See here for discussion on defaults: https://github.com/rwth-i6/returnn/issues/522
We calculate statistics over all axes except the given in_dim. I.e. all other axes are reduced for the statistics.
To compensate for the normalization, there are learnable parameters gamma and beta (optional, used when option affine is True).
The behavior depends on whether this is used in training or evaluation, although this is often configurable in other frameworks. The usual behavior in training:
    # Using statistics from current batch.
    mean_cur_batch, variance_cur_batch = moments(x, reduce_dims)
    y = (x - mean_cur_batch) / sqrt(variance_cur_batch + epsilon)
    y = gamma * y + beta
    # Updating running statistics for later use.
    mean = (1 - momentum) * mean + momentum * mean_cur_batch
    variance = (1 - momentum) * variance + momentum * variance_cur_batch
The usual behavior when not in training (i.e. in evaluation):
    # Using collected statistics. Not using statistics from current batch.
    y = (x - mean) / sqrt(variance + epsilon)
    y = gamma * y + beta
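The training and evaluation behavior described above can be made concrete with a runnable sketch for a single scalar feature. BatchNormSketch is a hypothetical class, not the RETURNN module; the initial running statistics (mean 0, variance 1) and the fixed gamma=1, beta=0 are assumptions made for the example.

```python
import math
from typing import List, Sequence


class BatchNormSketch:
    """Hypothetical scalar-feature batch norm following the pseudocode above."""

    def __init__(self, momentum: float = 0.1, eps: float = 1e-3):
        self.momentum = momentum
        self.eps = eps
        self.gamma = 1.0
        self.beta = 0.0
        # Assumed initial running statistics.
        self.mean = 0.0
        self.variance = 1.0

    def __call__(self, batch: Sequence[float], train: bool) -> List[float]:
        if train:
            # Use statistics from the current batch.
            n = len(batch)
            mean = sum(batch) / n
            variance = sum((v - mean) ** 2 for v in batch) / n
            # Update the running statistics for later use.
            m = self.momentum
            self.mean = (1 - m) * self.mean + m * mean
            self.variance = (1 - m) * self.variance + m * variance
        else:
            # Use the collected statistics only.
            mean, variance = self.mean, self.variance
        inv_std = 1.0 / math.sqrt(variance + self.eps)
        return [self.gamma * (v - mean) * inv_std + self.beta for v in batch]


bn = BatchNormSketch(momentum=0.1)
train_out = bn([2.0, 4.0], train=True)  # batch mean 3.0, batch variance 1.0
eval_out = bn([0.3], train=False)       # running mean is now about 0.3
```

After one training step with momentum 0.1, the running mean has moved only a tenth of the way from 0 towards the batch mean of 3, which is why evaluation results can lag behind early in training.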
- Parameters:
in_dim – the feature dimension of the input
affine – whether to use learnable parameters gamma and beta
momentum – momentum for the running mean and variance
eps – epsilon for the variance
track_running_stats – If True, uses statistics of the current batch for normalization during training, and the tracked statistics (running mean and variance) during evaluation. If False, uses statistics of the current batch for normalization during both training and evaluation.
use_mask –
whether to use a mask for dynamic spatial dims. This must be specified if the input has dynamic spatial dims. True would use the correct masking then. However, that is inconsistent with all other frameworks, which ignore the masking, and it is also slower, and the fused op would not be used. False would be consistent with all other frameworks, and potentially allows for the use of an efficient fused op internally.
distributed – compute batch statistics over the global batch across all DDP workers (SyncBatchNorm-style) instead of per-worker. None (default) reads the global config option rf_batch_norm_distributed (default False). Only meaningful under Torch DDP grad-sync.
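To see why use_mask matters for inputs with dynamic spatial dims, consider a zero-padded batch. The following pure-Python sketch is only an illustration: batch_moments is a hypothetical helper, and the padding value 0.0 and the [batch][time] layout are assumptions.

```python
from typing import List, Sequence, Tuple


def moments_flat(values: Sequence[float]) -> Tuple[float, float]:
    """Mean and biased variance of a flat list."""
    n = len(values)
    mean = sum(values) / n
    return mean, sum((v - mean) ** 2 for v in values) / n


def batch_moments(
    batch: Sequence[Sequence[float]], seq_lens: Sequence[int], use_mask: bool
) -> Tuple[float, float]:
    """Hypothetical helper: statistics over a padded [batch][time] input."""
    if use_mask:
        # Only the valid frames of each sequence contribute.
        values: List[float] = [v for seq, n in zip(batch, seq_lens) for v in seq[:n]]
    else:
        # Padding frames are counted like real data.
        values = [v for seq in batch for v in seq]
    return moments_flat(values)


padded = [[2.0, 4.0, 6.0], [4.0, 0.0, 0.0]]  # second sequence has one valid frame
lengths = [3, 1]
masked = batch_moments(padded, lengths, use_mask=True)
unmasked = batch_moments(padded, lengths, use_mask=False)
```

Here the masked statistics are mean 4.0 and variance 2.0, while the unmasked mean is pulled down to 8/3 by the two padding zeros. How much this matters in practice depends on how much padding a batch contains.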
- returnn.frontend.normalization.batch_norm_distributed_default() -> bool
Global-config default for BatchNorm distributed (SyncBatchNorm-style global stats). Controlled via the option rf_batch_norm_distributed, mirroring rf_dropout_broadcast. Default False (per-worker stats). Only has an effect under Torch DDP; enable it explicitly there.
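The idea behind global statistics can be sketched without any distributed runtime. In the following pure-Python example the all-reduce is simulated by a plain sum over per-worker lists; global_moments is a hypothetical helper, not a RETURNN or PyTorch function, and the variance is derived from the reduced sum, sum-of-squares and count as E[x^2] - E[x]^2.

```python
from typing import Sequence, Tuple


def global_moments(worker_batches: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Hypothetical helper: combine per-worker sum, sum-of-squares and count."""
    # Each generator below stands in for one all-reduce (sum) across workers.
    total_sum = sum(sum(batch) for batch in worker_batches)
    total_sq_sum = sum(sum(v * v for v in batch) for batch in worker_batches)
    total_count = sum(len(batch) for batch in worker_batches)
    mean = total_sum / total_count
    variance = total_sq_sum / total_count - mean ** 2  # biased variance
    return mean, variance


def local_variance(batch: Sequence[float]) -> float:
    mean = sum(batch) / len(batch)
    return sum((v - mean) ** 2 for v in batch) / len(batch)


workers = [[1.0, 2.0], [3.0, 4.0]]  # two simulated workers
global_mean, global_var = global_moments(workers)
avg_local_var = sum(local_variance(b) for b in workers) / len(workers)
```

The global variance over all four values is 1.25, whereas each worker on its own sees a variance of only 0.25. Per-worker statistics can therefore differ substantially from global ones when the per-worker batches are small or not identically distributed.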
- returnn.frontend.normalization.normalize(a: Tensor, *, axis: Dim | Sequence[Dim], epsilon: float = 1e-06) -> Tensor
Mean- and variance-normalize some input in the given input dimension(s), such that the resulting tensor has mean 0 and variance 1.
If you want the result to be shifted and scaled again, you need additional parameters, cf. Normalize.
- Parameters:
a – input
axis – axis over which the mean and variance are computed
epsilon – epsilon for numerical stability
- Returns:
(a - mean) / sqrt(variance + epsilon)
- class returnn.frontend.normalization.Normalize(*, param_dims: Dim | Sequence[Dim], epsilon: float = 1e-06, scale: bool = True, bias: bool = True)
normalize() with additional scale and bias.
- Parameters:
param_dims – shape of the scale and bias parameters
epsilon – epsilon for numerical stability
scale – whether to include a trainable scale
bias – whether to include a trainable bias
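The relation between the plain normalization and the variant with scale and bias can be sketched in pure Python. This does not use RETURNN: normalize_1d and NormalizeSketch are hypothetical names, a flat list stands in for the reduced axis, and initializing the scale to 1 and the bias to 0 is an assumption.

```python
import math
from typing import List, Optional, Sequence


def normalize_1d(values: Sequence[float], epsilon: float = 1e-6) -> List[float]:
    """(a - mean) / sqrt(variance + epsilon) over a flat list."""
    n = len(values)
    mean = sum(values) / n
    variance = sum((v - mean) ** 2 for v in values) / n
    return [(v - mean) / math.sqrt(variance + epsilon) for v in values]


class NormalizeSketch:
    """Hypothetical module: normalize_1d plus optional scale and bias."""

    def __init__(self, size: int, epsilon: float = 1e-6, scale: bool = True, bias: bool = True):
        self.epsilon = epsilon
        # Parameters only exist if the corresponding flag is set.
        self.scale: Optional[List[float]] = [1.0] * size if scale else None
        self.bias: Optional[List[float]] = [0.0] * size if bias else None

    def __call__(self, values: Sequence[float]) -> List[float]:
        out = normalize_1d(values, self.epsilon)
        if self.scale is not None:
            out = [s * v for s, v in zip(self.scale, out)]
        if self.bias is not None:
            out = [v + b for v, b in zip(out, self.bias)]
        return out


layer = NormalizeSketch(size=4)
layer.scale = [2.0] * 4  # stands in for trained values
layer.bias = [5.0] * 4
shifted = layer([1.0, 2.0, 3.0, 4.0])
constant = normalize_1d([7.0, 7.0, 7.0])  # zero variance: epsilon avoids division by zero
```

With a scale of 2 and a bias of 5 the output has a mean of 5 and a standard deviation of about 2, which shows how the extra parameters undo the fixed mean 0 and variance 1 when that is useful for the model.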