Recurrent state

class returnn.frontend.state.State(*args, **kwargs)

Covers the complete state of a recurrent module, i.e. exactly what needs to be stored and passed back into the module as initial state the next time you call it.
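Here is a minimal plain-Python sketch of this contract, without RETURNN. It assumes a hypothetical toy module `running_sum_step` whose whole state is a running total kept in a dict. Whatever one call returns as state is passed back in as the initial state of the next call. This means a sequence can be processed in chunks and resumed from the stored state.

```python
from typing import Dict, List, Optional, Tuple


def running_sum_step(x: float, state: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
    """Hypothetical toy recurrent module: its complete state is the running total."""
    new_total = state["total"] + x
    # Return the output plus everything needed to continue on the next call.
    return new_total, {"total": new_total}


def run_sequence(
    xs: List[float], state: Optional[Dict[str, float]] = None
) -> Tuple[List[float], Dict[str, float]]:
    if state is None:
        state = {"total": 0.0}  # initial state for a fresh sequence
    outputs = []
    for x in xs:
        # The state returned by this call becomes the initial state of the next one.
        y, state = running_sum_step(x, state)
        outputs.append(y)
    return outputs, state


outputs, final_state = run_sequence([1.0, 2.0, 3.0])
print(outputs)      # [1.0, 3.0, 6.0]
print(final_state)  # {'total': 6.0}

# Resume from a stored state: same result as one uninterrupted run.
_, stored = run_sequence([1.0, 2.0])
print(run_sequence([3.0], stored))  # ([6.0], {'total': 6.0})
```

Splitting the input and resuming from the stored state reproduces the uninterrupted result. That is what "exactly what needs to be stored" means in practice.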

This behaves somewhat like a namedtuple (entries can be accessed as attributes), although it derives from dict.
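The following sketch illustrates the namedtuple-like behavior with a hypothetical `MiniState` class. It is not RETURNN's implementation, only an assumed way of mapping attribute access onto dict entries. Because the fields live in the dict itself, generic dict-walking utilities see them as ordinary keys.

```python
from typing import Any


class MiniState(dict):
    """Hypothetical stand-in for a dict-based state with attribute access."""

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails: fall back to dict keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        # Store "attributes" as dict entries so dict-based utilities can see them.
        self[name] = value


s = MiniState()
s.h = 1
s.c = 2
print(s.h, s["c"])  # 1 2
print(dict(s))      # {'h': 1, 'c': 2}
print(list(s))      # ['h', 'c']
```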

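When you derive further from this class, make sure that it works correctly with tree (the dm-tree package). tree creates new instances of the same class by calling type(instance)(keys_and_values), with keys_and_values = ((key, result[key]) for key in instance). The constructor must therefore accept these (key, value) pairs as a positional argument and pass them on to the base class. See LstmState for an example: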

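import returnn.frontend as rf
from returnn.tensor import Tensor

class LstmState(rf.State):
    def __init__(self, *_args, h: Tensor = None, c: Tensor = None):
        super().__init__(*_args)  # tree passes the (key, value) pairs here when rebuilding
        if not _args:  # regular construction from keyword arguments
            self.h = h
            self.c = c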
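The following self-contained sketch shows why the constructor must forward positional arguments. It mimics the rebuild described above without depending on tree. `MiniState`, `MiniLstmState`, `BrokenLstmState` and `map_values` are hypothetical names for this illustration. `map_values` only imitates the documented `type(instance)(keys_and_values)` call.

```python
from typing import Any, Callable, Optional


class MiniState(dict):
    """Hypothetical dict-based state with attribute access (stand-in for rf.State)."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value


class MiniLstmState(MiniState):
    def __init__(self, *_args, h: Optional[float] = None, c: Optional[float] = None):
        super().__init__(*_args)  # rebuild path: _args holds the (key, value) pairs
        if not _args:             # normal path: construct from keyword arguments
            self.h = h
            self.c = c


class BrokenLstmState(MiniState):
    def __init__(self, *_args, h: Optional[float] = None, c: Optional[float] = None):
        super().__init__()  # bug: ignores the (key, value) pairs passed on rebuild
        self.h = h
        self.c = c


def map_values(fn: Callable[[Any], Any], instance: dict) -> dict:
    # Imitates the documented rebuild: type(instance)(keys_and_values).
    result = {key: fn(instance[key]) for key in instance}
    keys_and_values = ((key, result[key]) for key in instance)
    return type(instance)(keys_and_values)


doubled = map_values(lambda v: 2 * v, MiniLstmState(h=1.0, c=2.0))
print(type(doubled).__name__, doubled.h, doubled.c)  # MiniLstmState 2.0 4.0

broken = map_values(lambda v: 2 * v, BrokenLstmState(h=1.0, c=2.0))
print(broken.h, broken.c)  # None None
```

`BrokenLstmState` ignores the pairs passed during the rebuild and resets its fields to the keyword defaults, so the mapped values are silently lost. Forwarding `*_args` to the base class and guarding the keyword path with `if not _args` avoids this.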

Also see:

flatten_tensors() → List[Tensor]

Returns all tensors contained in this state as a flat list. See cls_flatten_tensors().

classmethod cls_flatten_tensors(obj: State | dict | Any) → List[Tensor]

Recursively iterates through obj and all its sub-objects and returns all tensors found, as a flat list.
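Here is a minimal sketch of such a recursive walk. It assumes a hypothetical `FakeTensor` placeholder in place of returnn.tensor.Tensor and handles only dicts (including dict subclasses such as State), lists and tuples. The actual implementation may differ. In particular, this sketch visits dict entries in insertion order, whereas a walk based on dm-tree's `tree.flatten` may visit dict keys in sorted order.

```python
from typing import Any, List


class FakeTensor:
    """Hypothetical placeholder for returnn.tensor.Tensor in this sketch."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"FakeTensor({self.name!r})"


def flatten_tensors(obj: Any) -> List[FakeTensor]:
    """Depth-first walk over dicts (incl. subclasses), lists and tuples, collecting tensors."""
    if isinstance(obj, FakeTensor):
        return [obj]
    if isinstance(obj, dict):
        children = list(obj.values())
    elif isinstance(obj, (list, tuple)):
        children = list(obj)
    else:
        return []  # other leaves (None, ints, ...) contain no tensors
    out: List[FakeTensor] = []
    for child in children:
        out.extend(flatten_tensors(child))
    return out


h, c, h2 = FakeTensor("h"), FakeTensor("c"), FakeTensor("h2")
nested = {"layer1": {"h": h, "c": c}, "layer2": [h2, None], "step": 3}
print([t.name for t in flatten_tensors(nested)])  # ['h', 'c', 'h2']
```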