
Swarm of Attention Variants

Juan Vera

April 2025

We can define Multi-Head Attention, for a single token $t$, as:

$$
q_t = W^qx_t \in \mathbb{R}^{d_{\text{model}}} \\[3mm]
K = W^kX \in \mathbb{R}^{\ell \times d_{\text{model}}} \\[3mm]
V = W^vX \in \mathbb{R}^{\ell \times d_{\text{model}}} \\[5mm]
q_t\text{.reshape}(\text{n}_{\text{heads}}, 1, \text{d}_{\text{head}}) \\[3mm]
K\text{.reshape}(\text{n}_{\text{heads}}, \ell, \text{d}_{\text{head}}) \\[3mm]
V\text{.reshape}(\text{n}_{\text{heads}}, \ell, \text{d}_{\text{head}}) \\[5mm]
\text{Attention}(q_t, K, V) = \text{softmax}\left(\frac{q_tK^{\top}}{\sqrt{d_{\text{head}}}}\right)V
$$

or parallelized, for all tokens $t$, as:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}}\right)V
$$

where $Q \in \mathbb{R}^{\ell \times d_{\text{model}}} \rightarrow Q\text{.reshape}(\text{n}_{\text{heads}}, \ell, \text{d}_{\text{head}})$.
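
Before moving to the variants, here's a minimal sketch of the parallelized multi-head form above in PyTorch. The function name, and the omission of masking and the output projection, are my own simplifications for illustration:

import torch
import torch.nn.functional as F

def multi_head_attention(q, k, v, n_heads):
    # q, k, v: (batch_size, seq_len, d_model)
    b, l, d_model = q.shape
    d_head = d_model // n_heads

    # split d_model into n_heads * d_head and move the head dim before seq_len
    q = q.view(b, l, n_heads, d_head).transpose(1, 2)   # (b, n_heads, l, d_head)
    k = k.view(b, l, n_heads, d_head).transpose(1, 2)
    v = v.view(b, l, n_heads, d_head).transpose(1, 2)

    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (d_head ** 0.5)  # (b, n_heads, l, l)
    attn_scores = F.softmax(attn_logits, dim = -1)
    attn_output = torch.matmul(attn_scores, v)           # (b, n_heads, l, d_head)

    # merge the heads back into d_model
    return attn_output.transpose(1, 2).contiguous().view(b, l, d_model)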

Multi-Query Attention

Multi-Query Attention has the same definition as Multi-Head Attention, with the caveat that the keys and values are shared across all attention heads.

For each head, we have a different set of queries, $Q \in \mathbb{R}^{1 \times \ell \times d_{\text{head}}}$; accounting for all heads, we'd have $Q \in \mathbb{R}^{n_{\text{heads}} \times \ell \times d_{\text{head}}}$.

But every head shares the same $K$ and $V$, both $\in \mathbb{R}^{1 \times \ell \times d_{\text{head}}}$, broadcast across the head dimension.

Given that we share the same $K$ and $V$ across all heads, it doesn't make sense for the initial projection from $X$ to $K$ and $V$ to land in $\mathbb{R}^{\ell \times d_{\text{model}}}$, as reshaping to $\mathbb{R}^{n_{\text{heads}} \times \ell \times d_{\text{head}}}$ would introduce $n_{\text{heads}}$ unique $K$ and $V$ matrices. Instead, we can directly project $K$ and $V$ as:

$$
K, V = W^{\{K, V\}}X, \hspace{2mm} W^{\{K, V\}} \in \mathbb{R}^{d_{\text{model}} \times d_{\text{head}}}
$$
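
As a concrete sketch of those projection shapes (the layer names w_q, w_k, and w_v are illustrative, not taken from any particular codebase):

import torch.nn as nn

d_model, n_heads = 512, 8
d_head = d_model // n_heads

w_q = nn.Linear(d_model, d_model, bias = False)  # per-head queries: (b, l, d_model)
w_k = nn.Linear(d_model, d_head, bias = False)   # single shared keys: (b, l, d_head)
w_v = nn.Linear(d_model, d_head, bias = False)   # single shared values: (b, l, d_head)

The attention mechanism itself can then be written as: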
import torch
import torch.nn as nn
import torch.nn.functional as F

# RotaryPositionalEmbedding is assumed to be defined elsewhere in the post.

class MultiQueryAttention(nn.Module):
    
    def __init__(
        self, 
        n_heads, 
        d_model, 
        context_length, 
        ntk_rope_scaling, 
        dyn_scaling
        ):
        
        super().__init__()
   
        self.n_heads = n_heads 
        self.d_model = d_model
        self.d_head = d_model // n_heads 
        self.context_length = context_length
        self.ntk_rope_scaling = ntk_rope_scaling
        self.dyn_scaling = dyn_scaling
  
        self.rope = RotaryPositionalEmbedding(
            d_head = self.d_head,
            context_length = self.context_length,
            ntk_rope_scaling = self.ntk_rope_scaling,
            dyn_scaling = self.dyn_scaling
        )
   
    def forward(self, q, k, v):
        
        # assumes q is (batch_size, context_length = seq_len, d_model)
        # assumes k is (batch_size, context_length = seq_len, d_head)
        # assumes v is (batch_size, context_length = seq_len, d_head) 
        
        b, l, d_model = q.shape
        
        assert d_model == self.d_model, f"Expected d_model to be {self.d_model}, but got {d_model}" 
        
        # (b, l, d_model) -> (b, n_heads, l, d_head)
        q = q.view(b, l, self.n_heads, self.d_head).transpose(1, 2)
        # single shared k, v: (b, l, d_head) -> (b, 1, l, d_head), broadcast across heads
        k = k.unsqueeze(1)
        v = v.unsqueeze(1)
        
        assert q.shape[2:] == k.shape[2:], f"Expected q and k to have the same shape for dimensions 2 (seq_len) and 3 (d_head), but got {q.shape}, {k.shape}" 
        assert k.shape == v.shape, f"Expected k and v to have the same shape, but got {k.shape}, {v.shape}"
        
        q = self.rope(q)
        k = self.rope(k) 
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5) 
        attn_scores = F.softmax(attn_logits, dim = -1)
        # (b, n_heads, l, d_head) -> (b, l, d_model)
        attn_output = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(b, l, d_model)
 
        return attn_output
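
A quick usage sketch, reusing the projection layers from the sketch above; the RoPE flags are placeholders, since their exact meaning depends on how RotaryPositionalEmbedding is defined:

mqa = MultiQueryAttention(
    n_heads = n_heads,
    d_model = d_model,
    context_length = 1024,
    ntk_rope_scaling = False,   # placeholder value
    dyn_scaling = False         # placeholder value
)

x = torch.randn(2, 128, d_model)     # (batch_size, seq_len, d_model)
out = mqa(w_q(x), w_k(x), w_v(x))    # projections from the sketch above
print(out.shape)                     # torch.Size([2, 128, 512])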

The benefit of Multi-Query Attention is that it saves memory in the $KV$-cache during inference, at the loss of some expressivity in the attention mechanism.
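
To make the memory argument concrete, the per-token $KV$-cache footprint is roughly

$$
\text{bytes per token} \approx 2 \cdot n_{\text{layers}} \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot \text{bytes per element}
$$

where the factor of $2$ accounts for storing both keys and values, and $n_{\text{kv}}$ is the number of distinct key/value heads: $n_{\text{kv}} = n_{\text{heads}}$ for Multi-Head Attention versus $n_{\text{kv}} = 1$ for Multi-Query Attention, so the cache shrinks by a factor of $n_{\text{heads}}$.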

Grouped-Query Attention

Grouped-Query Attention can be seen as the "balance" between the extremes of Multi-Head Attention and Multi-Query Attention: rather than a single $K$ and $V$ being shared amongst all heads for every $Q \in \mathcal{Q}$, there's a unique pair $K, V \in \mathcal{K}, \mathcal{V}$ for each of the $g$ groups, where each group contains $\frac{H}{g}$ heads and $H$ is the total number of heads.

The cardinality of $\mathcal{K}$ and $\mathcal{V}$ is then $g$.

Similar to Multi-Query Attention, we must initially project $X \rightarrow K, V$ to a dimensionality other than $d_{\text{model}}$, but like Multi-Head Attention, $K$ and $V$ are still reshaped (and repeated) to match the dimensions of $Q\text{.reshape}(n_{\text{heads}}, \ell, d_{\text{head}})$.

$$
Q = W^qX \in \mathbb{R}^{\ell \times d_{\text{model}}} \\[3mm]
K = W^kX \in \mathbb{R}^{\ell \times (d_{\text{head}} \cdot g)} \\[3mm]
V = W^vX \in \mathbb{R}^{\ell \times (d_{\text{head}} \cdot g)} \\[3mm]
Q = Q\text{.reshape}(\text{n}_{\text{heads}}, \ell, d_{\text{head}}) \\[3mm]
K = K\text{.reshape}(g, \ell, d_{\text{head}}) \\[3mm]
V = V\text{.reshape}(g, \ell, d_{\text{head}}) \\[3mm]
K = K\text{.repeat\_interleave}\left(\tfrac{H}{g}\right) \\[3mm]
V = V\text{.repeat\_interleave}\left(\tfrac{H}{g}\right) \\[5mm]
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{\text{head}}}}\right)V
$$

or coded out,

class GroupedQueryAttention(nn.Module):
    
    def __init__(
        self,
        n_heads:int,
        n_groups:int,
        d_model, 
        context_length, 
        ntk_rope_scaling, 
        dyn_scaling        
        ):
        
        super().__init__()
  
        self.n_heads = n_heads 
        self.n_groups = n_groups
        self.d_model = d_model
        self.d_head = d_model // n_heads 
        self.context_length = context_length
        self.ntk_rope_scaling = ntk_rope_scaling
        self.dyn_scaling = dyn_scaling
   
        self.rope = RotaryPositionalEmbedding(
            d_head = self.d_head,
            context_length = self.context_length,
            ntk_rope_scaling = self.ntk_rope_scaling,
            dyn_scaling = self.dyn_scaling
        )
    
    def forward(self, q, k, v):
        
        # assumes q is (batch_size, context_length = seq_len, d_model)
        # k will be reshaped to be (batch_size, n_groups, context_length, d_head)
        # v will be reshaped to be (batch_size, n_groups, context_length, d_head)       
        #
        # assuming the dimension of k, v is divisible by n_groups, for even k, v reshaping 
        # such that we end up with the same d_head for q, k, v
       
        b, l, d_model = q.shape
        
        assert self.n_heads % self.n_groups == 0, f"Expected n_heads to be divisible by n_groups, but got n_heads: \
            {self.n_heads}, n_groups: {self.n_groups} with remainder of {self.n_heads % self.n_groups}"
        assert d_model == self.d_model, f"Expected d_model to be {self.d_model}, but got {d_model}"  
        
        repeats = self.n_heads // self.n_groups
        # (b, l, d_model) -> (b, n_heads, l, d_head)
        q = q.view(b, l, self.n_heads, self.d_head).transpose(1, 2)
        # (b, l, n_groups * d_head) -> (b, n_groups, l, d_head), then repeat each group's k, v across its heads
        k = k.view(b, l, self.n_groups, self.d_head).transpose(1, 2).repeat_interleave(repeats = repeats, dim = 1)
        v = v.view(b, l, self.n_groups, self.d_head).transpose(1, 2).repeat_interleave(repeats = repeats, dim = 1)
        q = self.rope(q)
        k = self.rope(k)
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / (self.d_head ** 0.5) 
        attn_scores = F.softmax(attn_logits, dim = -1)
        # (b, n_heads, l, d_head) -> (b, l, d_model)
        attn_output = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(b, l, d_model)
        
        return attn_output 
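
And a matching usage sketch for the grouped variant, again with illustrative projection layers and placeholder RoPE flags:

n_heads, n_groups, d_model = 8, 2, 512
d_head = d_model // n_heads

w_q = nn.Linear(d_model, d_model, bias = False)             # per-head queries
w_k = nn.Linear(d_model, n_groups * d_head, bias = False)   # g key sets
w_v = nn.Linear(d_model, n_groups * d_head, bias = False)   # g value sets

gqa = GroupedQueryAttention(
    n_heads = n_heads,
    n_groups = n_groups,
    d_model = d_model,
    context_length = 1024,
    ntk_rope_scaling = False,   # placeholder value
    dyn_scaling = False         # placeholder value
)

x = torch.randn(2, 128, d_model)
out = gqa(w_q(x), w_k(x), w_v(x))
print(out.shape)                     # torch.Size([2, 128, 512])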

The benefit of Grouped-Query Attention is that we still save some memory in the $KV$-cache during inference, with some loss of expressivity in the attention mechanism; both effects scale with the number of groups, $g$. For example, with $H = 32$ heads and $g = 8$ groups, each layer caches $8$ key/value sets rather than $32$ (Multi-Head) or $1$ (Multi-Query).