returnn.torch.util.debug_inf_nan
Helpers to debug nan/inf values in tensors. E.g., you get nan/inf values in the loss, and you want to know where they come from. There could be multiple potential sources:
- The parameters are already broken (nan/inf).
Then some previous step caused this. For this, we might want to add another option which performs a check before we update the parameters, so that an update can never break them unnoticed.
- The gradients are broken (nan/inf).
There are some PyTorch utilities to check this. This is currently not the focus here.
- Some part of the (forward) computation results in nan/inf.
Currently, this is the focus here. We want to know where this happens.
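The first two cases above (broken parameters or gradients) can be ruled out with a direct check before looking at the forward computation. The following is a minimal sketch, assuming a plain `torch.nn.Module`. The helper name `find_non_finite` is hypothetical and not part of RETURNN.

```python
from typing import List

import torch


def find_non_finite(module: torch.nn.Module) -> List[str]:
    """Return the names of parameters (and gradients) containing nan/inf."""
    bad = []
    for name, param in module.named_parameters():
        if not torch.isfinite(param).all():
            bad.append(name)
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name + ".grad")
    return bad


model = torch.nn.Linear(3, 2)
print(find_non_finite(model))  # freshly initialized model: nothing to report

with torch.no_grad():
    model.weight[0, 0] = float("nan")  # deliberately break one parameter
print(find_non_finite(model))  # now lists the broken parameter
```

If such a check reports nothing, the nan/inf values most likely arise in the forward computation itself.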
We could run the forward pass again in different modes:
- Python tracing, and inspecting all local variables which are tensors.
(Probably slow).
- PyTorch JIT tracing to compute the loss. This will give us the computation graph.
We can run this computation graph again and inspect all the intermediate values, and then see where the nan/inf values come from.
- PyTorch profiling.
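To illustrate the idea of re-running the forward pass and inspecting intermediate values, here is a coarse sketch that uses module forward hooks. Forward hooks are not one of the modes listed above and only see module outputs, not individual ops. The helper name `report_non_finite_outputs` and the toy `Log` module are hypothetical.

```python
import torch


def report_non_finite_outputs(model: torch.nn.Module, *inputs):
    """Run model(*inputs) and return names of submodules with nan/inf outputs."""
    found = []

    def make_hook(name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                found.append(name)

        return hook

    # Skip the root module (empty name), hook every submodule.
    handles = [
        mod.register_forward_hook(make_hook(name))
        for name, mod in model.named_modules()
        if name
    ]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return found


class Log(torch.nn.Module):
    def forward(self, x):
        return torch.log(x)  # log(0) = -inf


model = torch.nn.Sequential(torch.nn.ReLU(), Log(), torch.nn.Linear(4, 2))
bad = report_non_finite_outputs(model, torch.tensor([[1.0, -2.0, 3.0, 0.5]]))
print(bad)  # the first entry is where the non-finite values originate
```

The first reported module is the origin. Later entries only propagate the non-finite values.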
Note that one problem is non-determinism in the computation, e.g. via dropout, so the method might not be totally reliable. Also, there might be inf/nan values which are ok, expected, and not a problem (e.g. masking the logits for attention). So we don’t stop on the first occurrence but just report all of them.
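As an example of an expected and harmless inf value, consider attention masking. The sketch below uses made-up logits and a made-up mask: the masked logits contain -inf, but the resulting softmax weights are finite.

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
mask = torch.tensor([[True, True, False, False]])  # True = position is valid

# Masked-out positions get -inf so that softmax assigns them zero weight.
masked_logits = logits.masked_fill(~mask, float("-inf"))
weights = torch.softmax(masked_logits, dim=-1)

has_inf = torch.isinf(masked_logits).any().item()
all_finite = torch.isfinite(weights).all().item()
print(has_inf, all_finite)
```

A debugging helper would report the -inf in `masked_logits`, even though nothing is wrong here.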
- returnn.torch.util.debug_inf_nan.debug_inf_nan(func: Callable[[], Tensor | None], *, with_grad: bool = False, report_every_op_call: bool = True, stop_reporting_after_first_inf_nan: bool = True, file: TextIO | TextIOBase | None = None)
Debug the function: call func and report inf/nan values that occur during the computation.
- Parameters:
func – will be called like func(). If with_grad is set, we expect some loss tensor as return value, and we will call loss = func(); loss.backward().
with_grad – whether to compute and debug gradients for inf/nan.
report_every_op_call – whether to report every op call.
stop_reporting_after_first_inf_nan – whether to stop reporting after the first inf/nan.
file – where to write the output to. Default is stdout.
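A possible usage sketch, based only on the signature documented above. The toy loss and the names `w`, `compute_loss` and `run_debug` are hypothetical, and the output format of `debug_inf_nan` is not shown because it is not specified here.

```python
import sys

import torch

# Toy parameter, chosen so that the forward pass produces -inf and nan.
w = torch.tensor([1.0, 0.0, -1.0], requires_grad=True)


def compute_loss() -> torch.Tensor:
    # log(0) = -inf and log(-1) = nan, so the summed loss is nan.
    return torch.log(w).sum()


def run_debug():
    # Imported lazily, so the rest of this sketch runs without RETURNN.
    from returnn.torch.util.debug_inf_nan import debug_inf_nan

    # with_grad=True: the returned loss is also backpropagated.
    debug_inf_nan(compute_loss, with_grad=True, file=sys.stderr)


# Call run_debug() in an environment where RETURNN is installed.
```

Wrapping the computation in a zero-argument function matches the `func()` calling convention described in the parameter list.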