returnn.torch.util.distributed
Differentiable distributed collectives for the Torch backend.
See also returnn.torch.distributed for the DDP context setup;
this module holds low-level autograd-aware collective ops.
- returnn.torch.util.distributed.all_reduce_sum(x: Tensor, *, group=None) → Tensor
Differentiable all-reduce (sum) across the distributed worker group.
Unlike a plain torch.distributed.all_reduce, which works in place and has no autograd support, this propagates gradients correctly (the backward all-reduce-sums the gradient), so it can be used inside the model forward, e.g. for SyncBatchNorm-style statistics. We avoid torch.distributed.nn.functional.all_reduce because it is deprecated, and its backward is only correct for sum anyway.
- Parameters:
x – local tensor, same shape on every worker
group – process group, or None for the default group
- Returns:
the sum of x across all workers, same shape, differentiable
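The behaviour described above (sum all-reduce in the forward, sum all-reduce of the gradient in the backward) can be sketched with a custom torch.autograd.Function. This is a minimal illustration under stated assumptions, not the RETURNN source: the names `_AllReduceSum`, `all_reduce_sum_sketch`, `synced_mean` and `demo_single_process` are hypothetical, and the demo assumes a single-process gloo group created from an in-memory `HashStore`, where the all-reduce is simply the identity. `synced_mean` shows a SyncBatchNorm-style global mean over the batch dimension.

```python
import torch
import torch.distributed as dist


class _AllReduceSum(torch.autograd.Function):
    """Hypothetical sketch: sum all-reduce whose backward also sum-all-reduces the gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, group) -> torch.Tensor:
        ctx.group = group
        # c10d all_reduce works in place, so operate on a contiguous copy
        # instead of modifying the caller's tensor.
        out = x.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # Every worker's loss depends on the summed output, so the gradient
        # w.r.t. the local input is the sum of the output gradients of all workers.
        grad = grad_out.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad, None  # no gradient for the `group` argument


def all_reduce_sum_sketch(x: torch.Tensor, *, group=None) -> torch.Tensor:
    """Hypothetical stand-in for all_reduce_sum, for illustration only."""
    return _AllReduceSum.apply(x, group)


def synced_mean(x: torch.Tensor, *, group=None) -> torch.Tensor:
    """Mean over dim 0 across all workers (SyncBatchNorm-style statistic)."""
    local_sum = x.sum(dim=0)
    local_count = x.new_tensor(float(x.shape[0]))  # same dtype/device as x
    total_sum = all_reduce_sum_sketch(local_sum, group=group)
    total_count = all_reduce_sum_sketch(local_count, group=group)
    return total_sum / total_count


def demo_single_process():
    """Single-process demo: with world_size=1 the all-reduce is the identity."""
    if not dist.is_initialized():
        # An in-memory HashStore is enough for a one-process rendezvous.
        dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    x = torch.randn(4, 3, requires_grad=True)
    m = synced_mean(x)
    m.sum().backward()
    return x, m
```

In a real multi-worker run, each process would call `synced_mean` on its own local batch after the process group has been set up (e.g. via returnn.torch.distributed), and the resulting mean and its gradients would then reflect the batches of all workers.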