returnn.torch.frontend._backend

Backend for exposing PyTorch-specific functionality.

class returnn.torch.frontend._backend.TorchBackend[source]

PyTorch backend

name: str | None = 'torch'[source]
RawTensorType[source]

alias of Tensor

static executing_eagerly() bool[source]
Returns:

whether we are executing eagerly

static set_random_seed(seed: int)[source]
Parameters:

seed

static get_random_state() Dict[str, bytes][source]
Returns:

random state

static set_random_state(state: Dict[str, bytes])[source]
Parameters:

state – as returned by get_random_state(). This might not always be successful (e.g. different hardware, different backend version), so the calling code should always call set_random_seed beforehand to have the random generators in a reasonable fallback state.
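
A minimal usage sketch (importing the private module directly just for illustration): seed first as a fallback, then try to restore a previously captured state.

    from returnn.torch.frontend._backend import TorchBackend

    # Capture the RNG state, e.g. to store it in a checkpoint.
    TorchBackend.set_random_seed(42)
    state = TorchBackend.get_random_state()  # Dict[str, bytes]

    # Later (possibly in a different process): seed first as a reasonable fallback,
    # then try to restore the captured state.
    TorchBackend.set_random_seed(42)
    TorchBackend.set_random_state(state)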

static get_dtype_name_raw(raw_tensor: Tensor) str[source]
Returns:

dtype of raw tensor, as string

static as_dtype_raw(dtype_name: str) dtype[source]
Parameters:

dtype_name – e.g. “float32”

Returns:

dtype object

static get_ndim_raw(raw_tensor: Tensor) int[source]
Returns:

ndim of raw tensor

static get_shape_raw(raw_tensor: Tensor) Tuple[int, ...][source]

shape of raw tensor

static get_shape_tuple_raw(raw_tensor: Tensor) Tuple[int, ...][source]
Returns:

shape of raw tensor

static get_known_shape_raw(raw_tensor: Tensor) Tuple[int | None, ...][source]
Returns:

shape of raw tensor; here for PyTorch the full shape is always known

static get_new_dim_raw(raw_tensor: Tensor, axis: int, *, name: str) Dim[source]
Parameters:
  • raw_tensor

  • axis

  • name

Returns:

new Dim object

static get_device(x: Tensor[Tensor]) str | None[source]

device of the tensor (e.g. "cpu" or "cuda:0"), or None

static copy_to_device(x: Tensor, device: str | None) Tensor[source]
Parameters:
  • x

  • device

static expand_dims_raw(raw_tensor: Tensor, axis: int) Tensor[source]
Parameters:
  • raw_tensor

  • axis – e.g. 1

Returns:

raw tensor with new axis

static expand_raw(raw_tensor: Tensor, axis: int, dim: int | Tensor) Tensor[source]
Parameters:
  • raw_tensor

  • axis – shape[axis] must be 1

  • dim – the new dim for shape[axis]

Returns:

shape[axis] expanded to dim. In PyTorch, and in other frameworks that support custom strides, this is an efficient view, not a copy.
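
For illustration, the corresponding view semantics in plain PyTorch (not the backend call itself):

    import torch

    x = torch.randn(4, 1, 8)       # shape[1] must be 1 to be expandable
    y = x.expand(4, 5, 8)          # shape[1] expands from 1 to 5
    assert y.shape == (4, 5, 8)
    assert y.data_ptr() == x.data_ptr()  # same storage: a view, not a copy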

static copy(tensor: Tensor[Tensor]) Tensor[Tensor][source]
static cast_raw(raw_tensor: Tensor, dtype: str) Tensor[source]

cast

static set_requires_gradient(tensor: Tensor[Tensor])[source]

set requires grad

static gradient(y: Tensor, x: Tensor) Tensor[source]
static stop_gradient(tensor: Tensor) Tensor[source]

stop grad

static scaled_gradient(tensor: Tensor, scale: float | Tensor) Tensor[source]

scaled gradient

static scaled_gradient_ext(x: Tensor, *, scale: float | Tensor = 1.0, shift: float | Tensor | None = None, scale_shift_by_sum_over_axis: Dim | None = None)[source]

scaled gradient ext

static gradient_checkpoint_scope()[source]

gradient checkpoint scope

static merge_dims(source: Tensor, *, dims: Sequence[Dim], out_dim: Dim | None = None) Tuple[Tensor, Dim][source]

Merges a list of axes into a single one (flattens the dims). E.g. if the input is (batch, width, height, dim) and dims=(width, height), the result is (batch, width*height, dim). Or if the input is (batch, time, height, dim) and dims=(height, dim), the result is (batch, time, height*dim). See the sketch after this entry.

Parameters:
  • source

  • dims

  • out_dim

Returns:

tensor, out_dim
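
A hedged sketch of a frontend-level call (assuming rf.zeros and the Dim constructor as used here; the dim names are made up for illustration):

    import returnn.frontend as rf
    from returnn.tensor import Dim

    batch = Dim(8, name="batch")
    width = Dim(10, name="width")
    height = Dim(12, name="height")
    feat = Dim(32, name="feature")

    x = rf.zeros([batch, width, height, feat], dtype="float32")
    # Merge width and height into one flattened dim, as in the example above.
    y, merged = rf.merge_dims(x, dims=(width, height))
    # y now has the dims (batch, merged, feat), with merged.dimension == 10 * 12.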

static split_dims(source: Tensor, *, axis: Dim, dims: Sequence[Dim], pad_to_multiples: bool | None = None, pad_value: None | int | float = None) Tensor[source]

split dims

static reshape(source: Tensor, in_dims: Sequence[Dim], out_dims: Sequence[Dim]) Tensor[source]
static split(source: Tensor, *, axis: Dim, out_dims: Sequence[Dim]) Tuple[Tensor, ...][source]
static expand_dim(source: Tensor, dim: Dim) Tensor[source]

expand dim

static squeeze(source: Tensor, axis: Dim) Tensor[source]
static concat(*sources: Tuple[Tensor, Dim], allow_broadcast: bool = False, out_dim: Dim) Tensor[source]
static pad(source: Tensor, *, axes: Sequence[Dim], padding: Sequence[Tuple[Dim | int, Dim | int]], out_dims: Sequence[Dim], handle_dynamic_dims: bool, mode: str = 'constant', value: int | float | complex | number | ndarray | bool | str | Tensor | None = None) Tensor[source]
static cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) Tensor[source]

cum concat step

static stack(sources: Sequence[Tensor], *, out_dim: Dim) Tensor[source]
static activation_raw(raw_tensor: Tensor, func: str) Tensor[source]
Parameters:
  • raw_tensor

  • func – e.g. “tanh”

Returns:

raw tensor after activation

static softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) Tensor[source]
Parameters:
  • tensor

  • axis

  • use_mask

Returns:

softmax over axis

static log_softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) Tensor[source]
Parameters:
  • tensor

  • axis

  • use_mask

Returns:

log_softmax over axis

static softmax_cross_entropy_with_logits(*, logits: Tensor, targets: Tensor, axis: Dim)[source]

Efficient cross entropy. For PyTorch, this is actually the default cross entropy function (torch.nn.functional.cross_entropy); see the sketch after this entry.

Parameters:
  • logits – target estimates given as inputs to softmax (i.e. unnormalized)

  • targets – probabilities, i.e. normalized, can also be sparse

  • axis – class labels dim over which softmax is computed

Returns:

cross entropy (same Dims as ‘logits’ but without ‘axis’)
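
For comparison, the corresponding plain-PyTorch computation with sparse integer targets would look roughly like this:

    import torch
    import torch.nn.functional as F

    batch, num_classes = 3, 5
    logits = torch.randn(batch, num_classes)        # unnormalized estimates
    targets = torch.randint(num_classes, (batch,))  # sparse class indices

    # Default PyTorch cross entropy: log-softmax + NLL over the class axis.
    loss = F.cross_entropy(logits, targets, reduction="none")  # shape [batch]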

static ctc_loss(*, logits: Tensor, logits_normalized: bool = False, targets: Tensor, input_spatial_dim: Dim, targets_spatial_dim: Dim, blank_index: int, max_approx: bool = False) Tensor[source]

CTC
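
A rough plain-PyTorch sketch of the underlying loss (torch.nn.functional.ctc_loss expects log-probabilities of shape [time, batch, classes]); the sizes here are made up:

    import torch
    import torch.nn.functional as F

    time_, batch, classes = 50, 2, 20   # classes includes the blank label
    log_probs = torch.randn(time_, batch, classes).log_softmax(dim=-1)
    targets = torch.randint(1, classes, (batch, 10))  # label indices; 0 is blank
    input_lengths = torch.full((batch,), time_, dtype=torch.long)
    target_lengths = torch.full((batch,), 10, dtype=torch.long)

    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=0, reduction="none")  # shape [batch]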

static create_parameter_raw(tensor: Parameter, *, device: str | None = None) Parameter[source]
Returns:

parameter

static set_parameter_initial_value(param: Parameter, value: None | Tensor | int | float | complex | number | ndarray | bool | str) None[source]
Parameters:
  • param – parameter

  • value – initial value

static set_parameter_trainable(param: Parameter, trainable: bool) None[source]

set trainable

static parameter_assign(param: Parameter, value: Tensor, *, op: str = 'assign') None[source]

param assign

static parameter_assign_key(param: Parameter, key: int | float | complex | number | ndarray | bool | str | Tensor | slice | Sequence[int | float | complex | number | ndarray | bool | str | Tensor | slice], value: Tensor, *, op: str = 'assign', axis: Dim | Sequence[Dim] | None = None, key_dim: Dim | Sequence[Dim] | None = None) None[source]

param assign with key

static parameter_move_to(param: Parameter, *, device: str | None = None, dtype: str | None = None)[source]

Move the parameter to the given device and/or dtype.

static compare_raw(a: Tensor, kind: str, b: Tensor) Tensor[source]
Parameters:
  • a

  • kind – “equal”, “less”, “less_equal”, “greater”, “greater_equal”, “not_equal”

  • b

Returns:

a kind b

static combine_raw(a: Tensor, kind: str, b: Tensor) Tensor[source]
Parameters:
  • a

  • kind – “add”, “sub”, “mul”, “truediv”, “floordiv”, “mod”, “pow”, “maximum”, “minimum”, “logical_and”, “logical_or”, “squared_difference”

  • b

Returns:

a kind b
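
One plausible way such a kind-string dispatch could be written in plain PyTorch (a hypothetical sketch, not necessarily the actual implementation):

    import torch

    # Hypothetical mapping from kind strings to elementwise PyTorch ops.
    _combine = {
        "add": torch.add,
        "sub": torch.sub,
        "mul": torch.mul,
        "truediv": torch.true_divide,
        "floordiv": torch.floor_divide,
        "mod": torch.remainder,
        "pow": torch.pow,
        "maximum": torch.maximum,
        "minimum": torch.minimum,
        "logical_and": torch.logical_and,
        "logical_or": torch.logical_or,
        "squared_difference": lambda a, b: (a - b) ** 2,
    }

    a, b = torch.tensor([1.0, 4.0]), torch.tensor([2.0, 2.0])
    print(_combine["pow"](a, b))  # tensor([ 1., 16.])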

static reshape_raw(raw_tensor: Tensor, shape: Sequence[int | Tensor] | Tensor) Tensor[source]
Parameters:
  • raw_tensor

  • shape

Returns:

reshaped raw tensor; wraps torch.reshape

classmethod squeeze_raw(raw_tensor: Tensor, axes: Sequence[int]) Tensor[source]

squeeze

static transpose_raw(raw_tensor: Tensor, perm: Sequence[int]) Tensor[source]
Parameters:
  • raw_tensor

  • perm – e.g. [0, 2, 1]

Returns:

permuted (transposed) raw tensor; wraps torch.permute

static convert_to_tensor(value: Tensor | Tensor | int | float | complex | number | ndarray | bool | str, *, dims: Sequence[Dim], dtype: str, sparse_dim: Dim | None = None, device: str | None = None, name: str | None = None) Tensor[Tensor][source]
Parameters:
  • value

  • dims

  • dtype

  • sparse_dim

  • device

  • name

Returns:

tensor

static full(dims: Sequence[Dim], fill_value: int | float | complex | number | ndarray | bool | str | Tensor, *, dtype: str, device: str | None = None, sparse_dim: Dim | None = None, feature_dim: Dim | None = None) Tensor[source]
static gather(source: Tensor, *, indices: Tensor | int, axis: Dim, clip_to_valid: bool = False) Tensor[source]

Gather.

There are a few options in PyTorch, each with somewhat different semantics, advantages, disadvantages, and limitations (see the comparison sketch after this list).

  • torch.gather, most generic

  • torch.index_select, similar to tf.gather, but does not support batch axes

  • Tensor.__getitem__

  • torch.embedding
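
To illustrate the difference between two of these options in plain PyTorch:

    import torch

    source = torch.arange(12.0).reshape(3, 4)   # e.g. [batch=3, vocab=4]

    # torch.index_select: one shared index list, no per-batch indices.
    idx = torch.tensor([0, 2])
    print(torch.index_select(source, dim=1, index=idx).shape)      # (3, 2)

    # torch.gather: indices given per batch entry (same rank as source).
    idx_per_batch = torch.tensor([[0], [1], [3]])
    print(torch.gather(source, dim=1, index=idx_per_batch).shape)  # (3, 1)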

static scatter(source: Tensor, *, indices: Tensor, indices_dim: Dim | Sequence[Dim], mode: str, fill_value: int | float, out_dim: Dim | Sequence[Dim]) Tensor[source]

Scatters into a new zero-tensor. If entries in indices are duplicated, the corresponding values in source will be added together (scatter_add in PyTorch; see the sketch after this entry). (TF segment_sum can be implemented via this.)

Parameters:
  • source – [batch_dims…, indices_dim(s)…, feature_dims…]

  • indices – [batch_dims…, indices_dim(s)…] -> out_dim

  • indices_dim

  • mode – “sum”, “max”, “min”

  • fill_value

  • out_dim

Returns:

[batch_dims…, out_dim, feature_dims…]
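
The duplicate-index behavior corresponds to scatter_add in plain PyTorch, e.g.:

    import torch

    values = torch.tensor([1.0, 2.0, 3.0, 4.0])
    indices = torch.tensor([0, 2, 2, 1])          # index 2 occurs twice
    out = torch.zeros(3).scatter_add(0, indices, values)
    print(out)  # tensor([1., 4., 5.]) – duplicated indices are summed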

static slice(source: Tensor, *, axis: Dim, start: int | Tensor | None = None, end: int | Tensor | None = None, step: int | Tensor | None = None, size: int | Tensor | Dim | None = None, out_dim: Dim) Tensor[source]
static flip(source: Tensor, *, axis: Dim) Tensor[source]
static where(cond: Tensor, true_: Tensor | int | float | complex | number | ndarray | bool | str, false_: Tensor | int | float | complex | number | ndarray | bool | str, *, allow_broadcast_all_sources: bool = False) Tensor[source]
static search_sorted(sorted_seq: Tensor, values: Tensor, *, axis: Dim, side: str = 'left', out_dtype: str = 'int32') Tensor[source]

search sorted

static is_finite(x: Tensor) Tensor[source]

is finite

static is_infinite(x: Tensor) Tensor[source]

is positive or negative infinite

static is_neg_infinite(x: Tensor) Tensor[source]

is negative infinite

static clip_by_value(x: Tensor, clip_value_min: Tensor | int | float | complex | number | ndarray | bool | str, clip_value_max: Tensor | int | float | complex | number | ndarray | bool | str, *, allow_broadcast_all_sources: bool = False) Tensor[source]

clip by value

static lerp(start: Tensor, end: Tensor, weight: float | Tensor, *, allow_broadcast_all_sources: bool = False) Tensor[source]
static cumsum(source: Tensor, *, spatial_dim: Dim) Tensor[source]
static matmul(a: Tensor[Tensor], b: Tensor[Tensor], *, reduce: Dim | Sequence[Dim], use_mask: bool = True) Tensor[Tensor][source]

batched matmul of a and b; see the base class docstring
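
A hedged frontend-level sketch (the dims and rf.zeros usage are assumptions for illustration):

    import returnn.frontend as rf
    from returnn.tensor import Dim

    batch = Dim(8, name="batch")
    time = Dim(20, name="time")
    in_dim = Dim(32, name="in")
    out_dim = Dim(64, name="out")

    a = rf.zeros([batch, time, in_dim], dtype="float32")
    b = rf.zeros([in_dim, out_dim], dtype="float32")
    # Contract over the shared in_dim; the remaining dims are kept/broadcast,
    # so c should roughly have the dims (batch, time, out_dim).
    c = rf.matmul(a, b, reduce=in_dim)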

static range_over_dim(dim: Dim, *, dtype: str | None = None, device: str | None = None) Tensor[Tensor][source]
Parameters:
  • dim

  • dtype

  • device

Returns:

tensor with shape [dim]

static reduce(source: Tensor[Tensor], *, mode: str, axis: Dim | Sequence[Dim], use_mask: bool = True) Tensor[Tensor][source]
static top_k(source: Tensor[Tensor], *, axis: Dim | Sequence[Dim], k: int | Tensor, k_dim: Dim | None = None, sorted: bool = True) Tuple[Tensor, Tensor | Sequence[Tensor], Dim][source]
static random_journal_record() Generator[_random_journal.RandomJournal][source]
Returns:

the journal

static random(*, dims: Sequence[Dim], dtype: str, device: str | None = None, sparse_dim: Dim | None = None, feature_dim: Dim | None = None, distribution: str, mean: int | float | Tensor | None = None, stddev: int | float | Tensor | None = None, bound: int | float | Tensor | None = None, minval: int | float | Tensor | None = None, maxval: int | float | Tensor | None = None, seed: int | Sequence[int] | ndarray | None = None, algorithm: str | None = None, explicit_state: Tensor | None = None, auto_update_state: bool | None = None, static: bool | None = None, out: Tensor[Tensor] | None = None) Tensor[source]

random. See rf.random for details.
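
A hedged usage sketch via rf.random, following the keyword arguments of the signature above (the dims are made up for illustration):

    import returnn.frontend as rf
    from returnn.tensor import Dim

    batch = Dim(4, name="batch")
    feat = Dim(16, name="feature")

    # Standard-normal noise.
    noise = rf.random(dims=[batch, feat], dtype="float32",
                      distribution="normal", mean=0.0, stddev=1.0)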

static masked_select(tensor: Tensor, *, mask: Tensor, dims: Sequence[Dim], out_dim: Dim | None = None) Tuple[Tensor, Dim][source]
Parameters:
  • tensor

  • mask

  • dims – the order of the dims defines the format. Those dims should be exactly the dims of the mask.

  • out_dim

Returns:

tensor where all dims in mask/dims are removed and replaced by a new dim. The new dim is also returned. If mask==True for all elements, the returned tensor is simply the flattened input tensor.
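
The flattening behavior corresponds to torch.masked_select in plain PyTorch:

    import torch

    x = torch.arange(6.0).reshape(2, 3)
    mask = torch.tensor([[True, False, True],
                         [True, True, False]])
    # All masked dims are flattened into a single new dim.
    print(torch.masked_select(x, mask))  # tensor([0., 2., 3., 4.])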

static masked_scatter(source: Tensor, *, mask: Tensor, dims: Sequence[Dim], in_dim: Dim) Tensor[source]

masked scatter

static batch_norm(source: Tensor[Tensor], *, in_dim: Dim | Sequence[Dim], running_mean: Tensor | None, running_variance: Tensor | None, gamma: Tensor | None, beta: Tensor | None, epsilon: float, momentum: float, affine: bool, use_mask: bool) Tensor[source]

batch norm

static conv(source: Tensor, *, in_dim: Dim, out_dim: Dim, in_spatial_dims: Sequence[Dim], out_spatial_dims: Sequence[Dim] | None = None, filter: Tensor, filter_size: Sequence[Dim], padding: str, strides: int | Sequence[int] | None = None, dilation_rate: int | Sequence[int] | None = None, groups: int | None = None, bias: Tensor | None = None) Tuple[Tensor, Sequence[Dim]][source]
static pool(source: Tensor, *, mode: str, pool_size: Sequence[int], padding: str = 'valid', dilation_rate: Sequence[int] | int = 1, strides: Sequence[int], in_spatial_dims: Sequence[Dim], out_spatial_dims: Sequence[Dim] | None = None) Tuple[Tensor, Sequence[Dim]][source]
static stft(x: Tensor, *, in_spatial_dim: Dim, frame_step: int, frame_length: int, fft_length: int, window_use_frame_length: bool = True, align_window_left: bool = True, window_enforce_even: bool = True, out_spatial_dim: Dim, out_dim: Dim) Tensor[source]
static lstm(source: Tensor[Tensor], *, state_h: Tensor[Tensor], state_c: Tensor[Tensor], ff_weight: Tensor[Tensor], rec_weight: Tensor[Tensor], bias: Tensor[Tensor] | None, spatial_dim: Dim, in_dim: Dim, out_dim: Dim) Tuple[Tensor[Tensor], Tuple[Tensor[Tensor], Tensor[Tensor]]][source]

Wraps the functional LSTM from PyTorch.

Returns:

Tuple consisting of two elements: the result as a Tensor and the new state as a State (different from the previous one).

TensorArrayType[source]

alias of List[Tensor]

static tensor_array_unstack(tensor: Tensor, *, axis: Dim) TensorArrayType[source]

unstack

static tensor_array_stack(tensor_array: TensorArrayType, *, axis: Dim, tensor_template: Tensor) Tensor[source]

stack