Base module class, Module.

class returnn.frontend.module.Module

This can represent a subnetwork in RETURNN.

You can write PyTorch-like code here, for example:

import returnn.frontend as rf
from returnn.tensor import Tensor, Dim


class MyModule(rf.Module):

  def __init__(self, dim: Dim, activation=rf.tanh):
    super().__init__()
    self.layer_norm = rf.LayerNorm(dim)
    self.linear = rf.Linear(dim, dim)
    self.activation = activation

  def __call__(self, x: Tensor) -> Tensor:
    x_ = x  # keep the input for the residual connection
    x = self.layer_norm(x)
    x = self.linear(x)
    x = self.activation(x)
    return x_ + x

A module (here, just like in PyTorch or Keras) has parameters, but to get an output for some input, an additional forward or __call__ call is required. This call can be made multiple times, and every such call shares the same module parameters.
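The parameter sharing described above can be seen in a plain-Python sketch. ToyScale is a hypothetical stand-in (not RETURNN code, and without any tensors): its single parameter is created once in __init__, every call reads the same attribute, and changing the parameter, as a training step would, affects all later calls.

```python
class ToyScale:
    """Hypothetical toy module (not part of RETURNN) with a single scalar parameter."""

    def __init__(self, init_value: float):
        # The parameter is created once, when the module is constructed.
        self.weight = init_value

    def __call__(self, x: list) -> list:
        # Every call reads the same self.weight; no new parameters are created here.
        return [self.weight * v for v in x]


scale = ToyScale(2.0)
a = scale([1.0, 2.0])  # first call
b = scale([3.0])  # second call, same parameter
scale.weight = 10.0  # simulate a parameter update, e.g. by an optimizer step
c = scale([1.0])  # later calls see the updated parameter
```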

__init__() usually takes module-level arguments that describe the parameters. As a module might be called multiple times, input-specific arguments such as spatial dims are usually arguments of __call__(). Other arguments that might vary between calls, such as epsilon, would also be arguments of __call__(), although there are no strict rules.

By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
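As a hypothetical illustration of this convention (plain Python, not a RETURNN module), ToyRMSNorm takes the feature dimension in __init__, because it determines the parameter shape, while the input x and the per-call option eps are arguments of __call__(). The RMS normalization itself is just a convenient example computation.

```python
import math
from typing import List


class ToyRMSNorm:
    """Hypothetical module: options that define the parameters go to __init__."""

    def __init__(self, dim: int):
        self.dim = dim
        self.scale = [1.0] * dim  # one learnable scale per feature

    def __call__(self, x: List[float], *, eps: float = 1e-6) -> List[float]:
        # The input x and options that may vary between calls (eps) are call arguments.
        assert len(x) == self.dim
        rms = math.sqrt(sum(v * v for v in x) / self.dim + eps)
        return [s * v / rms for s, v in zip(self.scale, x)]


norm = ToyRMSNorm(2)
y_default = norm([3.0, 4.0])
y_big_eps = norm([3.0, 4.0], eps=1.0)  # same module and parameters, different per-call option
```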

default_initial_state(*, batch_dims: Sequence[Dim]) → State | None

Default initial state, to be used if the module has recurrent (hidden) state. When a module has recurrent state, the convention is that __call__() returns a tuple with a State instance as its last item, and accepts a state argument that is a State with the same nested structure. The returned initial state can be a nested structure as well, and it should match the structure of the state argument and of the returned state.
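A plain-Python sketch of the recurrent-state convention follows. ToyState (a dict with attribute access) and ToyAccumulator are hypothetical stand-ins for rf.State and a real recurrent module; the toy works on scalars, so batch_dims is accepted only to mirror the signature above and is otherwise ignored. The module returns its new state as the last tuple item, and the caller feeds it back in on the next step.

```python
from typing import Any, Sequence, Tuple


class ToyState(dict):
    """Hypothetical stand-in for rf.State: a dict whose entries are also readable as attributes."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


class ToyAccumulator:
    """Hypothetical recurrent module: outputs the running mean of all inputs seen so far."""

    def default_initial_state(self, *, batch_dims: Sequence[Any] = ()) -> ToyState:
        # batch_dims is unused in this scalar toy; it only mirrors the documented signature.
        return ToyState(total=0.0, step=0)

    def __call__(self, x: float, *, state: ToyState) -> Tuple[float, ToyState]:
        new_state = ToyState(total=state.total + x, step=state.step + 1)
        mean = new_state.total / new_state.step
        # Convention: the new state is the last tuple item and has the same structure as `state`.
        return mean, new_state


acc = ToyAccumulator()
state = acc.default_initial_state(batch_dims=())
outputs = []
for x in [2.0, 4.0, 6.0]:
    y, state = acc(x, state=state)
    outputs.append(y)
```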

get_default_name() → str

Get a default layer name (used when no Module attribute points to this module). This is used by NameCtx for the RETURNN layer naming, but only when the RETURNN layer name is not implied by the module attribute hierarchy.

get_deep(target: str) → Any

Returns the deep (nested) attribute given by target if it exists; otherwise, raises an error.

set_deep(target: str, value: Any) → None

Sets the deep (nested) attribute given by target to value.
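The sketch below shows what deep attribute access could look like, assuming that target is a dot-separated attribute path (as with PyTorch's get_submodule); that path format is an assumption. The get_deep and set_deep functions and the Node class are hypothetical plain-Python helpers, not the RETURNN implementation.

```python
from functools import reduce
from typing import Any


def get_deep(obj: Any, target: str) -> Any:
    """Hypothetical helper: follow a dot-separated attribute path, e.g. "encoder.linear.weight"."""
    # getattr raises AttributeError for a missing path component.
    return reduce(getattr, target.split("."), obj)


def set_deep(obj: Any, target: str, value: Any) -> None:
    """Hypothetical helper: resolve all parent attributes, then set the last one."""
    *parents, last = target.split(".")
    owner = reduce(getattr, parents, obj)
    setattr(owner, last, value)


class Node:
    """Plain attribute container used to build a small hierarchy."""


root = Node()
root.encoder = Node()
root.encoder.linear = Node()
root.encoder.linear.weight = 0.5

set_deep(root, "encoder.linear.weight", 1.5)
```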

children() → Iterator[Module]

Get all immediate child modules, excluding self.

named_children() → Iterator[Tuple[str, Module]]

Get all immediate child modules together with their names, excluding self.

modules(*, recurse: bool = True, include_self: bool = True) → Iterator[Module]

Get all child modules, optionally recursively, and optionally including self.

named_modules(*, recurse: bool = True, include_self: bool = True, memo: Set[RefIdEq[Module]] | None = None, prefix: str = '') → Iterator[Tuple[str, Module]]

Get all child modules together with their names, optionally recursively. self is included iff include_self=True (the default).
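To illustrate the difference between the immediate children and the full (recursive) traversal, here is a hypothetical plain-Python ToyModule, not RETURNN's implementation. Any attribute holding a ToyModule counts as a child. The handling of memo and prefix is an assumption modeled on PyTorch's named_modules: memo collects already-visited modules (by id() here, where RETURNN uses RefIdEq) so that a module reachable via two attributes is yielded only once, and prefix is prepended to the dotted names.

```python
from typing import Iterator, Optional, Set, Tuple


class ToyModule:
    """Hypothetical minimal module: every attribute holding a ToyModule is a child."""

    def named_children(self) -> Iterator[Tuple[str, "ToyModule"]]:
        # Immediate children only, in attribute definition order; self is excluded.
        for name, value in vars(self).items():
            if isinstance(value, ToyModule):
                yield name, value

    def named_modules(
        self,
        *,
        recurse: bool = True,
        include_self: bool = True,
        memo: Optional[Set[int]] = None,
        prefix: str = "",
    ) -> Iterator[Tuple[str, "ToyModule"]]:
        if memo is None:
            memo = set()
        if id(self) in memo:
            return  # already yielded under another name
        memo.add(id(self))
        if include_self:
            yield prefix, self
        for name, child in self.named_children():
            child_prefix = f"{prefix}.{name}" if prefix else name
            if recurse:
                yield from child.named_modules(memo=memo, prefix=child_prefix)
            elif id(child) not in memo:
                memo.add(id(child))
                yield child_prefix, child


class Block(ToyModule):
    def __init__(self):
        self.norm = ToyModule()
        self.linear = ToyModule()


class Net(ToyModule):
    def __init__(self):
        self.block = Block()
        self.shared = self.block.linear  # the same module, reachable under a second name


net = Net()
```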

named_parameters(*, recurse: bool = True) → Iterator[Tuple[str, Parameter]]

Get all parameters of this module, together with their names.

With recurse=True (the default), this also iterates over all child modules and yields their parameters as well.

parameters(*, recurse: bool = True) → Iterator[Parameter]

Get all parameters, without names. See named_parameters() for more details.
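A hypothetical plain-Python sketch of how named_parameters() and parameters() relate: ToyParameter and ToyModule are illustrative stand-ins, not RETURNN classes. Parameters of child modules get the attribute path as a name prefix, recurse=False restricts the result to the module's own parameters, and parameters() is the same traversal with the names dropped. This toy does not de-duplicate parameters that are reachable under several names.

```python
from typing import Iterator, Tuple


class ToyParameter:
    """Hypothetical stand-in for a parameter: just holds a value."""

    def __init__(self, value: float):
        self.value = value


class ToyModule:
    """Hypothetical module: ToyParameter attributes are parameters, ToyModule attributes are children."""

    def named_parameters(self, *, recurse: bool = True) -> Iterator[Tuple[str, ToyParameter]]:
        for name, value in vars(self).items():
            if isinstance(value, ToyParameter):
                yield name, value
            elif recurse and isinstance(value, ToyModule):
                # Child parameter names are prefixed with the attribute name of the child.
                for sub_name, param in value.named_parameters(recurse=True):
                    yield f"{name}.{sub_name}", param

    def parameters(self, *, recurse: bool = True) -> Iterator[ToyParameter]:
        # Same traversal as named_parameters(), but without the names.
        for _, param in self.named_parameters(recurse=recurse):
            yield param


class Linear(ToyModule):
    def __init__(self):
        self.weight = ToyParameter(1.0)
        self.bias = ToyParameter(0.0)


class Net(ToyModule):
    def __init__(self):
        self.scale = ToyParameter(2.0)
        self.linear = Linear()


net = Net()
```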

property has_parameters

Whether this module has any parameters (variables).

apply(fn: Callable[[Module], None]) → T

Applies the function fn to all child modules and to self.
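As a hypothetical plain-Python sketch (not RETURNN's implementation), apply visits every module in the tree, including self, and calls fn on each. Returning self, which allows chaining, is an assumption consistent with the return type T above. The children-first visiting order mirrors PyTorch's nn.Module.apply and is likewise an assumption; the example does not depend on it.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T", bound="ToyModule")


class ToyModule:
    """Hypothetical module that stores its children in an explicit list."""

    def __init__(self, *children: "ToyModule"):
        self.child_modules: List["ToyModule"] = list(children)
        self.training = True

    def apply(self: T, fn: Callable[["ToyModule"], None]) -> T:
        # Recurse into all children first, then call fn on self.
        for child in self.child_modules:
            child.apply(fn)
        fn(self)
        return self


def set_eval_mode(module: ToyModule) -> None:
    module.training = False


leaf_a = ToyModule()
leaf_b = ToyModule()
inner = ToyModule(leaf_b)
root = ToyModule(leaf_a, inner)

result = root.apply(set_eval_mode)
```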

to(*, device: str | None = None, dtype: str | None = None)

Move all parameters to the specified device and/or dtype.

This is an inplace operation. Afterward, all parameters are on the new device/dtype. See

register_forward_hook(hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Any | None], *, prepend: bool = False) → RemovableHandle

Register a forward hook. This is very similar to PyTorch (the code is even partly copied from PyTorch). The hook is called after __call__, following this logic (where module=self):

result = module(*args, **kwargs)
hook_result = hook(module, args, kwargs, result)
if hook_result is not None:
  result = hook_result

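Parameters:

  • hook – the hook function, called with (module, args, kwargs, result) as shown above

  • prepend – if there are multiple hooks, this one is registered in front of all others; otherwise, it is appended at the end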
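The following plain-Python sketch mimics the hook logic described above. ToyModule, ToyRemovableHandle, and the use of a forward() method are hypothetical (they are neither RETURNN's nor PyTorch's code). The assumption that a hook returning None leaves the output unchanged, while any other return value replaces it, follows the Any | None return type in the signature. prepend=True moves the new hook to the front of the call order.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, Tuple

# Hook signature as documented above: (module, args, kwargs, result) -> new result or None.
Hook = Callable[[Any, Tuple[Any, ...], Dict[str, Any], Any], Any]


class ToyRemovableHandle:
    """Hypothetical handle: remove() unregisters the hook it was created for."""

    def __init__(self, hooks: "OrderedDict[int, Hook]", key: int):
        self._hooks = hooks
        self._key = key

    def remove(self) -> None:
        self._hooks.pop(self._key, None)


class ToyModule:
    """Hypothetical module with forward hooks; subclasses implement forward()."""

    def __init__(self):
        self._forward_hooks: "OrderedDict[int, Hook]" = OrderedDict()
        self._next_hook_id = 0

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def register_forward_hook(self, hook: Hook, *, prepend: bool = False) -> ToyRemovableHandle:
        key = self._next_hook_id
        self._next_hook_id += 1
        self._forward_hooks[key] = hook
        if prepend:
            self._forward_hooks.move_to_end(key, last=False)  # run before all existing hooks
        return ToyRemovableHandle(self._forward_hooks, key)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        result = self.forward(*args, **kwargs)
        for hook in list(self._forward_hooks.values()):
            hook_result = hook(self, args, kwargs, result)
            if hook_result is not None:
                result = hook_result  # a non-None return value replaces the output
        return result


class Double(ToyModule):
    def forward(self, x: int) -> int:
        return 2 * x


mod = Double()
log = []
add_one = mod.register_forward_hook(lambda m, args, kwargs, out: out + 1)
mod.register_forward_hook(lambda m, args, kwargs, out: log.append(out), prepend=True)
y = mod(5)  # logger runs first (sees 10), then add_one turns 10 into 11
add_one.remove()
y_after_remove = mod(5)
```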

class returnn.frontend.module.Functional(func)

Used for functions (pure functions, i.e. not methods of another module), and used by ModuleList to wrap any functions or lambdas as modules.

(This is often not necessary, but sometimes useful.)
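A hypothetical plain-Python sketch of the idea: ToyFunctional (not RETURNN's Functional) wraps a function or lambda so that it has the same call interface as a module, which lets it sit in a list of modules. It has no parameters of its own; run_all is likewise a hypothetical helper.

```python
from typing import Any, Callable, List


class ToyFunctional:
    """Hypothetical stand-in for Functional: wraps a plain function so it can be used like a module."""

    def __init__(self, func: Callable[..., Any]):
        self.func = func  # no parameters; the wrapped function does all the work

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)


def relu(x: float) -> float:
    return max(x, 0.0)


# A list of modules mixing a named function and a lambda.
layers: List[ToyFunctional] = [ToyFunctional(relu), ToyFunctional(lambda x: 3.0 * x)]


def run_all(modules: List[ToyFunctional], x: float) -> float:
    # Apply each wrapped function in order, like a sequential container.
    for module in modules:
        x = module(x)
    return x
```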


get_default_name() → str

Default name.