returnn.torch.util.rope
PyTorch utility for Rotary Position Embedding (RoPE).
Provides a core kernel, compiled with torch.compile, for applying RoPE.
The kernel is numerically identical to the plain RF reference implementation
returnn.frontend.attention._apply_rope_real()
but fuses the four element-wise operations into a single kernel launch, avoiding intermediate allocations.
It also supports all PyTorch floating-point dtypes natively (including bfloat16 and float16) without any upcasting.
- returnn.torch.util.rope.apply_rope(x: Tensor, pos_enc: Tensor) → Tensor
RoPE kernel operating on raw PyTorch tensors.
This function is intended to be compiled with torch.compile.
- Parameters:
x – input tensor [..., D] with feat dim last; any float dtype
pos_enc – positional encoding [..., D], broadcast-compatible with x; the first D/2 entries along the last axis are sin, the second D/2 are cos
- Returns:
rotated tensor with the same shape and dtype as x
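As a rough illustration of the contract described above, the following eager-mode sketch builds a pos_enc table in the documented [sin | cos] layout and applies a rotation to x. It is not the RETURNN kernel: the names rope_reference and make_pos_enc are hypothetical, and both the pairing of feature i with feature i + D/2 and the base-10000 frequency schedule are assumptions made for this example. The actual implementation may pair features differently.

```python
import torch


def rope_reference(x: torch.Tensor, pos_enc: torch.Tensor) -> torch.Tensor:
    """Eager-mode RoPE sketch (hypothetical helper, not the RETURNN kernel)."""
    half = x.shape[-1] // 2
    # Documented pos_enc layout: first D/2 entries are sin, second D/2 are cos.
    sin, cos = pos_enc[..., :half], pos_enc[..., half:]
    # ASSUMPTION: feature i is paired with feature i + D/2.
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate every (x1, x2) pair by its angle.
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    return torch.cat([out1, out2], dim=-1)


def make_pos_enc(num_pos: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Build a [T, D] table laid out as [sin | cos] (hypothetical helper).

    ASSUMPTION: the usual base-10000 geometric frequency schedule.
    """
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # [D/2]
    positions = torch.arange(num_pos, dtype=torch.float32)  # [T]
    angles = positions[:, None] * inv_freq[None, :]  # [T, D/2]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)  # [T, D]


torch.manual_seed(0)
x = torch.randn(2, 4, 5, 8).to(torch.bfloat16)  # [batch, heads, time, D]
# Cast pos_enc to the dtype of x so type promotion does not upcast the result.
pos_enc = make_pos_enc(num_pos=5, dim=8).to(x.dtype)  # [T, D], broadcasts over batch and heads
y = rope_reference(x, pos_enc)
print(y.shape, y.dtype)
```

In this sketch the output keeps the shape and dtype of x, position 0 (angle zero) is left unchanged, and the per-vector norm is preserved up to rounding because each pair is only rotated. To experiment with fusion, the same function could be wrapped with torch.compile, which is omitted here to keep the example independent of a compiler toolchain.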