returnn.torch.util.distributed

Differentiable distributed collectives for the Torch backend.

See also returnn.torch.distributed for the DDP context setup; this module holds low-level autograd-aware collective ops.

returnn.torch.util.distributed.all_reduce_sum(x: Tensor, *, group=None) → Tensor

Differentiable all-reduce (sum) across the distributed worker group.

Unlike a plain torch.distributed.all_reduce, this propagates gradients correctly: in the backward pass, the incoming gradient is itself all-reduce-summed across workers. This makes it usable inside the model forward, e.g. to compute SyncBatchNorm-style statistics across workers. We deliberately avoid torch.distributed.nn.functional.all_reduce: it is deprecated, and its backward is only correct for the sum reduction anyway.
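The backward rule can be checked with a short derivation. It assumes K workers, worker k holding x_k, and a total loss L that is the sum of the per-worker losses L_j, where each L_j is computed from worker j's copy of the output y:

```latex
y = \sum_{k=1}^{K} x_k, \qquad L = \sum_{j=1}^{K} L_j(y), \qquad g_j := \frac{\partial L_j}{\partial y}

\frac{\partial L}{\partial x_i}
  = \sum_{j=1}^{K} \frac{\partial L_j}{\partial y}\,\frac{\partial y}{\partial x_i}
  = \sum_{j=1}^{K} g_j
  \qquad \text{since } \frac{\partial y}{\partial x_i} = I
```

Every worker i therefore needs the same quantity, the sum of the output gradients g_j from all workers, which is exactly an all-reduce (sum) of the incoming gradients.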

Parameters:
  • x – local tensor, same shape on every worker

  • group – process group, or None for the default group

Returns:

the element-wise sum of x over all workers in group; same shape as x, and differentiable
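The mechanism can be sketched with a custom torch.autograd.Function. This is a minimal sketch under stated assumptions, not the RETURNN implementation: _AllReduceSum, all_reduce_sum_sketch, and the SyncBatchNorm-style helper global_mean are hypothetical names. The demo assumes a CPU build with the gloo backend and runs a single process (world_size=1), where the all-reduce is the identity, so it only exercises the autograd plumbing.

```python
import torch
import torch.distributed as dist


class _AllReduceSum(torch.autograd.Function):
    """Hypothetical sketch: all-reduce (sum) whose backward also all-reduce-sums the gradient."""

    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        # dist.all_reduce works in place, so reduce a contiguous copy, not the input itself.
        y = x.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(y, op=dist.ReduceOp.SUM, group=group)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        # dL/dx_i is the sum of the output gradients of all workers.
        # grad_y may be a stride-0 expanded tensor, so make a contiguous copy first.
        grad_x = grad_y.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(grad_x, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_x, None  # no gradient for the `group` argument


def all_reduce_sum_sketch(x, *, group=None):
    """Hypothetical stand-in for all_reduce_sum (not the RETURNN implementation)."""
    return _AllReduceSum.apply(x, group)


def global_mean(x, *, group=None):
    """Hypothetical helper: mean over the batch dim of x, taken across all workers."""
    local_sum = x.sum(dim=0)  # shape [F]
    local_count = torch.tensor([float(x.shape[0])], device=x.device)  # shape [1], no grad
    total_sum = all_reduce_sum_sketch(local_sum, group=group)
    total_count = all_reduce_sum_sketch(local_count, group=group)
    return total_sum / total_count


# Single-process demo: world_size=1 with an in-memory HashStore and the gloo (CPU) backend.
# With one worker the collective is the identity; real jobs start one process per worker.
dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
mean = global_mean(x)
mean.sum().backward()
print(mean.detach(), x.grad)  # mean == [2., 3.]; every x.grad entry == 0.5
dist.destroy_process_group()
```

In a real job, each of the K processes calls init_process_group with its own rank; total_count then sums to the global batch size, and each worker's x.grad receives contributions from the losses of all workers, as in the derivation of the backward rule.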