Swarm of Attention Variants

April 22, 2025

Comprehensive overview of attention mechanism variants. Exploring multi-head attention, sparse attention, and other innovative approaches to the attention paradigm.

We can define Multi-Head Attention, for a single token $t$, as:

$$
q_t = W^q x_t \in \mathbb{R}^{d_{\text{model}}} \\[3mm]
K = W^k X \in \mathbb{R}^{\ell \times d_{\text{model}}} \\[3mm]
V = W^v X \in \mathbb{R}^{\ell \times d_{\text{model}}} \\[5mm]
q_t\text{.reshape}(n_{\text{heads}}, 1, d_{\text{head}}) \\[3mm]
K\text{.reshape}(n_{\text{heads}}, \ell, d_{\text{head}}) \\[3mm]
V\text{.reshape}(n_{\text{heads}}, \ell, d_{\text{head}}) \\[5mm]
\text{Attention}(q_t, K, V) = \text{softmax}\left(\frac{q_t K^{\top}}{\sqrt{d_{\text{head}}}}\right)V
$$

or parallelized, for all tokens $t$, as:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}}\right)V
$$

where $Q \in \mathbb{R}^{\ell \times d_{\text{model}}} \rightarrow Q\text{.reshape}(n_{\text{heads}}, \ell, d_{\text{head}})$.
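As a quick sanity check of the shapes above, here is a minimal sketch of the parallelized form in PyTorch. This is a bare sketch of scaled dot-product attention over heads, with the projections, masking, and dropout omitted; the names `n_heads` and `d_head` follow the notation above.

```python
import torch
import torch.nn.functional as F

def multi_head_attention(q, k, v, n_heads):
    # q, k, v: (batch, seq_len, d_model)
    b, l, d_model = q.shape
    d_head = d_model // n_heads

    # split d_model into n_heads slices of d_head each: (b, n_heads, l, d_head)
    q = q.view(b, l, n_heads, d_head).transpose(1, 2)
    k = k.view(b, l, n_heads, d_head).transpose(1, 2)
    v = v.view(b, l, n_heads, d_head).transpose(1, 2)

    # scaled dot-product attention, computed independently per head
    logits = q @ k.transpose(-2, -1) / d_head ** 0.5  # (b, n_heads, l, l)
    scores = F.softmax(logits, dim=-1)
    out = scores @ v                                  # (b, n_heads, l, d_head)

    # merge the heads back into d_model
    return out.transpose(1, 2).contiguous().view(b, l, d_model)
```

Note that the softmax is taken over the last axis (keys), so each query position's attention weights sum to one.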

Multi-Query Attention

Multi-Query Attention has the same definition as Multi-Head Attention, with the caveat that the keys and values are shared across all attention heads.

For each head, we have a different set of queries, $Q \in \mathbb{R}^{1 \times \ell \times d_{\text{head}}}$, where if we account for all heads, we'd have $Q \in \mathbb{R}^{n_{\text{heads}} \times \ell \times d_{\text{head}}}$.

But across heads, we share a single $K$ and $V$, each $\in \mathbb{R}^{1 \times \ell \times d_{\text{head}}}$, broadcast over the $n_{\text{heads}}$ heads.

Given that we share the same $K$ and $V$ for all heads, it doesn't make sense for the initial projection from $X$ to $K$ and $V$ to map to $\mathbb{R}^{\ell \times d_{\text{model}}}$, as reshaping to $\mathbb{R}^{n_{\text{heads}} \times \ell \times d_{\text{head}}}$ would introduce $n_{\text{heads}}$ unique $K$ and $V$ matrices. Instead, we can directly project $K$ and $V$ as:

$$
K, V = W^{\{K, V\}}X, \quad W^{\{K, V\}} \in \mathbb{R}^{d_{\text{model}} \times d_{\text{head}}}
$$
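Concretely, the projections can be sketched with `nn.Linear` layers (the layer names here are hypothetical; the `MultiQueryAttention` module below expects `q`, `k`, `v` already projected this way):

```python
import torch
import torch.nn as nn

d_model, n_heads = 512, 8
d_head = d_model // n_heads  # 64

# queries keep the full d_model width (one d_head slice per head after reshaping)...
w_q = nn.Linear(d_model, d_model, bias=False)
# ...but keys/values project straight down to a single d_head-wide slice,
# shared by every head
w_k = nn.Linear(d_model, d_head, bias=False)
w_v = nn.Linear(d_model, d_head, bias=False)

x = torch.randn(1, 10, d_model)   # (batch, seq_len, d_model)
q, k, v = w_q(x), w_k(x), w_v(x)  # q: (1, 10, 512); k, v: (1, 10, 64)
```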
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiQueryAttention(nn.Module):
    def __init__(
        self,
        n_heads,
        d_model,
        context_length,
        ntk_rope_scaling,
        dyn_scaling
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_head = d_model // n_heads
        self.context_length = context_length
        self.ntk_rope_scaling = ntk_rope_scaling
        self.dyn_scaling = dyn_scaling
        # RotaryPositionalEmbedding is assumed to be defined elsewhere
        self.rope = RotaryPositionalEmbedding(
            d_head=self.d_head,
            context_length=self.context_length,
            ntk_rope_scaling=self.ntk_rope_scaling,
            dyn_scaling=self.dyn_scaling
        )

    def forward(self, q, k, v):
        # assumes q is (batch_size, context_length = seq_len, d_model)
        # assumes k is (batch_size, context_length = seq_len, d_head)
        # assumes v is (batch_size, context_length = seq_len, d_head)
        b, l, d_model = q.shape
        assert d_model == self.d_model, f"Expected d_model to be {self.d_model}, but got {d_model}"

        # split q into heads: (b, n_heads, l, d_head)
        q = q.view(b, l, self.n_heads, self.d_head).transpose(1, 2)
        # add a singleton head dim so the shared k, v broadcast across all heads
        k = k.unsqueeze(1)
        v = v.unsqueeze(1)

        assert q.shape[2:] == k.shape[2:], f"Expected q and k to have the same shape for dimensions 2 (seq_len) and 3 (d_head), but got {q.shape}, {k.shape}"
        assert k.shape == v.shape, f"Expected k and v to have the same shape, but got {k.shape}, {v.shape}"

        q = self.rope(q)
        k = self.rope(k)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)
        attn_scores = F.softmax(attn_logits, dim=-1)
        # merge heads back into d_model: (b, l, d_model)
        attn_output = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(b, l, d_model)
        return attn_output
```

The benefit of Multi-Query Attention is that it saves memory in the $KV$-cache during inference, at the loss of some expressivity in the attention mechanism.
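To put a number on the savings, compare the per-token $KV$-cache entries under the two schemes. The config below is illustrative, not taken from any particular model:

```python
# illustrative config (hypothetical, not a specific model)
n_heads, d_head, n_layers = 32, 128, 32

# per token, per layer, we cache one K row and one V row
mha_kv = 2 * n_heads * d_head * n_layers  # every head has its own K, V
mqa_kv = 2 * 1 * d_head * n_layers        # one K, V shared by all heads

print(mha_kv // mqa_kv)  # → 32: MQA shrinks the cache by a factor of n_heads
```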

Grouped-Query Attention

Grouped-Query Attention can be seen as the "balance" between the extremes of Multi-Head Attention and Multi-Query Attention: rather than a single $K$ and $V$ being shared amongst all heads for all $Q \in \mathcal{Q}$, there's a unique $K, V \in \mathcal{K}, \mathcal{V}$ for each of $g$ groups, where the size of each group is $\frac{H}{g}$ heads, $H$ being the total number of heads.

The cardinality of $\mathcal{K}$ and $\mathcal{V}$ is then $g$.

Similar to Multi-Query Attention, we must initially project $X \rightarrow K, V$ to a dimensionality other than $d_{\text{model}}$, but like Multi-Head Attention, $K$ and $V$ are still reshaped to match the dimension of $Q\text{.reshape}(n_{\text{heads}}, \ell, d_{\text{head}})$.

$$
Q = W^q X \in \mathbb{R}^{\ell \times d_{\text{model}}} \\[3mm]
K = W^k X \in \mathbb{R}^{\ell \times d_{\text{head}} \cdot g} \\[3mm]
V = W^v X \in \mathbb{R}^{\ell \times d_{\text{head}} \cdot g} \\[3mm]
Q = Q\text{.reshape}(n_{\text{heads}}, \ell, d_{\text{head}}) \\[3mm]
K = K\text{.reshape}(g, \ell, d_{\text{head}}) \\[3mm]
V = V\text{.reshape}(g, \ell, d_{\text{head}}) \\[3mm]
K = K\text{.interleave}\left(\tfrac{H}{g}\right) \\[3mm]
V = V\text{.interleave}\left(\tfrac{H}{g}\right) \\[5mm]
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}}\right)V
$$

or coded out,

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        n_heads: int,
        n_groups: int,
        d_model,
        context_length,
        ntk_rope_scaling,
        dyn_scaling
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_groups = n_groups
        self.d_model = d_model
        self.d_head = d_model // n_heads
        self.context_length = context_length
        self.ntk_rope_scaling = ntk_rope_scaling
        self.dyn_scaling = dyn_scaling
        # RotaryPositionalEmbedding is assumed to be defined elsewhere
        self.rope = RotaryPositionalEmbedding(
            d_head=self.d_head,
            context_length=self.context_length,
            ntk_rope_scaling=self.ntk_rope_scaling,
            dyn_scaling=self.dyn_scaling
        )

    def forward(self, q, k, v):
        # assumes q is (batch_size, context_length = seq_len, d_model)
        # k will be reshaped to (batch_size, n_groups, context_length, d_head)
        # v will be reshaped to (batch_size, n_groups, context_length, d_head)
        #
        # assumes the dimension of k, v is divisible by n_groups, for even k, v
        # reshaping, such that we end up with the same d_head for q, k, v
        b, l, d_model = q.shape

        assert self.n_heads % self.n_groups == 0, (
            f"Expected n_heads to be divisible by n_groups, but got n_heads: "
            f"{self.n_heads}, n_groups: {self.n_groups} with remainder of {self.n_heads % self.n_groups}"
        )
        assert d_model == self.d_model, f"Expected d_model to be {self.d_model}, but got {d_model}"

        repeats = self.n_heads // self.n_groups

        # split q into heads: (b, n_heads, l, d_head)
        q = q.view(b, l, self.n_heads, self.d_head).transpose(1, 2)
        # split k, v into groups, then repeat each group to cover its H/g heads
        k = k.view(b, l, self.n_groups, self.d_head).transpose(1, 2).repeat_interleave(repeats=repeats, dim=1)
        v = v.view(b, l, self.n_groups, self.d_head).transpose(1, 2).repeat_interleave(repeats=repeats, dim=1)

        q = self.rope(q)
        k = self.rope(k)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5)
        attn_scores = F.softmax(attn_logits, dim=-1)
        # merge heads back into d_model: (b, l, d_model)
        attn_output = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(b, l, d_model)
        return attn_output
```

The benefit of Grouped-Query Attention is that we still save some memory in the $KV$-cache during inference, with some loss of expressivity in the attention mechanism; how much of each depends on the number of groups, $g$, assigned to the mechanism.
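The same back-of-the-envelope arithmetic as before shows how $g$ interpolates between the two extremes, again with illustrative (hypothetical) numbers:

```python
# illustrative config (hypothetical, not a specific model)
n_heads, d_head, n_layers = 32, 128, 32

def kv_entries_per_token(g):
    # g groups means g unique K, V pairs cached per token, at every layer;
    # the 2x accounts for caching both K and V
    return 2 * g * d_head * n_layers

# g = n_heads recovers Multi-Head Attention, g = 1 recovers Multi-Query
for g in (32, 8, 1):
    print(g, kv_entries_per_token(g))
```

The cache cost scales linearly in $g$, so e.g. $g = 8$ with $H = 32$ heads cuts the $KV$-cache to a quarter of the Multi-Head size while keeping 8 distinct key/value projections.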