returnn.frontend.module
Base module class, Module.
- class returnn.frontend.module.Module
This can represent a subnetwork in RETURNN.
You can write PyTorch-like code here, like:
```python
import returnn.frontend as rf
from returnn.tensor import Tensor, Dim


class MyModule(rf.Module):
    def __init__(self, dim: Dim, activation=rf.tanh):
        super().__init__()
        self.layer_norm = rf.LayerNorm(dim)
        self.linear = rf.Linear(dim, dim)
        self.activation = activation

    def __call__(self, x: Tensor) -> Tensor:
        x_ = x  # keep the input for the residual connection
        x = self.layer_norm(x)
        x = self.linear(x)
        x = self.activation(x)
        return x_ + x
```
A module (here, just like in PyTorch or Keras) has parameters, but getting an output for some input requires an additional forward or __call__ call, which can be made multiple times. All such calls share the same module parameters.
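To make the parameter-sharing point concrete, here is a minimal plain-Python sketch, not RETURNN code: the hypothetical ToyLinear creates its weights once in __init__, and every __call__ reuses them.

```python
import random


class ToyLinear:
    """Toy stand-in for a module: parameters are created once, in __init__."""

    def __init__(self, in_features: int, out_features: int):
        # Module-level arguments (sizes) describe the parameters.
        self.weight = [
            [random.uniform(-1.0, 1.0) for _ in range(out_features)]
            for _ in range(in_features)
        ]

    def __call__(self, x: list) -> list:
        # The input arrives per call; the same weights are reused every time.
        return [
            sum(x[i] * self.weight[i][j] for i in range(len(x)))
            for j in range(len(self.weight[0]))
        ]


layer = ToyLinear(3, 2)
y1 = layer([1.0, 0.0, 0.0])
y2 = layer([1.0, 0.0, 0.0])  # same input, same shared weights, same output
```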
The __init__() would usually get module-level arguments which describe the parameters. As a module might be called multiple times, any input-specific arguments such as spatial dims are usually arguments of __call__(). Other arguments that might vary between calls, such as epsilon, would also be arguments of __call__(), although there are no strict rules. By convention, any options to the module are passed to __init__(), and potentially changing inputs (other tensors) are passed to __call__().
- default_initial_state(*, batch_dims: Sequence[Dim]) → State | None
- Returns:
default initial state, to be used if the module has recurrent (hidden) state. When a module has recurrent state, the convention is to return a tuple with an instance of State as the last item, and to accept the state argument as a State with the same nested structure. This can be a nested structure and should match the structure of the state argument and the returned value.
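The recurrent-state convention can be sketched in plain Python. This is an illustration under assumptions: State here is a hypothetical minimal stand-in (a dict with attribute access), RunningSum is a made-up toy module, and batch_dims is accepted but ignored since no tensors are involved.

```python
from typing import Optional, Sequence, Tuple


class State(dict):
    """Hypothetical minimal stand-in for State: a dict with attribute access."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value


class RunningSum:
    """Toy recurrent module: the output is the running sum of all inputs so far."""

    def default_initial_state(self, *, batch_dims: Sequence = ()) -> State:
        # batch_dims is ignored in this scalar toy.
        return State(total=0.0)

    def __call__(self, x: float, *, state: Optional[State] = None) -> Tuple[float, State]:
        if state is None:
            state = self.default_initial_state()
        new_total = state.total + x
        # Convention: the State is the last item of the returned tuple and has
        # the same nested structure as the incoming state argument.
        return new_total, State(total=new_total)


mod = RunningSum()
state = mod.default_initial_state()
outputs = []
for x in [1.0, 2.0, 3.0]:
    y, state = mod(x, state=state)
    outputs.append(y)
```

After the loop, outputs holds the running sums and state carries the final total.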
- get_default_name() → str
Get a default layer name (used when there is no Module attribute pointing to this module). This is used by NameCtx for the RETURNN layer naming (but only when the RETURNN layer name is not implied by the module attribute hierarchy).
- get_deep(target: str) → Any
Returns the deep attribute given by target if it exists, otherwise raises an error.
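A minimal plain-Python sketch of resolving a deep attribute, assuming target is a dot-separated attribute path (as with PyTorch's Module.get_submodule); the text above does not state the format, and this standalone get_deep function is a hypothetical illustration rather than RETURNN's implementation.

```python
from types import SimpleNamespace
from typing import Any


def get_deep(obj: Any, target: str) -> Any:
    """Resolve an assumed dot-separated path such as "encoder.linear.weight"."""
    if target == "":
        return obj  # empty path: the object itself
    for name in target.split("."):
        if not hasattr(obj, name):
            raise AttributeError(f"{type(obj).__name__} has no attribute {name!r}")
        obj = getattr(obj, name)
    return obj


root = SimpleNamespace(encoder=SimpleNamespace(linear=SimpleNamespace(weight="W")))
weight = get_deep(root, "encoder.linear.weight")
```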
- named_children() → Iterator[Tuple[str, Module]]
Get all immediate children modules, excluding self.
- modules(*, recurse: bool = True, include_self: bool = True) → Iterator[Module]
Get all children modules, optionally recursively, optionally including self.
- named_modules(*, recurse: bool = True, include_self: bool = True, memo: Set[RefIdEq[Module]] | None = None, prefix: str = '') → Iterator[Tuple[str, Module]]
Get all children modules together with their names, optionally recursively. Self is included iff include_self=True (the default).
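The traversal options of named_modules() can be sketched in plain Python. Assumptions: the hypothetical toy Mod class finds its children by scanning its instance attributes, and memo holds id() values as a simple substitute for Set[RefIdEq[Module]]. The memo makes a submodule that is reachable under two names appear only once.

```python
from typing import Iterator, Optional, Set, Tuple


class Mod:
    """Toy module: child modules are plain attributes holding other Mod instances."""

    def named_children(self) -> Iterator[Tuple[str, "Mod"]]:
        for name, value in vars(self).items():
            if isinstance(value, Mod):
                yield name, value

    def named_modules(
        self,
        *,
        recurse: bool = True,
        include_self: bool = True,
        memo: Optional[Set[int]] = None,
        prefix: str = "",
    ) -> Iterator[Tuple[str, "Mod"]]:
        if memo is None:
            memo = set()
        if id(self) in memo:
            return  # already visited: a shared submodule is reported only once
        memo.add(id(self))
        if include_self:
            yield prefix, self
        for name, child in self.named_children():
            child_prefix = f"{prefix}.{name}" if prefix else name
            if recurse:
                # The child yields itself and its whole subtree.
                yield from child.named_modules(memo=memo, prefix=child_prefix)
            elif id(child) not in memo:
                memo.add(id(child))
                yield child_prefix, child


shared = Mod()
root = Mod()
root.enc = Mod()
root.enc.ff = shared
root.dec = Mod()
root.dec.ff = shared  # the same object under a second name
names = [name for name, _ in root.named_modules()]
```

Here names is ["", "enc", "enc.ff", "dec"]: the shared module is listed once, under the first name found.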
- named_parameters(*, recurse: bool = True) → Iterator[Tuple[str, Parameter]]
Get all children parameters, together with their names.
With recurse=True (default), this iterates over all children modules and iterates through their parameters as well.
- parameters(*, recurse: bool = True) → Iterator[Parameter]
Get all children parameters. Also see named_parameters() for more documentation.
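A plain-Python sketch of how recursive parameter naming could work, assuming child parameter names are prefixed with the child's attribute name and a dot. Parameter and Mod here are hypothetical toy classes, and this sketch does not deduplicate parameters that are shared between modules.

```python
from typing import Iterator, Tuple


class Parameter:
    """Toy stand-in for a parameter."""

    def __init__(self, value: float):
        self.value = value


class Mod:
    def named_parameters(self, *, recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
        for name, value in vars(self).items():
            if isinstance(value, Parameter):
                yield name, value
            elif recurse and isinstance(value, Mod):
                # Prefix the child's parameter names with the child's attribute name.
                for sub_name, param in value.named_parameters(recurse=True):
                    yield f"{name}.{sub_name}", param

    def parameters(self, *, recurse: bool = True) -> Iterator[Parameter]:
        # Same traversal as named_parameters(), just without the names.
        for _, param in self.named_parameters(recurse=recurse):
            yield param


root = Mod()
root.bias = Parameter(0.0)
root.sub = Mod()
root.sub.weight = Parameter(1.0)
```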
- apply(fn: Callable[[Module], None]) → T
Applies the function fn to all children modules and self.
- Returns:
self
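A minimal sketch of apply() on a hypothetical toy Mod class. It visits self first and then all descendants (the visiting order is an assumption), and returns self so that calls can be chained.

```python
from typing import Callable, Iterator, TypeVar

T = TypeVar("T", bound="Mod")


class Mod:
    def modules(self) -> Iterator["Mod"]:
        # Self first, then all descendants, depth-first.
        yield self
        for value in vars(self).values():
            if isinstance(value, Mod):
                yield from value.modules()

    def apply(self: T, fn: Callable[["Mod"], None]) -> T:
        for mod in self.modules():
            fn(mod)
        return self  # allows chaining, e.g. Mod().apply(init_fn)


root = Mod()
root.a = Mod()
root.a.b = Mod()
visited = []
result = root.apply(lambda m: visited.append(m))
```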
- to(*, device: str | None = None, dtype: str | None = None)
Move all parameters to the specified device and/or dtype.
This is an in-place operation. Afterward, all parameters are on the new device/dtype. See rf.Parameter.to().
- register_forward_hook(hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Any | None], *, prepend: bool = False) → RemovableHandle
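Register a forward hook. This is very similar to PyTorch (the code is even partly copied from PyTorch). The hook is called after __call__, with module=self: it receives (module, args, kwargs, result), and if it returns a value other than None, that value replaces the result.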
- Parameters:
hook
prepend – if True, this hook is registered in front of all existing hooks; otherwise it is appended at the end
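The hook mechanism can be sketched in plain Python, assuming PyTorch-like semantics: each hook is called as hook(module, args, kwargs, result) after the forward computation, and a non-None return value replaces the result. HookedModule, its forward method, and this RemovableHandle are hypothetical toy versions, not RETURNN's classes.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# (module, args, kwargs, result) -> new result, or None to keep the result
Hook = Callable[[Any, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]


class RemovableHandle:
    """Toy handle: remove() unregisters the hook it was created for."""

    def __init__(self, hooks: List[Hook], hook: Hook):
        self._hooks = hooks
        self._hook = hook

    def remove(self) -> None:
        if self._hook in self._hooks:
            self._hooks.remove(self._hook)


class HookedModule:
    def __init__(self):
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook, *, prepend: bool = False) -> RemovableHandle:
        if prepend:
            self._forward_hooks.insert(0, hook)  # run before all existing hooks
        else:
            self._forward_hooks.append(hook)  # run after all existing hooks
        return RemovableHandle(self._forward_hooks, hook)

    def forward(self, x: float) -> float:
        return 2 * x

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        result = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook_result = hook(self, args, kwargs, result)
            if hook_result is not None:
                result = hook_result
        return result


m = HookedModule()
handle = m.register_forward_hook(lambda mod, args, kwargs, res: res + 1)
with_hook = m(3.0)  # 2 * 3.0, then + 1 from the hook
handle.remove()
without_hook = m(3.0)
```

A hook that only observes (returns None) leaves the result unchanged, and prepend=True changes the order in which result-modifying hooks are chained.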
- class returnn.frontend.module.Functional(func, *, attribs: Dict[str, Any] | None = None)
Used for functions (purely functional, i.e. not methods of another module) and via ModuleList to wrap any functions or lambdas as modules. (This is often not necessary, but sometimes useful.)
- Parameters:
func – callable. You might want to use functools.partial to fix some arguments.
attribs – optional dict of attributes to set on this module, e.g. out_dim.
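A plain-Python sketch of the idea behind Functional, using a hypothetical ToyFunctional class (not RETURNN's implementation) that stores the callable, sets the given attribs as attributes, and forwards calls. functools.partial fixes the factor argument here.

```python
import functools
from typing import Any, Callable, Dict, Optional


class ToyFunctional:
    """Toy stand-in: wraps a plain function so it can be used like a module."""

    def __init__(self, func: Callable, *, attribs: Optional[Dict[str, Any]] = None):
        self.func = func
        for key, value in (attribs or {}).items():
            setattr(self, key, value)  # e.g. out_dim, so callers can query it

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)


def scale(x: float, *, factor: float) -> float:
    return x * factor


double = ToyFunctional(functools.partial(scale, factor=2.0), attribs={"out_dim": 1})
```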