Tensor programs: batch × heads × sequence × dim

Advanced Machine learning
Created by Best · 01.06.2026 at 06:20 UTC

Production attention is a stack of GEMMs (general matrix multiplies) running on a GPU. Typical shapes involve batch size $B$, number of heads $h$, sequence length $n$, and head dimension $d_h$. Queries might live in $\mathbb{R}^{B \times h \times n \times d_h}$ before scores in $\mathbb{R}^{B \times h \times n \times n}$ are formed by contracting the last dimension of $Q$ with the last dimension of $K$.
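As a minimal sketch of that contraction, the NumPy snippet below forms the scores with an explicit einsum. The sizes, the helper name `attention_scores`, and the $1/\sqrt{d_h}$ scaling are assumptions for illustration (the scaling is the common convention and is not stated in this card).

```python
import numpy as np

def attention_scores(q, k):
    """Contract the last (d_h) axis of q with the last axis of k."""
    d_h = q.shape[-1]
    # b = batch, h = head, q/k = query/key position, d = head dimension.
    # Assumption: scores are scaled by 1/sqrt(d_h), the usual convention.
    return np.einsum("bhqd,bhkd->bhqk", q, k) / np.sqrt(d_h)

rng = np.random.default_rng(0)
B, h, n, d_h = 2, 3, 5, 4  # illustrative sizes only
q = rng.standard_normal((B, h, n, d_h))
k = rng.standard_normal((B, h, n, d_h))

scores = attention_scores(q, k)
print(scores.shape)  # (2, 3, 5, 5)
```

Naming every axis in the einsum string makes it explicit that batch and head indices are carried through unchanged while only `d` is summed over.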

A broadcasting error across the head dimension pairs a query head with the wrong key head and silently scrambles the outputs, because the result still has the expected shape. Always verify tensor strides and einsum indices when debugging.
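One common way this kind of scrambling can arise is sketched below: keys stored as $(B, n, h, d_h)$ are turned into $(B, h, n, d_h)$ with `reshape` instead of `transpose`. The layout of the incoming keys and all variable names are assumptions for illustration. Both results have the same shape, but only the transposed one keeps each head's data together, and the strides reveal the difference.

```python
import numpy as np

rng = np.random.default_rng(0)
B, h, n, d_h = 2, 3, 5, 4  # illustrative sizes only

# Assumed layout: keys arrive as (B, n, h, d_h), e.g. from a projection layer.
k_bnhd = rng.standard_normal((B, n, h, d_h))
q = rng.standard_normal((B, h, n, d_h))

k_good = k_bnhd.transpose(0, 2, 1, 3)  # moves the head axis: correct pairing
k_bad = k_bnhd.reshape(B, h, n, d_h)   # reinterprets memory: heads get mixed

def scores(q, k):
    return np.einsum("bhqd,bhkd->bhqk", q, k)

same_shape = k_good.shape == k_bad.shape
same_values = np.allclose(scores(q, k_good), scores(q, k_bad))
print(same_shape, same_values)  # True False

# The strides differ even though the shapes agree.
print(k_good.strides, k_bad.strides)
```

A shape assertion alone would accept `k_bad`; comparing strides or checking values against a trusted reference is what exposes the bug.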

FlashAttention fuses the tiled matrix multiplies with the softmax inside on-chip GPU SRAM, reducing round trips to high-bandwidth memory (HBM). It recomputes some intermediates in the backward pass rather than storing the full $n \times n$ attention matrix, trading compute for memory. Tiling keeps blocks of queries and keys resident in fast memory while accumulating softmax statistics blockwise.
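The blockwise accumulation of softmax statistics can be illustrated on the CPU. The sketch below is not a GPU kernel and not FlashAttention itself; it only mirrors the arithmetic for a single head, tiling over keys and values while keeping a running row maximum and denominator. The function names, the block size, and the $1/\sqrt{d}$ scaling are assumptions.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference: materializes the full (n, n) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=4):
    """Visits keys/values block by block with running softmax statistics."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((n, 1), -np.inf)      # running row maximum
    l = np.zeros((n, 1))              # running softmax denominator
    acc = np.zeros((n, v.shape[-1]))  # running unnormalized output
    for start in range(0, k.shape[0], block):
        k_blk = k[start:start + block]
        v_blk = v[start:start + block]
        s = (q @ k_blk.T) * scale     # only an (n, block) slice of scores
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)  # rescales the old statistics
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_blk
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
n, d = 10, 8  # n is deliberately not a multiple of the block size
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))

out_naive = naive_attention(q, k, v)
out_tiled = tiled_attention(q, k, v, block=4)
print(np.allclose(out_naive, out_tiled))  # True
```

The tiled version never holds more than an $n \times \text{block}$ slice of scores at once, which is the property a fused kernel exploits to stay in fast memory.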

Profile attention on real hardware: a naive implementation that materializes the $O(n^2)$ attention matrix may become memory-bound well before it is compute-bound. FlashAttention-style kernels change the bottleneck; low SM (streaming multiprocessor) utilization often means that a fused attention kernel is the fix, rather than a wider model.
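A back-of-the-envelope calculation shows why materialization can dominate memory traffic. The sizes below (batch 8, 16 heads, sequence length 4096, head dimension 64, 2 bytes per element) are hypothetical and chosen only to make the arithmetic concrete; real numbers should come from profiling.

```python
def attention_matrix_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes needed to store the full (seq_len x seq_len) scores per head."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

def qkv_bytes(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes needed to store Q, K and V together."""
    return 3 * batch * heads * seq_len * head_dim * bytes_per_elem

# Hypothetical configuration, 16-bit elements.
B, h, n, d_h = 8, 16, 4096, 64
scores_size = attention_matrix_bytes(B, h, n)
inputs_size = qkv_bytes(B, h, n, d_h)

print(f"{scores_size / 2**30:.1f} GiB")  # 4.0 GiB
print(f"{inputs_size / 2**20:.1f} MiB")  # 192.0 MiB
```

In this configuration the score matrix is roughly 21 times larger than Q, K and V combined, and the ratio grows linearly with $n$.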

When porting to another framework, re-derive the einsum indices on paper before trusting an auto-translated graph. Swapping the head and batch axes is the usual mistake that passes shape checks but fails semantically.
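One way to back up the paper derivation is a small check against a slow, explicit reference, sketched below. `ported_scores` is a hypothetical stand-in for the translated graph, and `swapped_scores` is a deliberately wrong port with the batch and head axes of the keys exchanged. With equal batch and head sizes the swap passes a shape check; with distinct sizes it fails loudly.

```python
import numpy as np

def reference_scores(q, k):
    """Slow but unambiguous: one explicit matmul per (batch, head) pair."""
    B, h, n, _ = q.shape
    out = np.empty((B, h, n, k.shape[2]))
    for b in range(B):
        for i in range(h):
            out[b, i] = q[b, i] @ k[b, i].T
    return out

def ported_scores(q, k):
    """Hypothetical stand-in for the translated graph under test."""
    return np.einsum("bhqd,bhkd->bhqk", q, k)

def swapped_scores(q, k):
    """Deliberately wrong port: batch and head axes of k exchanged."""
    return np.einsum("bhqd,hbkd->bhqk", q, k)

rng = np.random.default_rng(0)

# Equal batch and head sizes: the swap is invisible to shape checks.
q = rng.standard_normal((3, 3, 5, 4))
k = rng.standard_normal((3, 3, 5, 4))
ref = reference_scores(q, k)
shapes_match = swapped_scores(q, k).shape == ref.shape
values_match = np.allclose(swapped_scores(q, k), ref)
port_ok = np.allclose(ported_scores(q, k), ref)

# Distinct batch and head sizes: the same bug now raises an error.
q2 = rng.standard_normal((2, 3, 5, 4))
k2 = rng.standard_normal((2, 3, 5, 4))
try:
    swapped_scores(q2, k2)
    raised = False
except ValueError:
    raised = True

print(shapes_match, values_match, port_ok, raised)  # True False True True
```

Choosing test sizes where every axis has a different length is a cheap habit that turns many silent axis swaps into immediate shape errors.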

Related cards
Builds on Differentiable key-value lookup view · Machine learning
Next Masking bugs and leakage · Machine learning
Tasks
Question 1

FlashAttention primarily optimizes:

Hint

Skim the paragraph on FlashAttention in this card before choosing. Eliminate options that contradict a definition stated in the card.

Question 2

Tiling in a fused attention kernel keeps:

Hint

Skim the paragraph on tiling in fused attention kernels in this card before choosing. Eliminate options that contradict a definition stated in the card.

Question 3

A 'fused' GPU kernel combines:

Hint

Skim the paragraph on what a fused kernel combines in this card before choosing. Eliminate options that contradict a definition stated in the card.

Question 4

Why does incorrect broadcasting across the head dimension scramble attention outputs?

Hint

Skim the paragraph on broadcasting errors across the head dimension in this card before choosing. Eliminate options that contradict a definition stated in the card.

Card Info
  • Topic: Machine learning
  • Difficulty: Advanced
Creator
Best
BestBuddy