15 (Bonus). State Space Models#

⚡Compute Note: You can run this notebook on CPU.

Large Language Models (LLMs) keep getting longer context windows — 8k, 32k, 128k, and beyond. But underneath this progress lies a fundamental limitation: self-attention in Transformers scales quadratically with sequence length, so doubling the context roughly quadruples the attention compute and memory footprint.

This makes extremely long-context modeling expensive, slow, and sometimes impossible to train on standard hardware. As we continue pushing towards models that can reason over entire books, multi-hour conversations, or days of sensor data, we need architectures that can scale.

That brings us to State Space Models (SSMs) — a class of sequence models that can process arbitrarily long inputs with:

  • O(N) compute

  • O(1) memory per step (a fixed-size state, independent of context length)

  • Effectively infinite context length

Unlike Transformers, which store all past tokens in a growing KV cache, SSMs maintain a compressed hidden state that evolves over time and summarizes the entire history. For more information, check this out.
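
To make that scaling difference concrete, here is a small back-of-envelope sketch (the model sizes are hypothetical, chosen only for illustration) comparing how the number of attention-score entries, the KV cache, and a fixed SSM state grow with context length:

# Rough, illustrative per-layer activation counts as context length T grows.
# d_model, n_heads and d_state are made-up values for this sketch.
d_model, n_heads, d_state = 1024, 16, 64

for T in (8_192, 32_768, 131_072):
    attn_scores = n_heads * T * T      # attention score entries: quadratic in T
    kv_cache = 2 * T * d_model         # cached keys + values: linear in T
    ssm_state = d_model * d_state      # fixed-size SSM state: independent of T
    print(f"T={T:>7,}  scores={attn_scores:,}  kv_cache={kv_cache:,}  state={ssm_state:,}")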

In this tutorial, we will build intuition about:

  • Why Transformers struggle with long context

  • How classical State Space Models work

  • How modern variants like Mamba add selectivity and nonlinearity

  • Why SSMs are becoming central to long-context LLM research (Mamba, Jamba, Hyena, RWKV, RetNet)

We will start from the classical state-space equations, discretize them for deep learning, and gradually build toward a minimal working implementation of a Selective SSM. We will compare this with MLPs, RNNs, and a tiny Transformer on a long-range synthetic task in which Transformers struggle but SSMs succeed.

Let’s begin by understanding why Transformers struggle as context lengths grow.

import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def set_seed(seed: int = 42) -> None:
    """Fix random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)
Using device: cuda

🧪 A Toy Problem That Requires Understanding Long-Range Structure#

To understand why State Space Models matter, we need a task that highlights how different architectures handle repeating patterns with long periods — something that requires tracking structure far back in the sequence, not just recent context.

Instead of a brutal copy-from-200-steps-ago task, we use a smoother but still challenging version of long-range dependence: 🔁 Periodic Sequence Prediction Task.

We create sequences with a hidden periodic structure. For each training example:

  1. We randomly choose a period P between 10 and 100

  2. We generate a random base pattern: \([b_0, b_1, \ldots, b_{P-1}]\)

  3. We repeat this pattern until we reach a full sequence of length \(T = 300\). The resulting sequence looks like:

\[x = (b_0, b_1, \ldots, b_{P-1}, b_0, b_1, \ldots, b_{P-1}, \ldots)\]

Our task is next-token prediction:

\[y_t = x_{(t+1) \bmod T}\]

This forces the model to infer — and internally represent — the periodicity of the sequence to make correct predictions.

We use:

  • Sequence length: \(T = 300\)

  • Period range: \(P \in [10, 100]\)

  • Vocabulary: integers \(\{0, 1, \ldots, 19\}\)

For example, with \(P = 3\) and base pattern \((4, 9, 2)\), the input reads \(4, 9, 2, 4, 9, 2, \ldots\) and the correct prediction at every position is simply the next element of the repeating cycle.

❗ Why This Task Still Requires Long-Range Understanding

The task demands models to capture global structure, not just local patterns:

  • MLP fails because it lacks any notion of order or repetition. It cannot infer periodicity.

  • Tiny Transformer has positional information and attention, but with only a few layers and small embedding size, learning long periods (e.g., 80–100) is difficult. Attention tends to focus locally rather than on long-range repeats.

  • GRU can, in principle, encode “where we are” in the cycle using its hidden state. It performs better than the MLP and tiny Transformer at moderate periods, but still struggles as periods grow large.

  • State Space Models (SSMs) can track a smoothly evolving internal state that naturally captures long-range periodic structure. These models are explicitly built to maintain information over long sequences without quadratic compute.

SEQ_LEN = 300          # full sequence length
VOCAB_SIZE = 20        # integers from 0 to 19
TRAIN_SAMPLES = 5000
TEST_SAMPLES = 5000

MIN_PERIOD = 10
MAX_PERIOD = 100


def make_periodic_sample(
    seq_len: int = SEQ_LEN,
    vocab_size: int = VOCAB_SIZE,
    min_period: int = MIN_PERIOD,
    max_period: int = MAX_PERIOD,
):
    """
    Generate one periodic sequence and its next-token targets.
    """
    # 1) Sample period
    period = random.randint(min_period, max_period)

    # 2) Sample base pattern
    pattern = torch.randint(0, vocab_size, (period,))

    # 3) Tile the pattern to length seq_len
    n_repeats = (seq_len + period - 1) // period
    x_full = pattern.repeat(n_repeats)[:seq_len]  # (seq_len,)

    # 4) Next-token prediction target (circular)
    # y[t] = x_full[(t + 1) % seq_len]
    y_full = torch.roll(x_full, shifts=-1)

    return x_full, y_full


def build_periodic_dataset(n_samples: int):
    X, Y = [], []
    for _ in range(n_samples):
        x, y = make_periodic_sample()
        X.append(x)
        Y.append(y)
    return torch.stack(X), torch.stack(Y)


X_train, Y_train = build_periodic_dataset(TRAIN_SAMPLES)
X_test, Y_test = build_periodic_dataset(TEST_SAMPLES)

print("Train shape:", X_train.shape)
print("Test shape:", X_test.shape)
Train shape: torch.Size([5000, 300])
Test shape: torch.Size([5000, 300])

🔹 Baseline 1: A Simple MLP#

To establish a baseline, we start with the simplest model possible: a Multilayer Perceptron (MLP).

This MLP:

  • Embeds each integer into a vector

  • Flattens the entire sequence into one large vector

  • Passes it through a few feedforward layers

  • Predicts all outputs at once

But there’s a fundamental issue. An MLP has no idea what sequence order means. It cannot learn temporal dependencies, positional structure, or long-range relationships.

This makes the MLP a perfect “dumb baseline” to highlight why architectures with recurrence or attention are needed for sequence modeling.

Below we train the MLP and observe its performance.

class MLPBaseline(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, 32)
        self.net = nn.Sequential(
            nn.Linear(SEQ_LEN * 32, 512),
            nn.ReLU(),
            nn.Linear(512, SEQ_LEN * VOCAB_SIZE),
        )

    def forward(self, x):
        # x: (B, T)
        emb = self.embed(x)                 # (B, T, 32)
        flat = emb.reshape(x.size(0), -1)   # (B, T*32)
        out = self.net(flat)                # (B, T * VOCAB_SIZE)
        return out.reshape(x.size(0), SEQ_LEN, VOCAB_SIZE)  # (B, T, V)


mlp = MLPBaseline().to(device)
criterion = nn.CrossEntropyLoss()
mlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

def train_mlp(model, n_steps: int = 300, batch_size: int = 64):
    model.train()
    for step in range(n_steps):
        idx = torch.randint(0, TRAIN_SAMPLES, (batch_size,))
        batch_x = X_train[idx].to(device)
        batch_y = Y_train[idx].to(device)

        logits = model(batch_x)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), batch_y.reshape(-1))

        mlp_optimizer.zero_grad()
        loss.backward()
        mlp_optimizer.step()

        if step % 50 == 0:
            print(f"[MLP] step {step} | loss = {loss.item():.4f}")

train_mlp(mlp)
[MLP] step 0 | loss = 3.0252
[MLP] step 50 | loss = 2.9010
[MLP] step 100 | loss = 2.4713
[MLP] step 150 | loss = 2.0472
[MLP] step 200 | loss = 1.8850
[MLP] step 250 | loss = 1.5553

🔹 Baseline 2: A Tiny Transformer#

The next baseline is a small Transformer encoder. Transformers are extremely powerful sequence models, but their ability to learn long-range dependencies depends heavily on:

  1. Attention capacity

  2. Depth and width of the model

  3. Proper positional information

Since attention alone does not encode order, we add sinusoidal positional encodings, exactly like in the original Attention is All You Need paper. This allows the model to distinguish whether a token occurred early or late in the sequence.

What we expect:

  • The Transformer should outperform the MLP.

  • But the task remains hard: the model must infer the period length and where it currently sits within the repeating cycle.

This makes the tiny Transformer a great middle baseline: it has some long-range capability, but not enough to handle variable-length periodic structure robustly.

The code below defines and trains this tiny Transformer baseline.

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=SEQ_LEN):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # shape (1, T, d_model)

    def forward(self, x):
        # x: (B, T, d_model)
        return x + self.pe[:, :x.size(1), :]


class TinyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        d_model = 64
        n_heads = 4
        n_layers = 2

        self.embed = nn.Embedding(VOCAB_SIZE, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=n_heads, 
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=n_layers
        )

        self.out = nn.Linear(d_model, VOCAB_SIZE)

    def forward(self, x):
        emb = self.embed(x)                      # (B, T, d_model)
        emb = self.pos_encoding(emb)             # add positional info
        h = self.transformer(emb)                # (B, T, d_model)
        return self.out(h)                       # (B, T, vocab)


transformer = TinyTransformer().to(device)
opt_t = torch.optim.Adam(transformer.parameters(), lr=2e-4)


def train_transformer(model, n_steps=2000):
    model.train()
    for step in range(n_steps):
        idx = torch.randint(0, TRAIN_SAMPLES, (32,))
        batch_x = X_train[idx].to(device)
        batch_y = Y_train[idx].to(device)

        logits = model(batch_x)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), batch_y.reshape(-1))

        opt_t.zero_grad()
        loss.backward()
        opt_t.step()

        if step % 500 == 0:
            print(f"[Transformer] step {step} | loss = {loss.item():.4f}")


train_transformer(transformer)
[Transformer] step 0 | loss = 3.1882
[Transformer] step 500 | loss = 2.8013
[Transformer] step 1000 | loss = 2.7920
[Transformer] step 1500 | loss = 2.7718

📊 Evaluating the Baseline Models#

Now that both the MLP and the tiny Transformer have been trained, we evaluate how well they perform on the periodic sequence prediction task.

The evaluation is simple:

  • Run the model on the entire test set

  • Compute argmax predictions at each timestep

  • Compare them with the ground truth targets

  • Compute the mean accuracy over all tokens in all sequences

This gives us a clear picture of how well each architecture tracks periodic structure whose period can stretch up to 100 steps into the past.

@torch.no_grad()
def evaluate(model):
    model.eval()
    x = X_test.to(device)
    y = Y_test.to(device)

    logits = model(x)
    preds = logits.argmax(dim=-1)

    correct = (preds == y).float().mean().item()
    return correct


mlp_acc = evaluate(mlp)
transformer_acc = evaluate(transformer)

print(f"MLP Test Accuracy:          {mlp_acc:.4f}")
print(f"Transformer Test Accuracy:  {transformer_acc:.4f}")
MLP Test Accuracy:          0.2388
Transformer Test Accuracy:  0.2720

🔹 Baseline 3: GRU#

Before we shift to State Space Models, it is worth examining how a recurrent neural network performs on this long-range dependency task. A GRU (Gated Recurrent Unit) processes the sequence one token at a time and maintains a hidden state that evolves with each step:

\[h_t = \mathrm{GRU}(x_t, h_{t-1})\]

This hidden state acts as a compressed memory of everything that has happened so far. Unlike the MLP and the tiny Transformer, the GRU naturally carries information forward in time.

✔ Why GRUs should perform better#

  1. They have true temporal recurrence. The hidden state is explicitly designed to carry information across time. If the period is, say, 60, the GRU can learn a 60-step rhythm and follow it.

  2. They require only O(1) memory with respect to sequence length. There is no attention matrix, no expanding KV cache, and no need to store all past tokens. A GRU processes a sequence of length 300 with the same memory cost as a sequence of length 30 (see the small stepwise sketch after this list).

  3. Their updates are local and stable. Each step only depends on the previous state. This often helps for tasks with simple long-range rules.
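
To make the O(1)-memory point concrete, here is a minimal stepwise sketch (separate from the batched GRU baseline we train below) that carries only a single fixed-size hidden vector through the sequence:

import torch
import torch.nn as nn

# Stepwise GRU inference: only the current input and one hidden vector are kept
# in memory, no matter how long the sequence is.
d_model = 256
cell = nn.GRUCell(input_size=d_model, hidden_size=d_model)

h = torch.zeros(1, d_model)        # the entire "memory" of everything seen so far
for t in range(300):               # same per-step cost at length 300 or 30,000
    x_t = torch.randn(1, d_model)  # stand-in for an embedded token
    h = cell(x_t, h)               # h_t = GRU(x_t, h_{t-1})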

✘ But GRUs still struggle with very long dependencies#

Even with gating, GRUs must push information through hundreds of recurrent updates. When the period becomes large (e.g., 80–100), the network may:

  • Fade or distort the periodic pattern

  • Lose track of the exact rhythm

  • Fail to generalize across different period lengths

This makes GRUs better than MLPs and small Transformers, but still not ideal for reliably modeling very long-range periodic structure.

Below, we implement and train a simple GRU model and evaluate its performance.

class GRUBaseline(nn.Module):
    def __init__(self, d_model: int = 256, num_layers: int = 2):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, d_model)
        self.gru = nn.GRU(
            input_size=d_model,
            hidden_size=d_model,
            num_layers=num_layers,
            batch_first=True,
        )
        self.out = nn.Linear(d_model, VOCAB_SIZE)

    def forward(self, x):
        # x: (B, T)
        emb = self.embed(x)              # (B, T, d_model)
        h, _ = self.gru(emb)             # (B, T, d_model)
        logits = self.out(h)             # (B, T, vocab)
        return logits


gru = GRUBaseline().to(device)
opt_gru = torch.optim.Adam(gru.parameters(), lr=1e-3)

def train_gru(model, n_steps: int = 3000, batch_size: int = 64):
    model.train()
    for step in range(n_steps):
        idx = torch.randint(0, TRAIN_SAMPLES, (batch_size,))
        batch_x = X_train[idx].to(device)
        batch_y = Y_train[idx].to(device)

        logits = model(batch_x)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), batch_y.reshape(-1))
        
        opt_gru.zero_grad()
        loss.backward()
        opt_gru.step()

        if step % 500 == 0:
            print(f"[GRU] step {step} | loss = {loss.item():.4f}")


train_gru(gru, n_steps=3000, batch_size=64)

gru_acc = evaluate(gru)
print(f"GRU Test Accuracy:          {gru_acc:.4f}")
[GRU] step 0 | loss = 3.0009
[GRU] step 500 | loss = 2.5025
[GRU] step 1000 | loss = 2.2018
[GRU] step 1500 | loss = 1.7013
[GRU] step 2000 | loss = 1.5784
[GRU] step 2500 | loss = 1.6664
GRU Test Accuracy:          0.4003

🧩 State Space Models with Selective Memory#

So far we’ve seen that Transformers are excellent at capturing rich dependencies, but their quadratic cost in sequence length makes them impractical for truly long contexts. Recurrent models (like GRUs) scale linearly, but struggle to remember precise long-term patterns due to vanishing gradients.

State Space Models (SSMs), and specifically Mamba, sit between these extremes: they provide continuous-time recurrence that can also be computed in parallel (as a convolution for classical SSMs, or a parallel scan for selective variants like Mamba), enabling efficient long-range reasoning without the quadratic attention cost.

I recommend reading this blog post.

🧠 The Core Idea: How Mamba Thinks About Sequences#

RNNs think step-by-step. Transformers look at all tokens at once with attention.

State Space Models (SSMs), like Mamba, use a third view:

“Keep a continuous memory state that evolves smoothly as we read the sequence.”

You can think of this memory as a vector \(x_t\) that gets updated as tokens come in:

  • \(x_t\) = what the model remembers at time t

  • It is updated using:

    • Old memory (what we remembered so far)

    • New input (current token embedding)

In continuous form, this is written as:

\[ \frac{dx(t)}{dt} = A\,x(t) + B\,u(t) \]

You don’t need to worry about the calculus; intuitively:

  • \(A\) says how memory decays or mixes over time

  • \(B\) says how strongly each input channel affects memory

  • \(u(t)\) is the current input (e.g., token embedding)

When we turn this into a discrete update over tokens, it becomes:

\[ x_{t+1} \approx \text{(some decay)} \cdot x_t + \text{(some weight)} \cdot u_t \]

So each step:

“Forget a bit of the old state, add a bit of the new input.”

We then read out a prediction from the state:

\[ y_t = C\,x_t \]

This “evolving memory + readout” is the backbone of long-sequence models like S4, Mamba, and S5.
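
Before adding any selectivity, here is a minimal sketch of that "evolving memory + readout" loop with fixed discrete parameters \(A, B, C\) (randomly chosen here, purely for illustration):

import torch

# Classical (non-selective) discrete SSM: decay the state, add the input, read out.
N, D, T = 16, 8, 100                       # state size, input size, sequence length
A_bar = torch.diag(torch.rand(N) * 0.95)   # "some decay": diagonal entries in [0, 1)
B_bar = torch.randn(N, D) * 0.1            # "some weight" on the new input
C = torch.randn(1, N)                      # readout

u = torch.randn(T, D)                      # input sequence (e.g. token embeddings)
x = torch.zeros(N)                         # hidden state
ys = []
for t in range(T):
    x = A_bar @ x + B_bar @ u[t]           # x_{t+1} = A x_t + B u_t
    ys.append(C @ x)                       # y_t = C x_t
y = torch.stack(ys)                        # (T, 1)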

⚙️ From S4 to Mamba: Making the Memory Selective#

Vanilla SSMs use a fixed step size (how fast time moves).
Mamba makes this input-adaptive and adds a gate:

  • It learns a per-step “time step”:

    \[ \Delta_t = \text{softplus}(W_\Delta u_t) \]

    → The model can speed up or slow down how fast memory changes.

  • It also learns a gate that decides how much of the SSM output to use:

    \[ z_t = \sigma(W_g u_t) \odot y_t \]

    → The gate can boost important tokens and down-weight irrelevant ones.

Together, this means:

Mamba doesn’t just remember — it chooses when and what to remember.

🧩 Our Simplified Implementation: Diagonal SSM + Learned Gates#

For this tutorial, we implement a simplified Mamba that is still faithful to the core idea, but much easier to code and understand.

We make two simplifications:

  1. Diagonal A
    Each hidden channel has its own decay rate \(A_i\). No big matrices.

  2. Explicit causal convolution
    We compute how each past token influences the current output using a simple exponential kernel.

The discrete update per channel \(i\) looks like:

\[ x_{t+1,i} = e^{-\Delta_{t,i} A_i}\,x_{t,i} + (\Delta_{t,i} B_i)\,u_{t,i} \]

  • First term: keep a decayed version of old memory

  • Second term: add a scaled version of the new input

Unrolling this over time gives a causal convolution:

\[ y_t = \sum_{k=0}^{t} \underbrace{\left(\prod_{j=k+1}^{t} e^{-\Delta_{j}A}\right)}_{\text{how much memory survives}} \underbrace{(\Delta_k B)\,u_k}_{\text{contribution from step }k} \]

We then apply a gate and project back to the model dimension:

\[ z_t = \sigma(W_g u_t) \odot y_t,\quad o_t = W_o z_t \]

🧮 Putting It All Together#

The code below implements this pipeline step by step:

  1. Input projection – embed tokens into a “state space” vector.

  2. Dynamic decay – learn a step size \(\Delta_t\) for each time step and channel.

  3. Exponential decays and drives – compute how fast each channel forgets and how strongly inputs enter the state.

  4. Causal convolution – combine all past inputs with these decays to get the current output \(y_t\).

  5. Selective gating – decide how much of \(y_t\) to use at each position.

  6. Projection to vocab space – turn the resulting features into logits for next-token prediction.

In short:

Mamba = RNN-like recurrence + SSM math + gating, all wrapped into a layer that scales linearly with sequence length and can handle very long contexts efficiently.

Now let’s see how this simplified Mamba layer looks in code 👇

class SimpleMambaLayer(nn.Module):
    """
    Minimal Mamba layer:
      - diagonal state matrix A (no NxN explosion)
      - learned B (the output projection plays the role of C)
      - input-dependent Δ_t (step size)
      - SSM scan: x_{t+1} = exp(-Δ_t * A) * x_t + (Δ_t * B) * u_t
      - convolution kernel from closed-form SSM
      - gating
    """
    def __init__(self, d_model=64, state_dim=64):
        super().__init__()
        self.d_model = d_model
        self.N = state_dim

        # Input projections
        self.in_proj = nn.Linear(d_model, self.N)
        self.delta_proj = nn.Linear(d_model, self.N)
        self.gate_proj = nn.Linear(d_model, self.N)

        # SSM parameters (diagonal A, diagonal B)
        self.A_log = nn.Parameter(torch.randn(self.N))   # will become positive via softplus
        self.B = nn.Parameter(torch.randn(self.N) * 0.1)

        # Output projection
        self.out_proj = nn.Linear(self.N, d_model)

    def forward(self, x):
        B, T, D = x.shape

        u = self.in_proj(x)                            # (B,T,N)
        delta = F.softplus(self.delta_proj(x))         # (B,T,N)
        gate = torch.sigmoid(self.gate_proj(x))        # (B,T,N)

        A = F.softplus(self.A_log)                     # (N,)
        decay = torch.exp(-delta * A)                  # (B,T,N)
        drive = delta * self.B                         # (B,T,N)

        # Compute kernel weights
        log_decay = torch.log(decay + 1e-7)            # (B,T,N)
        prefix = torch.cumsum(log_decay, dim=1)        # (B,T,N)

        # Compute k_t = exp(prefix[t] - prefix[k]) * drive[k]
        y = torch.zeros_like(u)

        for t in range(T):
            w = torch.exp(prefix[:, t:t+1, :] - prefix[:, :t+1, :]) # (B,t+1,N)
            y[:, t] = torch.sum(w * (drive[:, :t+1, :] * u[:, :t+1, :]), dim=1)

        z = gate * y
        return self.out_proj(z)

class SimpleMambaModel(nn.Module):
    def __init__(self, d_model: int = 64, state_dim: int = 1024, num_layers: int = 1):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, d_model)

        layers = []
        for _ in range(num_layers):
            layers.append(SimpleMambaLayer(d_model, state_dim))
        self.layers = nn.ModuleList(layers)

        self.out = nn.Linear(d_model, VOCAB_SIZE)

    def forward(self, x):
        # x: (B, T)
        h = self.embed(x)  # (B, T, d_model)
        for layer in self.layers:
            h = layer(h)
        logits = self.out(h)  # (B, T, vocab)
        return logits

mamba = SimpleMambaModel(d_model=64, state_dim=1024, num_layers=2).to(device)
opt_mamba = torch.optim.Adam(mamba.parameters(), lr=1e-3)

def train_mamba(model, n_steps: int = 3000, batch_size: int = 64):
    model.train()
    for step in range(n_steps):
        idx = torch.randint(0, TRAIN_SAMPLES, (batch_size,))
        batch_x = X_train[idx].to(device)
        batch_y = Y_train[idx].to(device)

        logits = model(batch_x)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), batch_y.reshape(-1))

        opt_mamba.zero_grad()
        loss.backward()
        opt_mamba.step()

        if step % 500 == 0:
            print(f"[Mamba-like] step {step} | loss = {loss.item():.4f}")


train_mamba(mamba, n_steps=2000, batch_size=64)

mamba_acc = evaluate(mamba)
print(f"Simple Mamba-like Test Accuracy: {mamba_acc:.4f}")
[Mamba-like] step 0 | loss = 2.9986
[Mamba-like] step 500 | loss = 2.5404
[Mamba-like] step 1000 | loss = 2.4470
[Mamba-like] step 1500 | loss = 2.1321
Simple Mamba-like Test Accuracy: 0.4613
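
A side note on the implementation above: the Python loop in SimpleMambaLayer.forward materializes the full unrolled kernel, which costs O(T²) work and is fine for a teaching demo, but the same output can be produced by a plain left-to-right recurrence with a fixed-size state. This is the O(1)-memory-per-step path an SSM uses at inference time. A minimal sketch, assuming decay, drive, and u are the tensors computed inside forward:

import torch

# Recurrent evaluation equivalent to the kernel loop in SimpleMambaLayer (sketch).
# Given decay, drive, u of shape (B, T, N) as computed in forward(), this yields
# the same y (up to the 1e-7 epsilon added before the log).
def recurrent_scan(decay, drive, u):
    B, T, N = u.shape
    state = torch.zeros(B, N, device=u.device)
    ys = []
    for t in range(T):
        # x_{t+1} = exp(-Δ_t A) * x_t + (Δ_t B) * u_t   (element-wise, diagonal A)
        state = decay[:, t] * state + drive[:, t] * u[:, t]
        ys.append(state)
    return torch.stack(ys, dim=1)          # (B, T, N), then gate * y as before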

🐍 Mamba 2: Making SSMs Faster and More Flexible#

So far, we’ve seen a simplified Mamba-style layer (Mamba 1) and used it to build intuition:

  • There is a hidden state that evolves over time.

  • The state update is governed by a state matrix \(A\) and input matrix \(B\).

  • We get linear-time sequence processing and good long-range memory.

Mamba 2 takes these ideas and asks:

“How can we keep the nice SSM structure,
but make it even more parallel, more hardware-friendly,
and better at modeling different kinds of patterns?”

⚙️ What Changes in Mamba 2?#

Mamba 2 keeps the same spirit:

  • State evolution via SSM

  • Input-dependent step sizes and gating

  • Linear-time selective scan over the sequence

But improves how the computation is organized:

  1. State Space Duality (SSD)
    The SSM layer is rewritten in a form that looks more like a multi-head block:

    • The channels are grouped into “heads”.

    • Each head can learn different temporal patterns (slow, fast, oscillatory…).

    • The math is arranged so that it can reuse many of the same tricks and kernels that fast attention implementations use.

  2. Better parallelism on hardware
    Mamba 2 is designed from the ground up to:

    • Maximize GPU utilization

    • Reduce memory movement

    • Make the selective scan more fused and batched

  3. Richer effective dynamics
    Instead of a single “flat” SSM, Mamba 2 behaves more like:

    • A set of structured SSM heads

    • Each head has its own parameters and sees the sequence slightly differently

    • Their outputs are combined, similar to multi-head attention, but still in O(T) time.

🧩 High-Level Mamba 2 Equations#

At a very high level, for each head \(h\) and channel \(i\), you can think of Mamba 2 computing something like:

  1. Head-specific state update:

    \[ x_{t+1}^{(h)} = A^{(h)}_t x_t^{(h)} + B^{(h)}_t u_t \]
  2. Head-specific output:

    \[ y_t^{(h)} = C^{(h)} x_t^{(h)}\]
  3. Combine heads and gate them:

    \[y_t = \sum_h g_t^{(h)} \odot y_t^{(h)}\]

Here:

  • Each head \(h\) has its own SSM parameters and sees a projection of the input.

  • The gates \(g_t^{(h)}\) decide how much each head contributes at each position.

  • The implementation uses a careful SSD formulation so that all heads can be computed efficiently in parallel.

You don’t need to track every symbol — the key idea is:

Mamba 2 = Mamba-style SSM + multi-head structure + better GPU-friendly math.
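
To make the head-wise structure a bit more tangible, here is a deliberately naive sketch of the three equations above, using fixed diagonal decays and a sequential loop. All names and shapes are illustrative only; the real Mamba2 block computes this with fused, parallel SSD kernels:

import torch
import torch.nn as nn

# Naive multi-head diagonal SSM with per-head gates (illustration only).
Bsz, T, D, H, N = 2, 16, 32, 4, 8          # batch, time, model dim, heads, state per head

u = torch.randn(Bsz, T, D)
B_proj = nn.ModuleList([nn.Linear(D, N) for _ in range(H)])   # B^(h)
C_proj = nn.ModuleList([nn.Linear(N, D) for _ in range(H)])   # C^(h)
g_proj = nn.ModuleList([nn.Linear(D, D) for _ in range(H)])   # gate g^(h)
decay = [torch.rand(N) * 0.9 + 0.05 for _ in range(H)]        # diagonal A^(h), fixed here

y = torch.zeros(Bsz, T, D)
for h in range(H):
    x = torch.zeros(Bsz, N)                                   # head-specific state
    for t in range(T):
        x = decay[h] * x + B_proj[h](u[:, t])                 # x_{t+1}^(h) = A^(h) x_t^(h) + B^(h) u_t
        g = torch.sigmoid(g_proj[h](u[:, t]))                 # gate g_t^(h)
        y[:, t] = y[:, t] + g * C_proj[h](x)                  # sum of gated head outputs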

🧠 Intuition: When Do We Need Mamba 2 Over Mamba 1?#

In our toy notebook, a single small Mamba-style layer (Mamba 1) is already enough to:

  • Beat the MLP

  • Compete with or beat a small Transformer and GRU on long-period patterns

So why care about Mamba 2?

  • Scaling up
    When you scale to many layers, large d_model, and very long sequences (e.g. tens of thousands of tokens), you want:

    • Stronger modeling capacity per layer

    • Excellent hardware efficiency

    • Ability to use head-wise structure, like attention

  • Better use of modern GPUs
    Mamba 2 reorganizes the computation so it can reuse high-performance kernels, making training and inference more efficient at scale.

  • More flexible temporal patterns
    With grouped channels / heads, Mamba 2 can let different parts of the model specialize:

    • Some heads track very long, slow signals.

    • Others focus on short, local patterns.

    • All of this stays within an SSM framework.

In short:

Mamba 1 is great for intuition and small-scale demos.
Mamba 2 is the engine you plug into real models when you want long context, high throughput, and modern GPU efficiency.

import torch.nn as nn
from mamba_ssm import Mamba2

class Mamba2Model(nn.Module):
    def __init__(self, d_model: int = 64, state_dim: int = 1024, num_layers: int = 1):
        super().__init__()
        # note: num_layers is kept for interface parity with SimpleMambaModel,
        # but this demo wraps a single Mamba2 block.
        self.embed = nn.Embedding(VOCAB_SIZE, d_model)

        self.mamba2 = Mamba2(
            d_model=d_model,
            d_state=state_dim,
            d_conv=4,
            expand=2,
            headdim=64,
        )
        self.mamba2.use_mem_eff_path = False

        self.out = nn.Linear(d_model, VOCAB_SIZE)

    def forward(self, x):
        # x: (B, T)
        h = self.embed(x)     # (B, T, d_model)
        h = self.mamba2(h)    # (B, T, d_model)
        logits = self.out(h)  # (B, T, vocab_size)
        return logits

mamba2 = Mamba2Model(d_model=64, state_dim=1024, num_layers=2).to(device)
opt_mamba2 = torch.optim.Adam(mamba2.parameters(), lr=1e-3)

def train_mamba2(model, n_steps: int = 3000, batch_size: int = 64):
    model.train()
    for step in range(n_steps):
        idx = torch.randint(0, TRAIN_SAMPLES, (batch_size,))
        batch_x = X_train[idx].to(device)
        batch_y = Y_train[idx].to(device)

        logits = model(batch_x)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), batch_y.reshape(-1))

        opt_mamba2.zero_grad()
        loss.backward()
        opt_mamba2.step()

        if step % 500 == 0:
            print(f"[Mamba2] step {step} | loss = {loss.item():.4f}")

train_mamba2(mamba2, n_steps=2000, batch_size=64)

mamba2_acc = evaluate(mamba2)
print(f"Mamba2 Test Accuracy: {mamba2_acc:.4f}")
[Mamba2] step 0 | loss = 3.0486
[Mamba2] step 500 | loss = 1.2650
[Mamba2] step 1000 | loss = 0.7374
[Mamba2] step 1500 | loss = 0.6665
Mamba2 Test Accuracy: 0.5248

We see that test accuracy climbs from roughly 24% for the MLP and 27% for the tiny Transformer to 40% for the GRU, 46% for our simplified Mamba-style layer, and about 52% for the official Mamba2 block.

✅ Conclusion: Why SSMs Matter for Long Context#

In this notebook we walked through a full mini-battle of sequence models on a synthetic long-range pattern task (periodic integer sequences):

  1. MLP baseline

    • Sees the whole sequence at once, but has no notion of time.

    • Treats the sequence like a bag of positions → struggles to lock onto long periodic structure.

  2. Tiny Transformer

    • Adds positional encodings and self-attention.

    • In theory, can model long-range dependencies, but with small depth/width it:

      • Spreads attention too diffusely,

      • Has to pay a quadratic cost in sequence length,

      • And still fails to capture long periodic patterns reliably.

  3. GRU (recurrent baseline)

    • Processes tokens one by one, with a hidden state as memory.

    • Naturally suited to temporal structure and often beats MLP / tiny Transformer.

    • But gradients still have to flow through hundreds of steps → long-range recall remains fragile.

  4. Mamba-style SSM (our “Mamba 1”)

    • Reframes the problem as a continuous-time state space system: \( x_{t+1} \approx A_t x_t + B_t u_t\).

    • Uses selective, input-dependent updates and exponential decays to keep or forget information over very long horizons.

    • Already shows how linear-time SSMs can compete with or outperform classical sequence models.

  5. Mamba2 (official SSM block)

    • Builds on the same SSM foundation, but adds:

      • A multi-head / grouped channel structure (State Space Duality),

      • Better GPU-friendly kernels,

      • More flexible temporal dynamics per head.

    • In our experiments, Mamba2 achieved the highest test accuracy on the periodic sequence task, while still scaling linearly in sequence length.

🧠 Key Takeaways#

  • Long context is not free.
    Simply increasing Transformer context length grows compute and memory quadratically, and small models quickly hit their limits.

  • Recurrence helps, but isn’t enough.
    GRUs/LSTMs can remember further back than MLPs, but still struggle on very long dependencies and don’t scale as nicely to very long sequences.

  • State Space Models provide a sweet spot.
    They offer:

    • Linear-time sequence processing,

    • Explicit control over how memory decays over time,

    • And with selectivity/gating (Mamba, Mamba2), input-adaptive memory that can stretch across thousands of tokens.

  • Mamba2 is a practical, scalable building block.
    It keeps the SSM core but packages it into a drop-in, GPU-efficient layer you can use much like a Transformer block.

🚀 Where to Go Next#

From here, you can:

  • Stack multiple Mamba / Mamba2 layers and compare to deeper Transformers.

  • Try real data (text, time series, code) with long-range dependencies.

  • Mix architectures:

    • Mamba layers for long context,

    • Attention layers for fine-grained local reasoning.

In short, this tutorial shows that if you care about long context + efficiency,
you don’t have to stay in the Transformer world — State Space Models, and especially Mamba2, are powerful alternatives for building the next generation of long-context LLMs.