
RoPE
Juan Vera
April 2025
Attention is defined as,

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
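As a minimal reference sketch of this formula in PyTorch (the function name, tensor names, and shapes are illustrative, not from the original):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch_size, seq_len, d_k)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # (batch_size, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)            # each row sums to 1 over the keys
    return weights @ v                             # (batch_size, seq_len, d_k)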
Positional Embeddings
Typically, Positional Embeddings are computed as:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

where:
- $L$ is the sequence length, with $pos \in \{0, \dots, L-1\}$ the position of the token
- $d_{\text{model}}$ is the embedding dimension, with $i \in \{0, \dots, d_{\text{model}}/2 - 1\}$ indexing the dimension pairs

These positional embeddings apply fixed positional information to a sequence of length $L$.
import torch
import torch.nn as nn

class PositionalEmbedding(nn.Module):
    def __init__(self, seq_len, d_model, dropout_p: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_p)

        # Precompute the (seq_len, d_model) table of sinusoidal embeddings.
        # float32 avoids the rounding error float16 incurs at large positions.
        pe = torch.zeros(size=(seq_len, d_model), dtype=torch.float32)
        position = torch.arange(start=0, end=seq_len, dtype=torch.float32)
        div_term = 10000 ** (torch.arange(start=0, end=d_model, step=2, dtype=torch.float32) / d_model)

        pe[:, 0::2] = torch.sin(position.unsqueeze(1) / div_term)
        pe[:, 1::2] = torch.cos(position.unsqueeze(1) / div_term)

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model), broadcast over the batch dimension
        self.register_buffer("positional_embedding", pe)

    def forward(self, x):
        # x is (batch_size, seq_len, d_model)
        x_pe = x + self.positional_embedding[:, :x.size(1), :]
        x = self.dropout(x_pe)
        return x
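For instance (shapes purely illustrative):

pe = PositionalEmbedding(seq_len=512, d_model=64)
x = torch.randn(8, 512, 64)   # (batch_size, seq_len, d_model)
out = pe(x)                   # same shape, now carrying positional information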
Rotary Positional Embeddings
Rotary Positional Embeddings are computed as:

$$f(x_m, m) = R^d_{\Theta, m}\, x_m, \qquad R^d_{\Theta, m} = \begin{pmatrix} \cos m\theta_1 & -\sin m\theta_1 & \cdots & 0 & 0 \\ \sin m\theta_1 & \cos m\theta_1 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\ 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2} \end{pmatrix}$$

where $x_m \in \mathbb{R}^d$ is the query or key vector of the $m$-th token and

$$\Theta = \left\{\theta_i = 10000^{-2(i-1)/d},\ i \in [1, 2, \dots, d/2]\right\}$$
Again, note that $m$ is the index of the current $m$-th token in the sequence of length $L$. For each $m$, a unique matrix $R^d_{\Theta, m}$ is constructed.

Here, $\theta_i$ is equivalent to the inverse frequency $\frac{1}{10000^{2i/d_{\text{model}}}}$ in the original formulation for the positional encodings, and given that we multiply $\theta_i$ by $m$, the $m$-th position in the sequence of length $L$, the angle $m\theta_i$ is equivalent to the argument $\frac{pos}{10000^{2i/d_{\text{model}}}}$ of the original formulation for positional embeddings. The difference is that the rotation is applied directly onto the keys and queries via a matrix multiplication, such that it directly influences the attention scores, $q_m^\top k_n$.
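To make the construction concrete, here is a minimal sketch (not from the original; the function name and the small dimensions are illustrative) that builds $R^d_{\Theta, m}$ explicitly as a block-diagonal matrix of $2 \times 2$ rotations and applies it to a single query/key vector:

import torch

def rope_rotation_matrix(m: int, d: int, base: float = 10000.0) -> torch.Tensor:
    # theta_i = base^(-2(i-1)/d) for i = 1 .. d/2
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    R = torch.zeros(d, d)
    for i, t in enumerate(theta):
        c, s = torch.cos(m * t), torch.sin(m * t)
        # 2x2 rotation block acting on dimensions (2i, 2i+1)
        R[2 * i, 2 * i], R[2 * i, 2 * i + 1] = c, -s
        R[2 * i + 1, 2 * i], R[2 * i + 1, 2 * i + 1] = s, c
    return R

x_m = torch.randn(8)                              # query/key vector of token m, d = 8
rotated = rope_rotation_matrix(m=3, d=8) @ x_m    # O(d^2) per token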
Since $R^d_{\Theta, m}$ is sparse, a more computationally efficient way of computing the matrix multiplication is:

$$R^d_{\Theta, m}\, x = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \otimes \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}$$

where $\otimes$ denotes the element-wise product, and we reduce the complexity of the original multiplication from $O(d^2)$ to $O(d)$ per token.
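A minimal sketch of this element-wise form for a single token (again illustrative, not the original code), with a quick check against applying the $2 \times 2$ rotation blocks directly:

import torch

def rope_elementwise(x: torch.Tensor, m: int, base: float = 10000.0) -> torch.Tensor:
    # x: (d,) query/key vector for the m-th token; d must be even
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # (d/2,)
    angles = torch.repeat_interleave(m * theta, 2)                     # (d,), each angle twice
    # (-x_2, x_1, -x_4, x_3, ...): swap each pair and negate its first element
    x_shifted = torch.stack([-x[1::2], x[0::2]], dim=-1).flatten()
    return x * torch.cos(angles) + x_shifted * torch.sin(angles)

# Check: identical to rotating each (even, odd) pair with its 2x2 rotation block.
x, m = torch.randn(8), 3
theta = 10000.0 ** (-torch.arange(0, 8, 2, dtype=torch.float32) / 8)
expected = torch.empty_like(x)
for i, t in enumerate(theta):
    c, s = torch.cos(m * t), torch.sin(m * t)
    expected[2 * i] = c * x[2 * i] - s * x[2 * i + 1]
    expected[2 * i + 1] = s * x[2 * i] + c * x[2 * i + 1]
assert torch.allclose(rope_elementwise(x, m), expected, atol=1e-6)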
When you have multiple tokens in your sequence, you can compute this for all positions at once as

$$\mathrm{RoPE}(X) = X \otimes C + \hat{X} \otimes S$$

where $X \in \mathbb{R}^{L \times d}$ stacks the vectors $x_m$ row-wise, $C, S \in \mathbb{R}^{L \times d}$ hold the duplicated angles, $C_{m, 2i-1} = C_{m, 2i} = \cos m\theta_i$ and $S_{m, 2i-1} = S_{m, 2i} = \sin m\theta_i$, and $\hat{X}$ is $X$ with each column pair $(x_{2i-1}, x_{2i})$ replaced by $(-x_{2i}, x_{2i-1})$.
If we need to run inference on (or finetune for) context lengths longer than the pretrained context window, we can introduce (dynamic) NTK-aware scaling, where the position $m$ is scaled down by a factor of

$$s = \frac{L_{\text{new}}}{L_{\text{pretrained}}}$$

or dynamically by

$$s = \alpha \cdot \frac{L_{\text{new}}}{L_{\text{pretrained}}} + (1 - \alpha), \qquad \alpha \in (0, 1]$$

so that the rotation angles become $\frac{m}{s}\theta_i$.
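For example (numbers purely illustrative): with $L_{\text{pretrained}} = 2048$ and $L_{\text{new}} = 8192$, the static factor is $s = 8192 / 2048 = 4$, so position $m$ becomes $m / 4$; with $\alpha = 0.5$, the dynamic factor is $s = 0.5 \cdot 4 + 0.5 = 2.5$.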
from typing import Optional, Union

import torch
import torch.nn as nn

class RotaryPositionalEmbedding(nn.Module):
    def __init__(
        self,
        d_head,
        context_len,
        ntk_rope_scaling: Union[dict, bool] = False,
        dyn_scaling: Optional[float] = None
    ):
        super().__init__()
        self.context_len = context_len

        # Positions 0 .. context_len - 1 (float32 avoids fp16 rounding at large positions).
        position = torch.arange(start=0, end=context_len, dtype=torch.float32).unsqueeze(1)

        if ntk_rope_scaling:
            assert isinstance(ntk_rope_scaling, dict), "ntk_rope_scaling should be a dictionary"
            assert 'pretrained_context_window' in ntk_rope_scaling, "pretrained_context_window should be in ntk_rope_scaling"
            assert 'new_context_window' in ntk_rope_scaling, "new_context_window should be in ntk_rope_scaling"
            if dyn_scaling:
                assert isinstance(dyn_scaling, float), "dyn_scaling should be a float"
                assert 0 < dyn_scaling <= 1, "dyn_scaling should be between 0 and 1"
                scale = (dyn_scaling * (ntk_rope_scaling['new_context_window'] / ntk_rope_scaling['pretrained_context_window'])) + (1 - dyn_scaling)
            else:
                scale = ntk_rope_scaling['new_context_window'] / ntk_rope_scaling['pretrained_context_window']
            position = position / scale

        # div_term_i = 10000^(2i / d_head), so position / div_term gives the angles m * theta_i.
        # Each value is repeated twice so it lines up with the (even, odd) dimension pairs.
        div_term = 10000 ** (torch.arange(start=0, end=d_head, step=2, dtype=torch.float32) / d_head)
        div_term = torch.repeat_interleave(div_term, repeats=2, dim=-1)

        # (context_len, d_head) lookup tables of cos(m * theta_i) and sin(m * theta_i).
        rope_cos = torch.cos(position / div_term)
        rope_sin = torch.sin(position / div_term)
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)

    def forward(self, x, _inference=False):
        # x is (..., seq_len, d_head)
        if not _inference:
            seq_len = x.shape[-2]
            cos = self.rope_cos[:seq_len, 0::2]
            sin = self.rope_sin[:seq_len, 0::2]
        else:
            # Decoding one token at a time: track the current position in self.t,
            # advancing it before use and capping it at the context length.
            assert x.shape[-2] == 1, f"Expected seq_len=1 during inference, got {x.shape[-2]}"
            if not hasattr(self, 't'):
                self.t = 1
            else:
                self.t = min(self.t + 1, self.context_len)
            cos = self.rope_cos[self.t - 1:self.t, 0::2]
            sin = self.rope_sin[self.t - 1:self.t, 0::2]

        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]

        # Rotate each (even, odd) pair by its position-dependent angle.
        rotated_x_even = x_even * cos - x_odd * sin
        rotated_x_odd = x_even * sin + x_odd * cos

        # Re-interleave the rotated pairs back into the original layout.
        rotated_x = torch.stack([rotated_x_even, rotated_x_odd], dim=-1).view_as(x)
        return rotated_x
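A brief usage sketch (shapes and parameter values are illustrative assumptions, not from the original):

rope = RotaryPositionalEmbedding(d_head=64, context_len=2048)

# Full-sequence (training / prefill) call: rotate queries and keys before the attention scores.
q = torch.randn(8, 2048, 64)   # (batch_size, seq_len, d_head)
q_rot = rope(q)

# Autoregressive decoding: one token per call; each _inference=True call advances the internal position.
q_step = torch.randn(8, 1, 64)
q_step_rot = rope(q_step, _inference=True)

# Extending the context window with (dynamic) NTK-style scaling of the positions.
rope_scaled = RotaryPositionalEmbedding(
    d_head=64,
    context_len=8192,
    ntk_rope_scaling={'pretrained_context_window': 2048, 'new_context_window': 8192},
    dyn_scaling=0.5,
)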