returnn.frontend.gradient
Utilities that affect the gradient.
- returnn.frontend.gradient.set_requires_gradient(source: Tensor)
- Parameters:
source – tensor that should be marked as requiring a gradient
- Returns:
nothing; modifies source in place
- returnn.frontend.gradient.gradient(y: Tensor, x: Tensor) -> Tensor
- Parameters:
y – scalar tensor to differentiate
x – tensor with respect to which the gradient is computed
- Returns:
gradient of y with respect to x, with the same shape as x
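The returned gradient has the same shape as x, one partial derivative per element. As a framework-independent way to sanity-check such a result, here is a pure-Python central finite-difference sketch. It is not how gradient() computes its result, and numerical_gradient and f are hypothetical names used only for illustration.

```python
def numerical_gradient(f, x, eps=1e-6):
    """Central finite-difference estimate of d f(x) / d x for a scalar-valued f.

    x is a flat list of floats; the result has the same length as x.
    """
    grad = []
    for i in range(len(x)):
        plus = list(x)
        minus = list(x)
        plus[i] += eps
        minus[i] -= eps
        grad.append((f(plus) - f(minus)) / (2 * eps))
    return grad


def f(x):
    # Scalar "y" as a function of the vector "x": y = sum_i x_i^2, so dy/dx_i = 2 * x_i.
    return sum(v * v for v in x)


g = numerical_gradient(f, [1.0, 2.0, 3.0])
print([round(v, 4) for v in g])  # [2.0, 4.0, 6.0]
```

Comparing such an estimate with the output of gradient(y, x) on a tiny input is a cheap check when writing custom losses.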
- returnn.frontend.gradient.stop_gradient(source: Tensor) -> Tensor
Wraps tf.stop_gradient (TensorFlow) or detach() (PyTorch): the forward value is unchanged, but no gradient flows back to source.
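Conceptually, stop_gradient is the identity in the forward pass and has a zero gradient in the backward pass. The pure-Python sketch below models each op as a hypothetical (value, backward) pair, not the RETURNN API, and compares d/dx [x * x] with d/dx [x * stop_gradient(x)].

```python
def identity(x):
    """Forward: x. Backward: pass the incoming gradient through unchanged."""
    return x, lambda grad: grad


def stop_gradient(x):
    """Forward: x. Backward: always zero, so nothing reaches x."""
    return x, lambda grad: 0.0


def grad_of_x_times_op(x, op):
    """d/dx [x * op(x)] via the product rule, using op's backward function."""
    value, backward = op(x)
    return value * 1.0 + x * backward(1.0)


print(grad_of_x_times_op(3.0, identity))       # 6.0, i.e. d(x*x)/dx = 2x
print(grad_of_x_times_op(3.0, stop_gradient))  # 3.0, the second factor acts as a constant
```

Both variants compute the same forward value x * x; only the gradient differs.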
- returnn.frontend.gradient.stop_gradient_scope()
Create a stop gradient scope. All tensors created within this scope will have their gradient stopped.
Example:

    a = ...
    b = ...
    c = ...
    with stop_gradient_scope():
        x = a + b
    y = x * c
In this example, the tensor x will have its gradient stopped, i.e. the gradient of x w.r.t. a and b will be zero.
- Returns:
context manager which enables stopping the gradient. It supports __enter__ and __exit__ and is intended to be used in a with statement.
- returnn.frontend.gradient.scaled_gradient(source: Tensor, scale: float | Tensor) -> Tensor
- Parameters:
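source – input tensor; the forward pass is the identity
scale – factor by which the gradient is multiplied in backprop. If it is the constant 0., stop_gradient() is used instead. With a negative factor, this can be used as a gradient reversal layer.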
- Returns:
source with scaled gradient
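A minimal sketch of the scaled_gradient semantics, using a hypothetical (value, backward) pair instead of real tensors: the forward value passes through unchanged, and the backward function multiplies the incoming gradient by scale. With scale = -1 the gradient is reversed; with 0 it is dropped, as with stop_gradient().

```python
def scaled_gradient(x, scale):
    """Toy version: forward is the identity, backward multiplies the gradient by scale."""
    return x, lambda grad: grad * scale


def loss_and_grad(x, scale):
    """loss = 0.5 * z**2 with z = scaled_gradient(x, scale); returns (loss, dloss/dx)."""
    z, backward = scaled_gradient(x, scale)
    loss = 0.5 * z * z
    dloss_dz = z
    return loss, backward(dloss_dz)


print(loss_and_grad(2.0, 1.0))   # (2.0, 2.0): plain gradient
print(loss_and_grad(2.0, -1.0))  # (2.0, -2.0): gradient reversal
print(loss_and_grad(2.0, 0.0))   # (2.0, 0.0): like stop_gradient
```

The loss value is identical in all three cases. With a reversed gradient, a gradient-descent step on x increases the loss computed after the layer, which is the point of a gradient reversal layer in adversarial setups.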
- returnn.frontend.gradient.scaled_gradient_ext(source: Tensor, *, scale: float | Tensor, shift: float | Tensor | None = None, scale_shift_by_sum_over_axis: Dim | None = None) -> Tensor
Identity in the forward pass. In backprop, the gradient is scaled by some factor (and optionally shifted). Can be used as a gradient reversal layer (with a negative factor). For TF, uses
returnn.tf.util.basic.scaled_gradient() or tf.stop_gradient().
- Parameters:
source – input tensor
scale – if the constant 0. and no shift is given, stop_gradient() is used
shift – optional offset added to the scaled gradient
scale_shift_by_sum_over_axis – if given, calculates the sum over this axis (absolute values) and multiplies the shift value by this sum.
- Returns:
source with transformed gradient
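The backward transformation can be sketched on a plain 2D list of gradient values. This sketch assumes the gradient is first multiplied by scale, and that with scale_shift_by_sum_over_axis the shift is multiplied by the sum of the absolute values of the already scaled gradient over that axis. The function name and the integer axis argument (instead of a Dim) are hypothetical.

```python
def scaled_gradient_ext_backward(grad, scale, shift=None, sum_over_axis=None):
    """Hypothetical backward pass on a 2D list grad[row][col]; returns the new gradient."""
    scaled = [[g * scale for g in row] for row in grad]
    if shift is None:
        return scaled  # scale == 0 and no shift gives all zeros, like stop_gradient
    if sum_over_axis is None:
        return [[g + shift for g in row] for row in scaled]
    if sum_over_axis == 1:
        # One abs-sum per row, broadcast back over the columns of that row.
        sums = [sum(abs(g) for g in row) for row in scaled]
        return [[g + shift * s for g in row] for row, s in zip(scaled, sums)]
    if sum_over_axis == 0:
        # One abs-sum per column, broadcast back over the rows.
        sums = [sum(abs(row[j]) for row in scaled) for j in range(len(scaled[0]))]
        return [[g + shift * sums[j] for j, g in enumerate(row)] for row in scaled]
    raise ValueError("sum_over_axis must be None, 0 or 1 for a 2D gradient")


out = scaled_gradient_ext_backward([[1.0, -2.0], [3.0, 4.0]], scale=0.5, shift=0.1, sum_over_axis=1)
print([[round(v, 6) for v in row] for row in out])  # [[0.65, -0.85], [1.85, 2.35]]
```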
- returnn.frontend.gradient.gradient_checkpoint_scope()
Create a gradient checkpoint scope. All tensors created within this scope will not be stored for backpropagation, but will be recomputed on the fly during backpropagation.
Example:

    a = ...
    b = ...
    c = ...
    with gradient_checkpoint_scope():
        x = a + b
    y = x * c
In this example, the tensor x will not be stored for backpropagation, i.e. the computation x = a + b will be recomputed during backpropagation. See
returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope for more documentation on the PyTorch-specific implementation.
- Returns:
context manager which enables gradient checkpointing. It supports __enter__ and __exit__ and is intended to be used in a with statement.
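The memory/compute trade-off can be sketched in pure Python for y = (a + b) * c. The hypothetical forward_backward helper either keeps the intermediate x for the backward pass or drops it and recomputes it when the gradient is needed. It only counts how often a + b is evaluated and is not the RETURNN or PyTorch implementation.

```python
def forward_backward(a, b, c, checkpoint):
    """Computes y = (a + b) * c and its gradients; returns (y, (dy/da, dy/db, dy/dc), add_calls)."""
    add_calls = 0

    def add(u, v):
        nonlocal add_calls
        add_calls += 1
        return u + v

    # Forward pass.
    x = add(a, b)
    y = x * c
    # With checkpointing, the activation x is not kept for backprop.
    saved_x = None if checkpoint else x
    del x

    # Backward pass.
    if saved_x is None:
        saved_x = add(a, b)  # recompute the dropped activation on the fly
    dy_dx = c
    grads = (dy_dx * 1, dy_dx * 1, saved_x)  # dx/da = dx/db = 1, dy/dc = x
    return y, grads, add_calls


print(forward_backward(2, 3, 4, checkpoint=False))  # (20, (4, 4, 5), 1)
print(forward_backward(2, 3, 4, checkpoint=True))   # (20, (4, 4, 5), 2)
```

Results and gradients are identical; checkpointing only trades one extra evaluation of a + b for not holding x in memory between the forward and backward pass.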