returnn.torch.util.scaled_gradient¶
Scaled gradients for the backward pass. This also covers gradient reversal, which is simply the case scale=-1. We extend the simple scaling with further optional transformations, such as shifting.
The code is adapted from our TF implementation, see returnn.tf.util.basic.scaled_gradient().
For some discussion on the specific implementation, see: https://discuss.pytorch.org/t/gradient-scaling-reversal/186392
Also see other reference implementations:
- https://github.com/facebookresearch/fairseq/blob/100cd91db19bb/fairseq/modules/grad_multiply.py
- https://github.com/janfreyberg/pytorch-revgrad/blob/449fa763a76d/src/pytorch_revgrad/functional.py
- https://github.com/tadeephuy/GradientReversal/blob/5d9857d63/gradient_reversal/functional.py
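The reference implementations above all follow the same torch.autograd.Function pattern: identity in the forward pass, a transformed gradient in the backward pass. A minimal sketch of that pattern (not RETURNN's actual internals; _ScaledGradient and my_scaled_gradient are hypothetical names):

```python
import torch


class _ScaledGradient(torch.autograd.Function):
    """Identity in the forward pass; scales the incoming gradient in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        # view_as gives a distinct output node while sharing storage with x
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Scale (or reverse, for scale=-1) the gradient w.r.t. x;
        # the scale argument itself gets no gradient.
        return grad_output * ctx.scale, None


def my_scaled_gradient(x: torch.Tensor, scale: float) -> torch.Tensor:
    return _ScaledGradient.apply(x, scale)
```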
- returnn.torch.util.scaled_gradient.scaled_gradient(x: Tensor, scale: float) → Tensor [source]¶
- Parameters:
x – the input tensor; returned unchanged in the forward pass
scale – factor by which the gradient is scaled in the backward pass
- Returns:
just x; however, in the backward pass, the gradient is scaled by the given factor (see the usage sketch below)
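A short usage sketch, assuming the module is importable as documented; with scale=-1.0 this acts as a gradient reversal layer, e.g. for domain-adversarial training:

```python
import torch
from returnn.torch.util.scaled_gradient import scaled_gradient

x = torch.randn(4, 8, requires_grad=True)
y = scaled_gradient(x, scale=-1.0)  # gradient reversal
assert torch.equal(y, x)  # forward pass is the identity
y.sum().backward()
# d(sum)/dx is all ones, so the reversed gradient is all -1:
print(x.grad)
```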
- returnn.torch.util.scaled_gradient.scaled_gradient_ext(x: Tensor, *, scale: float | Tensor = 1.0, shift: float | Tensor | None = None, scale_shift_by_sum_over_axis: int | None = None)[source]¶
- Parameters:
x – the input tensor; returned unchanged in the forward pass
scale – will scale the gradient by this value
shift – will shift the gradient by this value
scale_shift_by_sum_over_axis – if given, will scale and shift the gradient by the sum over the given axis
- Returns:
just x, but the gradient in the backward pass will be transformed accordingly (see the usage sketch below)
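A short usage sketch, assuming from the parameter docs that the transform for plain scale and shift is grad * scale + shift (scale_shift_by_sum_over_axis is left at its default here):

```python
import torch
from returnn.torch.util.scaled_gradient import scaled_gradient_ext

x = torch.randn(4, 8, requires_grad=True)
y = scaled_gradient_ext(x, scale=0.5, shift=0.1)
y.sum().backward()
# d(sum)/dx is all ones; assuming grad * scale + shift,
# each gradient entry becomes 0.5 * 1.0 + 0.1 = 0.6
print(x.grad)
```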