
Swarm of Attention Variants
Juan Vera
April 2025
We can define Multi-Head Attention, for a single token $x_t \in \mathbb{R}^{d_{\text{model}}}$, as:

$$\text{head}_i = \text{softmax}\!\left(\frac{(x_t W_i^Q)(X W_i^K)^\top}{\sqrt{d_{\text{head}}}}\right)(X W_i^V), \qquad \text{MHA}(x_t) = \text{Concat}(\text{head}_1, \dots, \text{head}_H)\, W^O$$

or parallelized, for all tokens $X \in \mathbb{R}^{L \times d_{\text{model}}}$, as:

$$\text{head}_i = \text{softmax}\!\left(\frac{(X W_i^Q)(X W_i^K)^\top}{\sqrt{d_{\text{head}}}}\right)(X W_i^V), \qquad \text{MHA}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_H)\, W^O$$

where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{head}}}$, $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$, $d_{\text{head}} = d_{\text{model}} / H$, and $H$ is the number of heads.
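As a concrete reference point, below is a minimal sketch of the parallelized computation above, without causal masking, RoPE, or dropout; the module and its projection layers (w_q, w_k, w_v, w_o) are illustrative, not taken from the variants discussed later.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    # minimal sketch of standard Multi-Head Attention (no mask, no RoPE)
    def __init__(self, n_heads, d_model):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # each head gets its own slice of these d_model -> d_model projections
        self.w_q = nn.Linear(d_model, d_model, bias = False)
        self.w_k = nn.Linear(d_model, d_model, bias = False)
        self.w_v = nn.Linear(d_model, d_model, bias = False)
        self.w_o = nn.Linear(d_model, d_model, bias = False)

    def forward(self, x):
        # x is (batch_size, seq_len, d_model)
        b, l, d_model = x.shape
        # project, then split into heads: (batch_size, n_heads, seq_len, d_head)
        q = self.w_q(x).view(b, l, self.n_heads, self.d_head).transpose(1, 2)
        k = self.w_k(x).view(b, l, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_v(x).view(b, l, self.n_heads, self.d_head).transpose(1, 2)
        attn = F.softmax(q @ k.transpose(-2, -1) / (self.d_head ** 0.5), dim = -1)
        # merge heads back: (batch_size, seq_len, d_model)
        out = (attn @ v).transpose(1, 2).contiguous().view(b, l, d_model)
        return self.w_o(out)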
Multi-Query Attention
Multi-Query Attention has the same definition as Multi-Head Attention, with the caveat that the keys and values are shared across all attention heads.
For each head, we have a distinct set of queries, $Q_i \in \mathbb{R}^{L \times d_{\text{head}}}$, where if we account for all heads, we'd have $Q \in \mathbb{R}^{H \times L \times d_{\text{head}}}$.
But for every head, we share the same $K$ and $V$, both $\in \mathbb{R}^{L \times d_{\text{head}}}$.
Given that we share the same $K$ and $V$ for all heads, it doesn't make sense for the initial projection from $X \in \mathbb{R}^{L \times d_{\text{model}}}$ to $K$ and $V$ to be to $\mathbb{R}^{L \times d_{\text{model}}}$, as the reshaping to $\mathbb{R}^{H \times L \times d_{\text{head}}}$ would introduce unique $K_i$ and $V_i$ matrices per head, such that we can instead directly project $K$ and $V$ as:

$$K = X W^K, \qquad V = X W^V, \qquad W^K, W^V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{head}}}$$
class MultiQueryAttention(nn.Module):
    def __init__(
        self,
        n_heads,
        d_model,
        context_length,
        ntk_rope_scaling,
        dyn_scaling
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_head = d_model // n_heads
        self.context_length = context_length
        self.ntk_rope_scaling = ntk_rope_scaling
        self.dyn_scaling = dyn_scaling
        self.rope = RotaryPositionalEmbedding(
            d_head = self.d_head,
            context_length = self.context_length,
            ntk_rope_scaling = self.ntk_rope_scaling,
            dyn_scaling = self.dyn_scaling
        )

    def forward(self, q, k, v):
        # assumes q is (batch_size, context_length = seq_len, d_model)
        # assumes k is (batch_size, context_length = seq_len, d_head)
        # assumes v is (batch_size, context_length = seq_len, d_head)
        b, l, d_model = q.shape
        assert d_model == self.d_model, f"Expected d_model to be {self.d_model}, but got {d_model}"

        # split q into heads: (batch_size, n_heads, seq_len, d_head)
        q = q.view(b, l, self.n_heads, self.d_head).transpose(1, 2)

        # add a singleton head dim so the single shared k / v broadcast across all heads
        k = k.unsqueeze(1)
        v = v.unsqueeze(1)

        assert q.shape[2:] == k.shape[2:], f"Expected q and k to have the same shape for dimensions 2 (seq_len) and 3 (d_head), but got {q.shape}, {k.shape}"
        assert k.shape == v.shape, f"Expected k and v to have the same shape, but got {k.shape}, {v.shape}"

        q = self.rope(q)
        k = self.rope(k)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)
        attn_scores = F.softmax(attn_logits, dim = -1)

        # merge heads back: (batch_size, n_heads, seq_len, d_head) -> (batch_size, seq_len, d_model)
        attn_output = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(b, l, d_model)
        return attn_output
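A usage sketch, assuming we project the input with linear layers shaped to match the comments in forward; w_q, w_k, w_v here are illustrative, and the ntk_rope_scaling / dyn_scaling arguments are placeholders whose accepted values depend on the RotaryPositionalEmbedding defined earlier.

d_model, n_heads = 512, 8
d_head = d_model // n_heads

w_q = nn.Linear(d_model, d_model, bias = False)   # per-head queries
w_k = nn.Linear(d_model, d_head, bias = False)    # single shared key projection
w_v = nn.Linear(d_model, d_head, bias = False)    # single shared value projection

mqa = MultiQueryAttention(
    n_heads = n_heads,
    d_model = d_model,
    context_length = 1024,
    ntk_rope_scaling = False,   # placeholder; depends on RotaryPositionalEmbedding
    dyn_scaling = False         # placeholder; depends on RotaryPositionalEmbedding
)

x = torch.randn(2, 128, d_model)        # (batch_size, seq_len, d_model)
out = mqa(w_q(x), w_k(x), w_v(x))       # (2, 128, 512)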
The benefit of Multi-Query Attention is that it saves memory in the KV-cache during inference, at the cost of some expressivity in the attention mechanism.
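To put a rough number on that saving, assume fp16 activations and the usual per-layer cache of one K and one V tensor per KV head; the model dimensions below are illustrative.

# rough per-layer KV-cache size at fp16 (2 bytes per element)
seq_len, n_heads, d_head = 4096, 32, 128
mha_cache = 2 * seq_len * n_heads * d_head * 2   # K and V for all 32 heads
mqa_cache = 2 * seq_len * 1 * d_head * 2         # one K and V shared by every head
print(mha_cache / 2**20, mqa_cache / 2**20)      # ~64 MiB vs ~2 MiB per layer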
Grouped-Query Attention
Grouped-Query Attention can be seen as the "balance" between the extremes of Multi-Head Attention and Multi-Query Attention, where rather than a single $K$ and $V$ being shared amongst all heads, there's a unique $K_g$ and $V_g$ for each of $G$ groups, where the size of each group is $H / G$ and $H$ is the total number of heads.
The cardinality of the set of $K$ and $V$ matrices is then $G$, with each $K_g, V_g \in \mathbb{R}^{L \times d_{\text{head}}}$.
Similar to Multi-Query Attention, we must initially project $X$ to a dimensionality other than $d_{\text{model}}$, here $\mathbb{R}^{L \times G \cdot d_{\text{head}}}$, but like Multi-Head Attention, $K$ and $V$ are still reshaped (and repeated across each group's heads) to match the head dimension of $Q$,
or coded out,
class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        n_heads: int,
        n_groups: int,
        d_model,
        context_length,
        ntk_rope_scaling,
        dyn_scaling
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_groups = n_groups
        self.d_model = d_model
        self.d_head = d_model // n_heads
        self.context_length = context_length
        self.ntk_rope_scaling = ntk_rope_scaling
        self.dyn_scaling = dyn_scaling
        self.rope = RotaryPositionalEmbedding(
            d_head = self.d_head,
            context_length = self.context_length,
            ntk_rope_scaling = self.ntk_rope_scaling,
            dyn_scaling = self.dyn_scaling
        )

    def forward(self, q, k, v):
        # assumes q is (batch_size, context_length = seq_len, d_model)
        # assumes k is (batch_size, context_length = seq_len, n_groups * d_head)
        # assumes v is (batch_size, context_length = seq_len, n_groups * d_head)
        #
        # k and v are reshaped to (batch_size, n_groups, seq_len, d_head), then each
        # group is repeated n_heads // n_groups times so q, k, v share the same d_head
        b, l, d_model = q.shape
        assert self.n_heads % self.n_groups == 0, (
            f"Expected n_heads to be divisible by n_groups, but got n_heads: {self.n_heads}, "
            f"n_groups: {self.n_groups} with remainder of {self.n_heads % self.n_groups}"
        )
        assert d_model == self.d_model, f"Expected d_model to be {self.d_model}, but got {d_model}"

        repeats = self.n_heads // self.n_groups

        # split q into heads: (batch_size, n_heads, seq_len, d_head)
        q = q.view(b, l, self.n_heads, self.d_head).transpose(1, 2)

        # split k, v into groups, then repeat each group so every head has a matching k, v
        k = k.view(b, l, self.n_groups, self.d_head).transpose(1, 2).repeat_interleave(repeats = repeats, dim = 1)
        v = v.view(b, l, self.n_groups, self.d_head).transpose(1, 2).repeat_interleave(repeats = repeats, dim = 1)

        q = self.rope(q)
        k = self.rope(k)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)
        attn_scores = F.softmax(attn_logits, dim = -1)

        # merge heads back: (batch_size, n_heads, seq_len, d_head) -> (batch_size, seq_len, d_model)
        attn_output = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(b, l, d_model)
        return attn_output
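A usage sketch along the same lines as the Multi-Query example, where w_k and w_v now project to n_groups * d_head to match the reshape in forward; the layer names and dimensions are illustrative.

d_model, n_heads, n_groups = 512, 8, 2
d_head = d_model // n_heads

w_q = nn.Linear(d_model, d_model, bias = False)             # per-head queries
w_k = nn.Linear(d_model, n_groups * d_head, bias = False)   # one key projection per group
w_v = nn.Linear(d_model, n_groups * d_head, bias = False)   # one value projection per group

gqa = GroupedQueryAttention(
    n_heads = n_heads,
    n_groups = n_groups,
    d_model = d_model,
    context_length = 1024,
    ntk_rope_scaling = False,   # placeholder; depends on RotaryPositionalEmbedding
    dyn_scaling = False         # placeholder
)

x = torch.randn(2, 128, d_model)        # (batch_size, seq_len, d_model)
out = gqa(w_q(x), w_k(x), w_v(x))       # (2, 128, 512)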
The benefit of Grouped-Query Attention is that we still save some memory in the KV-cache during inference, with some loss of expressivity in the attention mechanism, but both are dependent on the number of groups, $G$, assigned to the mechanism.
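Extending the earlier back-of-the-envelope numbers, the per-layer cache cost scales linearly with $G$, interpolating between the Multi-Head ($G = H$) and Multi-Query ($G = 1$) extremes; the dimensions are again illustrative.

# per-layer KV-cache at fp16 as a function of the number of KV groups G
seq_len, d_head = 4096, 128
for G in (32, 8, 4, 1):
    cache_mib = 2 * seq_len * G * d_head * 2 / 2**20   # 2x for K and V, 2 bytes/elem
    print(f"G = {G:2d}: {cache_mib:.0f} MiB per layer")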