returnn.frontend.nested

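Utility functions that apply masking, gathering, and packing/unpacking (masked select/scatter) to nested structures, e.g. dicts, lists, or tuples containing tensors.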

returnn.frontend.nested.mask_nested(s: T, *, mask: Tensor, mask_cpu: Tensor | None = None, mask_value: T | Tensor | float | None, dim_map: Dict[Dim, Dim] | None = None, allow_dim_extension: bool = True) -> T

Applies where(mask, s, mask_value) leaf-wise to the nested structure s: entries where mask is True are taken from s, all others from mask_value.

Parameters:
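  • s – nested structure

  • mask – mask tensor. See where().

  • mask_cpu – mask tensor on CPU. This is used e.g. for dynamic dim sizes.

  • mask_value – value used where mask is False. Either a nested structure like s (masked leaf by leaf) or a single tensor or scalar used for all leaves.

  • dim_map – if given, this will be updated with the new dim map

  • allow_dim_extension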

Returns:

s with masked values
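As a rough, hypothetical illustration of the where(mask, s, mask_value) semantics (not RETURNN's actual implementation), the following pure-Python sketch masks a nested structure leaf-wise. It assumes that dicts and tuples are the containers, that each leaf is a plain Python list standing in for a tensor whose only axis is the masked one, and it leaves out all Dim handling (mask_cpu, dim_map, allow_dim_extension). The names _map and mask_nested_sketch are made up for this example.

```python
from typing import Any, Callable


# Hypothetical helper: dicts and tuples are containers; any other object
# (here: a Python list acting as a 1-D "tensor") is treated as a leaf.
def _map(fn: Callable[..., Any], *structs: Any) -> Any:
    """Apply fn leaf-wise over structurally identical nested structures."""
    first = structs[0]
    if isinstance(first, dict):
        return {k: _map(fn, *(s[k] for s in structs)) for k in first}
    if isinstance(first, tuple):
        return tuple(_map(fn, *items) for items in zip(*structs))
    return fn(*structs)


def mask_nested_sketch(s: Any, *, mask: list, mask_value: Any) -> Any:
    """Sketch of where(mask, s, mask_value), applied to every leaf of s."""

    def where(x: list, y: list) -> list:
        # Take x where mask is True, otherwise y.
        return [xi if m else yi for m, xi, yi in zip(mask, x, y)]

    if isinstance(mask_value, (dict, tuple)):
        # mask_value has the same nested structure as s: mask leaf by leaf.
        return _map(where, s, mask_value)
    # Otherwise mask_value is assumed to be a scalar, broadcast to every leaf.
    return _map(lambda x: where(x, [mask_value] * len(x)), s)


state = {"h": [1.0, 2.0, 3.0], "c": ([4.0, 5.0, 6.0],)}
prev_state = {"h": [0.0, 0.0, 0.0], "c": ([9.0, 9.0, 9.0],)}
active = [True, False, True]  # the second sequence has already ended

print(mask_nested_sketch(state, mask=active, mask_value=prev_state))
# {'h': [1.0, 0.0, 3.0], 'c': ([4.0, 9.0, 6.0],)}
print(mask_nested_sketch(state, mask=active, mask_value=-1.0))
# {'h': [1.0, -1.0, 3.0], 'c': ([4.0, -1.0, 6.0],)}
```

One common use of this pattern is keeping the previous recurrent state for sequences that have already ended, by passing the old state as mask_value.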

returnn.frontend.nested.gather_nested(s: T, *, indices: Tensor, dim_map: Dict[Dim, Dim] | None = None) -> T

This is like gather(), but for nested structures.

Parameters:
  • s – nested structure

  • indices – indices tensor. See gather().

  • dim_map – if given, this will be updated with the new dim map

Returns:

s with gathered tensors
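A minimal pure-Python sketch of what gathering over a nested structure means, written as a hypothetical toy model rather than RETURNN's implementation: dicts and tuples are the containers, each leaf is a Python list standing in for a 1-D tensor, the gathered axis is the list index, and there is no Dim or dim_map handling. _map and gather_nested_sketch are made-up names, and the beam-search back-pointer scenario is only one illustrative use.

```python
from typing import Any, Callable, List


# Hypothetical helper: dicts and tuples are containers, Python lists are leaves.
def _map(fn: Callable[..., Any], *structs: Any) -> Any:
    """Apply fn leaf-wise over structurally identical nested structures."""
    first = structs[0]
    if isinstance(first, dict):
        return {k: _map(fn, *(s[k] for s in structs)) for k in first}
    if isinstance(first, tuple):
        return tuple(_map(fn, *items) for items in zip(*structs))
    return fn(*structs)


def gather_nested_sketch(s: Any, *, indices: List[int]) -> Any:
    """Sketch of gather() on every leaf of s: out[i] = leaf[indices[i]]."""
    return _map(lambda x: [x[i] for i in indices], s)


# Per-hypothesis state for a beam of size 3 (one entry per hypothesis).
state = {"h": [10, 11, 12], "c": ([20, 21, 22], [30, 31, 32])}
# Back-pointers: new hypotheses 0 and 1 both continue old hypothesis 2,
# new hypothesis 2 continues old hypothesis 0.
backrefs = [2, 2, 0]

print(gather_nested_sketch(state, indices=backrefs))
# {'h': [12, 12, 10], 'c': ([22, 22, 20], [32, 32, 30])}
```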

returnn.frontend.nested.masked_select_nested(s: T, *, mask: Tensor, mask_cpu: Tensor | None = None, dims: Sequence[Dim], out_dim: Dim | None = None, dim_map: Dict[Dim, Dim] | None = None) -> Tuple[T, Dim, Dict[Dim, Dim]]

This is like masked_select(), but for nested structures.

Parameters:
  • s – nested structure

  • mask – mask tensor. See masked_select().

  • mask_cpu – mask tensor on CPU. This is used e.g. for dynamic dim sizes.

  • dims – dims to mask. See masked_select().

  • out_dim – the packed out dim. See masked_select(). If not given, a new one will be created.

  • dim_map – if given, this will be updated with the new dim map

Returns:

a tuple of s with the masked dims packed into out_dim, out_dim itself, and the dim map
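The packing performed by masked_select_nested() can be pictured with the following pure-Python toy model, which is a hypothetical sketch and not RETURNN's implementation. It assumes dicts and tuples as containers, each leaf a 2-D Python list indexed as [batch][time], and a 2-D boolean mask over the same two axes (playing the role of dims). It returns the packed structure together with the packed axis length in place of out_dim, and ignores mask_cpu and dim_map. _map and masked_select_nested_sketch are made-up names.

```python
from typing import Any, Callable, List, Tuple


# Hypothetical helper: dicts and tuples are containers, Python lists are leaves.
def _map(fn: Callable[..., Any], *structs: Any) -> Any:
    """Apply fn leaf-wise over structurally identical nested structures."""
    first = structs[0]
    if isinstance(first, dict):
        return {k: _map(fn, *(s[k] for s in structs)) for k in first}
    if isinstance(first, tuple):
        return tuple(_map(fn, *items) for items in zip(*structs))
    return fn(*structs)


def masked_select_nested_sketch(s: Any, *, mask: List[List[bool]]) -> Tuple[Any, int]:
    """Pack all entries of every leaf where mask is True into one flat axis.

    Returns the packed structure and the size of the packed axis (the
    analogue of out_dim). Entries are packed in row-major order.
    """

    def select(x: List[list]) -> list:
        return [v for row, mrow in zip(x, mask) for v, m in zip(row, mrow) if m]

    out_size = sum(m for mrow in mask for m in mrow)  # number of True entries
    return _map(select, s), out_size


# Batch of 2 padded sequences (lengths 2 and 1); 0 and -1 are padding.
s = {"x": [[1, 2, 0], [3, 0, 0]], "y": ([[5, 6, -1], [7, -1, -1]],)}
mask = [[True, True, False], [True, False, False]]

packed, out_size = masked_select_nested_sketch(s, mask=mask)
print(packed, out_size)
# {'x': [1, 2, 3], 'y': ([5, 6, 7],)} 3
```

Packing like this drops the padding frames, so any per-frame computation applied to the packed structure only touches real entries.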

returnn.frontend.nested.masked_scatter_nested(s: T, backup: T, *, mask: Tensor, mask_cpu: Tensor, dims: Sequence[Dim], in_dim: Dim, masked_select_dim_map: Dict[Dim, Dim], masked_scatter_dim_map: Dict[Dim, Dim] | None = None) -> T

Reverse of masked_select_nested(): scatters the packed tensors of s back into backup at the positions where mask is True.

Parameters:
  • s – nested structure, where dims are packed, i.e. (in_dim,…)

  • backup – nested structure to scatter into. Its tensors have shape (dims…, …).

  • mask – mask tensor. See masked_scatter()/masked_select().

  • mask_cpu – mask tensor on CPU. This is used e.g. for dynamic dim sizes. See masked_scatter().

  • dims – dims to mask. See masked_scatter()/masked_select().

  • in_dim – the packed dim of s, i.e. the out_dim from masked_select_nested(). See masked_scatter().

  • masked_select_dim_map – the dim map returned by masked_select_nested(). It describes how the dims of s correspond to the dims of backup.

  • masked_scatter_dim_map – if given, this will be updated with any new dims created by this function

Returns:

backup with s scattered in
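The scatter direction can be illustrated with a hypothetical pure-Python toy model (not RETURNN's implementation). Dicts and tuples are the containers, each packed leaf of s is a flat Python list over the packed axis (in_dim), each leaf of backup is a 2-D list indexed [batch][time], and mask is a 2-D boolean list over those two axes. Dim handling (mask_cpu, in_dim, both dim maps) is left out, and _map and masked_scatter_nested_sketch are made-up names.

```python
from typing import Any, Callable, List


# Hypothetical helper: dicts and tuples are containers, Python lists are leaves.
def _map(fn: Callable[..., Any], *structs: Any) -> Any:
    """Apply fn leaf-wise over structurally identical nested structures."""
    first = structs[0]
    if isinstance(first, dict):
        return {k: _map(fn, *(s[k] for s in structs)) for k in first}
    if isinstance(first, tuple):
        return tuple(_map(fn, *items) for items in zip(*structs))
    return fn(*structs)


def masked_scatter_nested_sketch(s: Any, backup: Any, *, mask: List[List[bool]]) -> Any:
    """Write the packed entries of s back into the True positions of backup.

    Positions where mask is False keep their value from backup. Packed
    entries are consumed in row-major order, matching a row-major packing.
    backup itself is not modified; new lists are returned.
    """
    end = object()  # sentinel to detect leftover packed entries

    def scatter(packed: list, full: List[list]) -> List[list]:
        it = iter(packed)
        out = [
            [next(it) if m else v for v, m in zip(row, mrow)]
            for row, mrow in zip(full, mask)
        ]
        assert next(it, end) is end, "more packed entries than True positions"
        return out

    return _map(scatter, s, backup)


mask = [[True, True, False], [True, False, False]]
packed = {"x": [10, 20, 30], "y": ([50, 60, 70],)}
backup = {"x": [[0, 0, 0], [0, 0, 0]], "y": ([[-1, -1, -1], [-1, -1, -1]],)}

print(masked_scatter_nested_sketch(packed, backup, mask=mask))
# {'x': [[10, 20, 0], [30, 0, 0]], 'y': ([[50, 60, -1], [70, -1, -1]],)}
```

Scattering the unmodified packed values back into the original padded tensors returns those tensors unchanged; in practice the packed values would first be updated by some per-entry computation, and backup supplies the values for the masked-out positions.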