returnn.torch.distributed

torch.distributed utils

class returnn.torch.distributed.DistributedContext(options: Dict[str, Any])[source]

This class sets up some helper functions for PyTorch distributed training
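
For orientation, a minimal sketch of the kind of options dict such a context is constructed from, assuming the options come from the torch_distributed entry of the RETURNN config; the option names reduce_type and param_sync_step are assumptions based on the accessors below, not confirmed by this page:

    # Sketch of a RETURNN config snippet (assumption: the options dict passed
    # to DistributedContext comes from the `torch_distributed` config entry).
    torch_distributed = {
        "reduce_type": "param",   # assumed option: average parameters instead of reducing gradients
        "param_sync_step": 100,   # assumed option: synchronize parameters every 100 steps
    }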

local_rank() int[source]

local rank, i.e. the rank of this process within its node

local_size() int[source]

local size, i.e. the number of processes on this node

rank() int[source]

global rank, i.e. the rank of this process across all nodes

size() int[source]

global size, i.e. the total number of processes across all nodes
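
As a usage sketch (not taken from the RETURNN docs), these accessors can be used to restrict work such as checkpoint writing or logging to a single process; get_ctx, documented further below, returns the context:

    from returnn.torch.distributed import get_ctx

    ctx = get_ctx()  # None if Torch distributed training is not enabled
    if ctx is None or ctx.rank() == 0:
        # Only the process with global rank 0 (or the single process in a
        # non-distributed run) writes checkpoints or prints summaries.
        print("global size:", 1 if ctx is None else ctx.size())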

get_param_sync_step() int | None[source]

param sync step, i.e. the interval (in train steps) at which parameters are synchronized across ranks, or None if not configured

maybe_make_distributed_module(module: Module) DistributedDataParallel | None[source]

Maybe make a wrapped distributed module.

Parameters:

module – original module

Returns:

the module wrapped in DistributedDataParallel, or None if the original module should be used as-is
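
A minimal sketch of how this hook could be called; the placeholder model and the surrounding code are illustrative assumptions, not RETURNN's actual engine logic:

    import torch
    from returnn.torch.distributed import get_ctx

    pt_model = torch.nn.Linear(8, 1)  # placeholder module for illustration

    ctx = get_ctx()
    if ctx is not None:
        wrapped = ctx.maybe_make_distributed_module(module=pt_model)
        if wrapped is not None:
            # e.g. a DistributedDataParallel wrapper; None means the original
            # module should be used unchanged.
            pt_model = wrapped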

step_after_param_update(*, module: Module, epoch_step_idx: int)[source]

Called once per train step, after the parameters have been updated.
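
A sketch of where this hook would sit in a hand-written training loop; the loop, model, and optimizer are illustrative assumptions, not RETURNN engine code:

    import torch
    from returnn.torch.distributed import get_ctx

    pt_model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(pt_model.parameters(), lr=0.1)
    ctx = get_ctx()

    for step_idx in range(10):
        loss = pt_model(torch.randn(4, 8)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if ctx is not None:
            # Notify the context that a parameter update just happened;
            # depending on the configured reduction, this may synchronize
            # parameters across ranks (e.g. every get_param_sync_step() steps).
            ctx.step_after_param_update(module=pt_model, epoch_step_idx=step_idx)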

returnn.torch.distributed.get_ctx(config=None) DistributedContext | None[source]

Parameters:

config (Config|None)

Returns:

the global context if Torch distributed is enabled, or None otherwise. If the context has not been set up yet, it will be created automatically.
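
A small sketch of querying the context at the start of a distributed-aware script:

    from returnn.torch.distributed import get_ctx

    ctx = get_ctx()  # creates the global context on first use if enabled
    if ctx is not None:
        print(
            f"global rank {ctx.rank()} of {ctx.size()}, "
            f"local rank {ctx.local_rank()} of {ctx.local_size()}"
        )
    else:
        print("Torch distributed is not enabled; running single-process.")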