returnn.frontend.attention

Attention

returnn.frontend.attention.dot_attention(query: Tensor, keys: Tensor, values: Tensor, *, key_dim: Dim, axis: Dim, att_dropout: float = 0.0, att_dropout_broadcast: bool | None = None) → Tensor[source]

Calculates attention over the given axis, for the given key dim. Any other unrelated axes do not matter here. This can be used for multi-head or single-head attention. The query may or may not have additional dimensions.

Parameters:
  • query – {…, key_dim}. For self-attention, do not use the same axis as in keys and values; rather, replace it with another new dim via replace_dim().

  • keys – {…, axis, key_dim}

  • values – {…, axis}

  • key_dim – the dim in keys and query which is reduced to calculate the attention energies

  • axis – the axis in keys and values to apply attention over. The softmax is taken over this axis, which is then reduced.

  • att_dropout – dropout for the attention weights

  • att_dropout_broadcast – whether to broadcast the dropout mask over all axes except axis. Normally not wanted. Broadcasting is disabled by default since behavior version 19.

Returns:

like values, but with axis removed, and possibly with any additional axes from query
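
The following is a minimal usage sketch of dot_attention() for a single-head case. The Dim sizes and names are illustrative, rf.random_normal() is only used here to create dummy inputs, and it assumes a RETURNN frontend backend (e.g. the PyTorch backend) has already been selected.

    import returnn.frontend as rf
    from returnn.frontend.attention import dot_attention
    from returnn.tensor import Dim

    # Illustrative dims; any dims with matching roles would do.
    batch_dim = Dim(3, name="batch")
    kv_time_dim = Dim(7, name="kv-time")  # the `axis` to attend over
    key_dim = Dim(64, name="key")
    value_dim = Dim(64, name="value")

    # Dummy inputs, just for the example.
    query = rf.random_normal([batch_dim, key_dim])
    keys = rf.random_normal([batch_dim, kv_time_dim, key_dim])
    values = rf.random_normal([batch_dim, kv_time_dim, value_dim])

    # Softmax over kv_time_dim, which is then reduced away.
    att = dot_attention(query, keys, values, key_dim=key_dim, axis=kv_time_dim)
    # att has dims {batch, value}: like `values`, but with `axis` removed.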

class returnn.frontend.attention.SelfAttentionBase(in_dim: Dim, proj_dim: Dim | None, *, key_dim_total: Dim, value_dim_total: Dim, num_heads: int | Dim, with_bias: bool = True, att_dropout: float = 0.1, att_dropout_broadcast: bool | None = None)[source]

Shared base class for (non-causal) self-attention (SelfAttention) and causal self-attention (CausalSelfAttention).

It uses dot_attention() for multi-headed dot-attention.

Parameters:
  • in_dim – input dim

  • proj_dim – if given, a final linear projection to this dim is added; otherwise, there is no projection after the attention

  • key_dim_total – total key dim; should be a multiple of num_heads

  • value_dim_total – total value dim; should be a multiple of num_heads

  • num_heads – number of heads

  • with_bias – whether to add a bias to the qkv and proj linear projections. This was False in the original Transformer, but many recent implementations use True by default. Also see: https://github.com/rwth-i6/returnn_common/issues/234.

  • att_dropout – dropout for the attention weights

  • att_dropout_broadcast – whether to broadcast the dropout mask over all axes except axis. Normally not wanted. Broadcasting is disabled by default since behavior version 19.

forward_qkv(source: Tensor) → Tuple[Tensor, Tensor, Tensor][source]
Returns:

q, k, v

attention(q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) → Tensor[source]

Apply attention.
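
As a rough sketch (not the verbatim implementation), the two methods above combine for non-causal self-attention roughly as follows; giving the keys and values their own axis via rf.replace_dim() follows the note on the query parameter of dot_attention():

    import returnn.frontend as rf
    from returnn.frontend.attention import SelfAttentionBase
    from returnn.tensor import Tensor, Dim

    def self_att_via_parts(self_att: SelfAttentionBase, source: Tensor, *, axis: Dim) -> Tensor:
        """Sketch of non-causal self-attention via forward_qkv() + attention()."""
        q, k, v = self_att.forward_qkv(source)
        # Keys/values must not share the query axis, so move them to a new kv axis.
        k, kv_axis = rf.replace_dim(k, in_dim=axis)
        v, _ = rf.replace_dim(v, in_dim=axis, out_dim=kv_axis)
        return self_att.attention(q, k, v, kv_axis=kv_axis)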

class returnn.frontend.attention.SelfAttention(in_dim: Dim, proj_dim: Dim | None, *, key_dim_total: Dim, value_dim_total: Dim, num_heads: int | Dim, with_bias: bool = True, att_dropout: float = 0.1, att_dropout_broadcast: bool | None = None)[source]

Classic self-attention on the sequence level.

Parameters:
  • in_dim – input dim

  • proj_dim – if given, a final linear projection to this dim is added; otherwise, there is no projection after the attention

  • key_dim_total – total key dim; should be a multiple of num_heads

  • value_dim_total – total value dim; should be a multiple of num_heads

  • num_heads – number of heads

  • with_bias – whether to add a bias to the qkv and proj linear projections. This was False in the original Transformer, but many recent implementations use True by default. Also see: https://github.com/rwth-i6/returnn_common/issues/234.

  • att_dropout – dropout for the attention weights

  • att_dropout_broadcast – whether to broadcast the dropout mask over all axes except axis. Normally not wanted. Broadcasting is disabled by default since behavior version 19.
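
A small usage sketch (sizes and names are illustrative; it assumes the module is called with the source tensor and the axis to attend over, and that a backend is selected as in the dot_attention example above):

    import returnn.frontend as rf
    from returnn.frontend.attention import SelfAttention
    from returnn.tensor import Dim

    model_dim = Dim(256, name="model")
    time_dim = Dim(20, name="time")
    batch_dim = Dim(3, name="batch")

    self_att = SelfAttention(
        in_dim=model_dim,
        proj_dim=model_dim,  # final linear projection back to the model dim
        key_dim_total=Dim(256, name="key-total"),
        value_dim_total=Dim(256, name="value-total"),
        num_heads=4,
    )

    x = rf.random_normal([batch_dim, time_dim, model_dim])  # dummy input
    y = self_att(x, axis=time_dim)  # attends over time_dim; dims {batch, time, model}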

class returnn.frontend.attention.CausalSelfAttention(in_dim: Dim, proj_dim: Dim | None, *, key_dim_total: Dim, value_dim_total: Dim, num_heads: int | Dim, with_bias: bool = True, att_dropout: float = 0.1, att_dropout_broadcast: bool | None = None)[source]

Classic causal self-attention.

Parameters:
  • in_dim – input dim

  • proj_dim – if given, a final linear projection to this dim is added; otherwise, there is no projection after the attention

  • key_dim_total – total key dim; should be a multiple of num_heads

  • value_dim_total – total value dim; should be a multiple of num_heads

  • num_heads – number of heads

  • with_bias – whether to add a bias to the qkv and proj linear projections. This was False in the original Transformer, but many recent implementations use True by default. Also see: https://github.com/rwth-i6/returnn_common/issues/234.

  • att_dropout – dropout for the attention weights

  • att_dropout_broadcast – whether to broadcast the dropout mask over all axes except axis. Normally not wanted. Broadcasting is disabled by default since behavior version 19.

default_initial_state(*, batch_dims: Sequence[Dim]) → CausalSelfAttentionState[source]

For causal attention.

class returnn.frontend.attention.CausalSelfAttentionState(*_args, k_accum: Tensor | None = None, v_accum: Tensor | None = None, accum_axis: Dim | None = None)[source]

State for StepwiseCausalSelfAttention.

Parameters:
  • k_accum – accumulated keys

  • v_accum – accumulated values

  • accum_axis – the axis over which the keys and values are accumulated
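
A hedged sketch of stepwise decoding with CausalSelfAttention and its state. That __call__ takes axis=single_step_dim together with the previous state and returns (output, new state) is an assumption here for illustration; check the actual implementation. Sizes and names are illustrative, and a backend is assumed to be selected.

    import returnn.frontend as rf
    from returnn.frontend.attention import CausalSelfAttention
    from returnn.tensor import Dim, single_step_dim

    model_dim = Dim(256, name="model")
    batch_dim = Dim(3, name="batch")

    causal_att = CausalSelfAttention(
        in_dim=model_dim, proj_dim=model_dim,
        key_dim_total=Dim(256, name="key-total"),
        value_dim_total=Dim(256, name="value-total"),
        num_heads=4,
    )

    state = causal_att.default_initial_state(batch_dims=[batch_dim])
    for _ in range(5):  # e.g. 5 decoding steps
        x_step = rf.random_normal([batch_dim, model_dim])  # current frame only
        # Assumption: single-step call, accumulating keys/values in the state.
        y_step, state = causal_att(x_step, axis=single_step_dim, state=state)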

class returnn.frontend.attention.RelPosSelfAttention(in_dim: Dim, proj_dim: Dim | None, *, key_dim_total: Dim, value_dim_total: Dim, num_heads: int | Dim, with_bias: bool = True, with_linear_pos: bool = True, with_pos_bias: bool = True, learnable_pos_emb: bool = False, learnable_pos_emb_clipping: int = 16, separate_pos_emb_per_head: bool = True, pos_emb_dropout: float = 0.0, att_dropout: float = 0.1)[source]

Self-attention with relative positional encoding. This covers both Shaw et al. self-att rel pos 2018 (https://arxiv.org/abs/1803.02155), and Dai et al. Transformer-XL style 2019 (https://arxiv.org/abs/1901.02860).

It uses relative_positional_encoding() or LearnedRelativePositionalEncoding.

To get Shaw et al. self-att rel pos 2018 / RETURNN SelfAttentionLayer + RelativePositionalEncodingLayer:

  • with_bias = False (at least that was the RETURNN behavior)

  • with_linear_pos = False

  • with_pos_bias = False

  • learnable_pos_emb = True

  • separate_pos_emb_per_head = False (at least that was the RETURNN default)

To get Dai et al. Transformer-XL style 2019:

  • with_bias = False would be like the paper; however, in most implementations it is True (default)

  • with_linear_pos = True (default)

  • with_pos_bias = True (default)

  • learnable_pos_emb = True (default)

  • separate_pos_emb_per_head = True (default)

Further details: https://github.com/rwth-i6/returnn_common/wiki/Relative-positional-encoding

Code references, partly adapted from there:

  • https://github.com/espnet/espnet/blob/4138010fb66ad27a43e8bee48a4932829a0847ae/espnet/nets/pytorch_backend/transformer/embedding.py#L260

  • https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L4

Parameters:
  • in_dim – input dim

  • proj_dim – if given, a final linear projection to this dim is added; otherwise, there is no projection after the attention

  • key_dim_total – total key dim; should be a multiple of num_heads

  • value_dim_total – total value dim; should be a multiple of num_heads

  • num_heads – number of heads

  • with_bias – whether to add a bias to the qkv and proj linear projections. This was False in the original Transformer, but many recent implementations use True by default. Also see: https://github.com/rwth-i6/returnn_common/issues/234.

  • att_dropout – dropout for the attention weights

  • att_dropout_broadcast – whether to broadcast the dropout mask over all axes except axis. Normally not wanted. Broadcasting is disabled by default since behavior version 19.
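
A construction sketch; the defaults correspond to the Transformer-XL style configuration described above, and the commented-out keyword arguments show the Shaw et al. 2018 style settings (sizes and names are illustrative):

    from returnn.frontend.attention import RelPosSelfAttention
    from returnn.tensor import Dim

    model_dim = Dim(256, name="model")
    rel_att = RelPosSelfAttention(
        in_dim=model_dim, proj_dim=model_dim,
        key_dim_total=Dim(256, name="key-total"),
        value_dim_total=Dim(256, name="value-total"),
        num_heads=4,
        # For Shaw et al. 2018 / RETURNN SelfAttentionLayer style, instead use:
        # with_bias=False, with_linear_pos=False, with_pos_bias=False,
        # learnable_pos_emb=True, separate_pos_emb_per_head=False,
    )
    # Called like SelfAttention, e.g. rel_att(x, axis=time_dim).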

class returnn.frontend.attention.RelPosCausalSelfAttention(in_dim: Dim, proj_dim: Dim | None, *, key_dim_total: Dim, value_dim_total: Dim, num_heads: int | Dim, with_bias: bool = True, with_linear_pos: bool = True, with_pos_bias: bool = True, learnable_pos_emb: bool = False, learnable_pos_emb_clipping: int = 16, separate_pos_emb_per_head: bool = True, pos_emb_dropout: float = 0.0, att_dropout: float = 0.1)[source]

Self-attention with relative positional encoding. This covers both Shaw et al. self-att rel pos 2018 (https://arxiv.org/abs/1803.02155), and Dai et al. Transformer-XL style 2019 (https://arxiv.org/abs/1901.02860).

It uses relative_positional_encoding() or LearnedRelativePositionalEncoding.

Same defaults as RelPosSelfAttention, which is mostly Transformer-XL style.

Further details: https://github.com/rwth-i6/returnn_common/wiki/Relative-positional-encoding

Code references, partly adapted from there:

  • https://github.com/espnet/espnet/blob/4138010fb66ad27a43e8bee48a4932829a0847ae/espnet/nets/pytorch_backend/transformer/embedding.py#L260

  • https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L4

Parameters:
  • in_dim – input dim

  • proj_dim – if given, a final linear projection to this dim is added; otherwise, there is no projection after the attention

  • key_dim_total – total key dim; should be a multiple of num_heads

  • value_dim_total – total value dim; should be a multiple of num_heads

  • num_heads – number of heads

  • with_bias – whether to add a bias to the qkv and proj linear projections. This was False in the original Transformer, but many recent implementations use True by default. Also see: https://github.com/rwth-i6/returnn_common/issues/234.

  • att_dropout – dropout for the attention weights

  • att_dropout_broadcast – whether to broadcast the dropout mask over all axes except axis. Normally not wanted. Broadcasting is disabled by default since behavior version 19.

class returnn.frontend.attention.CrossAttention(encoder_dim: Dim, query_in_dim: Dim, proj_dim: Dim | None, *, key_dim_total: Dim, value_dim_total: Dim, num_heads: int | Dim, with_bias: bool = True, att_dropout: float = 0.1, att_dropout_broadcast: bool | None = None)[source]

Cross-attention.

It uses dot_attention() for multi-headed dot-attention.

Parameters:
  • encoder_dim – encoder output dim = input dim for key and value

  • query_in_dim – input dim for the query

  • proj_dim – if given, a final linear projection to this dim is added; otherwise, there is no projection after the attention

  • key_dim_total – total key dim; should be a multiple of num_heads

  • value_dim_total – total value dim; should be a multiple of num_heads

  • num_heads – number of heads

  • with_bias – whether to add a bias to the qkv and proj linear projections. This was False in the original Transformer, but many recent implementations use True by default. Also see: https://github.com/rwth-i6/returnn_common/issues/234.

  • att_dropout – dropout for the attention weights

  • att_dropout_broadcast – whether to broadcast the dropout mask over all axes except axis. Normally not wanted. Broadcasting is disabled by default since behavior version 19.

transform_encoder(encoder: Tensor, *, axis: Dim) → State[source]

Transform the encoder output. This is intended as an initial API suggestion.

forward_kv(source: Tensor) → Tuple[Tensor, Tensor][source]

This would be calculated once for the whole sequence (batch) and then always reused for attention().

Returns:

k, v

forward_query(source: Tensor) → Tensor[source]

This is calculated for every different query.

Returns:

q

attention(q: Tensor, k: Tensor, v: Tensor, *, kv_axis: Dim) → Tensor[source]

Apply attention.
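
A hedged sketch of CrossAttention in an encoder-decoder setting. transform_encoder() is called once on the encoder output; how its returned state is passed back in for the queries (here via an encoder argument to __call__) is an assumption for illustration, as is the use of rf.random_normal() for dummy inputs and the illustrative sizes and names.

    import returnn.frontend as rf
    from returnn.frontend.attention import CrossAttention
    from returnn.tensor import Dim

    enc_dim = Dim(256, name="enc")
    dec_dim = Dim(256, name="dec")
    enc_time = Dim(50, name="enc-time")
    dec_time = Dim(20, name="dec-time")
    batch_dim = Dim(3, name="batch")

    cross_att = CrossAttention(
        encoder_dim=enc_dim, query_in_dim=dec_dim, proj_dim=dec_dim,
        key_dim_total=Dim(256, name="key-total"),
        value_dim_total=Dim(256, name="value-total"),
        num_heads=4,
    )

    enc_out = rf.random_normal([batch_dim, enc_time, enc_dim])  # dummy encoder output
    dec_in = rf.random_normal([batch_dim, dec_time, dec_dim])   # dummy decoder input

    enc_state = cross_att.transform_encoder(enc_out, axis=enc_time)  # k, v computed once
    out = cross_att(dec_in, encoder=enc_state)  # assumption: reuses the cached k, v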

class returnn.frontend.attention.LearnedRelativePositionalEncoding(feat_dim: Dim, *, clipping: int = 16, dtype: str | None = None, causal: bool = False)[source]

Learnable relative positional encoding.

E.g. as used in Shaw et al., 2018 (https://arxiv.org/abs/1803.02155).

https://github.com/rwth-i6/returnn_common/wiki/Relative-positional-encoding

Parameters:
  • feat_dim – feature dim, for the emb matrix and output

  • clipping – max distance to consider. The emb matrix shape is [2 * clipping + 1, feat_dim] if not causal, else [clipping + 1, feat_dim]. The first and last frames are the clipping frames.

  • dtype – for the emb matrix and output

full_matrix(*, query_spatial_dim: Dim, key_value_spatial_dim: Dim, query_offset: int | Tensor | None = None) → Tensor[source]
Returns:

the full matrix [query_spatial_dim, key_value_spatial_dim, feat_dim]. However, note that __call__ is usually preferred, as it gives a more efficient format.
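
A small sketch using the documented full_matrix() signature, e.g. for inspecting the learned embeddings (sizes and names are illustrative; __call__ remains the more efficient path inside a model):

    from returnn.frontend.attention import LearnedRelativePositionalEncoding
    from returnn.tensor import Dim

    feat_dim = Dim(64, name="pos-feat")
    rel_pos_emb = LearnedRelativePositionalEncoding(feat_dim, clipping=16)

    q_time = Dim(10, name="q-time")
    kv_time = Dim(10, name="kv-time")
    # Full [q_time, kv_time, feat_dim] matrix; distances beyond +/- clipping are clipped.
    mat = rel_pos_emb.full_matrix(query_spatial_dim=q_time, key_value_spatial_dim=kv_time)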

returnn.frontend.attention.relative_positional_encoding(*, query_spatial_dim: Dim, key_value_spatial_dim: Dim, feat_dim: Dim, query_offset: int = 0, dtype: str | None = None) → Tuple[Tensor, Dim][source]

Implements relative positional encoding, Transformer-XL style (https://arxiv.org/abs/1901.02860), as used for example by RelPosSelfAttention.

Code references, partly adapted from there:

  • https://github.com/espnet/espnet/blob/4138010fb66ad27a43e8bee48a4932829a0847ae/espnet/nets/pytorch_backend/transformer/embedding.py#L260

  • https://github.com/kimiyoung/transformer-xl/blob/44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L4

Note that this encoding is stored in a cache, so it is only calculated once and then reused.

Note that we could extend the implementation later to also buffer it even across mini-batches, like the ESPnet implementation does, e.g. by storing it in an auxiliary variable and increasing its size when needed. But this is not done yet, to keep the code simple.

Returns:

a tensor of shape [spatial_dim * 2 - 1, feat_dim], and the output spatial dim (spatial_dim * 2 - 1). The center is the relative position i - j = 0; everything to the right is for i - j > 0, everything to the left for i - j < 0.
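
A usage sketch with equal query and key/value spatial dims of size T, giving an output spatial dim of size 2 * T - 1 (dims are illustrative):

    from returnn.frontend.attention import relative_positional_encoding
    from returnn.tensor import Dim

    time_dim = Dim(10, name="time")
    feat_dim = Dim(64, name="pos-feat")
    enc, out_spatial_dim = relative_positional_encoding(
        query_spatial_dim=time_dim, key_value_spatial_dim=time_dim, feat_dim=feat_dim)
    # enc has dims {out_spatial_dim (size 2 * 10 - 1 = 19), feat_dim};
    # the center index corresponds to relative position i - j = 0.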

returnn.frontend.attention.sinusoidal_positional_encoding(*, spatial_dim: Dim, feat_dim: Dim, offset: int | Tensor | None = None, dtype: str | None = None, device: str | None = None) → Tensor[source]

Implements absolute sinusoidal positional encoding.

Code adapted from relative_positional_encoding() and our TF util get_positional_encoding().

Note that this encoding is stored in a cache, so it is only calculated once and then reused.

Note that we could extend the implementation later to also buffer it even across mini-batches, like the ESPnet implementation does, e.g. by storing it in an auxiliary variable and increasing its size when needed. But this is not done yet, to keep the code simple.

Returns:

a tensor of shape [spatial_dim, feat_dim] if spatial_dim != single_step_dim, else [feat_dim]
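
A usage sketch with a static spatial dim (dims are illustrative); the result can e.g. be added to an input embedding with matching dims:

    from returnn.frontend.attention import sinusoidal_positional_encoding
    from returnn.tensor import Dim

    time_dim = Dim(10, name="time")
    feat_dim = Dim(64, name="feat")
    pos_enc = sinusoidal_positional_encoding(spatial_dim=time_dim, feat_dim=feat_dim)
    # pos_enc has dims {time_dim, feat_dim}.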