jepa-rs Browser Demo

Self-supervised JEPA training and inference in your browser, currently running on the deterministic, CPU-backed WASM path.


I-JEPA Architecture

                    ┌────────────────┐
         x_context ─►  Context       │
                    │  Encoder (θ)   ├─► s_x ──┐
                    └────────────────┘         │
                                               ▼
                                         ┌──────────┐
                               z (opt.) ─►          │
                                         │ Predictor├─► ŝ_y ──┐
                      target_positions ─►│          │         │
                                         └──────────┘         │  ┌──────────┐
                                                              ├──► EnergyFn │─► loss
                    ┌────────────────┐                        │  └──────────┘
         x_target  ─►  Target        │                        │
                    │  Encoder (ξ)   ├─► s_y ─────────────────┘
                    └────────────────┘
                         ↑
                         │ EMA(θ → ξ)
                

Context Encoder

A Vision Transformer (ViT) that encodes the visible patches; gradients flow through it during training. It sees only the unmasked context tokens: target patches are removed before self-attention.

crates/jepa-vision/src/vit.rs → VitEncoder
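A minimal sketch of strict pre-encoder masking, assuming patch embeddings and positional embeddings are plain `Vec<f32>` rows indexed row-major over the patch grid. The helper `select_context_tokens` is hypothetical and is not the `VitEncoder` API. It adds each patch's positional embedding before dropping the target patches, so the surviving context tokens still carry their grid positions into self-attention.

```rust
/// Hypothetical helper (not the jepa-rs API): adds positional embeddings to
/// the patch embeddings, then keeps only the context indices, so target
/// tokens never enter self-attention.
fn select_context_tokens(
    patch_emb: &[Vec<f32>], // one embedding per patch, row-major over the grid
    pos_emb: &[Vec<f32>],   // one positional embedding per patch position
    context_idx: &[usize],  // indices of the visible (unmasked) patches
) -> Vec<Vec<f32>> {
    context_idx
        .iter()
        .map(|&i| {
            // Position is added before dropping, so each kept token
            // still encodes where it sits on the patch grid.
            patch_emb[i]
                .iter()
                .zip(&pos_emb[i])
                .map(|(p, q)| p + q)
                .collect::<Vec<f32>>()
        })
        .collect()
}
```

Because the sequence handed to the transformer is shorter than the full grid, attention cost scales with the number of context tokens rather than the total number of patches.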

Target Encoder

Uses the same ViT architecture as the context encoder, but its weights are an Exponential Moving Average (EMA) of the context encoder's weights, so no gradients flow through this path.

crates/jepa-core/src/ema.rs → Ema
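The EMA rule in the diagram, ξ ← m·ξ + (1 − m)·θ, can be sketched element-wise over flattened parameter slices. The function `ema_update` and the flat `&[f32]` parameter layout are assumptions for illustration; `crates/jepa-core/src/ema.rs` may organize parameters differently.

```rust
/// Hypothetical EMA step (not the jepa-rs `Ema` API):
/// xi <- momentum * xi + (1 - momentum) * theta, element-wise.
/// `target` holds the target-encoder weights (xi), `online` the
/// context-encoder weights (theta). Only `target` is modified and no
/// gradient is computed.
fn ema_update(target: &mut [f32], online: &[f32], momentum: f32) {
    assert_eq!(target.len(), online.len(), "parameter shapes must match");
    for (xi, theta) in target.iter_mut().zip(online) {
        *xi = momentum * *xi + (1.0 - momentum) * theta;
    }
}
```

With a momentum close to 1, the target encoder drifts slowly behind the context encoder, which keeps the regression targets stable from step to step.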

Predictor

A narrow transformer that predicts target representations from the context embeddings using position-conditioned prediction tokens.

crates/jepa-vision/src/image.rs → TransformerPredictor
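One way to form position-conditioned prediction tokens, sketched under assumptions: a single shared learnable `mask_token` is added to the positional embedding of each target patch, and the results are appended after the context embeddings. `build_predictor_input` is a hypothetical name, and the sketch assumes the context embeddings are already at the predictor's width (a real predictor would typically project them first).

```rust
/// Hypothetical sketch (not the `TransformerPredictor` API): builds the
/// predictor's input sequence as [context embeddings..., prediction tokens...].
/// Each prediction token = shared mask token + positional embedding of the
/// target patch it must predict.
fn build_predictor_input(
    context_emb: &[Vec<f32>], // assumed already at predictor width
    mask_token: &[f32],       // shared learnable token (assumption)
    pos_emb: &[Vec<f32>],     // positional embedding per patch position
    target_idx: &[usize],     // patches whose representations must be predicted
) -> Vec<Vec<f32>> {
    let mut seq: Vec<Vec<f32>> = context_emb.to_vec();
    for &i in target_idx {
        // The only thing distinguishing prediction tokens is their position.
        let token: Vec<f32> = mask_token
            .iter()
            .zip(&pos_emb[i])
            .map(|(m, p)| m + p)
            .collect();
        seq.push(token);
    }
    seq
}
```

After the transformer runs over this sequence, the outputs at the appended positions would serve as the predictions ŝ_y, one per target patch.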

Block Masking

Generates contiguous rectangular blocks of masked target patches on the 2D patch grid, ensuring context and target tokens are disjoint.

crates/jepa-core/src/masking.rs → BlockMasking
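A simplified sketch of the context/target split. Here the block rectangles are passed in explicitly (a real sampler would draw block positions and sizes at random), and every non-target patch counts as context. `block_mask` and its `(top, left, height, width)` convention are hypothetical, not the `BlockMasking` API.

```rust
/// Hypothetical block-masking sketch on a grid_h x grid_w patch grid.
/// Each block is (top, left, height, width) in patch units. Patches inside
/// any block become targets; all others are context, so the two index sets
/// are disjoint by construction. Returns (context, target), both sorted.
fn block_mask(
    grid_h: usize,
    grid_w: usize,
    blocks: &[(usize, usize, usize, usize)],
) -> (Vec<usize>, Vec<usize>) {
    let mut is_target = vec![false; grid_h * grid_w];
    for &(top, left, h, w) in blocks {
        // Clamp so a block hanging over the edge is truncated, not out of bounds.
        for r in top..(top + h).min(grid_h) {
            for c in left..(left + w).min(grid_w) {
                is_target[r * grid_w + c] = true; // row-major patch index
            }
        }
    }
    let (mut context, mut target) = (Vec::new(), Vec::new());
    for (i, &t) in is_target.iter().enumerate() {
        if t {
            target.push(i);
        } else {
            context.push(i);
        }
    }
    (context, target)
}
```

For example, a 2×2 block at row 1, column 1 of a 4×4 grid yields targets 5, 6, 9, 10 and leaves the other 12 patches as context.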

Energy Function

Measures prediction quality in representation space. L2, cosine, and smooth L1 distances are supported.

crates/jepa-core/src/energy.rs → L2Energy
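The three distances can be sketched for a single predicted/target embedding pair. The function names are hypothetical, and treating "L2" as the mean squared difference is an assumption; `L2Energy` may use a sum or a different reduction. The smooth L1 form follows the common Huber-style definition with threshold `beta`.

```rust
/// Hypothetical energy functions: lower energy means a better prediction.

/// Mean squared difference (assumed reduction for "L2").
fn l2_energy(pred: &[f32], target: &[f32]) -> f32 {
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / pred.len() as f32
}

/// 1 - cosine similarity: 0 for parallel vectors, 2 for opposite ones.
fn cosine_energy(pred: &[f32], target: &[f32]) -> f32 {
    let dot: f32 = pred.iter().zip(target).map(|(p, t)| p * t).sum();
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Clamp the denominator to avoid dividing by zero for zero vectors.
    1.0 - dot / (norm(pred) * norm(target)).max(1e-8)
}

/// Smooth L1: quadratic for |d| < beta, linear beyond; averaged over dims.
fn smooth_l1_energy(pred: &[f32], target: &[f32], beta: f32) -> f32 {
    let per_elem = pred.iter().zip(target).map(|(p, t)| {
        let d = (p - t).abs();
        if d < beta { 0.5 * d * d / beta } else { d - 0.5 * beta }
    });
    per_elem.sum::<f32>() / pred.len() as f32
}
```

Cosine energy ignores embedding magnitude, while the L2 and smooth L1 variants penalize it; smooth L1 also grows only linearly for large errors, which limits the influence of outliers.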

Collapse Regularizer

Prevents representation collapse via VICReg or Barlow Twins loss terms that encourage variance and decorrelation.

crates/jepa-core/src/collapse.rs → VICReg
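A sketch of the two VICReg-style terms named above, computed over a batch of embeddings `batch[n][d]`. The variance term hinges each dimension's standard deviation at `gamma`, and the covariance term sums squared off-diagonal covariances. `vicreg_terms` is a hypothetical name; the term weights and the Barlow Twins alternative are not shown, and `crates/jepa-core/src/collapse.rs` may differ in reduction and defaults.

```rust
/// Hypothetical VICReg-style regularizer. Returns (variance_term, covariance_term).
/// - variance: mean over dims of max(0, gamma - sqrt(var + eps)); large when
///   dimensions collapse to a constant.
/// - covariance: sum of squared off-diagonal covariances / d; large when
///   dimensions are correlated.
fn vicreg_terms(batch: &[Vec<f32>], gamma: f32, eps: f32) -> (f32, f32) {
    let n = batch.len();
    assert!(n > 1, "need at least two samples for an unbiased variance");
    let d = batch[0].len();
    // Per-dimension means.
    let mean: Vec<f32> = (0..d)
        .map(|j| batch.iter().map(|row| row[j]).sum::<f32>() / n as f32)
        .collect();
    // Unbiased d x d covariance matrix.
    let mut cov = vec![vec![0.0f32; d]; d];
    for row in batch {
        for a in 0..d {
            for b in 0..d {
                cov[a][b] += (row[a] - mean[a]) * (row[b] - mean[b]);
            }
        }
    }
    for a in 0..d {
        for b in 0..d {
            cov[a][b] /= (n - 1) as f32;
        }
    }
    // Hinge on the per-dimension standard deviation (diagonal of cov).
    let variance_term = (0..d)
        .map(|j| (gamma - (cov[j][j] + eps).sqrt()).max(0.0))
        .sum::<f32>()
        / d as f32;
    let mut off_diag = 0.0f32;
    for a in 0..d {
        for b in 0..d {
            if a != b {
                off_diag += cov[a][b] * cov[a][b];
            }
        }
    }
    (variance_term, off_diag / d as f32)
}
```

A fully collapsed batch (every row identical) drives the variance term toward `gamma`, while a batch with spread-out, uncorrelated dimensions makes both terms zero.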

Training Data Flow

  1. Masking: BlockMasking generates context/target index split on the patch grid
  2. Context encoding: Only visible patches enter the context encoder (strict pre-encoder masking)
  3. Target encoding: Full image is encoded by the EMA target encoder (detached from gradient graph)
  4. Prediction: Predictor uses context embeddings + target position tokens to predict target representations
  5. Loss: Energy function (L2) + collapse regularizer (VICReg) compute the total loss
  6. Update: AdamW optimizer steps the context encoder and predictor; EMA updates the target encoder
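The six steps above can be traced end to end on a deliberately tiny toy, sketched under heavy assumptions: every patch is a single `f32`, each encoder is one scalar weight, the predictor is the parameter-free mean of the context embeddings, plain SGD stands in for AdamW, and the collapse regularizer is omitted. `ToyJepa` and `train_step` are hypothetical and only illustrate the ordering of the data flow and where the stop-gradient and EMA sit.

```rust
/// Hypothetical toy model: one scalar weight per encoder.
struct ToyJepa {
    theta: f32, // context-encoder weight, trained by gradient descent
    xi: f32,    // target-encoder weight, updated only by EMA
}

/// One toy training step following steps 1-6. Returns the loss.
fn train_step(m: &mut ToyJepa, patches: &[f32], target_idx: &[usize], lr: f32, momentum: f32) -> f32 {
    assert!(!target_idx.is_empty() && target_idx.len() < patches.len());
    // 1. Masking: every patch not in `target_idx` is context (disjoint sets).
    let context_idx: Vec<usize> = (0..patches.len()).filter(|i| !target_idx.contains(i)).collect();
    // 2. Context encoding: only visible patches are encoded.
    let s_x: Vec<f32> = context_idx.iter().map(|&i| m.theta * patches[i]).collect();
    // 3. Target encoding: full input through the EMA encoder, then select the
    //    targets. These values are treated as constants (no gradient w.r.t. xi).
    let s_full: Vec<f32> = patches.iter().map(|&p| m.xi * p).collect();
    let s_y: Vec<f32> = target_idx.iter().map(|&i| s_full[i]).collect();
    // 4. Prediction: parameter-free toy predictor, the mean context embedding.
    let pred = s_x.iter().sum::<f32>() / s_x.len() as f32;
    // 5. Loss: mean squared error only (no collapse regularizer in this toy).
    let n = s_y.len() as f32;
    let loss = s_y.iter().map(|&y| (pred - y).powi(2)).sum::<f32>() / n;
    // Analytic gradient: d pred / d theta = mean of the context patch values.
    let mean_x = context_idx.iter().map(|&i| patches[i]).sum::<f32>() / context_idx.len() as f32;
    let grad = s_y.iter().map(|&y| 2.0 * (pred - y) * mean_x).sum::<f32>() / n;
    // 6. Update: SGD on theta (standing in for AdamW), then EMA on xi.
    m.theta -= lr * grad;
    m.xi = momentum * m.xi + (1.0 - momentum) * m.theta;
    loss
}
```

For instance, with patches `[1, 2, 3, 4]`, target index `3`, both weights starting at 1, `lr = 0.1`, and `momentum = 0.5`, the prediction is 2 against a target of 4, giving a loss of 4; θ then moves to 1.8 and ξ to 1.4.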