RoPE

April 21, 2025

Rotary Position Embedding (RoPE) for transformers: a method that encodes relative positional information directly into the attention computation.

Attention is defined as,

$$
q_m = W^q x_m \\[3mm]
k_n = W^k x_n \\[3mm]
v_n = W^v x_n \\[3mm]
\alpha_{m, n} = \frac{\exp(q_m k_n^{\top})}{\sum_{i = 1}^{\ell}\exp(q_m k_i^{\top})} \\[3mm]
o_m = \sum_{n = 1}^{\ell} \alpha_{m, n} v_n
$$
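As a minimal sketch of this definition, here is the computation for a single query in plain Python, so the arithmetic stays visible. The helper name `attention_output` and the toy numbers are illustrative assumptions; the output for position $m$ is taken to be the $\alpha$-weighted sum of the value vectors over all positions $n$.

```python
import math


def attention_output(q_m, keys, values):
    """Return (alpha, o_m) for one query vector against all keys and values."""
    # Unnormalised scores q_m . k_n for every position n.
    scores = [sum(q_i * k_i for q_i, k_i in zip(q_m, k_n)) for k_n in keys]
    # Softmax over n; subtracting the max only improves numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # o_m is the alpha-weighted sum of the value vectors.
    d_v = len(values[0])
    o_m = [sum(a * v[j] for a, v in zip(alpha, values)) for j in range(d_v)]
    return alpha, o_m


# Toy inputs: 3 positions, 2-dimensional keys and values (made-up numbers).
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
alpha, o_m = attention_output([1.0, 0.0], keys, values)
```

Note that nothing in this computation depends on where a token sits in the sequence, which is why positional information has to be injected separately.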

Positional Embeddings

Typically, Positional Embeddings are computed as:

$$
\text{PE}_{[pos, 2i]} = \sin\left(\frac{\text{pos}}{10000^{\frac{2i}{d_\text{model}}}}\right) \\[3mm]
\text{PE}_{[pos, 2i + 1]} = \cos\left(\frac{\text{pos}}{10000^{\frac{2i}{d_\text{model}}}}\right) \\[3mm]
X = X + \text{PE}
$$

where:

  • $\ell$ is the sequence length
  • $i \in [0, \frac{d_{\text{model}}}{2} - 1]$ indexes the sine/cosine pairs
  • $\text{pos} \in [0, \ell - 1]$ is the token position

These positional embeddings apply fixed positional information to a sequence of length $\ell$.

```python
import torch
import torch.nn as nn


class PositionalEmbedding(nn.Module):
    def __init__(self, seq_len, d_model, dropout_p: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_p)

        # float32 throughout: float16 cannot represent positions above 2048 exactly.
        pe = torch.zeros(size=(seq_len, d_model), dtype=torch.float32)
        position = torch.arange(start=0, end=seq_len, dtype=torch.float32)
        # 10000^(2i / d_model), one entry per sine/cosine pair i (d_model assumed even)
        div_term = 10000 ** (torch.arange(start=0, end=d_model, step=2, dtype=torch.float32) / d_model)

        # torch.sin / torch.cos take no dtype argument; the result inherits the input dtype.
        pe[:, 0::2] = torch.sin(position.unsqueeze(1) / div_term)
        pe[:, 1::2] = torch.cos(position.unsqueeze(1) / div_term)

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model), broadcasts over the batch
        self.register_buffer("positional_embedding", pe)

    def forward(self, x):
        # x is (batch_size, seq_len, d_model)
        x_pe = x + self.positional_embedding[:, :x.size(1), :]
        x = self.dropout(x_pe)
        return x
```
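To sanity-check the formula by hand, the following sketch evaluates a single row of $\text{PE}$ in plain Python. The function name `sinusoidal_row` is hypothetical and the code assumes an even $d_{\text{model}}$.

```python
import math


def sinusoidal_row(pos, d_model, base=10000.0):
    """Return PE[pos, :] as a list: sin in even slots, cos in odd slots."""
    row = []
    for i in range(d_model // 2):
        angle = pos / base ** (2 * i / d_model)
        row += [math.sin(angle), math.cos(angle)]
    return row


first_row = sinusoidal_row(0, 4)  # [0.0, 1.0, 0.0, 1.0], since sin(0) = 0 and cos(0) = 1
second_row = sinusoidal_row(1, 4)
```

Position 0 always yields the alternating pattern of zeros and ones; later positions differ most in the low-index pairs, where the frequency is highest.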

Rotary Positional Embeddings

Rotary Positional Embeddings are computed as:

$$
f(x_m, m) = R^{d_{\text{model}}}_{\Theta, m} (W^{(q, k)}x_m)
$$

where

$$
R^{d_{\text{model}}}_{\Theta, m} =
\begin{bmatrix}
\cos(m \theta_0) & -\sin(m \theta_0) & 0 & 0 & \cdots & 0 & 0 \\
\sin(m \theta_0) & \cos(m \theta_0) & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \cos(m \theta_1) & -\sin(m \theta_1) & \cdots & 0 & 0 \\
0 & 0 & \sin(m \theta_1) & \cos(m \theta_1) & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \cos(m \theta_{\frac{d}{2}-1}) & -\sin(m \theta_{\frac{d}{2}-1}) \\
0 & 0 & 0 & 0 & \cdots & \sin(m \theta_{\frac{d}{2}-1}) & \cos(m \theta_{\frac{d}{2}-1})
\end{bmatrix}
$$

and $d = d_{\text{model}}$.

Again, note that $m$ is the index of the current token in the sequence of length $\ell$.

For each $m \in [0, \ell - 1]$, a unique matrix $R^{d_{\text{model}}}_{\Theta, m}$ is constructed.

Here, $\theta_i \in \Theta$ is equivalent to the scaling factor in the original formulation of the positional encodings,

$$
\theta_i = \frac{1}{10000^{\frac{2i}{d_{\text{model}}}}}
$$

and since we multiply by $m$, the $m$th position in the sequence of length $\ell$,

$$
m\theta_i = \frac{m}{10000^{\frac{2i}{d_{\text{model}}}}} = \frac{\text{pos}}{10000^{\frac{2i}{d_{\text{model}}}}}
$$

which is the same argument that is passed to the sine and cosine in the original positional embeddings. The difference is that the rotation is applied directly to the queries and keys via a matrix multiplication, so that position directly influences the attention scores $\alpha_{m, n}$.
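Because each $2 \times 2$ block is a rotation, the dot product between a rotated query at position $m$ and a rotated key at position $n$ depends only on the offset $m - n$. The sketch below checks this numerically in plain Python; `rope_rotate`, `dot`, and the vectors are made-up illustrations, and the vector length is assumed to be even.

```python
import math


def rope_rotate(x, m, base=10000.0):
    """Rotate each consecutive pair of x by the angle m * theta_i."""
    d = len(x)
    out = []
    for i in range(d // 2):
        theta_i = base ** (-2 * i / d)
        c, s = math.cos(m * theta_i), math.sin(m * theta_i)
        x_a, x_b = x[2 * i], x[2 * i + 1]
        out += [x_a * c - x_b * s, x_a * s + x_b * c]
    return out


def dot(a, b):
    return sum(a_i * b_i for a_i, b_i in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Both pairs of positions are 2 apart, so the scores should agree
# up to floating-point error even though the absolute positions differ.
score_near = dot(rope_rotate(q, 3), rope_rotate(k, 1))
score_far = dot(rope_rotate(q, 10), rope_rotate(k, 8))
```

This is the sense in which the embedding is relative: shifting both tokens by the same amount leaves the attention score unchanged.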

The matrix multiplication can be computed more efficiently as:

$$
R_{\Theta, m}^d x_m =
\begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{bmatrix}
\odot
\begin{bmatrix} \cos m \theta_0 \\ \cos m \theta_0 \\ \cos m \theta_1 \\ \cos m \theta_1 \\ \vdots \\ \cos m \theta_{\frac{d}{2}-1} \\ \cos m \theta_{\frac{d}{2}-1} \end{bmatrix}
+
\begin{bmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{bmatrix}
\odot
\begin{bmatrix} \sin m \theta_0 \\ \sin m \theta_0 \\ \sin m \theta_1 \\ \sin m \theta_1 \\ \vdots \\ \sin m \theta_{\frac{d}{2}-1} \\ \sin m \theta_{\frac{d}{2}-1} \end{bmatrix}
$$

which reduces the cost of the original multiplication from $\mathcal{O}(d_{\text{model}}^2)$ to $\mathcal{O}(d_{\text{model}})$ per token.
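The equivalence between the block-diagonal matrix product and the element-wise form can be checked numerically. This is a small sketch in plain Python under the same definition of $\theta_i$; the function names and the test vector are hypothetical.

```python
import math


def rotate_dense(x, m, base=10000.0):
    """Build the full d x d block-diagonal matrix and multiply: O(d^2)."""
    d = len(x)
    R = [[0.0] * d for _ in range(d)]
    for i in range(d // 2):
        angle = m * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        R[2 * i][2 * i], R[2 * i][2 * i + 1] = c, -s
        R[2 * i + 1][2 * i], R[2 * i + 1][2 * i + 1] = s, c
    return [sum(r_ij * x_j for r_ij, x_j in zip(row, x)) for row in R]


def rotate_elementwise(x, m, base=10000.0):
    """x * cos + x_inv * sin with each angle repeated twice: O(d)."""
    d = len(x)
    angles = [m * base ** (-2 * (j // 2) / d) for j in range(d)]
    # x_inv = (-x_2, x_1, -x_4, x_3, ...): swap each pair, negate the first entry.
    x_inv = [0.0] * d
    for i in range(0, d, 2):
        x_inv[i], x_inv[i + 1] = -x[i + 1], x[i]
    return [x[j] * math.cos(angles[j]) + x_inv[j] * math.sin(angles[j]) for j in range(d)]


x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]
dense = rotate_dense(x, 7)
fast = rotate_elementwise(x, 7)
max_gap = max(abs(a - b) for a, b in zip(dense, fast))
```

The element-wise version never materialises the zeros of the sparse matrix, which is where the saving comes from.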

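When the sequence contains multiple tokens, the same computation can be applied to all of them at once:

$$
R^d_{\Theta, m}X = X \odot R_{\cos} + X_{\text{inv}} \odot R_{\sin}
$$

where $R_{\cos}, R_{\sin}, X, X_{\text{inv}} \in \mathbb{R}^{d_{\text{model}}\times \ell}$, column $m$ of each matrix corresponds to position $m$, $X_{\text{inv}}$ holds the pair-swapped, sign-flipped entries $(-x_2, x_1, -x_4, x_3, \dots)$ of each column of $X$, and $\odot$ denotes the element-wise product.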

To handle context lengths longer than the pretrained one, at inference or when finetuning, we can introduce (dynamic) NTK-aware scaling, where $m$ is divided by a factor of $\frac{L'}{L}$, or dynamically by $(\alpha \cdot \frac{L'}{L}) + (1 - \alpha)$. Here $L$ is the pretrained context window, $L'$ is the new context window, and $\alpha \in (0, 1]$.

```python
from typing import Optional, Union

import torch
import torch.nn as nn


class RotaryPositionalEmbedding(nn.Module):
    def __init__(
        self,
        d_head,
        context_len,
        ntk_rope_scaling: Union[dict, bool] = False,
        dyn_scaling: Optional[float] = None,
    ):
        super().__init__()
        self.context_len = context_len  # used to clamp the inference counter
        self.t = 1  # 1-based index of the next token to decode in inference mode

        # float32: float16 cannot represent positions above 2048 exactly.
        position = torch.arange(start=0, end=context_len, dtype=torch.float32).unsqueeze(1)

        if ntk_rope_scaling:
            assert isinstance(ntk_rope_scaling, dict), "ntk_rope_scaling should be a dictionary"
            assert 'pretrained_context_window' in ntk_rope_scaling, "pretrained_context_window should be in ntk_rope_scaling"
            assert 'new_context_window' in ntk_rope_scaling, "new_context_window should be in ntk_rope_scaling"
            window_ratio = ntk_rope_scaling['new_context_window'] / ntk_rope_scaling['pretrained_context_window']

            if dyn_scaling:
                assert isinstance(dyn_scaling, float), "dyn_scaling should be a float"
                assert 0 < dyn_scaling <= 1, "dyn_scaling should be between 0 and 1"
                scale = (dyn_scaling * window_ratio) + (1 - dyn_scaling)
            else:
                scale = window_ratio

            position = position / scale

        # 10000^(2i / d_head), one entry per rotated pair i -> shape (d_head / 2,)
        div_term = 10000 ** (torch.arange(start=0, end=d_head, step=2, dtype=torch.float32) / d_head)

        # Shape (context_len, d_head / 2): row m holds cos / sin of m * theta_i.
        # The even and odd halves of x share these values, so no repeat is needed.
        self.register_buffer("rope_cos", torch.cos(position / div_term))
        self.register_buffer("rope_sin", torch.sin(position / div_term))

    def forward(self, x, _inference=False):
        # x is (..., seq_len, d_head)
        if not _inference:
            # Slice along the position axis so shorter sequences broadcast correctly.
            seq_len = x.shape[-2]
            cos = self.rope_cos[:seq_len]
            sin = self.rope_sin[:seq_len]
        else:
            assert x.shape[-2] == 1, f"Expected seq_len=1 during inference, got {x.shape[-2]}"
            # Pick the single row for the current decoding position.
            cos = self.rope_cos[self.t - 1:self.t]
            sin = self.rope_sin[self.t - 1:self.t]
            # Advance the counter, clamped to the last cached position.
            self.t = min(self.t + 1, self.context_len)

        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]

        rotated_x_even = x_even * cos - x_odd * sin
        rotated_x_odd = x_even * sin + x_odd * cos

        # Interleave the two halves back into (..., seq_len, d_head).
        rotated_x = torch.stack([rotated_x_even, rotated_x_odd], dim=-1).flatten(start_dim=-2)
        return rotated_x
```
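The scale factor computed in the constructor can be worked through by hand. The sketch below uses a hypothetical helper, `position_scale`, and assumed window sizes of 2048 (pretrained) and 8192 (new); neither number comes from a specific model.

```python
def position_scale(pretrained_len, new_len, alpha=None):
    """Factor that positions are divided by; alpha blends it towards 1."""
    ratio = new_len / pretrained_len
    if alpha is None:
        return ratio
    return alpha * ratio + (1 - alpha)


static_scale = position_scale(2048, 8192)         # 4.0
dynamic_scale = position_scale(2048, 8192, 0.5)   # 0.5 * 4.0 + 0.5 = 2.5

# With the static factor, the last position of the longer window maps back
# inside the range of positions seen during pretraining.
last_scaled_position = (8192 - 1) / static_scale  # 2047.75
```

With $\alpha = 1$ the dynamic factor reduces to the static one, and smaller values of $\alpha$ compress the positions less.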