returnn.frontend.nested
Some utility functions on nested structures.
- returnn.frontend.nested.mask_nested(s: T, *, mask: Tensor, mask_cpu: Tensor | None = None, mask_value: T | Tensor | float | None, dim_map: Dict[Dim, Dim] | None = None, allow_dim_extension: bool = True) -> T
Applies where(mask, s, mask_value) for nested structures.
- Parameters:
s – nested structure
mask – mask tensor
mask_cpu – mask tensor for CPU. This is used e.g. for dynamic dim sizes.
mask_value – value used where mask is False. See where().
dim_map
allow_dim_extension
- Returns:
s with masked values
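To illustrate the where(mask, s, mask_value) semantics, here is a minimal pure-Python sketch. It assumes nested structures built from dicts and tuples and uses plain 1-D lists of numbers as stand-ins for tensors. `mask_nested_sketch` is a hypothetical helper, not part of RETURNN, and it ignores mask_cpu, dim_map and allow_dim_extension.

```python
from typing import Any, List

# Hypothetical stand-in: a "tensor" is a plain 1-D list of numbers and a nested
# structure is built from dicts and tuples. This is NOT the RETURNN API.


def mask_nested_sketch(s: Any, mask: List[bool], mask_value: Any) -> Any:
    """Elementwise where(mask, s, mask_value) for every leaf of s.

    mask_value is either a scalar (used for all leaves) or a nested
    structure with the same layout as s.
    """
    if isinstance(s, dict):
        return {
            k: mask_nested_sketch(v, mask, mask_value[k] if isinstance(mask_value, dict) else mask_value)
            for k, v in s.items()
        }
    if isinstance(s, tuple):
        return tuple(
            mask_nested_sketch(v, mask, mask_value[i] if isinstance(mask_value, tuple) else mask_value)
            for i, v in enumerate(s)
        )
    # Leaf: keep the value from s where mask is True, otherwise take mask_value.
    if isinstance(mask_value, list):
        return [x if m else y for x, m, y in zip(s, mask, mask_value)]
    return [x if m else mask_value for x, m in zip(s, mask)]


state = {"h": [1.0, 2.0, 3.0], "c": ([4.0, 5.0, 6.0],)}
prev = {"h": [-1.0, -2.0, -3.0], "c": ([-4.0, -5.0, -6.0],)}
mask = [True, False, True]
print(mask_nested_sketch(state, mask, 0.0))   # {'h': [1.0, 0.0, 3.0], 'c': ([4.0, 0.0, 6.0],)}
print(mask_nested_sketch(state, mask, prev))  # {'h': [1.0, -2.0, 3.0], 'c': ([4.0, -5.0, 6.0],)}
```

Passing a nested structure with the same layout as mask_value selects, per position, between two structures, so old values can be kept wherever the mask is False.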
- returnn.frontend.nested.gather_nested(s: T, *, indices: Tensor, dim_map: Dict[Dim, Dim] | None = None) -> T
This is like gather(), but for nested structures.
- Parameters:
s – nested structure
indices – indices tensor. See gather().
dim_map – if given, this will be updated with the new dim map
- Returns:
s with gathered tensors
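The idea of applying one gather to every tensor of a nested structure can be sketched in pure Python as follows. This assumes dict/tuple containers and plain 1-D lists as stand-ins for tensors; `gather_nested_sketch` is hypothetical, not RETURNN's implementation, and does not model dims or dim_map.

```python
from typing import Any, List

# Hypothetical stand-in: leaves are plain 1-D lists, containers are dicts/tuples.
# This only illustrates the idea; it is NOT RETURNN's gather_nested.


def gather_nested_sketch(s: Any, indices: List[int]) -> Any:
    """Return a structure of the same layout where each leaf is [leaf[i] for i in indices]."""
    if isinstance(s, dict):
        return {k: gather_nested_sketch(v, indices) for k, v in s.items()}
    if isinstance(s, tuple):
        return tuple(gather_nested_sketch(v, indices) for v in s)
    return [s[i] for i in indices]


# E.g. reordering per-hypothesis state by backpointers (hypothetical example data).
state = {"h": [10, 11, 12], "c": ([20, 21, 22],)}
backrefs = [2, 0, 0]
print(gather_nested_sketch(state, backrefs))  # {'h': [12, 10, 10], 'c': ([22, 20, 20],)}
```

Every leaf is indexed with the same indices, so all tensors in the structure stay consistent with each other after the gather.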
- returnn.frontend.nested.masked_select_nested(s: T, *, mask: Tensor, mask_cpu: Tensor | None = None, dims: Sequence[Dim], out_dim: Dim | None = None, dim_map: Dict[Dim, Dim] | None = None) -> Tuple[T, Dim, Dict[Dim, Dim]]
This is like masked_select(), but for nested structures.
- Parameters:
s – nested structure
mask – mask tensor. See masked_select().
mask_cpu – mask tensor for CPU. This is used e.g. for dynamic dim sizes.
dims – dims to mask. See masked_select().
out_dim – the packed out dim. See masked_select(). If not given, a new one will be created.
dim_map – if given, this will be updated with the new dim map
- Returns:
s with masked dims, out_dim, and a newly created dim map
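The packing behavior can be sketched in pure Python, assuming each leaf is a [batch][time] nested list, containers are dicts/tuples, and the masked dims are (batch, time). `masked_select_nested_sketch` is a hypothetical helper, not RETURNN's implementation: it represents the packed out dim only by its size and does not build a dim map.

```python
from typing import Any, List, Tuple

# Hypothetical stand-in: each leaf is a [batch][time] nested list, containers
# are dicts/tuples, and the packed "out_dim" is represented only by its size.
# This is NOT RETURNN's masked_select_nested; dims and dim_map are not modeled.


def masked_select_nested_sketch(s: Any, mask: List[List[bool]]) -> Tuple[Any, int]:
    """Pack the (batch, time) entries where mask is True into one flat dim, per leaf."""
    out_size = sum(m for row in mask for m in row)  # number of True entries

    def pack(node: Any) -> Any:
        if isinstance(node, dict):
            return {k: pack(v) for k, v in node.items()}
        if isinstance(node, tuple):
            return tuple(pack(v) for v in node)
        # Row-major walk over (batch, time), keeping only masked-in entries.
        return [x for row, mrow in zip(node, mask) for x, m in zip(row, mrow) if m]

    return pack(s), out_size


s = {"x": [[1, 2, 3], [4, 5, 6]], "y": ([[7, 8, 9], [10, 11, 12]],)}
mask = [[True, True, False], [True, False, False]]  # sequence lengths 2 and 1
packed, out_size = masked_select_nested_sketch(s, mask)
print(packed, out_size)  # {'x': [1, 2, 4], 'y': ([7, 8, 10],)} 3
```

Because every leaf is packed with the same mask in the same order, position i of the packed dim refers to the same (batch, time) entry in all tensors of the structure.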
- returnn.frontend.nested.masked_scatter_nested(s: T, backup: T, *, mask: Tensor, mask_cpu: Tensor, dims: Sequence[Dim], in_dim: Dim, masked_select_dim_map: Dict[Dim, Dim], masked_scatter_dim_map: Dict[Dim, Dim] | None = None) -> T
Reverse of masked_select_nested().
- Parameters:
s – nested structure, where dims are packed, i.e. (in_dim,…)
backup – nested structure to scatter into, with tensors like (dims…,…)
mask – mask tensor. See masked_scatter()/masked_select().
mask_cpu – mask tensor for CPU. This is used e.g. for dynamic dim sizes. See masked_scatter().
dims – dims to mask. See masked_scatter()/masked_select().
in_dim – the packed in dim. See masked_scatter().
masked_select_dim_map – the dim map from masked_select_nested(). This describes how to map dims from s to backup.
masked_scatter_dim_map – this will be updated for any new dims created by this function
- Returns:
backup with s scattered in
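The reverse direction can be sketched in pure Python as well, assuming packed 1-D lists in s, [batch][time] nested lists in backup, and dict/tuple containers. `masked_scatter_nested_sketch` is hypothetical, not RETURNN's implementation; it does not model in_dim, mask_cpu or either dim map, and it assumes the packed values are in row-major (batch, time) order.

```python
from typing import Any, List

# Hypothetical stand-in: s holds packed 1-D lists, backup holds [batch][time]
# nested lists, containers are dicts/tuples.
# This is NOT RETURNN's masked_scatter_nested; dims and dim maps are not modeled.


def masked_scatter_nested_sketch(s: Any, backup: Any, mask: List[List[bool]]) -> Any:
    """Write packed values back to the masked-in (batch, time) positions of backup."""
    if isinstance(s, dict):
        return {k: masked_scatter_nested_sketch(s[k], backup[k], mask) for k in s}
    if isinstance(s, tuple):
        return tuple(masked_scatter_nested_sketch(p, b, mask) for p, b in zip(s, backup))
    packed = iter(s)
    # Row-major walk over (batch, time): consume one packed value per True entry,
    # keep the backup value everywhere else.
    return [
        [next(packed) if m else b for b, m in zip(brow, mrow)]
        for brow, mrow in zip(backup, mask)
    ]


mask = [[True, True, False], [True, False, False]]
packed = {"x": [10, 20, 40]}
backup = {"x": [[0, 0, 0], [0, 0, 0]]}
print(masked_scatter_nested_sketch(packed, backup, mask))  # {'x': [[10, 20, 0], [40, 0, 0]]}
```

Scattering the packed values of a structure back into that same structure as backup reproduces it unchanged, which is the sense in which this reverses a masked select.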