returnn.tf.frontend_layers.masked_computation

Masked computation. Wraps MaskedComputationLayer in RETURNN.

https://github.com/rwth-i6/returnn_common/issues/23

class returnn.tf.frontend_layers.masked_computation.MaskedComputation(mask: Tensor, *, name: str = 'masked_computation')[source]

This is expected to be inside a Loop.

Usage example:

loop = nn.Loop(...)
loop.state.y = ...  # some initial output
loop.state.h = ...  # some initial state
with loop:
  mask = ...  # dtype bool, shape [batch] or whatever, for current (fast) frame
  with nn.MaskedComputation(mask=mask):
    loop.state.y, loop.state.h = slow_rnn(x, loop.state.h)
  y = loop.state.y  # access from outside

This is equivalent to:

loop = nn.Loop(...)
loop.state.y = ...  # some initial output
loop.state.h = ...  # some initial state
with loop:
  mask = ...  # dtype bool, shape [batch] or whatever, for current frame
  y_, h_ = slow_rnn(x, loop.state.h)
  loop.state.y = nest.map(lambda a, b: nn.where(cond=mask, x=a, y=b), y_, loop.state.y)
  loop.state.h = nest.map(lambda a, b: nn.where(cond=mask, x=a, y=b), h_, loop.state.h)
  y = loop.state.y
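
The select-by-mask semantics of this equivalent form can be illustrated with plain TensorFlow. This is only a sketch: tf.where stands in for nn.where, and the concrete shapes and values are made up for illustration.

import tensorflow as tf

# One loop step: entries where the mask is True take the freshly computed
# value, all other entries keep their previous value.
mask = tf.constant([True, False, True])              # shape [batch]
y_old = tf.constant([[0., 0.], [0., 0.], [0., 0.]])  # previous state, shape [batch, dim]
y_new = tf.constant([[1., 1.], [2., 2.], [3., 3.]])  # newly computed, shape [batch, dim]

# Expand the mask so it broadcasts against the feature dim.
y = tf.where(mask[:, None], y_new, y_old)
print(y.numpy())  # [[1. 1.] [0. 0.] [3. 3.]]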

In pseudocode, for the non-batched case (mask is just a scalar bool), it would look like:

y = ...  # some initial output
h = ...  # some initial state
while True:
  mask = ...  # bool
  if mask:
    y, h = slow_rnn(x, h)
Parameters:

mask (Tensor) – bool, shape [batch]

class returnn.tf.frontend_layers.masked_computation.MaskedComputationModule(masked_computation: MaskedComputation)[source]

This is for internal use by MaskedComputation.

By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
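
To illustrate this convention (the class name and signatures below are only a hypothetical sketch of the pattern, not the actual MaskedComputationModule API):

class SlowRNNWrapper(nn.Module):
  # Hypothetical example module following the convention.

  def __init__(self, hidden_dim):
    super().__init__()
    # Options (hyperparameters, dims) are fixed at construction time.
    self.hidden_dim = hidden_dim

  def __call__(self, x, state):
    # Potentially changing inputs (tensors such as x and the recurrent state)
    # are passed per call.
    new_state = ...  # compute the new state from x and the previous state
    y = ...          # compute the output from the new state
    return y, new_state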