returnn.torch.util.gradient_checkpoint

Gradient checkpointing.

This follows much of the code of the official torch.utils.checkpoint, using torch.autograd.graph.saved_tensors_hooks and TorchDispatchMode, and also handles forking and resetting the RNG state in a similar way.

See also returnn.tf.util.gradient_checkpoint: the same API and logic in TF. That implementation makes heavy use of the TF computation graph, i.e. graph mode, which makes this particular feature much easier to implement.

See also:

https://github.com/rwth-i6/returnn/issues/1552
https://discuss.pytorch.org/t/gradient-checkpointing/205416
https://gist.github.com/soulitzer/ec1049a947be046de7fbc2af61a4ee8c

class returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope[source]

Create a gradient checkpoint scope. Tensors created within this scope are not stored for backpropagation; instead, they are recomputed on the fly during backpropagation.

Example:

a = ...
b = ...
c = ...
with gradient_checkpoint_scope():
    x = a + b
y = x * c

In this example, the tensor x will not be stored for backpropagation, i.e. the computation x = a + b will be recomputed during backpropagation.
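For contrast, here is a plain-PyTorch sketch (no RETURNN involved) that uses torch.autograd.graph.saved_tensors_hooks merely to observe what autograd stores for this computation. Without a gradient checkpoint scope, the intermediate x is kept alive for backpropagation:

```python
import torch

saved = []  # tensors that autograd saves for backward

def pack(t):
    saved.append(t)
    return t

def unpack(t):
    return t

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
c = torch.randn(4, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    x = a + b  # addition saves nothing for backward
    y = x * c  # multiplication saves both operands: x and c

# The intermediate x is among the tensors kept for backward.
assert any(s.data_ptr() == x.data_ptr() for s in saved)
y.sum().backward()
```

Inside gradient_checkpoint_scope, x would not be stored this way; it would be recomputed from a and b when the backward pass needs it.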

Internally, this uses the PyTorch torch.autograd.graph.saved_tensors_hooks mechanism to override what is stored for backpropagation and how to recompute it, and the PyTorch TorchDispatchMode to intercept all operations within the scope. Note that the usage of torch.autograd.graph.saved_tensors_hooks is tricky here, as we need it beyond the scope of the gradient_checkpoint_scope itself, specifically for all future usages of the tensor x in the example. See the code documentation for more details on this.
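The hook mechanism itself can be illustrated with plain PyTorch. The sketch below is not RETURNN's actual implementation: the pack hook replaces each saved tensor with a small closure, and the unpack hook calls it during backpropagation. gradient_checkpoint_scope builds on the same hooks, but records the producing ops (via TorchDispatchMode) so that the closure genuinely recomputes the value instead of cloning it as this demo does:

```python
import torch

recomputed = []  # records that the unpack hook ran during backward

def pack_hook(t):
    # Instead of keeping the tensor alive directly, keep a callable that can
    # reproduce it. (A real checkpoint implementation would replay the
    # producing ops here; this sketch just clones eagerly to keep the demo
    # simple, so it does not actually save memory.)
    data = t.detach().clone()
    def recompute():
        recomputed.append(True)
        return data
    return recompute

def unpack_hook(recipe):
    return recipe()  # called lazily when backward needs the saved tensor

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    x = a * b  # mul saves both operands for backward -> pack_hook fires

x.sum().backward()  # unpack_hook fires here
assert recomputed
```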

Note that PyTorch itself also provides a gradient checkpointing API, namely torch.utils.checkpoint. That API is different: you cannot directly specify what not to store, i.e. what to recompute. Instead, you specify the start and end points of a region; its inputs and output are stored, and PyTorch recomputes everything in between. For the example above, you would define that y is the end point and will be stored. It looks like this:

a = ...
b = ...
c = ...
y = torch.utils.checkpoint.checkpoint(lambda: (a + b) * c, use_reentrant=False)

During backpropagation, PyTorch reruns the whole checkpointed function, i.e. it recomputes both a + b and the multiplication by c; only the inputs and the output y are stored. We find this API more cumbersome to use and less flexible, because in many cases you know what you want to recompute, i.e. what you do not want to store. The PyTorch API is instead about what you want to store, and everything else in between is recomputed.
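To make the comparison concrete, here is a small runnable sketch of the torch.utils.checkpoint API (the call counter is only for illustration). It shows that the whole function body runs once in the forward pass and once more during backpropagation:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}  # counts how often the checkpointed function runs

def fn(a, b, c):
    calls["n"] += 1
    return (a + b) * c

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = torch.randn(3, requires_grad=True)

# use_reentrant=False selects the non-reentrant implementation, which is the
# variant recommended by PyTorch (recent versions warn if it is not specified).
y = checkpoint(fn, a, b, c, use_reentrant=False)
y.sum().backward()

# fn ran once in forward and once more during backward to rebuild the graph:
assert calls["n"] == 2
```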

See also:

https://github.com/rwth-i6/returnn/issues/1552
https://discuss.pytorch.org/t/gradient-checkpointing/205416

exit_saved_tensors_hooks_scope()[source]

Exit the saved_tensors_hooks scope, if not yet done.