returnn.frontend.gradient

Utilities which affect the gradient

returnn.frontend.gradient.set_requires_gradient(source: Tensor)
Parameters:

source – the tensor to mark as requiring a gradient

Returns:

nothing, modifies source in-place

returnn.frontend.gradient.gradient(y: Tensor, x: Tensor) → Tensor
Parameters:
  • y – scalar tensor to differentiate

  • x – tensor w.r.t. which the gradient is computed

Returns:

gradient of y w.r.t. x
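To illustrate what set_requires_gradient() and gradient() compute, here is a minimal plain-Python sketch, not the RETURNN implementation: a tiny scalar reverse-mode autodiff in which the hypothetical class Var stands in for Tensor, and set_requires_gradient and gradient are simplified stand-ins with the same names. The rule that a result requires a gradient if any of its inputs does is an assumption modeled on PyTorch.

```python
class Var:
    """Scalar node of a tiny reverse-mode autodiff graph (illustration only)."""

    def __init__(self, value: float, parents=()):
        self.value = value
        # pairs of (parent node, local derivative d self / d parent)
        self.parents = tuple(parents)
        # assumption (as in PyTorch): a result requires a gradient if any input does
        self.requires_grad = any(p.requires_grad for p, _ in self.parents)

    def __add__(self, other: "Var") -> "Var":
        return Var(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other: "Var") -> "Var":
        return Var(self.value * other.value, [(self, other.value), (other, self.value)])


def set_requires_gradient(source: Var) -> None:
    source.requires_grad = True  # modifies source in-place, returns nothing


def gradient(y: Var, x: Var) -> float:
    """d y / d x, summed over all paths from y back to x."""
    if not x.requires_grad:
        raise ValueError("call set_requires_gradient(x) before building y")
    # topological order, so each node's gradient is complete before it is propagated
    order, seen = [], set()

    def visit(node: Var) -> None:
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(y)
    grads = {id(y): 1.0}
    for node in reversed(order):
        g = grads.get(id(node), 0.0)
        for parent, local in node.parents:
            grads[id(parent)] = grads.get(id(parent), 0.0) + g * local
    return grads.get(id(x), 0.0)


a, b = Var(2.0), Var(3.0)
set_requires_gradient(a)
y = a * b + a
print(gradient(y, a))  # 4.0, i.e. d/da (a*b + a) = b + 1
```

Asking for gradient(y, b) raises here, because b was never marked.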

returnn.frontend.gradient.stop_gradient(source: Tensor) → Tensor

Identity in the forward pass; blocks the gradient in backprop. Wraps tf.stop_gradient (TensorFlow) or Tensor.detach (PyTorch).
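The effect of stop_gradient() can be sketched with forward-mode dual numbers, a swapped-in technique chosen only because it is short; the hypothetical Dual class and this stop_gradient stand-in are not part of RETURNN. The value passes through unchanged, while its derivative is set to zero.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Forward-mode dual number: a value and its derivative w.r.t. one seeded input."""
    value: float
    tangent: float

    def __mul__(self, other: "Dual") -> "Dual":
        # product rule
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)


def stop_gradient(x: Dual) -> Dual:
    """Same value as x, but no derivative flows through it."""
    return Dual(x.value, 0.0)


x = Dual(3.0, 1.0)  # seed: d x / d x = 1
y_plain = x * x  # d/dx x^2 = 2x
y_stopped = x * stop_gradient(x)  # only the first factor is differentiated
print(y_plain)  # Dual(value=9.0, tangent=6.0)
print(y_stopped)  # Dual(value=9.0, tangent=3.0)
```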

returnn.frontend.gradient.stop_gradient_scope()

Create a stop gradient scope. All tensors created within this scope will have their gradient stopped.

Example:

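from returnn.frontend.gradient import stop_gradient_scope

a = ...
b = ...
c = ...
with stop_gradient_scope():
    x = a + b  # created inside the scope: gradient is stopped
y = x * c  # created outside the scope: regular gradient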

In this example, the tensor x has its gradient stopped, i.e. no gradient flows back through x to a and b (the gradient of y w.r.t. a and b is zero), while y still has a regular gradient w.r.t. c.

Returns:

context manager which enables stopping the gradient. It supports __enter__ and __exit__ and is intended to be used with the with statement.

returnn.frontend.gradient.scaled_gradient(source: Tensor, scale: float | Tensor) → Tensor
Parameters:
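  • source – input tensor; it is returned unchanged in the forward pass

  • scale – factor by which the gradient is multiplied in backprop. If it is the constant 0., stop_gradient() is used instead. With a negative factor, this can be used as a gradient reversal layer.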

Returns:

source with scaled gradient
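Because the forward pass of scaled_gradient() is the identity, its backward rule only multiplies the incoming gradient by scale. The sketch below applies that rule by hand in a toy graph where a shared feature h feeds a main head and, through a negative scale, an adversarial head (gradient reversal). The function names and the toy graph are hypothetical; this is not RETURNN code.

```python
def scaled_gradient_backward(grad_out: float, scale: float) -> float:
    """Backward rule: forward is identity, so d out / d in = scale."""
    return scale * grad_out


def loss_and_grad_w(w: float, x: float, c: float, d: float, lam: float):
    """Toy graph: h = w * x feeds a main head (h * c) and, via
    scaled_gradient(h, -lam), an adversarial head (h * d)."""
    h = w * x
    h_rev = h  # forward pass of scaled_gradient(h, -lam): identity
    loss = h * c + h_rev * d
    # reverse pass, starting from d loss / d loss = 1
    grad_h_main = 1.0 * c
    grad_h_rev = 1.0 * d
    grad_h = grad_h_main + scaled_gradient_backward(grad_h_rev, -lam)
    grad_w = grad_h * x
    return loss, grad_w


print(loss_and_grad_w(1.0, 2.0, 3.0, 5.0, 0.5))  # (16.0, 1.0)
```

Without the reversal, the gradient w.r.t. w would be (c + d) * x = 16.0; with it, the adversarial contribution is flipped and scaled, giving (c - lam * d) * x = 1.0.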

returnn.frontend.gradient.scaled_gradient_ext(source: Tensor, *, scale: float | Tensor, shift: float | Tensor | None = None, scale_shift_by_sum_over_axis: Dim | None = None) → Tensor

Identity in the forward pass. In backprop, scales the gradient by some factor (and optionally shifts it, see shift). With a negative factor, it can be used as a gradient reversal layer. For TF, it uses returnn.tf.util.basic.scaled_gradient() or tf.stop_gradient().

Parameters:
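  • source – input tensor; it is returned unchanged in the forward pass

  • scale – factor for the gradient. If it is the constant 0. and no shift is given, stop_gradient() is used instead.

  • shift – optional value by which the gradient is shifted in backprop

  • scale_shift_by_sum_over_axis – if given, the sum of absolute values over this axis is computed, and the shift value is multiplied by this sum.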

Returns:

source with transformed gradient
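The backward transformation of scaled_gradient_ext() can be sketched on a plain 2D list of gradients, summing over the last axis as a stand-in for scale_shift_by_sum_over_axis. The exact formula is an assumption read from the parameter descriptions: each entry becomes grad * scale + shift, where shift is first multiplied by the sum of absolute values of the incoming gradient along that axis when the sum option is on. The function name and the shift_by_row_abs_sum flag are hypothetical.

```python
from typing import List, Optional


def scaled_gradient_ext_backward(
    grad: List[List[float]],
    *,
    scale: float,
    shift: Optional[float] = None,
    shift_by_row_abs_sum: bool = False,  # hypothetical stand-in for scale_shift_by_sum_over_axis
) -> List[List[float]]:
    """Transforms the incoming gradient; the forward pass would be identity."""
    out = []
    for row in grad:
        if shift is None:
            row_shift = 0.0
        elif shift_by_row_abs_sum:
            # assumed: shift times the sum of |grad| over the last axis
            row_shift = shift * sum(abs(g) for g in row)
        else:
            row_shift = shift
        out.append([g * scale + row_shift for g in row])
    return out


g = [[1.0, -3.0], [2.0, 2.0]]
print(scaled_gradient_ext_backward(g, scale=0.5, shift=0.25, shift_by_row_abs_sum=True))
# [[1.5, -0.5], [2.0, 2.0]]
```

With scale=-1.0 and no shift this reduces to gradient reversal, and with scale=0.0 and no shift the gradient becomes zero, matching the stop_gradient() special case.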

returnn.frontend.gradient.gradient_checkpoint_scope()

Create a gradient checkpoint scope. All tensors created within this scope will not be stored for backpropagation, but will be recomputed on the fly during backpropagation.

Example:

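from returnn.frontend.gradient import gradient_checkpoint_scope

a = ...
b = ...
c = ...
with gradient_checkpoint_scope():
    x = a + b  # not stored, recomputed during backprop
y = x * c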

In this example, the tensor x will not be stored for backpropagation, i.e. the computation x = a + b will be recomputed during backpropagation.

See returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope for more documentation on the PyTorch-specific implementation.

Returns:

context manager which enables gradient checkpointing. It supports __enter__ and __exit__ and is intended to be used with the with statement.
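The principle behind gradient_checkpoint_scope() (store less in the forward pass, recompute during backprop) can be sketched by hand for the example above. This is a hypothetical plain-Python illustration, not the RETURNN or PyTorch mechanism: forward() keeps only the inputs and drops x, and backward() recomputes x = a + b when it is needed for the gradient w.r.t. c. The call counter shows that the checkpointed part runs twice.

```python
calls = {"segment": 0}  # counts evaluations of the checkpointed part


def segment(a: float, b: float) -> float:
    """The part inside the checkpoint scope: x = a + b."""
    calls["segment"] += 1
    return a + b


def forward(a: float, b: float, c: float):
    x = segment(a, b)
    y = x * c
    saved = (a, b, c)  # only the inputs are kept; x itself is dropped
    return y, saved


def backward(saved, grad_y: float = 1.0):
    a, b, c = saved
    x = segment(a, b)  # recomputed on the fly, since x was not stored
    grad_x = grad_y * c  # d y / d x
    grad_c = grad_y * x  # d y / d c, needs the recomputed x
    return grad_x, grad_x, grad_c  # d y/d a = d y/d b = d y/d x, because x = a + b


y, saved = forward(2.0, 3.0, 4.0)
print(y, backward(saved), calls["segment"])  # 20.0 (4.0, 4.0, 5.0) 2
```

The trade-off is one extra evaluation of the segment in exchange for not keeping x alive between the forward and the backward pass.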