returnn.torch.util.gradient_checkpoint
Gradient checkpointing.

This follows much of the code of the official torch.utils.checkpoint, using torch.autograd.graph.saved_tensors_hooks and TorchDispatchMode, but also handles the RNG fork and reset in a similar way.

See also returnn.tf.util.gradient_checkpoint: same API and logic in TF, although that implementation heavily makes use of the TF computation graph, i.e. graph mode, which makes this particular feature much easier to implement.

See also:
https://github.com/rwth-i6/returnn/issues/1552
https://discuss.pytorch.org/t/gradient-checkpointing/205416
https://gist.github.com/soulitzer/ec1049a947be046de7fbc2af61a4ee8c
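The RNG fork and reset can be illustrated in isolation with a small sketch (hypothetical code, not taken from this module): save the RNG state before the forward computation and restore it before the recomputation, so that random ops such as dropout produce the same result in both passes.

    import torch
    import torch.nn.functional as F

    # Save the CPU RNG state before the forward computation.
    # (For CUDA tensors, torch.cuda.get_rng_state_all() would also be saved.)
    cpu_rng_state = torch.get_rng_state()

    x = F.dropout(torch.ones(4), p=0.5, training=True)  # forward pass

    # Recomputation (as it would happen in backward): reset the RNG state first,
    # so the recomputed tensor matches the original one exactly.
    torch.set_rng_state(cpu_rng_state)
    x_recomputed = F.dropout(torch.ones(4), p=0.5, training=True)

    assert torch.equal(x, x_recomputed)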
class returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope
Create a gradient checkpoint scope. All tensors created within this scope will not be stored for backpropagation, but will be recomputed on the fly during backpropagation.
Example:
    a = ...
    b = ...
    c = ...
    with gradient_checkpoint_scope():
        x = a + b
    y = x * c
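A concrete, runnable variant of this example might look as follows (a sketch; the tensor shapes and the final backward call are illustrative, and the import path follows the module name above):

    import torch
    from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope

    a = torch.randn(8, 16, requires_grad=True)
    b = torch.randn(8, 16, requires_grad=True)
    c = torch.randn(8, 16, requires_grad=True)

    with gradient_checkpoint_scope():
        x = a + b  # not stored for backprop, recomputed on the fly
    y = x * c  # uses x outside the scope; x is still not stored

    y.sum().backward()  # triggers the recomputation of x
    print(a.grad.shape, b.grad.shape, c.grad.shape)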
In this example, the tensor x will not be stored for backpropagation, i.e. the computation x = a + b will be recomputed during backpropagation.

Internally, this uses the PyTorch torch.autograd.graph.saved_tensors_hooks mechanism to override what we store for backpropagation and how to recompute it, and the PyTorch TorchDispatchMode to intercept all operations within the scope. Note that the usage of torch.autograd.graph.saved_tensors_hooks is tricky here, as we need it beyond the scope of the gradient_checkpoint_scope, specifically for all future usages of the tensor x in the example. See the code documentation for more details on this.
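To illustrate just the saved_tensors_hooks part, here is a minimal, hypothetical sketch; unlike the real gradient_checkpoint_scope, it does not use TorchDispatchMode to record the operations, but instead registers a manual recompute recipe for the tensor x:

    import torch

    a = torch.randn(4, requires_grad=True)
    b = torch.randn(4, requires_grad=True)
    c = torch.randn(4, requires_grad=True)

    recompute_recipes = {}  # data_ptr -> function that recomputes the tensor

    def pack(tensor):
        # Called whenever autograd wants to store a tensor for backprop.
        fn = recompute_recipes.get(tensor.data_ptr())
        if fn is not None:
            return "recompute", fn  # store only the recipe, not the tensor
        return "stored", tensor  # everything else is stored as usual

    def unpack(packed):
        kind, payload = packed
        if kind == "recompute":
            with torch.no_grad():
                return payload()  # recompute the tensor on demand in backward
        return payload

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        x = a + b
        recompute_recipes[x.data_ptr()] = lambda: a + b  # manual recipe for x
        y = x * c  # mul saves x and c for backprop; x is saved only as the recipe

    y.sum().backward()  # recomputes x via the recipe

Note that in this sketch, y = x * c must still happen inside the saved_tensors_hooks block; gradient_checkpoint_scope avoids exactly this restriction by keeping the hooks active beyond the scope, as described above.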
Note that PyTorch itself also provides a gradient checkpointing API, namely torch.utils.checkpoint. That API is different: you cannot easily specify what not to store / what to recompute. Instead, you specify a start/end point of what to store for backpropagation, and PyTorch recomputes everything in between. For the example above, you would define that y is the end point and will be stored. It looks like this:

    a = ...
    b = ...
    c = ...
    y = torch.utils.checkpoint.checkpoint(lambda: (a + b) * c)

PyTorch will not recompute ... * c here, but it will recompute a + b.
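A runnable version of this torch.utils.checkpoint variant might look as follows (a sketch; use_reentrant=False is chosen here so that gradients also flow to the tensors captured by the lambda closure):

    import torch
    import torch.utils.checkpoint

    a = torch.randn(8, 16, requires_grad=True)
    b = torch.randn(8, 16, requires_grad=True)
    c = torch.randn(8, 16, requires_grad=True)

    # use_reentrant=False selects the non-reentrant implementation, which also
    # allows gradients to flow to tensors captured by the closure.
    y = torch.utils.checkpoint.checkpoint(lambda: (a + b) * c, use_reentrant=False)

    y.sum().backward()  # a + b is recomputed during this backward pass
    print(a.grad.shape, b.grad.shape, c.grad.shape)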
We find this API more cumbersome to use and less flexible, because in many cases you know what you want to recompute, i.e. what you do not want to store. The PyTorch API is rather about what you want to store, and then everything else in between is recomputed.

See also:
https://github.com/rwth-i6/returnn/issues/1552
https://discuss.pytorch.org/t/gradient-checkpointing/205416