Attention Quiz
Answer the following questions to check your understanding.
🎮 Interactive Quiz (Recommended)
Why divide by √d_k in Self-Attention?
What do Q, K, V represent?
What is the main advantage of Multi-Head Attention?
What does GQA (Grouped Query Attention) do?
What is the role of a causal mask?
What does repeat_kv do in GQA?
What is the advantage of Flash Attention over the standard implementation?
🎯 Comprehensive Questions
Q8: Practical scenario
If you find that attention weights are almost uniform (each position is ~1/seq_len), what might be wrong and how can you fix it?
Show reference answer
Possible causes:
Q/K initialization issues:
- projection weights initialized with values that are too small
- Q·K scores near 0
- softmax(0, 0, ..., 0) ≈ uniform
Scaling factor issues:
- divided by too large a value
- or forgot the square root (divide by d_k instead of √d_k)
Head dim misconfigured:
- overly large head_dim makes variance huge
- but this tends to cause extreme, not uniform, distributions
Model didn’t learn meaningful patterns:
- data issues
- insufficient capacity
Diagnostics:
```python
# check Q·K scores before softmax
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(head_dim)
print(f"scores mean: {scores.mean()}, std: {scores.std()}")
# normal: mean ≈ 0, std ≈ 1
# problematic: std too small (near 0)
```
Fixes:
Check initialization:
```python
# use Xavier or Kaiming initialization
nn.init.xavier_uniform_(self.wq.weight)
```
Verify scaling factor:
```python
# ensure sqrt(head_dim), not head_dim
scale = math.sqrt(self.head_dim)
```
Visualize attention:
```python
plt.imshow(attn_weights[0, 0].detach().numpy())
plt.title("Attention weights")
plt.colorbar()
```
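To make the diagnosis concrete, here is a minimal self-contained sketch (random tensors and hypothetical sizes, not taken from the experiment code) comparing correct √d_k scaling with the common bug of dividing by d_k itself, which pushes the scores toward 0 and the attention toward uniform:
```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
head_dim, seq_len = 64, 8                           # hypothetical sizes
xq = torch.randn(1, 1, seq_len, head_dim)           # [batch, head, seq, head_dim]
xk = torch.randn(1, 1, seq_len, head_dim)
raw = torch.matmul(xq, xk.transpose(-2, -1))        # unscaled Q·K scores

for name, scores in [("sqrt(d_k)", raw / math.sqrt(head_dim)), ("d_k (bug)", raw / head_dim)]:
    attn = F.softmax(scores, dim=-1)
    print(f"{name}: score std={scores.std().item():.3f}, "
          f"max attn weight={attn.max().item():.3f} (uniform would be {1/seq_len:.3f})")
```
With the buggy divisor the score std collapses to roughly 1/√d_k of its proper value, so the softmax output sits close to 1/seq_len at every position.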
Q9: Conceptual understanding
Why do Q, K, V all come from the same input x, but still require three different projection matrices? Can we just use x directly?
Show reference answer
Problem with using x directly:
If Q = K = V = x, then:
```python
scores = x @ x.T  # self dot-products
```
This is essentially raw dot-product similarity in the embedding space (cosine similarity scaled by the vector norms).
Issues:
Symmetry: the raw score from token_i to token_j equals the raw score from token_j to token_i
- but language relations are often asymmetric
- e.g., “cat eats fish”: cat should attend to fish, fish doesn’t need to attend to cat
Limited expressiveness:
- only captures similarity
- cannot represent relations like subject/object or modifier
Why three projections matter:
```python
Q = x @ W_Q  # "as a query, what should I focus on?"
K = x @ W_K  # "as a key, what should I reveal?"
V = x @ W_V  # "what content should I pass?"
```
Advantages:
- Asymmetry: Q and K differ, enabling directional relations
- Role separation: query view, key view, content view
- Expressiveness: can learn arbitrary relation patterns
Analogy:
- Library search:
- your question (Q): written in natural language
- index labels (K): keywords/tags
- content (V): the actual text
- these speak different “languages,” yet matching the question (Q) against the labels (K) retrieves the right content (V)
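A tiny numerical sketch (random toy matrices, hypothetical sizes) makes the asymmetry argument concrete: the x @ x.T score matrix is symmetric by construction, while scores built from separate W_Q and W_K projections generally are not:
```python
import torch

torch.manual_seed(0)
seq_len, dim = 4, 8                                 # hypothetical toy sizes
x = torch.randn(seq_len, dim)
W_Q, W_K = torch.randn(dim, dim), torch.randn(dim, dim)

raw_scores = x @ x.T                                # Q = K = x
proj_scores = (x @ W_Q) @ (x @ W_K).T               # separate query/key views

print(torch.allclose(raw_scores, raw_scores.T))     # True: symmetric by construction
print(torch.allclose(proj_scores, proj_scores.T))   # False (almost surely): direction matters
```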
Q10: Code understanding
Explain why contiguous() is necessary here:
```python
output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)
```
Show reference answer
Background:
PyTorch tensors have two concepts:
- Storage: physical memory layout
- View: how you interpret that memory
What transpose does:
```python
# suppose output is [batch, n_heads, seq, head_dim]
# memory is laid out in that order
output = output.transpose(1, 2)
# now shape is [batch, seq, n_heads, head_dim]
# but memory layout is unchanged
```
Problem:
view() requires contiguous memory, but transpose makes it non-contiguous:
```python
# original memory order (simplified):
# [head0_pos0, head0_pos1, head1_pos0, head1_pos1, ...]
# logical order after transpose:
# [head0_pos0, head1_pos0, head0_pos1, head1_pos1, ...]
# non-contiguous → view fails
```
What contiguous() does:
```python
output = output.transpose(1, 2).contiguous()
# 1. reallocate memory
# 2. reorder data to match new view
# 3. now view is safe
```
Performance note:
- contiguous() copies memory (costly), but it’s required here
- alternative: reshape() handles this implicitly but is less explicit
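For intuition, here is a small stand-alone sketch (hypothetical sizes) showing the transposed tensor become non-contiguous, view() fail, and contiguous() fix it:
```python
import torch

batch, n_heads, seq_len, head_dim = 2, 4, 3, 8      # hypothetical sizes
output = torch.randn(batch, n_heads, seq_len, head_dim)

t = output.transpose(1, 2)                          # [batch, seq, n_heads, head_dim]
print(t.is_contiguous())                            # False

try:
    t.view(batch, seq_len, -1)                      # raises: view needs contiguous memory
except RuntimeError as e:
    print("view failed:", e)

merged = t.contiguous().view(batch, seq_len, -1)    # works after the copy
print(merged.shape)                                 # torch.Size([2, 3, 32])
```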
Best practice:
```python
# when you know you need contiguous memory
output = output.transpose(1, 2).contiguous().view(...)
# or use reshape
output = output.transpose(1, 2).reshape(...)
```
✅ Completion check
After finishing all questions, check whether you can:
- [ ] Get Q1–Q7 all correct: solid basics
- [ ] Provide 2+ diagnostics in Q8: debugging ability
- [ ] Explain projections in Q9: design principles understood
- [ ] Explain why contiguous() is needed in Q10: PyTorch memory model
If anything is unclear, return to teaching.md or rerun the experiments.
🎓 Advanced challenge
Want to go deeper? Try:
Modify experiment code:
- implement Attention without the √d_k scaling and observe how the softmax distribution changes
- implement MQA (Multi-Query Attention) and compare it with GQA (a starter sketch follows this list)
- visualize attention patterns from different heads
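As a starting point for the MQA item above, here is a minimal sketch (random tensors, hypothetical sizes) of Multi-Query Attention: all query heads share a single K/V head, and broadcasting plays the role that an explicit repeat_kv plays in a GQA implementation:
```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, n_heads, head_dim = 1, 6, 8, 16     # hypothetical toy sizes

q = torch.randn(batch, n_heads, seq_len, head_dim)  # one query projection per head
k = torch.randn(batch, 1, seq_len, head_dim)        # MQA: a single shared K head
v = torch.randn(batch, 1, seq_len, head_dim)        # MQA: a single shared V head

# broadcasting expands the single K/V head across all query heads,
# the same effect repeat_kv achieves explicitly in GQA
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5  # [batch, n_heads, seq, seq]
attn = F.softmax(scores, dim=-1)
out = attn @ v                                      # [batch, n_heads, seq, head_dim]
print(out.shape)                                    # torch.Size([1, 8, 6, 16])
```
Turning this into GQA mainly means using n_kv_heads > 1 and repeating each K/V head over its group of query heads.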
Read papers:
- Attention Is All You Need - original Transformer paper
- GQA Paper - Grouped Query Attention
- Flash Attention - efficient attention implementation
Implement variants:
- implement Cross-Attention (Q and KV come from different inputs; see the sketch after this list)
- implement Sliding Window Attention
- implement Sparse Attention
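For the Cross-Attention variant, here is a minimal single-head sketch (hypothetical dimensions, no masking or multi-head logic) showing the only structural change: Q comes from one sequence, while K and V come from another:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Single-head cross-attention sketch: Q from x, K/V from context."""
    def __init__(self, dim):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.scale = dim ** 0.5

    def forward(self, x, context):
        q = self.wq(x)                               # [batch, len_q, dim]
        k = self.wk(context)                         # [batch, len_kv, dim]
        v = self.wv(context)                         # [batch, len_kv, dim]
        scores = q @ k.transpose(-2, -1) / self.scale
        return F.softmax(scores, dim=-1) @ v

x = torch.randn(1, 5, 32)                            # "decoder" side
context = torch.randn(1, 9, 32)                      # "encoder" side
print(CrossAttention(32)(x, context).shape)          # torch.Size([1, 5, 32])
```
A multi-head version would split dim into heads exactly as in the self-attention code; only the source of K and V changes.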
Next: go to 04. FeedForward to learn feed-forward networks.