
RoPE

Juan Vera

April 2025

Attention is defined as,

$$
q_m = W^q x_m \\[3mm]
k_n = W^k x_n \\[3mm]
v_n = W^v x_n \\[3mm]
\alpha_{m, n} = \frac{\exp(q_m k_n^{\top})}{\sum_{i = 1}^{\ell}\exp(q_m k_i^{\top})} \\[3mm]
o_m = \sum_{n = 1}^{\ell} \alpha_{m, n} v_n
$$
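
As a point of reference, here is a minimal PyTorch sketch of this (single-head, unscaled, unmasked) attention computation; the dimensions and random weights are arbitrary placeholders:

import torch

d_model, seq_len = 64, 8                 # arbitrary dimensions
x = torch.randn(seq_len, d_model)        # token representations x_1 ... x_l

W_q = torch.randn(d_model, d_model)      # W^q
W_k = torch.randn(d_model, d_model)      # W^k
W_v = torch.randn(d_model, d_model)      # W^v

q = x @ W_q                              # rows are q_m
k = x @ W_k                              # rows are k_n
v = x @ W_v                              # rows are v_n

alpha = torch.softmax(q @ k.T, dim=-1)   # alpha_{m, n}: softmax over n
o = alpha @ v                            # o_m = sum_n alpha_{m, n} v_n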

Positional Embeddings

Typically, Positional Embeddings are computed as:

$$
\text{PE}_{[pos, 2i]} = \sin\left(\frac{\text{pos}}{10000^{\frac{2i}{d_\text{model}}}}\right) \\[3mm]
\text{PE}_{[pos, 2i + 1]} = \cos\left(\frac{\text{pos}}{10000^{\frac{2i}{d_\text{model}}}}\right) \\[3mm]
X = X + \text{PE}
$$

where:

  • $\ell$ is the sequence length
  • $i \in [0, d_\text{model}/2)$ indexes the dimension pairs $(2i, 2i + 1)$
  • $\text{pos} \in [0, \ell)$ is the token position

These positional embeddings add fixed positional information to a sequence of length $\ell$.

import torch
import torch.nn as nn

class PositionalEmbedding(nn.Module):
    
    def __init__(self, seq_len, d_model, dropout_p: float = 0.1):
        super().__init__()

        self.dropout = nn.Dropout(p = dropout_p)

        # precompute the (seq_len, d_model) table of sinusoidal embeddings
        pe = torch.zeros(size = (seq_len, d_model), dtype = torch.float32)
        position = torch.arange(start = 0, end = seq_len, dtype = torch.float32)
        # 10000^(2i / d_model) for i = 0 ... d_model/2 - 1
        div_term = 10000 ** (torch.arange(start = 0, end = d_model, step = 2, dtype = torch.float32) / d_model)
        pe[:, 0::2] = torch.sin(position.unsqueeze(1) / div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position.unsqueeze(1) / div_term)   # odd dimensions
        pe = pe.unsqueeze(0)                                        # (1, seq_len, d_model) for broadcasting
        self.register_buffer("positional_embedding", pe)

    def forward(self, x):
        # x is (batch_size, seq_len, d_model)
        x_pe = x + self.positional_embedding[:, :x.size(1), :]
        x = self.dropout(x_pe)
        return x
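
A brief usage sketch, with arbitrary batch size, sequence length, and model dimension:

pe = PositionalEmbedding(seq_len=128, d_model=512)

x = torch.randn(4, 128, 512)   # (batch_size, seq_len, d_model)
out = pe(x)                    # positional information added, then dropout
print(out.shape)               # torch.Size([4, 128, 512])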

Rotary Positional Embeddings

Rotary Positional Embeddings are computed as:

$$
f(x_m, m) = R^{d_{\text{model}}}_{\Theta, m} \left(W^{(q, k)} x_m\right)
$$

where $R^{d_{\text{model}}}_{\Theta, m} =$

$$
\begin{bmatrix}
\cos(m \theta_0) & -\sin(m \theta_0) & 0 & 0 & \cdots & 0 & 0 \\
\sin(m \theta_0) & \cos(m \theta_0) & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \cos(m \theta_1) & -\sin(m \theta_1) & \cdots & 0 & 0 \\
0 & 0 & \sin(m \theta_1) & \cos(m \theta_1) & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \cos(m \theta_{\frac{d}{2}-1}) & -\sin(m \theta_{\frac{d}{2}-1}) \\
0 & 0 & 0 & 0 & \cdots & \sin(m \theta_{\frac{d}{2}-1}) & \cos(m \theta_{\frac{d}{2}-1})
\end{bmatrix}
$$

with $d = d_{\text{model}}$.
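
To make the definition concrete, the helper below (a hypothetical `rotation_matrix`, not part of any library) materializes $R^{d}_{\Theta, m}$ as a block-diagonal matrix of $2 \times 2$ rotations; in practice the full matrix is never built:

import torch

def rotation_matrix(m: int, d: int, base: float = 10000.0) -> torch.Tensor:
    # theta_i = 1 / base^(2i / d) for i = 0 ... d/2 - 1
    theta = 1.0 / base ** (2 * torch.arange(d // 2, dtype=torch.float32) / d)
    R = torch.zeros(d, d)
    for i, t in enumerate(m * theta):
        # 2x2 rotation by angle m * theta_i acting on the (2i, 2i + 1) pair
        R[2 * i,     2 * i]     =  torch.cos(t)
        R[2 * i,     2 * i + 1] = -torch.sin(t)
        R[2 * i + 1, 2 * i]     =  torch.sin(t)
        R[2 * i + 1, 2 * i + 1] =  torch.cos(t)
    return R

R_3 = rotation_matrix(m=3, d=8)   # the 8x8 rotation applied to the token at position m = 3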

Again, note that $m$ is the index of the current ($m$th) token in the sequence of length $\ell$.

For each $m \in [0, \ell)$, a unique matrix $R^{d_{\text{model}}}_{\Theta, m}$ is constructed.

Here, $\theta_i \in \Theta$ is equivalent to the scaling factor in the original formulation for the positional encodings,

$$
\theta_i = \frac{1}{10000^{\frac{2i}{d_{\text{model}}}}}
$$

and given that we multiply by $m$, the $m$th position in the sequence of length $\ell$, then,

$$
m\theta_i = \frac{m}{10000^{\frac{2i}{d_{\text{model}}}}} = \frac{\text{pos}}{10000^{\frac{2i}{d_{\text{model}}}}}
$$

equivalent to the scaling factor in the original formulation for positional embeddings, the difference being that this is applied directly onto the keys and queries via a matrix multiplication, so that it directly influences the attention scores $\alpha_{m, n}$.

A more computationally efficient way to perform this matrix multiplication is:

$$
R_{\Theta, m}^d x_m = \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{bmatrix} \odot \begin{bmatrix} \cos m \theta_0 \\ \cos m \theta_0 \\ \cos m \theta_1 \\ \cos m \theta_1 \\ \vdots \\ \cos m \theta_{\frac{d}{2}-1} \\ \cos m \theta_{\frac{d}{2}-1} \end{bmatrix} + \begin{bmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{bmatrix} \odot \begin{bmatrix} \sin m \theta_0 \\ \sin m \theta_0 \\ \sin m \theta_1 \\ \sin m \theta_1 \\ \vdots \\ \sin m \theta_{\frac{d}{2}-1} \\ \sin m \theta_{\frac{d}{2}-1} \end{bmatrix}
$$

which reduces the complexity of the original multiplication from $\mathcal{O}(d_{\text{model}}^2)$ to $\mathcal{O}(d_{\text{model}})$ per token.
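
As a quick sanity check, the sketch below (with hypothetical $d = 8$ and $m = 3$) verifies that the elementwise form matches multiplying by the full block-diagonal matrix:

import math
import torch

d, m = 8, 3
x = torch.randn(d)

# theta_i = 1 / 10000^(2i / d) for i = 0 ... d/2 - 1
theta = 1.0 / 10000 ** (2 * torch.arange(d // 2, dtype=torch.float32) / d)

# O(d^2) reference: the full block-diagonal rotation matrix R_{Theta, m}
blocks = [torch.tensor([[math.cos(m * t), -math.sin(m * t)],
                        [math.sin(m * t),  math.cos(m * t)]]) for t in theta.tolist()]
R = torch.block_diag(*blocks)

# O(d) elementwise form
cos = torch.repeat_interleave(torch.cos(m * theta), repeats=2)   # [cos m0, cos m0, cos m1, cos m1, ...]
sin = torch.repeat_interleave(torch.sin(m * theta), repeats=2)
x_rot = torch.stack([-x[1::2], x[0::2]], dim=-1).flatten()       # [-x_2, x_1, -x_4, x_3, ...]

assert torch.allclose(x * cos + x_rot * sin, R @ x, atol=1e-5)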

When the sequence contains multiple tokens, this can be computed as

$$
R^d_{\Theta}X = X \odot R_{\cos} + X_{\text{inv}} \odot R_{\sin}
$$

where $R_{\cos}, R_{\sin}, X_{\text{inv}}, X \in \mathbb{R}^{d_{\text{model}} \times \ell}$, and $X_{\text{inv}}$ holds the pairwise swapped-and-negated components of $X$, as in the single-token case above.

During inference, if we need to handle context lengths longer than the pretrained context window without finetuning, we can introduce (dynamic) NTK-aware scaling, where the position $m$ is divided by a scale factor of $\frac{L'}{L}$, or dynamically by $(\alpha \cdot \frac{L'}{L}) + (1 - \alpha)$, where $L$ is the pretrained context window and $L'$ is the new context window.

import torch
import torch.nn as nn
from typing import Union

class RotaryPositionalEmbedding(nn.Module):

    def __init__(
        self, 
        d_head, 
        context_len,
        ntk_rope_scaling: Union[dict, bool] = False, 
        dyn_scaling: Union[bool, float] = None
        ):
        
        super().__init__()

        self.context_len = context_len
        position = torch.arange(start=0, end=context_len, dtype=torch.float32).unsqueeze(1)
        
        if ntk_rope_scaling:
            assert isinstance(ntk_rope_scaling, dict), "ntk_rope_scaling should be a dictionary"
            assert 'pretrained_context_window' in ntk_rope_scaling, "pretrained_context_window should be in ntk_rope_scaling"
            assert 'new_context_window' in ntk_rope_scaling, "new_context_window should be in ntk_rope_scaling"
            
            if dyn_scaling:
                assert isinstance(dyn_scaling, float), "dyn_scaling should be a float"
                assert 0 < dyn_scaling <= 1, "dyn_scaling should be between 0 and 1"
                # dynamic scaling: interpolate between no scaling (1) and full scaling (L' / L)
                scale = (dyn_scaling * (ntk_rope_scaling['new_context_window'] / ntk_rope_scaling['pretrained_context_window'])) + (1 - dyn_scaling)
            else:
                # static scaling: divide positions by L' / L
                scale = ntk_rope_scaling['new_context_window'] / ntk_rope_scaling['pretrained_context_window']
            
            position /= scale

        # theta_i = 1 / 10000^(2i / d_head), repeated so each (even, odd) pair shares an angle
        div_term = 10000 ** (torch.arange(start=0, end=d_head, step=2, dtype=torch.float32) / d_head)
        div_term = torch.repeat_interleave(div_term, repeats=2, dim=-1)
        rope_cos = torch.cos(position / div_term)   # (context_len, d_head)
        rope_sin = torch.sin(position / div_term)   # (context_len, d_head)
        
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)

    def forward(self, x, _inference=False):
        # x is (batch_size, seq_len, d_head)
        if not _inference:
            # one angle per pair, for the first seq_len positions
            cos = self.rope_cos[:x.size(1), 0::2]
            sin = self.rope_sin[:x.size(1), 0::2]
        else:
            # autoregressive decoding: x holds the single token at position self.t
            assert x.shape[1] == 1, f"Expected seq_len=1 during inference, got {x.shape[1]}"
            if not hasattr(self, 't'):
                self.t = 1
            cos = self.rope_cos[self.t - 1:self.t, 0::2]
            sin = self.rope_sin[self.t - 1:self.t, 0::2]
        
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        
        # pairwise rotation: (x_even, x_odd) -> (x_even cos - x_odd sin, x_even sin + x_odd cos)
        rotated_x_even = x_even * cos - x_odd * sin
        rotated_x_odd = x_even * sin + x_odd * cos
        
        # re-interleave the rotated pairs back into the original layout
        rotated_x = torch.stack([rotated_x_even, rotated_x_odd], dim=-1).view_as(x)

        if _inference:
            # advance the position counter, clamped to the context length
            self.t = min(self.t + 1, self.context_len)
        
        return rotated_x
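
A brief usage sketch, with arbitrary shapes and a hypothetical scaling configuration; the same rotation is applied to both queries and keys before the attention scores are computed:

rope = RotaryPositionalEmbedding(d_head=64, context_len=2048)

q = torch.randn(4, 2048, 64)   # (batch_size, seq_len, d_head) queries
k = torch.randn(4, 2048, 64)   # keys

q_rot = rope(q)                # rotated queries, same shape as q
k_rot = rope(k)                # rotated keys, same shape as k

# extending to a longer context window with (dynamic) NTK-aware scaling
rope_scaled = RotaryPositionalEmbedding(
    d_head=64,
    context_len=4096,
    ntk_rope_scaling={'pretrained_context_window': 2048, 'new_context_window': 4096},
    dyn_scaling=0.5,
)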