returnn.torch.frontend.bridge

Direct bridge between pure PyTorch code and modules on the one side, and RF code and modules on the other.

https://github.com/rwth-i6/returnn/issues/1287

returnn.torch.frontend.bridge.pt_module_to_rf_module(pt_module: Module) → Module
Parameters:

pt_module – torch module

Returns:

RF module
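
A minimal sketch of wrapping a plain torch.nn.Linear so it can be used from RF code (it assumes returnn.frontend is imported as rf and that the returned wrapper is an rf.Module; the call convention of the wrapper is described under PTModuleAsRFModule below):

    import torch
    import returnn.frontend as rf
    from returnn.torch.frontend.bridge import pt_module_to_rf_module

    pt_mod = torch.nn.Linear(10, 20)
    rf_mod = pt_module_to_rf_module(pt_mod)
    assert isinstance(rf_mod, rf.Module)  # behaves like an RF module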

returnn.torch.frontend.bridge.wrapped_pt_module_to_rf_module(pt_module: Module) → Module | None
Parameters:

pt_module – torch module

Returns:

RF module if the torch module is a wrapped RF module, or None otherwise
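
This can be used to check whether a given torch module is actually a wrapper around an RF module, e.g. one created via rf_module_to_pt_module(). A minimal sketch (it assumes rf.Linear/rf.Dim with the signatures used here, and that the wrapper hands back the original RF module object):

    import torch
    import returnn.frontend as rf
    from returnn.torch.frontend.bridge import rf_module_to_pt_module, wrapped_pt_module_to_rf_module

    rf_mod = rf.Linear(rf.Dim(10, name="in"), rf.Dim(20, name="out"))
    pt_mod = rf_module_to_pt_module(rf_mod)
    assert wrapped_pt_module_to_rf_module(pt_mod) is rf_mod  # unwraps to the RF module

    plain = torch.nn.Linear(10, 20)
    assert wrapped_pt_module_to_rf_module(plain) is None  # plain torch module, not a wrapper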

returnn.torch.frontend.bridge.rf_module_to_pt_module(rf_module: Module, *, aux_params_as_buffers: bool = True) → Module
Parameters:
  • rf_module – RF module

  • aux_params_as_buffers – whether to map RF auxiliary parameters to PyTorch buffers; otherwise they become normal parameters, i.e. they occur in model.named_parameters(). Note that even when aux params are part of model.named_parameters(), they usually have no gradient and are therefore not updated by the optimizer. Historically this was False; it is now True by default, as that is the more reasonable behavior. Changing this also changes the optimizer state dict, but such a state dict is converted automatically.

Returns:

torch module
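
A sketch of what aux_params_as_buffers changes on the PyTorch side, using an RF module that has auxiliary parameters such as running statistics (rf.BatchNorm and its constructor arguments are an assumption of this sketch):

    import returnn.frontend as rf
    from returnn.torch.frontend.bridge import rf_module_to_pt_module

    feat_dim = rf.Dim(32, name="feature")
    rf_mod = rf.BatchNorm(feat_dim)  # has aux params (running mean/variance); signature assumed

    pt_default = rf_module_to_pt_module(rf_mod)  # aux_params_as_buffers=True
    pt_legacy = rf_module_to_pt_module(rf_mod, aux_params_as_buffers=False)

    # With the default, aux params appear in named_buffers() and are thus not
    # handed to the optimizer; with the legacy behavior, they appear in
    # named_parameters() (but usually without a gradient).
    print(sorted(name for name, _ in pt_default.named_buffers()))
    print(sorted(name for name, _ in pt_legacy.named_parameters()))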

class returnn.torch.frontend.bridge.PTModuleAsRFModule(pt_module: Module)

Wrapped PyTorch module, usable as an RF module.

It is recommended to use pt_module_to_rf_module() instead of using this directly.

By convention, any options to the module are passed to __init__(), and potentially changing inputs (other tensors) are passed to __call__().

property pt_module: Module

torch module
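
A minimal sketch of direct use (normally you would call pt_module_to_rf_module() instead); that the wrapper is itself an rf.Module is an assumption of this sketch:

    import torch
    import returnn.frontend as rf
    from returnn.torch.frontend.bridge import PTModuleAsRFModule

    wrapped = PTModuleAsRFModule(torch.nn.Linear(10, 20))  # module options went to torch.nn.Linear.__init__
    assert isinstance(wrapped, rf.Module)
    assert isinstance(wrapped.pt_module, torch.nn.Linear)  # access the underlying torch module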

class returnn.torch.frontend.bridge.RFModuleAsPTModule(rf_module: Module, *, aux_params_as_buffers: bool = True)

Wrapped RF module, usable as a PyTorch module.

It is recommended to use rf_module_to_pt_module() instead of using this directly.

property rf_module: Module

RF module

forward(*args, **kwargs)

register_parameter(name: str, param: Parameter | None) → None

(re)register parameter

register_buffer(name: str, tensor: Tensor | None, persistent: bool = True) → None

(re)register buffer
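
A minimal sketch of how the wrapper is typically used on the PyTorch side (it assumes rf.Linear/rf.Dim with the signatures used here, and that rf_module_to_pt_module() returns an RFModuleAsPTModule instance):

    import torch
    import returnn.frontend as rf
    from returnn.torch.frontend.bridge import rf_module_to_pt_module, RFModuleAsPTModule

    rf_mod = rf.Linear(rf.Dim(10, name="in"), rf.Dim(20, name="out"))
    pt_mod = rf_module_to_pt_module(rf_mod)
    assert isinstance(pt_mod, RFModuleAsPTModule)
    assert pt_mod.rf_module is rf_mod  # access the underlying RF module

    # The RF parameters are exposed as torch parameters, so standard PyTorch
    # tooling (optimizers, state_dict) can operate on the wrapper.
    opt = torch.optim.SGD(pt_mod.parameters(), lr=0.1)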