returnn.tf.frontend_layers.cond
Conditional logic
https://github.com/rwth-i6/returnn_common/issues/24
- class returnn.tf.frontend_layers.cond.Cond(condition: Tensor, *, name: str = 'cond')
Conditional branching. It behaves essentially like a Python if ... else ... statement: only one branch is executed, and the condition must be a bool scalar. This maps to CondLayer in RETURNN and to tf.cond in TensorFlow.
Example:
    with Cond(cond) as cond_obj:
        cond_obj.true = mod_true_case(x)
        cond_obj.false = mod_false_case(x)
    y = cond_obj.result
Corresponds to:
    if cond:
        y = mod_true_case(x)
    else:
        y = mod_false_case(x)
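Because only one branch runs, the contract resembles tf.cond, which receives the two branches as callables rather than as precomputed values. The following is a minimal pure-Python sketch of that contract, assuming nothing about RETURNN internals: toy_cond is a hypothetical name, and a plain Python bool stands in for a scalar bool tensor.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def toy_cond(pred: bool, true_fn: Callable[[], T], false_fn: Callable[[], T]) -> T:
    """Hypothetical stand-in for tf.cond: evaluate exactly one branch."""
    # The condition must be a single boolean value (a "bool scalar").
    if not isinstance(pred, bool):
        raise TypeError(f"condition must be a bool scalar, got {type(pred).__name__}")
    # Branches are passed as callables, so the unselected one is never run.
    return true_fn() if pred else false_fn()


calls = []  # records which branch functions were actually executed


def mod_true_case(x: int) -> int:
    calls.append("true")
    return x + 1


def mod_false_case(x: int) -> int:
    calls.append("false")
    return x - 1


x = 10
y = toy_cond(x > 5, lambda: mod_true_case(x), lambda: mod_false_case(x))
print(y, calls)  # 11 ['true']
```

Passing lambdas instead of mod_true_case(x) directly is the design choice that makes "only one branch is executed" hold; calling both functions up front would run both regardless of the condition.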
The context scope has two states, corresponding to the True and False computation branches. The initial state is the True branch. Assigning cond_obj.true has the side effect of switching the computation to the False branch.
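The two-state behaviour of the context scope can be modelled without any framework. The sketch below uses a hypothetical toy class, ToyCond, which is not RETURNN's implementation: it only tracks which branch is active, stores the assigned values, and selects one of them for result. Unlike the real Cond, the assigned values here are ordinary Python objects computed eagerly; the sketch illustrates the state transition, not deferred execution.

```python
from typing import Any


class ToyCond:
    """Hypothetical model of Cond's two-state context scope (not RETURNN code)."""

    def __init__(self, condition: bool):
        self.condition = condition
        self.branch = "true"  # initial state: the True branch
        self._values = {}  # branch name -> assigned value

    def __enter__(self) -> "ToyCond":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # On a clean exit, both branches must have been assigned.
        if exc_type is None and set(self._values) != {"true", "false"}:
            raise RuntimeError("assign both .true and .false inside the scope")

    @property
    def true(self) -> Any:
        return self._values.get("true")

    @true.setter
    def true(self, value: Any) -> None:
        if self.branch != "true":
            raise RuntimeError(".true must be assigned first")
        self._values["true"] = value
        self.branch = "false"  # side effect: switch to the False branch

    @property
    def false(self) -> Any:
        return self._values.get("false")

    @false.setter
    def false(self, value: Any) -> None:
        if self.branch != "false":
            raise RuntimeError("assign .true before .false")
        self._values["false"] = value

    @property
    def result(self) -> Any:
        # Select the value of the branch chosen by the condition.
        return self._values["true"] if self.condition else self._values["false"]


states = []
with ToyCond(True) as cond_obj:
    states.append(cond_obj.branch)  # before assigning .true
    cond_obj.true = "A"
    states.append(cond_obj.branch)  # after assigning .true
    cond_obj.false = "B"
y = cond_obj.result
print(states, y)  # ['true', 'false'] A
```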
- class returnn.tf.frontend_layers.cond.CondModule(cond: Cond)
This module is used internally by Cond to create the RETURNN CondLayer for the conditional code. It is not meant to be used directly by the user. By convention, any options to the module are passed to __init__, and potentially changing inputs (other tensors) are passed to __call__().
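To illustrate the convention only, here is a hypothetical module in plain Python: fixed options go to __init__, and inputs that may change between calls go to __call__(). In the same spirit, CondModule receives its Cond object as the constructor argument. ScaleModule and its factor option are invented for this sketch and are not part of RETURNN.

```python
class ScaleModule:
    """Hypothetical module following the __init__ / __call__ convention."""

    def __init__(self, factor: float):
        # Option: fixed once, when the module is constructed.
        self.factor = factor

    def __call__(self, x: float) -> float:
        # Input: may differ on every call.
        return self.factor * x


double = ScaleModule(factor=2.0)
outputs = [double(1.0), double(3.5)]
print(outputs)  # [2.0, 7.0]
```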