Rust + WASM · Interactive Demo

Attention Residuals

Replacing fixed residual connections with learned softmax attention over depth.
Each layer selectively routes information from all preceding blocks.

Paper: Attention Residuals (MoonshotAI / Kimi) · Implementation: attnres (burn framework)


The Problem with Standard Residuals

In standard Transformers, the residual connection is a simple addition:

h_{l+1} = h_l + F_l(h_l)

Every layer contributes equally to the output, regardless of whether its representation is useful at a given position. In deep networks (100+ layers), this creates two problems:

  1. Information dilution. Early layers’ features get progressively washed out as more layers add to the sum. The signal-to-noise ratio degrades with depth.
  2. No selective routing. A layer that computed a “useful” representation has the same weight as one that computed noise. There is no mechanism to choose which layers matter.
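The dilution effect can be seen numerically. The sketch below (a toy illustration, not from the paper; it uses a small inline LCG so no crates are needed) sums n roughly orthogonal random "layer outputs" with unit weight and measures how much of the first layer's direction survives in the running sum — the cosine similarity shrinks like 1/√n:

```rust
/// Deterministic pseudo-random value in [-1, 1) via a 64-bit LCG (no crates).
fn lcg_unit(state: &mut u64) -> f64 {
    *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
    ((*state >> 11) as f64 / (1u64 << 53) as f64) * 2.0 - 1.0
}

/// Cosine similarity between the first layer's output and the unit-weight
/// residual sum of `n_layers` random layer outputs.
fn cosine_with_first(n_layers: usize, dim: usize) -> f64 {
    let mut s = 42u64;
    let layers: Vec<Vec<f64>> = (0..n_layers)
        .map(|_| (0..dim).map(|_| lcg_unit(&mut s)).collect())
        .collect();
    // Standard residual: every layer is added with weight 1.
    let mut h = vec![0.0; dim];
    for l in &layers {
        for (hi, li) in h.iter_mut().zip(l) { *hi += li; }
    }
    let dot: f64 = h.iter().zip(&layers[0]).map(|(a, b)| a * b).sum();
    let nh = h.iter().map(|v| v * v).sum::<f64>().sqrt();
    let n0 = layers[0].iter().map(|v| v * v).sum::<f64>().sqrt();
    dot / (nh * n0)
}

fn main() {
    // The first layer's signal fades as depth grows (≈ 1/√n).
    println!("cos(h, b0),   4 layers: {:.3}", cosine_with_first(4, 256));
    println!("cos(h, b0), 100 layers: {:.3}", cosine_with_first(100, 256));
}
```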

Key insight: What if each layer could attend to all prior layers and dynamically decide how much weight to assign each one?


Attention Residuals: The Algorithm

Stack block representations

Collect all completed block sums b_0, …, b_{N-1} plus the current partial block into a value matrix:

V = [b_0; b_1; …; b_N (partial)]  ∈  ℝ^{(N+1)×D}

Normalize keys with RMSNorm

Prevent large-magnitude blocks from dominating attention logits. Without this, deeper blocks (which accumulate more layer outputs) would receive disproportionate weight.

K = RMSNorm(V) = (V / √mean(V²)) · γ   (applied row-wise over the feature dimension)
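As a concrete sketch (plain Rust, no burn tensors; the function name is illustrative), RMSNorm rescales each block so that large- and small-magnitude blocks yield keys of comparable scale:

```rust
/// RMSNorm over the feature dimension: x / sqrt(mean(x^2) + eps) * gamma.
fn rms_norm(x: &[f64], gamma: &[f64], eps: f64) -> Vec<f64> {
    let ms = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
    let scale = 1.0 / (ms + eps).sqrt();
    x.iter().zip(gamma).map(|(v, g)| v * scale * g).collect()
}

fn main() {
    let gamma = vec![1.0; 4];
    // A block with 100x larger magnitude produces the same key as a small one,
    // so it cannot dominate the depth-attention logits by scale alone.
    let big = rms_norm(&[10.0, -10.0, 10.0, -10.0], &gamma, 1e-6);
    let small = rms_norm(&[0.1, -0.1, 0.1, -0.1], &gamma, 1e-6);
    println!("{:?}\n{:?}", big, small); // both ≈ [1.0, -1.0, 1.0, -1.0]
}
```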

Compute depth attention logits

A learned pseudo-query w_l ∈ ℝ^D scores each block. Crucially, w_l is initialized to zero, ensuring the model starts as a standard residual and transitions smoothly as training progresses.

logits_i = K_i · w_l    ∀ i ∈ {0, …, N}

Softmax over depth

The softmax is taken over the block/depth dimension, not the sequence dimension. This is attention over layers, not over tokens.

α_i = softmax(logits)_i = exp(logits_i) / ∑_j exp(logits_j)
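The zero-initialization behavior can be checked directly: with w_l = 0 every logit K_i · w_l is zero, so the softmax over depth is uniform. A standalone sketch (the function name is illustrative):

```rust
/// Softmax over the depth axis, numerically stabilized by subtracting the max.
fn softmax_depth(logits: &[f64]) -> Vec<f64> {
    let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - m).exp()).collect();
    let z: f64 = exps.iter().sum();
    exps.iter().map(|e| e / z).collect()
}

fn main() {
    // Zero pseudo-query → all logits are 0 → uniform weights 1/(N+1).
    println!("{:?}", softmax_depth(&[0.0; 5])); // each entry 0.2
    // A trained query can concentrate mass on one useful block instead.
    println!("{:?}", softmax_depth(&[0.0, 0.0, 4.0, 0.0, 0.0]));
}
```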

Weighted combination

The output is a learned convex combination of all block representations. Each layer can choose exactly how much information to draw from each depth.

h = ∑_i α_i · V_i

Two AttnRes per transformer layer. Each transformer layer has two sublayers (self-attention and MLP). Both sublayers get their own AttnRes operation with its own learned pseudo-query w_l, so each sublayer independently decides which prior blocks are most relevant to its computation.
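Putting the five steps together, here is a minimal CPU sketch of one AttnRes operation in plain Rust (no burn tensors; `attn_res` and its signature are illustrative, not the attnres crate's API):

```rust
/// RMSNorm over the feature dimension: x / sqrt(mean(x^2) + eps) * gamma.
fn rms_norm(x: &[f64], gamma: &[f64]) -> Vec<f64> {
    let ms = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
    let s = 1.0 / (ms + 1e-6).sqrt();
    x.iter().zip(gamma).map(|(v, g)| v * s * g).collect()
}

/// One AttnRes step. `v` holds the block sums b_0..b_N (last row is the
/// current partial block); `w` is this sublayer's learned pseudo-query.
fn attn_res(v: &[Vec<f64>], w: &[f64], gamma: &[f64]) -> Vec<f64> {
    // Steps 1-2: RMS-normalize each block row to form the keys.
    let keys: Vec<Vec<f64>> = v.iter().map(|row| rms_norm(row, gamma)).collect();
    // Step 3: score each key against the pseudo-query.
    let logits: Vec<f64> = keys.iter()
        .map(|k| k.iter().zip(w).map(|(a, b)| a * b).sum())
        .collect();
    // Step 4: softmax over the depth axis (stabilized).
    let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - m).exp()).collect();
    let z: f64 = exps.iter().sum();
    // Step 5: convex combination of the (unnormalized) block values.
    let mut h = vec![0.0; w.len()];
    for (row, e) in v.iter().zip(&exps) {
        for (hi, vi) in h.iter_mut().zip(row) { *hi += (e / z) * vi; }
    }
    h
}

fn main() {
    let v = vec![vec![1.0, 0.0], vec![3.0, 0.0]];
    // Zero pseudo-query → uniform depth attention → simple average of blocks.
    println!("{:?}", attn_res(&v, &[0.0, 0.0], &[1.0, 1.0])); // [2.0, 0.0]
}
```

With a nonzero pseudo-query the same function shifts its output toward whichever block's key aligns with w, which is the selective routing described above.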

Interactive: Core AttnRes Operation

Adjust the pseudo-query weights and observe how the attention distribution over depth changes in real time. With zero weights (initialization), all sources receive equal attention. As weights evolve during training, selective patterns emerge.

Training: Watching Patterns Emerge

Observe how depth attention patterns evolve during training. The pseudo-query vectors w_l start at zero (uniform attention) and gradually learn to attend selectively to the most useful blocks at each depth.

Step counter · Loss readout
Panels: loss curve, depth-attention heatmap (evolving during training), and pseudo-query norms ‖w_l‖. Initialize a model and start training to populate them.

Standard Residual vs. AttnRes

Standard Residual

h_{l+1} = h_l + F_l(h_l)
  • Fixed weight = 1 per layer
  • No selectivity over depth
  • Information dilution in deep models
  • Zero learnable parameters added

Attention Residual

h = ∑_i α_i · b_i
  • Learned weights via softmax
  • Selective routing over depth
  • Preserves signal in deep networks
  • +D params per sublayer (negligible)

Block AttnRes reduces overhead. Instead of attending over every individual prior layer (cost quadratic in depth), Block AttnRes groups layers into N blocks and attends over blocks. With a typical N = 4–8 blocks for a 100-layer model, the overhead is minimal.
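As a back-of-the-envelope check on the "+D params per sublayer" claim (the layer count and d_model below are illustrative, not from the paper):

```rust
/// Extra parameters added by AttnRes: one pseudo-query w_l ∈ ℝ^D per sublayer,
/// two sublayers (self-attention + MLP) per transformer layer.
fn attn_res_extra_params(layers: u64, d_model: u64) -> u64 {
    2 * layers * d_model
}

fn main() {
    // e.g. a hypothetical 100-layer model with d_model = 4096:
    let extra = attn_res_extra_params(100, 4096);
    println!("{} extra params", extra); // 819200 — negligible at billion scale
}
```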