returnn.torch.util.rope

PyTorch utility for Rotary Position Embedding (RoPE).

Provides a core kernel for applying RoPE, compiled with torch.compile. The kernel is numerically identical to the plain RF (RETURNN frontend) reference implementation returnn.frontend.attention._apply_rope_real(), but it fuses the four element-wise operations into a single kernel launch, which avoids intermediate allocations. It also natively supports all PyTorch floating-point dtypes (including bfloat16 and float16) without upcasting.
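For illustration, the rotation can be written as an unfused reference in plain PyTorch. This is a minimal sketch, not the RETURNN code: the helper name apply_rope_sketch is hypothetical, and it assumes the half-split pairing (feature i is rotated together with feature i + D/2), which this page does not specify. It reads sin and cos from pos_enc in the layout documented for apply_rope().

```python
import torch


def apply_rope_sketch(x: torch.Tensor, pos_enc: torch.Tensor) -> torch.Tensor:
    """Hypothetical unfused RoPE reference (not the RETURNN implementation).

    Assumption: half-split pairing, i.e. feature i rotates with feature i + D/2.
    pos_enc[..., :D/2] holds sin, pos_enc[..., D/2:] holds cos.
    """
    half = x.shape[-1] // 2
    # Cast so the result keeps x's dtype (assumed; avoids promotion to float32).
    pos_enc = pos_enc.to(x.dtype)
    sin, cos = pos_enc[..., :half], pos_enc[..., half:]
    x1, x2 = x[..., :half], x[..., half:]
    # In eager mode every op below is a separate kernel with its own
    # intermediate tensor; a compiler can fuse them into one launch.
    out1 = x1 * cos - x2 * sin
    out2 = x2 * cos + x1 * sin
    return torch.cat([out1, out2], dim=-1)


# Usage: positions [T, D] broadcast against x [B, T, D].
x = torch.randn(2, 5, 8)
angles = torch.rand(5, 4) * 6.0
pos_enc = torch.cat([angles.sin(), angles.cos()], dim=-1)
y = apply_rope_sketch(x, pos_enc)
```

Because each (x1_i, x2_i) pair is multiplied by a 2x2 rotation matrix, the per-position vector norm is preserved. Wrapping such a function with torch.compile (PyTorch 2.x) is what allows the separate element-wise ops to be fused, which is the motivation stated above.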

returnn.torch.util.rope.apply_rope(x: Tensor, pos_enc: Tensor) → Tensor

RoPE kernel operating on raw PyTorch tensors.

This function is intended to be compiled (e.g. with torch.compile).

Parameters:
  • x – input tensor [..., D] with the feature dim last; any floating-point dtype

  • pos_enc – positional encoding [..., D], broadcast-compatible with x; along the last axis, the first D/2 entries are the sin values and the last D/2 entries are the cos values

Returns:

rotated tensor with the same shape and dtype as x
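A pos_enc argument in the documented layout can be built for a sequence of positions as sketched below. The helper make_rope_pos_enc is hypothetical, and the frequency schedule base ** (-2i / D) with base = 10000 is the common RoPE choice, assumed here rather than taken from RETURNN. The result has shape [T, D] and broadcasts against x of shape [B, H, T, D].

```python
import torch


def make_rope_pos_enc(seq_len: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: pos_enc [seq_len, dim], sin in first half, cos in second.

    Assumes the common RoPE frequencies theta_i = base ** (-2i / dim).
    """
    assert dim % 2 == 0, "feature dim must be even"
    i = torch.arange(dim // 2, dtype=torch.float32)
    inv_freq = base ** (-2.0 * i / dim)              # [dim/2]
    pos = torch.arange(seq_len, dtype=torch.float32)  # [seq_len]
    angles = pos[:, None] * inv_freq[None, :]         # [seq_len, dim/2]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)


pos_enc = make_rope_pos_enc(seq_len=16, dim=64)
q = torch.randn(2, 4, 16, 64)  # [batch, heads, time, feature]
# q_rot = apply_rope(q, pos_enc)  # returnn.torch.util.rope.apply_rope
```

At position 0 all angles are zero (sin = 0, cos = 1), so the rotation is the identity there; only relative offsets between query and key positions affect their dot product.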