returnn.tf.util.gradient_checkpoint

Gradient checkpointing.

returnn.tf.util.gradient_checkpoint.gradient_checkpoint_scope()
Returns:

context manager, where all tensors created inside the scope will be recomputed at backprop time, based on existing tensors that were created earlier outside the scope.

If prepare_gradient_checkpointing() is not called later, this has no effect. Likewise, if no gradients are calculated, it has no effect.
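
A minimal usage sketch, assuming the TF graph-mode setup that RETURNN uses (tf.compat.v1); the tensors x, w, h and the loss are illustrative only:

    import tensorflow as tf
    from returnn.tf.util import gradient_checkpoint

    tf.compat.v1.disable_eager_execution()

    x = tf.random.normal([8, 32])                # created outside the scope: kept in memory
    w = tf.Variable(tf.random.normal([32, 32]))  # model parameter, also outside the scope

    with gradient_checkpoint.gradient_checkpoint_scope():
        # h is created inside the scope: it is not kept for the backward pass
        # but recomputed from x and w at backprop time.
        h = tf.nn.relu(tf.matmul(x, w))

    loss = tf.reduce_sum(tf.square(h))
    # prepare_gradient_checkpointing() must be called before the gradients
    # are computed for this to take effect; see below.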

returnn.tf.util.gradient_checkpoint.gradient_checkpoint_exclude_scope()
Returns:

context manager, where all tensors created inside the scope will be excluded from recomputation at backprop time.
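
A sketch of nesting this inside a checkpoint scope, under the same graph-mode assumption and with illustrative tensor names; tensors created inside the exclude scope are stored for the backward pass as usual instead of being recomputed:

    import tensorflow as tf
    from returnn.tf.util import gradient_checkpoint

    tf.compat.v1.disable_eager_execution()
    x = tf.random.normal([8, 32])
    w = tf.Variable(tf.random.normal([32, 32]))

    with gradient_checkpoint.gradient_checkpoint_scope():
        a = tf.matmul(x, w)           # recomputed at backprop time
        with gradient_checkpoint.gradient_checkpoint_exclude_scope():
            b = tf.nn.softmax(a)      # excluded: stored as usual
        c = a + b                     # recomputed at backprop time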

returnn.tf.util.gradient_checkpoint.prepare_gradient_checkpointing()

Call this after the computation graph for the model + loss has been created, and before the gradients are calculated (i.e. before tf.gradients is called).

This will create a copy of all the ops that were created within the gradient_checkpoint_scope() scope.

It then patches op._gradient_function of all consuming ops to use the copied ops instead, so that during backpropagation all such tensors are effectively recalculated.
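
Putting it together, a sketch of the required call order, under the same assumptions as the sketches above (graph mode, illustrative tensor and loss names):

    import tensorflow as tf
    from returnn.tf.util import gradient_checkpoint

    tf.compat.v1.disable_eager_execution()

    x = tf.random.normal([8, 32])
    w = tf.Variable(tf.random.normal([32, 32]))

    # 1. Build the complete graph for model + loss,
    #    using gradient_checkpoint_scope() where tensors should be recomputed.
    with gradient_checkpoint.gradient_checkpoint_scope():
        h = tf.nn.relu(tf.matmul(x, w))
    loss = tf.reduce_sum(tf.square(h))

    # 2. Copy the ops from the checkpoint scopes and patch the consuming ops'
    #    gradient functions.
    gradient_checkpoint.prepare_gradient_checkpointing()

    # 3. Only now compute the gradients: the backward pass recomputes h
    #    instead of reading a stored value.
    grads = tf.gradients(loss, [w])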