(Q?)KV Cache

Juan Vera

April 2025

Abstract

I got tired of intuitively knowing how KV-Cache works without seeing it from first principles for myself, so here you go.

Preliminaries

We can define the Attention Mechanism as:

$$(1) \hspace{5mm} \text{Attention}(q, K, V) = \sum_{t=1}^{l} \left(\frac{\exp\left(\frac{q k_t^{\top}}{\sqrt{d_{\text{model}}}}\right)}{\sum_{j=1}^{l}\exp\left(\frac{q k_j^{\top}}{\sqrt{d_{\text{model}}}}\right)}\right) v_t = \hat{v} \in \mathbb{R}^{d_{\text{model}}}$$

where $l$ is the sequence length, $t$ indexes the token being attended to, and $j$ is the summation index over all $l$ tokens in the softmax denominator.

We assume $q, v_t, k_t \in \mathbb{R}^{d_{\text{model}}}$, where $d_{\text{model}}$ is the dimensionality of the attention space.

The dot product $q k_t^{\top}$ is a similarity score, since $q k_t^{\top} = \|q\| \|k_t\| \cos(\theta)$: the more similar in direction $q$ and $k_t$ are, the larger $\cos\theta$ will be.
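To make that identity concrete, here's a quick numerical check (a toy example I'm adding, not from the original text; torch.nn.functional.cosine_similarity is just used to compute $\cos\theta$):

import torch
import torch.nn.functional as F

d_model = 256
q = torch.randn(d_model)
k = torch.randn(d_model)

# q . k == ||q|| * ||k|| * cos(theta)
cos_theta = F.cosine_similarity(q, k, dim=0)
print(torch.allclose(q @ k, q.norm() * k.norm() * cos_theta))  # True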

The normalization by $\sqrt{d_{\text{model}}}$ avoids saturating the $\text{softmax}(\cdot)$. Without it, the attention scores $\alpha_t = \frac{\exp(q k_t^{\top})}{\sum_{j=1}^{l}\exp(q k_j^{\top})}$ can differ wildly in magnitude: if one $\alpha_t$ is extremely large relative to the others, the model attends almost exclusively to the corresponding $v_t$ and effectively ignores the rest.

While sharp attention is normal behavior, desirable to some degree, the scaling yields a much more evenly distributed attention-score matrix, where surrounding tokens play a larger role in the next-token prediction than they would without normalization.
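As a quick illustration of the saturation argument (a sketch I'm adding, assuming random Gaussian $q$ and $K$): the raw dot products have a spread on the order of $\sqrt{d_{\text{model}}}$, so the unscaled softmax tends toward a near one-hot distribution, while the scaled version stays much flatter.

import torch
import torch.nn.functional as F

seq_len, d_model = 10, 256
q = torch.randn(d_model)
K = torch.randn(seq_len, d_model)

scores = q @ K.T  # raw dot products; their spread grows with sqrt(d_model)

unscaled = F.softmax(scores, dim=-1)
scaled = F.softmax(scores / d_model ** 0.5, dim=-1)

print(unscaled.max().item())  # frequently close to 1.0 -> saturated, near one-hot
print(scaled.max().item())    # noticeably flatter distribution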

After computing all $l$ scores $\alpha_t$ via the $\text{softmax}$, we scale each $v_t$ by its $\alpha_t$ and sum, to get $\hat{v}$.

We can define this same operation as a matrix multiplication:

$$(2) \hspace{5mm} \text{Attention}(q, K, V) = \text{softmax}\left(\frac{qK^{\top}}{\sqrt{d_{\text{model}}}}\right)V$$

where

  • $K, V \in \mathbb{R}^{l \times d_{\text{model}}}$, where $l$ is the sequence length and $d_{\text{model}}$ is the dimensionality of the attention space.
  • $\vec{\alpha} = \text{softmax}\left(\frac{qK^{\top}}{\sqrt{d_{\text{model}}}}\right) \in \mathbb{R}^{l}$
  • $q \in \mathbb{R}^{d_{\text{model}}}$

I won't waste my time trying to write mathematical notation for this, but essentially, you can define it as:

import torch
import torch.nn.functional as F

seq_len = 10
d_model = 256  # a common embedding size

q = torch.randn(size=(d_model,))
K = torch.randn(size=(seq_len, d_model))
V = torch.randn(size=(seq_len, d_model))

# scaled dot-product scores of the single query q against all keys
attn_probs = F.softmax(torch.matmul(q, K.transpose(0, 1)) / d_model ** 0.5, dim=-1)
print(attn_probs.shape)  # (seq_len,)

We can simply matmul $\vec{\alpha}$ and $V$, as the operation is equivalent to summing all $l$ products $\alpha_t v_t$.

If $V \in \mathbb{R}^{l \times d_{\text{model}}}$, then

$$(2.5) \hspace{5mm} \begin{bmatrix} \alpha_1 & \alpha_2 & \cdots & \alpha_l \end{bmatrix} \begin{bmatrix} v_{1,1} & v_{1,2} & \cdots & v_{1,d_{\text{model}}} \\ v_{2,1} & v_{2,2} & \cdots & v_{2,d_{\text{model}}} \\ \vdots & \vdots & \ddots & \vdots \\ v_{l,1} & v_{l,2} & \cdots & v_{l,d_{\text{model}}} \end{bmatrix} \in \mathbb{R}^{d_{\text{model}}}$$

where multiplying $\vec{\alpha}$ with the $j$th column of $V$ is equivalent to the summation $\rightarrow$ multiplication in $(1)$; stacking all $d_{\text{model}}$ such products gives the output $\hat{v}$, the attention-weighted combination of the value vectors for the current query token.

If you ran:

out = torch.matmul(attn_probs, V)
print(out.shape)  # (d_model,)

you'd get out as a vector in $\mathbb{R}^{d_{\text{model}}}$.
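As a sanity check (continuing with the tensors from the snippets above), the explicit loop over the summation in $(1)$ matches the matmul form:

# loop form of equation (1): alpha_t-weighted sum of the value rows
out_loop = torch.zeros(d_model)
for t in range(seq_len):
    out_loop += attn_probs[t] * V[t]

print(torch.allclose(out, out_loop, atol=1e-5))  # True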

Of course, to compute all attention scores and, correspondingly, the full result $\hat{V} \in \mathbb{R}^{l \times d_{\text{model}}}$, we can define $q$ as a matrix as well:

$$(3) \hspace{5mm} \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{\text{model}}}}\right)V = \hat{V} \in \mathbb{R}^{l \times d_{\text{model}}}$$

where $Q \in \mathbb{R}^{l \times d_{\text{model}}}$, such that:

$$\begin{bmatrix} \alpha_{1,1} & \alpha_{1,2} & \cdots & \alpha_{1,l} \\ \alpha_{2,1} & \alpha_{2,2} & \cdots & \alpha_{2,l} \\ \vdots & \vdots & \ddots & \vdots \\ \alpha_{l,1} & \alpha_{l,2} & \cdots & \alpha_{l,l} \end{bmatrix} \begin{bmatrix} v_{1,1} & v_{1,2} & \cdots & v_{1,d_{\text{model}}} \\ v_{2,1} & v_{2,2} & \cdots & v_{2,d_{\text{model}}} \\ \vdots & \vdots & \ddots & \vdots \\ v_{l,1} & v_{l,2} & \cdots & v_{l,d_{\text{model}}} \end{bmatrix} \in \mathbb{R}^{l \times d_{\text{model}}}$$

or coded out:

import torch
import torch.nn.functional as F

seq_len = 10
d_model = 256  # a common embedding size

Q = torch.randn(size=(seq_len, d_model))
K = torch.randn(size=(seq_len, d_model))
V = torch.randn(size=(seq_len, d_model))

# scaled dot-product scores of every query against every key
attn_probs = F.softmax(torch.matmul(Q, K.transpose(0, 1)) / d_model ** 0.5, dim=-1)
print(attn_probs.shape)  # (seq_len, seq_len)

allows us to then compute the matrix multiplication with $V$ to get $\hat{V} \in \mathbb{R}^{l \times d_{\text{model}}}$.

out = torch.matmul(attn_probs, V)
print(out.shape) # (seq_len, d_model)

Of course, you can further parallelize this process by computing it in batches.
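For instance, a batched version might look like the following (my own toy sketch; torch.matmul broadcasts over the leading batch dimension):

import torch
import torch.nn.functional as F

batch_size, seq_len, d_model = 4, 10, 256

Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

# the batch dimension is carried through both matmuls automatically
attn_probs = F.softmax(Q @ K.transpose(-2, -1) / d_model ** 0.5, dim=-1)
out = attn_probs @ V
print(out.shape)  # (batch_size, seq_len, d_model)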

KV Cache

To answer the question, "why the hell can we cache the K and V matrices?"

Given the definition of the attention mechanism, $(3)$:

$$(3) \hspace{5mm} \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{\text{model}}}}\right)V = \hat{V} \in \mathbb{R}^{l \times d_{\text{model}}}$$

where:

  • $Q \in \mathbb{R}^{l \times d_{\text{model}}}$
  • $K \in \mathbb{R}^{l \times d_{\text{model}}}$
  • $V \in \mathbb{R}^{l \times d_{\text{model}}}$
  • $QK^{\top} \in \mathbb{R}^{l \times l}$

caching becomes important in autoregressive settings -- where we always predict the next token, $t$, given an input of length $t - 1$.

Given the input $X \in \mathbb{R}^{l \times d_{\text{model}}}$, we compute $Q, K, V$ as:

$$Q_t = X_t W_Q \\[3mm] V_t = X_t W_V \\[3mm] K_t = X_t W_K$$

all in $\mathbb{R}^{l \times d_{\text{model}}}$, or equivalently, during autoregressive inference where $l = t$, in $\mathbb{R}^{t \times d_{\text{model}}}$.
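For shape intuition, here's a toy sketch of these projections (the projection matrices are hypothetical random tensors, assuming the projections keep the model dimension):

import torch

l, d_model = 10, 256
X = torch.randn(l, d_model)

W_Q = torch.randn(d_model, d_model)
W_K = torch.randn(d_model, d_model)
W_V = torch.randn(d_model, d_model)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)  # each (l, d_model)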

When predicting the next token, $t + 1$, the matrices $Q, K, V$ come to be in $\mathbb{R}^{(t+1) \times d_{\text{model}}}$, such that:

$$K_{t+1} = \begin{bmatrix} k_1 \\ k_2 \\ \vdots \\ k_t \\ k_{t+1} \end{bmatrix} = X_{t+1} W_K \\[4mm] V_{t+1} = \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_t \\ v_{t+1} \end{bmatrix} = X_{t+1} W_V$$

Notice that:

$$K_{t+1} = \begin{bmatrix} K_t \\ k_{t+1} \end{bmatrix} \\[3mm] V_{t+1} = \begin{bmatrix} V_t \\ v_{t+1} \end{bmatrix}$$

such that you can cache any $K_t$ and $V_t$ and reuse them, while only having to compute $k_{t+1}$ and $v_{t+1}$ for the newly generated token.
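Here's a minimal sketch of a single decode step with a KV cache (my own illustration, assuming single-head attention, random weights, and a hypothetical embedding x_new for token $t+1$):

import torch
import torch.nn.functional as F

t, d_model = 9, 256

# hypothetical projection matrices
W_Q = torch.randn(d_model, d_model)
W_K = torch.randn(d_model, d_model)
W_V = torch.randn(d_model, d_model)

# cached K_t and V_t from the first t tokens (random here, for illustration)
K_cache = torch.randn(t, d_model)
V_cache = torch.randn(t, d_model)

x_new = torch.randn(d_model)  # embedding of token t+1

q_new = x_new @ W_Q  # only the newest query is needed
k_new = x_new @ W_K
v_new = x_new @ W_V

# append a single row instead of recomputing K_{t+1}, V_{t+1} from scratch
K_cache = torch.cat([K_cache, k_new.unsqueeze(0)], dim=0)
V_cache = torch.cat([V_cache, v_new.unsqueeze(0)], dim=0)

alpha = F.softmax(q_new @ K_cache.T / d_model ** 0.5, dim=-1)
v_hat = alpha @ V_cache  # attention output for the newest token only
print(v_hat.shape)  # (d_model,)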

Q Cache?

Looking back at equations $(2)$ and $(2.5)$:

$$(2) \hspace{5mm} \text{Attention}(q, K, V) = \text{softmax}\left(\frac{qK^{\top}}{\sqrt{d_{\text{model}}}}\right)V \\[5mm] (2.5) \hspace{5mm} \begin{bmatrix} \alpha_1 & \alpha_2 & \cdots & \alpha_l \end{bmatrix} \begin{bmatrix} v_{1,1} & v_{1,2} & \cdots & v_{1,d_{\text{model}}} \\ v_{2,1} & v_{2,2} & \cdots & v_{2,d_{\text{model}}} \\ \vdots & \vdots & \ddots & \vdots \\ v_{l,1} & v_{l,2} & \cdots & v_{l,d_{\text{model}}} \end{bmatrix} \in \mathbb{R}^{d_{\text{model}}}$$

you can see that for a given $q$, you get a single vector of attention scores, $\vec{\alpha}$, computed with respect to all rows $k_i \in K$ (or all columns of $K^{\top}$).

Given that during autoregressive generation you only need to predict the next token, the attention scores you need form the single vector $\vec{\alpha}$ for the newest query, rather than the full matrix $A \in \mathbb{R}^{l \times l}$. It's therefore redundant to cache $Q$ at all, when all you really need is the current $q$ to compute the scaled dot product.
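A quick numerical check of that claim (my own sketch, using hypothetical Q_full and K_full matrices): the last row of the full $l \times l$ attention matrix is identical to the scores computed from the newest query alone.

import torch
import torch.nn.functional as F

seq_len, d_model = 10, 256
Q_full = torch.randn(seq_len, d_model)
K_full = torch.randn(seq_len, d_model)

full = F.softmax(Q_full @ K_full.T / d_model ** 0.5, dim=-1)        # (seq_len, seq_len)
single = F.softmax(Q_full[-1] @ K_full.T / d_model ** 0.5, dim=-1)  # (seq_len,)

print(torch.allclose(full[-1], single))  # True: only the newest query's row is needed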