returnn.torch.util.scaled_gradient

Scaled gradients for the backward pass. This also covers gradient reversal, which is simply the special case scale=-1. We also extend plain scaling with further optional transformations, such as shifting.
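
Written out, the behaviour described above is roughly the following. Here s stands for the scale, c for an optional shift (c = 0 when no shift is used), and L for an arbitrary scalar loss. This notation is ours and does not come from the RETURNN code. With s = -1 and c = 0 this reduces to gradient reversal.

```latex
\text{forward:}\quad y = x
\qquad\qquad
\text{backward:}\quad \frac{\partial L}{\partial x} = s \, \frac{\partial L}{\partial y} + c
```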

The code is adapted from our TF implementation, see returnn.tf.util.basic.scaled_gradient().

For some discussion on the specific implementation, see: https://discuss.pytorch.org/t/gradient-scaling-reversal/186392
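
One implementation question for such functions is whether to define a custom torch.autograd.Function or to obtain the same gradient from ordinary tensor arithmetic. For comparison, here is a sketch of the arithmetic variant based on Tensor.detach(). This is an alternative technique and not necessarily what RETURNN does. The helper name scaled_gradient_detach_trick is hypothetical.

```python
import torch


def scaled_gradient_detach_trick(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Hypothetical helper: the value equals x, and the gradient w.r.t. x is multiplied by scale."""
    x_const = x.detach()  # same values as x, but no gradient flows through this term
    # (x - x_const) is zero in value, so the forward result equals x,
    # while the derivative of scale * (x - x_const) w.r.t. x is scale.
    return x_const + scale * (x - x_const)
```

This version needs no custom backward. However, it adds a few elementwise ops in both passes. A non-finite scale (inf or nan) would also turn the forward value into nan, because 0 * inf is nan. A custom autograd.Function avoids both issues.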

Also see other reference implementations:
  • https://github.com/facebookresearch/fairseq/blob/100cd91db19bb/fairseq/modules/grad_multiply.py
  • https://github.com/janfreyberg/pytorch-revgrad/blob/449fa763a76d/src/pytorch_revgrad/functional.py
  • https://github.com/tadeephuy/GradientReversal/blob/5d9857d63/gradient_reversal/functional.py

returnn.torch.util.scaled_gradient.scaled_gradient(x: Tensor, scale: float) → Tensor
Parameters:
  • x – input tensor

  • scale – factor by which the gradient is multiplied in the backward pass

Returns:

x unchanged; in the backward pass, the gradient is scaled by the given factor
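
Below is a minimal sketch of how such a function can be written with a custom torch.autograd.Function. It is modelled on the description above and is not copied from RETURNN. The names _ScaledGradientSketch and scaled_gradient_sketch are hypothetical.

```python
import torch


class _ScaledGradientSketch(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by `scale` in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale  # plain Python float, kept on ctx for the backward pass
        # Return a view so the output is a new autograd node with the same values as x.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One entry per forward input: x gets the scaled gradient, scale gets none.
        return grad_output * ctx.scale, None


def scaled_gradient_sketch(x: torch.Tensor, scale: float) -> torch.Tensor:
    return _ScaledGradientSketch.apply(x, scale)
```

Passing scale=-1.0 turns this into a gradient reversal layer, which is the special case mentioned at the top of this page.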

returnn.torch.util.scaled_gradient.scaled_gradient_ext(x: Tensor, *, scale: float | Tensor = 1.0, shift: float | Tensor | None = None, scale_shift_by_sum_over_axis: int | None = None)
Parameters:
  • x – input tensor

  • scale – the gradient is multiplied by this value

  • shift – this value is added to the gradient

  • scale_shift_by_sum_over_axis – if given, will scale and shift by the sum over the given axis

Returns:

x unchanged; in the backward pass, the gradient is transformed accordingly
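
Below is a minimal sketch of the extended variant, again as a custom torch.autograd.Function with hypothetical names. The handling of scale_shift_by_sum_over_axis is an assumption. In this sketch, the shift is multiplied by the sum of absolute gradient values over that axis, and that sum is kept as a size-1 dimension so it broadcasts back. This is one reading of the parameter description, so check the RETURNN source before relying on it.

```python
from typing import Optional, Union

import torch


class _ScaledGradientExtSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, shift, scale_shift_by_sum_over_axis):
        ctx.scale = scale
        ctx.shift = shift
        ctx.axis = scale_shift_by_sum_over_axis
        return x.view_as(x)  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad):
        out = grad * ctx.scale  # scale the incoming gradient
        if ctx.shift is not None:
            shift = ctx.shift
            if ctx.axis is not None:
                # ASSUMPTION: weight the shift by sum(|grad|) over the axis (keepdim for broadcasting).
                shift = shift * grad.abs().sum(dim=ctx.axis, keepdim=True)
            out = out + shift  # add the (possibly weighted) shift
        # Gradients only for x; scale, shift and the axis get none.
        return out, None, None, None


def scaled_gradient_ext_sketch(
    x: torch.Tensor,
    *,
    scale: Union[float, torch.Tensor] = 1.0,
    shift: Optional[Union[float, torch.Tensor]] = None,
    scale_shift_by_sum_over_axis: Optional[int] = None,
) -> torch.Tensor:
    return _ScaledGradientExtSketch.apply(x, scale, shift, scale_shift_by_sum_over_axis)
```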