returnn.tf.frontend_layers.cond

Conditional logic

https://github.com/rwth-i6/returnn_common/issues/24

class returnn.tf.frontend_layers.cond.Cond(condition: Tensor, *, name: str = 'cond')

Conditional branching. This behaves like if ... else ...: only one branch is executed, and the condition must be a bool scalar. It maps to CondLayer in RETURNN and to tf.cond in TensorFlow.

Example:

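from returnn.tf.frontend_layers.cond import Cond

# cond: bool scalar Tensor; mod_true_case, mod_false_case: the branch computations
with Cond(cond) as cond_obj:
  cond_obj.true = mod_true_case(x)  # computed in the True branch
  cond_obj.false = mod_false_case(x)  # computed in the False branch
  y = cond_obj.result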

Corresponds to:

if cond:
  y = mod_true_case(x)
else:
  y = mod_false_case(x)
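To illustrate the "only one branch is executed" semantics outside of RETURNN, here is a minimal pure-Python sketch. The function toy_cond and the helpers true_case and false_case are hypothetical and are not part of RETURNN; toy_cond only borrows the callable-per-branch style of tf.cond, so the unselected branch is never evaluated. Accepting only a Python bool is a simplification standing in for the "bool scalar" requirement.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def toy_cond(condition: bool, true_fn: Callable[[], T], false_fn: Callable[[], T]) -> T:
    """Hypothetical pure-Python stand-in for Cond's semantics (not RETURNN code)."""
    # Simplification: accept only a Python bool as the "bool scalar" condition.
    if not isinstance(condition, bool):
        raise TypeError(f"condition must be a bool scalar, got {type(condition).__name__}")
    # Branches are passed as callables, so only the selected one is evaluated.
    return true_fn() if condition else false_fn()


calls: List[str] = []


def true_case() -> int:
    calls.append("true")
    return 1


def false_case() -> int:
    calls.append("false")
    return 0


y = toy_cond(True, true_case, false_case)
print(y, calls)  # prints: 1 ['true']
```

Only true_case was called; false_case never ran, mirroring how only one branch of the conditional is executed.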

The context scope has two states, corresponding to the True and False computation branches. The initial state is the True branch. Assigning cond_obj.true has the side effect of switching the computation to the False branch, which is why the True case is assigned first, as in the example above.
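The state switching can be modeled with a small toy class. ToyCond below is hypothetical and is not the RETURNN implementation: it only tracks which branch is active, switches to the False branch when true is assigned, and makes result available once both values are set. Unlike the real Cond, the right-hand sides are ordinary Python values evaluated eagerly, so this sketch does not model lazy branch execution.

```python
from typing import Any


class ToyCond:
    """Hypothetical model of Cond's two-state scope (not RETURNN code)."""

    _UNSET = object()

    def __init__(self, condition: bool) -> None:
        self.condition = condition
        self.in_true_branch = True  # the initial state is the True branch
        self._true: Any = ToyCond._UNSET
        self._false: Any = ToyCond._UNSET

    def __enter__(self) -> "ToyCond":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        return None

    @property
    def true(self) -> Any:
        return self._true

    @true.setter
    def true(self, value: Any) -> None:
        if not self.in_true_branch:
            raise RuntimeError("true was already assigned")
        self._true = value
        self.in_true_branch = False  # side effect: switch to the False branch

    @property
    def false(self) -> Any:
        return self._false

    @false.setter
    def false(self, value: Any) -> None:
        if self.in_true_branch:
            raise RuntimeError("assign true before false")
        self._false = value

    @property
    def result(self) -> Any:
        if self._true is ToyCond._UNSET or self._false is ToyCond._UNSET:
            raise RuntimeError("result is only available after assigning true and false")
        return self._true if self.condition else self._false


with ToyCond(False) as c:
    c.true = 10  # evaluated eagerly here, unlike a real graph branch
    c.false = 20
    y = c.result
print(y)  # prints: 20
```

Assigning false before true raises an error in this toy, reflecting that the False branch only becomes active after true has been assigned.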

property true: T

The getter is usually not needed; the True-branch value is set by assigning to this property (cond_obj.true = ...).

property false: T

The getter is usually not needed; the False-branch value is set by assigning to this property (cond_obj.false = ...).

property result: T
Returns:

The result, available after you have assigned both true and false.

add_op_to_current_branch(op: Tensor)
Parameters:

op – e.g. an assign op. The value of the tensor is irrelevant; only the underlying op matters.
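The idea of attaching an op only for its side effect can be sketched in plain Python. ToyBranchOps is a hypothetical class whose method merely mimics the documented name; it is not RETURNN code. Each registered callable plays the role of an assign-like op: its return value is discarded, and it runs only if its branch is the one taken.

```python
from typing import Callable, Dict, List


class ToyBranchOps:
    """Hypothetical illustration (not RETURNN code) of side-effect-only branch ops."""

    def __init__(self) -> None:
        self.current = True  # start in the True branch, as Cond does
        self.ops: Dict[bool, List[Callable[[], object]]] = {True: [], False: []}

    def add_op_to_current_branch(self, op: Callable[[], object]) -> None:
        # Only the op itself is recorded; its return value is never used.
        self.ops[self.current].append(op)

    def switch_to_false_branch(self) -> None:
        self.current = False

    def run(self, condition: bool) -> None:
        # Only the ops of the taken branch execute.
        for op in self.ops[condition]:
            op()


counter = {"value": 0}
branches = ToyBranchOps()
branches.add_op_to_current_branch(lambda: counter.update(value=counter["value"] + 1))
branches.switch_to_false_branch()
branches.add_op_to_current_branch(lambda: counter.update(value=-1))
branches.run(True)
print(counter["value"])  # prints: 1
```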

add_other_branch_prehook(callback: Callable[[], Any])

Add a prehook to the other branch.

add_other_branch_posthook(callback: Callable[[], Any])

Add a posthook to the other branch.
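The following toy sketch shows one way hooks for "the other branch" could behave. ToyHookedCond and its build method are hypothetical, and the timing is an assumption made for illustration only: hooks registered while the True branch is built run immediately before (prehook) and after (posthook) the False branch is built. The exact timing in RETURNN may differ.

```python
from typing import Any, Callable, List


class ToyHookedCond:
    """Hypothetical model (not RETURNN code) of hooks registered for the other branch."""

    def __init__(self) -> None:
        self.log: List[str] = []
        self._prehooks: List[Callable[[], Any]] = []
        self._posthooks: List[Callable[[], Any]] = []

    def add_other_branch_prehook(self, callback: Callable[[], Any]) -> None:
        self._prehooks.append(callback)

    def add_other_branch_posthook(self, callback: Callable[[], Any]) -> None:
        self._posthooks.append(callback)

    def build(self, true_body: Callable[[Any], None], false_body: Callable[[Any], None]) -> None:
        true_body(self)  # True branch; may register hooks for the False branch
        for hook in self._prehooks:
            hook()  # assumed: runs before the other (False) branch
        false_body(self)
        for hook in self._posthooks:
            hook()  # assumed: runs after the other (False) branch


def true_body(cond: ToyHookedCond) -> None:
    cond.log.append("true body")
    cond.add_other_branch_prehook(lambda: cond.log.append("prehook"))
    cond.add_other_branch_posthook(lambda: cond.log.append("posthook"))


def false_body(cond: ToyHookedCond) -> None:
    cond.log.append("false body")


c = ToyHookedCond()
c.build(true_body, false_body)
print(c.log)  # prints: ['true body', 'prehook', 'false body', 'posthook']
```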

class returnn.tf.frontend_layers.cond.CondModule(cond: Cond)

This module is used internally by Cond to create the RETURNN CondLayer for the conditional code. It is not meant to be used directly.

By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
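The __init__/__call__ convention can be shown with a minimal, framework-free module. ScaleModule is a hypothetical example class, not part of RETURNN: its fixed option (factor) is set once in __init__, while the input that changes from call to call (x) is passed to __call__.

```python
class ScaleModule:
    """Hypothetical module following the options/inputs convention (not RETURNN code)."""

    def __init__(self, factor: float) -> None:
        # Fixed configuration (an option) goes to __init__.
        self.factor = factor

    def __call__(self, x: float) -> float:
        # Changing inputs go to __call__.
        return self.factor * x


double = ScaleModule(2.0)
print(double(3.0), double(5.0))  # prints: 6.0 10.0
```

Keeping options in __init__ means one configured module instance can be applied to many different inputs.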