Optimizing Contextual Processing with Multi-Head and Grouped-Query Attention
Deep dive into the math of self-attention and explore how GQA reduces memory overhead in modern long-context LLMs.
The Information Bottleneck: Self-Attention and Scaling
Modern large language models rely on the Transformer architecture to process text sequences in parallel rather than sequentially. This shift away from recurrent neural networks allowed for massive scaling but introduced a new bottleneck in how information is weighted across long contexts. Self-attention is the engine of this process, enabling every token in a sequence to look at every other token to build a contextual representation.
The core of self-attention involves three learned linear transformations that produce Query, Key, and Value matrices. You can think of the Query as what a token is looking for, the Key as what a token contains, and the Value as the information it provides when a match is found. By calculating the dot product between queries and keys, the model determines a similarity score that dictates how much information from each value should be aggregated.
A critical but often overlooked detail is the scaling factor: the query-key dot products are divided by the square root of the key dimension before the softmax. Without this scaling, the dot products grow large as the model's dimensionality increases, pushing the softmax into saturated regions where gradients become vanishingly small. This mathematical stabilization ensures that the model can learn effectively across many layers and thousands of dimensions.
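The full formula, softmax(QK^T / sqrt(d_k))V, can be sketched in a few lines. This is a minimal single-sequence example with illustrative shapes, not an optimized implementation:

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # Similarity scores: a (seq_len, seq_len) matrix of query-key dot products,
    # divided by sqrt(d_k) to keep the softmax out of its saturated regions
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # Softmax turns each row of scores into weights that sum to 1
    weights = torch.softmax(scores, dim=-1)
    # Aggregate the Value vectors according to those weights
    return weights @ v

q = k = v = torch.randn(8, 64)  # 8 tokens, d_k = 64
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([8, 64])
```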
The Mechanics of Vector Similarity
In a high-dimensional space, the dot product acts as a measure of alignment between two vectors. When a Query vector aligns closely with a Key vector, their dot product is high, indicating that the information held by that Key is relevant to the current processing step. The resulting softmax-normalized attention weight then acts as a gate, determining what fraction of the corresponding Value vector passes into the next layer of the network.
This process happens across multiple heads in parallel, which is known as Multi-Head Attention. Each head is initialized with different weights, allowing the model to simultaneously focus on different aspects of the text, such as grammar, factual entities, and long-range semantic dependencies. This diversity of perspective is what gives Transformers their nuanced understanding of complex natural language.
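Mechanically, the head split is just a reshape: the model dimension is divided into per-head slices so each head attends within its own subspace. A minimal sketch with illustrative hyperparameters (not from any particular model):

```python
import torch

d_model, n_heads = 512, 8
head_dim = d_model // n_heads    # 64 dimensions per head

x = torch.randn(2, 10, d_model)  # (batch, seq_len, d_model)

# Reshape so each of the 8 heads sees its own 64-dim slice, then move the
# head axis forward so attention runs over all heads in parallel
heads = x.view(2, 10, n_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 64])
```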
The Memory Crisis: KV Caching and Bandwidth
While training a model is compute-intensive, running inference on a trained model introduces a different set of engineering challenges. In the decoding phase, where the model generates text one token at a time, it must reference the entire preceding context for every new word. Recomputing the Key and Value vectors for all previous tokens at every step would be computationally disastrous and redundant.
To solve this, developers use a technique called KV Caching, which stores the previously computed Key and Value vectors in GPU memory. While this saves massive amounts of computation, it creates a significant memory bottleneck because the cache grows linearly with the length of the sequence. For a model with billions of parameters and a long context window, the KV cache can quickly consume the majority of available VRAM, limiting the number of concurrent users a single GPU can serve.
- Batch size: Increasing the number of parallel requests multiplies the cache size proportionally.
- Sequence length: Long-context applications like document analysis require vast amounts of memory to store history.
- Model depth: Every layer in the transformer maintains its own KV cache, compounding the memory requirements.
- Precision: Storing values in FP16 or BF16 consumes two bytes per element, making quantization a popular optimization.
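A back-of-the-envelope calculation shows how these factors multiply. The hyperparameters below are illustrative, loosely modeled on an 8B-class model with grouped KV heads, not taken from any specific release:

```python
# Back-of-the-envelope KV cache size (illustrative hyperparameters:
# 32 layers, 8 KV heads, head_dim 128, FP16 storage)
n_layers, n_kv_heads, head_dim = 32, 8, 128
bytes_per_elem = 2            # FP16/BF16: two bytes per element
seq_len, batch = 8192, 4      # context length and concurrent requests

# 2x for the Key and Value tensors, stored per layer, per token
cache_bytes = 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem
print(f"{cache_bytes / 2**30:.1f} GiB")  # 4.0 GiB
```

Doubling the batch size or the sequence length doubles this figure, which is why long-context serving is dominated by cache memory rather than weights.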
The primary performance constraint during this phase is memory bandwidth rather than raw FLOPs. The GPU spends more time moving KV cache data from high-bandwidth memory to the processing cores than it does actually performing the attention calculations. Reducing the size of these cached vectors is therefore the most effective way to speed up inference and support longer conversations.
Prefill versus Decoding Phases
Inference is split into two distinct stages: prefill and decoding. During prefill, the model processes the entire user prompt at once, which is a compute-bound task where the GPU cores are fully utilized. The resulting Key and Value vectors are then saved into the cache for the next stage.
In the decoding stage, the model generates tokens one by one in an autoregressive fashion. Because each step only produces a single new vector while reading thousands of cached ones, the GPU remains underutilized while waiting for memory transfers. This asymmetry explains why the first token of a response often feels slower to generate than subsequent ones.
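The decode loop can be sketched as a cache that grows by one entry per step while being read in full at every step. This is a toy illustration of the caching pattern only; no real attention is computed:

```python
import torch

head_dim = 64
k_cache = torch.empty(0, head_dim)
v_cache = torch.empty(0, head_dim)

for step in range(5):
    # Each decode step produces exactly one new Key/Value pair...
    k_new = torch.randn(1, head_dim)
    v_new = torch.randn(1, head_dim)
    k_cache = torch.cat([k_cache, k_new], dim=0)
    v_cache = torch.cat([v_cache, v_new], dim=0)
    # ...while attention for the new token must read the entire cache,
    # which is why decoding is memory-bandwidth-bound

print(k_cache.shape)  # torch.Size([5, 64])
```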
Modern Optimizations: Grouped Query Attention
To address the memory overhead of the KV cache, researchers initially proposed Multi-Query Attention, where all query heads share a single Key and Value head. While this drastically reduces the memory footprint, it often results in a significant drop in model quality because the model's ability to attend to different types of information is constrained. Multi-head and multi-query represent two extremes of a trade-off between performance and accuracy.
Grouped Query Attention, or GQA, was introduced as the ideal middle ground and is now the standard in models like Llama 3 and Mistral. In GQA, query heads are divided into groups, and each group shares a single set of Key and Value heads. This allows for a massive reduction in KV cache size with almost no loss in the model's reasoning capabilities or linguistic accuracy.
```python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_heads):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.group_size = n_heads // n_kv_heads

        # Standard linear projections; K and V project to fewer heads than Q
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)

        # Repeat K and V across the groups to match Q's head count
        # This effectively implements the head-sharing logic
        k = k.repeat_interleave(self.group_size, dim=2)
        v = v.repeat_interleave(self.group_size, dim=2)

        # Proceed with standard scaled dot-product attention
        return q, k, v
```

By using a ratio such as eight query heads for every one KV head, a model can reduce the memory traffic required for the KV cache by a factor of eight. This optimization allows developers to run much larger models on consumer hardware and enables the processing of massive documents that would otherwise exceed the VRAM limits of professional data center GPUs.
The Benefits of Uptraining
One of the unique advantages of Grouped Query Attention is that existing Multi-Head Attention models can be converted to GQA through a process called uptraining. This involves initializing the grouped Key and Value heads by averaging the weights of the original heads and then performing a small amount of fine-tuning. This allows researchers to upgrade the efficiency of existing models without the astronomical cost of training from scratch.
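The mean-pooling initialization can be sketched as a reshape-and-average over the original projection weights. The head counts below and the single K projection shown are illustrative assumptions, not the exact recipe of any particular model:

```python
import torch

# Convert MHA Key-projection weights to GQA by mean-pooling each group
# of original heads into one shared KV head (illustrative shapes)
n_heads, n_kv_heads, head_dim, d_model = 8, 2, 64, 512
group_size = n_heads // n_kv_heads  # 4 query heads per KV head

# Original MHA weights: one K head per query head
k_weight = torch.randn(n_heads, head_dim, d_model)

# Average every group of 4 heads into a single shared head, then fine-tune
k_weight_gqa = k_weight.view(n_kv_heads, group_size, head_dim, d_model).mean(dim=1)
print(k_weight_gqa.shape)  # torch.Size([2, 64, 512])
```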
This architectural flexibility has led to a rapid adoption of GQA across the open-source community. Developers can now take a high-performance model and optimize it for low-latency inference environments by simply adjusting the head grouping configuration. This democratizes access to advanced LLMs by lowering the hardware barrier for deployment.
Flexible Positioning: The Magic of RoPE
Traditional Transformers used absolute positional encodings, where a unique vector was added to each token to indicate its place in the sequence. This approach worked for fixed context lengths but failed when models encountered sequences longer than their training data. Rotary Positional Embeddings, or RoPE, solved this by treating the embedding space as a series of two-dimensional planes where position is encoded as a rotation.
Instead of adding a static vector, RoPE applies a rotation matrix to the Query and Key vectors. The angle of rotation depends on the token's position in the sequence. Because the dot product between two rotated vectors only depends on the difference between their angles, the model naturally learns the relative distance between tokens rather than their absolute coordinates.
RoPE is a game-changer for context length because it allows for length extrapolation. If a model is trained on 4,000 tokens, RoPE provides a mathematical path to extend that context to 128,000 tokens by simply adjusting the base frequency of the rotations.
```python
import torch

def apply_rotary_emb(x, cos, sin):
    # Split the last dimension into even/odd pairs for 2D rotation
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]

    # Apply the rotation formula: [x1*cos - x2*sin, x1*sin + x2*cos]
    # This rotates each pair in its own 2D plane
    rotated_x = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    # Flatten the pairs back into the original interleaved layout
    return rotated_x.flatten(-2)

# Note: cos and sin are precomputed based on sequence indices
```

The rotation mechanism ensures that the magnitude of the vectors remains unchanged, which preserves the stability of the attention scores. Furthermore, the decay property of these rotations means that as tokens get further apart, their attention weights naturally diminish. This mirrors how humans process language, where recent context is usually more relevant than something mentioned many pages ago.
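The cos and sin tables mentioned above can be precomputed once per sequence length. A minimal sketch, assuming the conventional base frequency of 10,000 and the interleaved pair layout:

```python
import torch

def build_rope_cache(seq_len, head_dim, base=10000.0):
    # One frequency per 2D plane; early dimensions rotate faster than later ones
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Rotation angle for each (position, plane) pair
    angles = torch.outer(torch.arange(seq_len).float(), freqs)
    return angles.cos(), angles.sin()

cos, sin = build_rope_cache(seq_len=16, head_dim=64)
print(cos.shape)  # torch.Size([16, 32]): one column per 2D plane
```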
Scaling Beyond the Training Limit
The mathematical elegance of RoPE allows for a technique called NTK-aware scaling, which adjusts the rotation frequencies to fit more tokens into the same representational space. By slightly shrinking the angles of rotation, developers can extend the effective context window of a model without requiring any additional training. This has enabled the jump from the 2,048 token limits of early models to the million-token windows seen in state-of-the-art systems.
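The base adjustment can be sketched in one line. The exponent dim/(dim − 2) follows the commonly cited NTK-aware heuristic, and the context lengths here are illustrative assumptions:

```python
# NTK-aware base scaling: stretch the rotation frequencies so a longer
# context reuses the angle range the model saw in training.
# The exponent is one published heuristic, not the only variant in use.
head_dim = 128
base = 10000.0
scale = 128_000 / 4_096          # target context / trained context

ntk_base = base * scale ** (head_dim / (head_dim - 2))
print(ntk_base)                  # a substantially larger base than 10,000
```

A larger base slows the rotations, so the same angular range now spans many more token positions without any retraining.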
However, scaling context length is not just about the math of positioning; it also requires the efficient memory management provided by GQA. Without GQA, a million-token context would require hundreds of gigabytes of VRAM just for the KV cache. The combination of RoPE for positioning and GQA for memory efficiency is the fundamental architectural foundation of the modern generative AI era.
