PennyLane

Contents

  1. Introduction
  2. What We’ll Build
  3. Understanding Quantum Error Correction
    1. Symplectic Representation
    2. Measurement and Syndrome
    3. Decoding: The Classical Half of QEC
    4. CSS Codes: Simplifying the Structure
  4. The Steane Code
  5. Lookup‑table (LUT) decoding
  6. Belief-Propagation (BP) Decoder
    1. The Sum-Product Algorithm
    2. Why This Works
    3. BP in JAX
  7. Catalyst hybrid kernel
  8. Simulating Errors and Full Correction
  9. Benchmarking logical vs. physical error rates
  10. Conclusion and Limitations
  11. References
  12. About the author

Decoding Quantum Errors on the Steane code with Belief Propagation & Catalyst

Tom Ginsberg

Published: August 25, 2025. Last updated: August 25, 2025.

Learn how to build, simulate, and decode the Steane code using JAX and Catalyst, blending quantum circuits with fast classical decoders in a seamless workflow.

Introduction

This tutorial walks you through a simplified error correction cycle using the Steane \([[7,1,3]]\) code. You’ll encode a logical qubit, introduce noise, extract syndromes, and apply decoding using two different strategies: a simple lookup table and a belief propagation (BP) decoder. Both decoders are implemented in JAX and JIT-compiled by Catalyst, allowing everything to run inside a single @qml.qjit circuit.

Why is this exciting? Quantum error correction (QEC) is essential for building reliable quantum computers, but it requires more than just quantum operations. Fast classical feedback is needed as well. Catalyst addresses this by fusing the classical and quantum workflows, giving us one unified, hardware-agnostic kernel that runs on CPUs, GPUs, and beyond.

What We’ll Build

By the end of this tutorial, you’ll have:

  • Encoded a logical \(|0\rangle_L\) using the Steane code.

  • Simulated noise using a configurable single-qubit Pauli channel that applies bit flips (\(X\)), phase flips (\(Z\)), or both (\(Y\)).

  • Extracted syndromes via ancilla-assisted stabilizer measurements.

  • Decoded errors using:

    • a Lookup Table (LUT) decoder

    • a Belief Propagation (BP) decoder

  • Benchmarked performance across different physical error rates.

Understanding Quantum Error Correction

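At its core, QEC protects quantum information through redundant encoding. Stabilizer codes are a foundational tool for this purpose. Each stabilizer is a multi-qubit Pauli operator (a tensor product of the single-qubit Paulis \(I\), \(X\), \(Y\), \(Z\)). Code states are exactly the states left unchanged (the \(+1\) eigenstates) by every stabilizer, so each stabilizer acts as a parity constraint that the code space must satisfy.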

Symplectic Representation

To formalize stabilizers and errors, we use symplectic vectors. For \(n\) physical qubits, any \(n\)-qubit Pauli operator can be represented as a binary vector of length \(2n\):

\[(v | u) = (v_1, v_2, ..., v_n | u_1, u_2, ..., u_n),\]

where:

  • \(v_i = 1\) if the Pauli has an X on qubit \(i\) (0 otherwise),

  • \(u_i = 1\) if the Pauli has a Z on qubit \(i\) (0 otherwise),

  • and \(v_i = u_i = 1\) if the Pauli has a Y on qubit \(i\), since \(X \cdot Z \propto Y\).

For example, the operator \(XZIY\) on 4 qubits corresponds to \((1,0,0, 1 | 0,1,0, 1)\). The stabilizer group is then generated by a set of such symplectic vectors, forming a stabilizer matrix of size \(m \times 2n\) (with \(m\) generators).
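The mapping from a Pauli string to its symplectic vector is mechanical. A minimal NumPy sketch (the helper name `pauli_to_symplectic` is ours, not a PennyLane API):

```python
import numpy as np


def pauli_to_symplectic(pauli: str) -> np.ndarray:
    """Map a Pauli string such as 'XZIY' to its binary vector (v | u)."""
    v = np.array([p in "XY" for p in pauli], dtype=np.int8)  # X component (X or Y)
    u = np.array([p in "ZY" for p in pauli], dtype=np.int8)  # Z component (Z or Y)
    return np.concatenate([v, u])


print(pauli_to_symplectic("XZIY"))  # [1 0 0 1 0 1 0 1]
```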

The commutation condition between two Pauli operators \((v|u)\) and \((v'|u')\) is captured by their symplectic inner product:

\[v \cdot u' + u \cdot v' \pmod{2}.\]

For a valid stabilizer code, all generators must commute, which implies that the symplectic inner product between any two rows of the stabilizer matrix must be zero.
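The commutation test can be sketched in a few lines of NumPy. The helper names are ours; as a check, the snippet also confirms that the six Steane generators (introduced later in this tutorial) pairwise commute.

```python
import numpy as np


def pauli_to_symplectic(pauli: str) -> np.ndarray:
    """Map a Pauli string to its binary vector (v | u)."""
    v = np.array([p in "XY" for p in pauli], dtype=np.int8)
    u = np.array([p in "ZY" for p in pauli], dtype=np.int8)
    return np.concatenate([v, u])


def symplectic_product(a: np.ndarray, b: np.ndarray) -> int:
    """v . u' + u . v' (mod 2): 0 if the Paulis commute, 1 if they anticommute."""
    n = len(a) // 2
    return int((a[:n] @ b[n:] + a[n:] @ b[:n]) % 2)


print(symplectic_product(pauli_to_symplectic("XX"), pauli_to_symplectic("ZZ")))  # 0
print(symplectic_product(pauli_to_symplectic("XI"), pauli_to_symplectic("ZI")))  # 1

# The six Steane generators (rows of H_X and H_Z) must pairwise commute
steane = ["IIIXXXX", "IXXIIXX", "XIXIXIX", "IIIZZZZ", "IZZIIZZ", "ZIZIZIZ"]
vecs = [pauli_to_symplectic(g) for g in steane]
print(all(symplectic_product(a, b) == 0 for a in vecs for b in vecs))  # True
```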

Measurement and Syndrome

When an error \(e\) occurs, represented as a symplectic vector \((v_e | u_e)\), the syndrome is calculated by taking the symplectic inner product of \(e\) with each stabilizer generator. This produces a syndrome bit \(s_i\) for each generator:

\[s_i = v_g^{(i)} \cdot u_e + u_g^{(i)} \cdot v_e \pmod{2},\]

where \((v_g^{(i)} | u_g^{(i)})\) is the \(i\)-th generator.
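A short NumPy sketch of this formula, using the stabilizers \(ZZI\) and \(IZZ\) of the three-qubit bit-flip code purely as an illustration (they are not part of this tutorial’s Steane example):

```python
import numpy as np


def syndrome(generators: np.ndarray, error: np.ndarray) -> np.ndarray:
    """Syndrome bits s_i = v_g^(i) . u_e + u_g^(i) . v_e (mod 2).

    generators: (m, 2n) stabilizer matrix, one symplectic row per generator.
    error:      (2n,) symplectic vector (v_e | u_e).
    """
    n = generators.shape[1] // 2
    v_g, u_g = generators[:, :n], generators[:, n:]
    v_e, u_e = error[:n], error[n:]
    return (v_g @ u_e + u_g @ v_e) % 2


gens = np.array([[0, 0, 0, 1, 1, 0],   # ZZI
                 [0, 0, 0, 0, 1, 1]])  # IZZ
x_on_q0 = np.array([1, 0, 0, 0, 0, 0])  # XII
print(syndrome(gens, x_on_q0))  # [1 0]
```

A \(Z\) error would give the all-zero syndrome here, since these \(Z\)-only generators commute with it.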

Below, we show the basic quantum circuit for extracting a syndrome value. The ancilla qubit is prepared in the \(|+\rangle\) state, controlled-Pauli operations are applied according to the stabilizer generator, and finally the ancilla is measured in the \(X\) basis to obtain the syndrome bit. In the code, both the \(|+\rangle\) preparation and the \(X\)-basis measurement are implemented with a Hadamard on the ancilla. Check out Arthur Pesah’s excellent blog post series 1 on stabilizer codes for a deeper introduction.

from typing import Callable, Optional, Dict, Union, Sequence

import pennylane as qml

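syndromes = ["XXZIZIX", "XXIIZZI"]
n_data = max(map(len, syndromes))  # number of data qubits
wires = n_data + len(syndromes)  # one ancilla per stabilizer
dev = qml.device("lightning.qubit", wires=wires)


@qml.qnode(device=dev)
def ancilla_assisted_syndrome_extraction(syndromes: list[str]):
    ancilla = n_data  # the first ancilla sits right after the data qubits
    for syndrome in syndromes:
        qml.Hadamard(ancilla)  # prepare |+>
        for i, s in enumerate(syndrome):
            if s == "X":
                qml.CNOT(wires=[ancilla, i])  # controlled-X onto data qubit i
            elif s == "Y":
                qml.CY(wires=[ancilla, i])  # controlled-Y onto data qubit i
            elif s == "Z":
                qml.CZ(wires=[ancilla, i])  # controlled-Z onto data qubit i
        qml.Hadamard(ancilla)  # so a computational-basis measurement reads out X
        qml.measure(ancilla)
        qml.Barrier()
        ancilla += 1  # the next stabilizer uses a fresh ancilla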


print(qml.draw(ancilla_assisted_syndrome_extraction, show_all_wires=True)(syndromes))
0: ────╭X──────────────────────||────╭X───────────────────||─┤
1: ────│──╭X───────────────────||────│──╭X────────────────||─┤
2: ────│──│──╭Z────────────────||────│──│─────────────────||─┤
3: ────│──│──│─────────────────||────│──│─────────────────||─┤
4: ────│──│──│──╭Z─────────────||────│──│──╭Z─────────────||─┤
5: ────│──│──│──│──────────────||────│──│──│──╭Z──────────||─┤
6: ────│──│──│──│──╭X──────────||────│──│──│──│───────────||─┤
7: ──H─╰●─╰●─╰●─╰●─╰●──H──┤↗├──||────│──│──│──│───────────||─┤
8: ────────────────────────────||──H─╰●─╰●─╰●─╰●──H──┤↗├──||─┤

Decoding: The Classical Half of QEC

Once you have syndrome bits from your stabilizer measurements, you need to figure out which error most likely occurred; this is the job of the decoder. Formally, given the syndrome, you’re solving for the most probable error, usually called the maximum likelihood estimate (MLE) for the error.

However, exact MLE decoding requires precise knowledge of your noise model and, for general codes, is computationally intractable (NP-hard). Naive approaches also scale exponentially: \(m\) one-bit syndrome measurements can take on \(2^m\) distinct values, and the number of candidate errors grows exponentially with the number of qubits. In practice, we rely on approximate methods tuned to assumptions about the noise model.

A complete quantum error correction (QEC) cycle. A logical state \(|\psi\rangle_L\) experiences an error before the stabilizers \(\mathcal{S}\) are measured. The resulting syndrome is decoded classically to produce a correction \(\mathcal{R}\), which is applied to restore the logical state.
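To make the exponential cost of exact decoding concrete, here is a brute-force sketch (NumPy; the function name is ours) that finds a lowest-weight bit-flip pattern consistent with a syndrome by scanning all \(2^n\) candidates. For independent bit flips with \(p < 1/2\), this is the single most likely error; it ignores degeneracy and handles only one error type.

```python
import itertools

import numpy as np


def brute_force_min_weight(H: np.ndarray, s: np.ndarray) -> np.ndarray:
    """Return a lowest-weight bit-flip pattern e with H e = s (mod 2).

    Scans all 2**n candidate patterns, so the cost doubles with every extra qubit.
    """
    n = H.shape[1]
    best = None
    for bits in itertools.product([0, 1], repeat=n):
        e = np.array(bits)
        if np.array_equal((H @ e) % 2, s) and (best is None or e.sum() < best.sum()):
            best = e
    return best


# 3-bit repetition code: checks e0 xor e1 and e1 xor e2
H_rep3 = np.array([[1, 1, 0], [0, 1, 1]])
print(brute_force_min_weight(H_rep3, np.array([1, 1])))  # [0 1 0]
```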

CSS Codes: Simplifying the Structure

CSS (Calderbank–Shor–Steane) codes are a special class of stabilizer codes in which every generator is either X-type or Z-type. Their symplectic vectors look like this:

  • X-type generator: \((v | 0)\) (only Xs)

  • Z-type generator: \((0 | u)\) (only Zs)

This allows us to represent the stabilizers with two binary matrices of size \(m_X \times n\) and \(m_Z \times n\):

  • \(H_X\) for X-type generators

  • \(H_Z\) for Z-type generators

All generators must commute so that they can be measured simultaneously. For a CSS code, this condition becomes:

\[H_X H_Z^T = 0 \pmod{2},\]

which ensures that all X and Z stabilizers commute pairwise. When measuring syndromes:

  • X-type stabilizers detect Z errors via \(s_X = H_X e_Z^T \pmod{2}\).

  • Z-type stabilizers detect X errors via \(s_Z = H_Z e_X^T \pmod{2}\).

This separation makes decoding modular, allowing you to handle X and Z errors independently. When we introduce the Steane code later, you’ll see these matrices explicitly and how they simplify syndrome calculation and decoding. See a similar diagram below for the CSS code cycle structure.

The Error Correction Cycle on a CSS Code
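These relations are easy to check numerically. The NumPy sketch below uses the Steane matrices from the next section (\(H_X = H_Z\)) and two arbitrarily chosen errors: a \(Z\) error on qubit 4 and an \(X\) error on qubit 1.

```python
import numpy as np

# Steane code: H_X = H_Z = parity-check matrix of the [7,4,3] Hamming code
H = np.array([[0, 0, 0, 1, 1, 1, 1],
              [0, 1, 1, 0, 0, 1, 1],
              [1, 0, 1, 0, 1, 0, 1]])
H_X, H_Z = H, H

# CSS commutation condition: H_X H_Z^T = 0 (mod 2)
print((H_X @ H_Z.T) % 2)  # all zeros

e_Z = np.zeros(7, dtype=int)
e_Z[4] = 1  # Z error on qubit 4
e_X = np.zeros(7, dtype=int)
e_X[1] = 1  # X error on qubit 1

s_X = (H_X @ e_Z) % 2  # X-type checks detect the Z error
s_Z = (H_Z @ e_X) % 2  # Z-type checks detect the X error
print(s_X, s_Z)  # [1 0 1] [0 1 0]
```

Each syndrome equals the column of \(H\) belonging to the affected qubit, which is what makes single-error decoding of this code so simple.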

The Steane Code

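The Steane code is one of the simplest quantum error-correcting codes: a CSS code built from the classical \([7,4,3]\) Hamming code, whose parity-check matrix supplies both the \(X\)-type and the \(Z\)-type generators. It encodes one logical qubit into seven physical qubits. The error-correcting ability of a code is characterized by its distance \(d\), the minimum weight of a Pauli operator that commutes with every stabilizer without being a stabilizer itself (that is, one that acts nontrivially on the logical information). A distance-\(d\) code can correct arbitrary errors on up to \(\lfloor (d-1)/2 \rfloor\) qubits. The Steane code has distance \(3\), so it can correct any single-qubit error. It uses six stabilizer generators, three \(X\)-type and three \(Z\)-type, given by: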

\[\begin{split}H_X = \begin{bmatrix} 0 & 0 & 0 & 1 & 1 & 1 & 1 \\ 0 & 1 & 1 & 0 & 0 & 1 & 1 \\ 1 & 0 & 1 & 0 & 1 & 0 & 1 \end{bmatrix}, \quad H_Z = H_X.\end{split}\]
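The distance can be confirmed by brute force. For a CSS code, an \(X\)-type logical operator is a vector in the kernel of \(H_Z\) that is not in the row space of \(H_X\); the distance is the smallest weight among these and the analogous \(Z\)-type operators. Since \(H_X = H_Z\) here, one search covers both. This NumPy sketch enumerates all \(2^7\) vectors, so it only suits very small codes.

```python
import itertools

import numpy as np

H = np.array([[0, 0, 0, 1, 1, 1, 1],
              [0, 1, 1, 0, 0, 1, 1],
              [1, 0, 1, 0, 1, 0, 1]])
n = H.shape[1]

# Binary row space of H: all products of X-type generators (the stabilizers)
row_space = {
    tuple(int(b) for b in (np.array(c) @ H) % 2)
    for c in itertools.product([0, 1], repeat=H.shape[0])
}

# Logical candidates: commute with every check (H x = 0) but are not stabilizers
weights = [
    sum(x)
    for x in itertools.product([0, 1], repeat=n)
    if not np.any((H @ np.array(x)) % 2) and x not in row_space
]
print(min(weights))  # 3
```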

We’ll start by implementing two decoding strategies:

  • Lookup Table (LUT): Pre-compute minimal corrections for every syndrome (possible for small codes like this one).

  • Belief Propagation (BP): An iterative message-passing algorithm that operates on the code’s Tanner graph (a bipartite graph representing the relationships between qubits and stabilizers). It approximates the marginal probabilities of errors on each qubit, offering greater scalability for larger, sparser codes.

We’ll begin with the LUT decoder because of its simplicity, and then move on to BP.

Lookup‑table (LUT) decoding

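For the Steane code, with \(3\) \(X\) and \(3\) \(Z\) stabilizer generators, there are \(2^3=8\) possible syndromes for each of the \(X\) and \(Z\) checks. We can create a small table that maps each three‑bit syndrome to its lowest-weight correction: the all-zero syndrome maps to no error, and each of the other seven maps to a distinct weight‑1 error.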

import jax.numpy as jnp
from itertools import combinations
from jax.typing import ArrayLike
import jax
from tabulate import tabulate


def lookup_decoder(matrix: ArrayLike, max_weight: int = 1):
    m, n = matrix.shape
    weights = 1 << jnp.arange(m, dtype=jnp.int32)  # syndrome bit i -> 2**i
    lut = jnp.zeros((1 << m, n), dtype=jnp.int8)
    filled = {0}  # the all-zero syndrome maps to "no error"

    # fill the table with the lowest‑weight correction for each syndrome
    # by iterating over errors in order of increasing weight
    for w in range(1, max_weight + 1):
        # iterate over all possible weight-w errors
        for qs in combinations(range(n), w):
            err = jnp.zeros(n, dtype=jnp.int8).at[jnp.array(qs)].set(1)  # error mask
            syn = (matrix @ err) % 2  # syndrome for this error
            idx = int(jnp.dot(syn, weights))  # syndrome bits to base 10 index
            if idx not in filled:  # keep only the first (lowest-weight) error found
                lut = lut.at[idx].set(err)
                filled.add(idx)

    @jax.jit
    def _decode(syndrome: ArrayLike):
        # convert the syndrome to base 10 and look it up in the table
        idx = jnp.dot(syndrome, weights)
        return lut[idx]

    return _decode


H_steane = jnp.array(
    [[0, 0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 0, 1, 1], [1, 0, 1, 0, 1, 0, 1]], dtype=int
)
lut_steane = lookup_decoder(H_steane)

# column j of H_steane, read top to bottom as a binary number, equals j + 1,
# so counting up through the syndromes moves the single-qubit error one position to the right
table_data = []
for i in range(8):
    decoded = lut_steane(jnp.array([int(x) for x in f"{i:03b}"]))
    table_data.append([f"{i:03b}", "".join(map(str, decoded))])

print(tabulate(table_data, headers=["Syndrome", "LUT Error"]))
Syndrome    LUT Error
----------  -----------
       000      0000000
       001      1000000
       010      0100000
       011      0010000
       100      0001000
       101      0000100
       110      0000010
       111      0000001

While this approach is optimal for small codes, it rapidly becomes infeasible for larger ones. For instance, the distance-\(30\) rotated surface code, which encodes a single logical qubit in \(900\) data qubits, has roughly \(450\) stabilizers of each type (\(X\) and \(Z\)). A full lookup table for just one check type would require about \(2^{450} \approx 2.9\times 10^{135}\) entries.
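The table size is simply \(2^m\) for \(m\) checks of one type. The short Python snippet below makes the growth concrete; the rotated surface code rows use \((d^2-1)/2\) checks per type for odd \(d\), and the last row is the \(450\)-check example from the text.

```python
# Lookup-table size 2**m for m checks of a single type (X or Z)
examples = {
    "Steane [[7,1,3]]": 3,
    "rotated surface code, d=5": (5**2 - 1) // 2,
    "rotated surface code, d=7": (7**2 - 1) // 2,
    "example from the text": 450,
}
for name, m in examples.items():
    print(f"{name:27s} m = {m:3d}  entries = {2**m:.2e}")
```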

Belief-Propagation (BP) Decoder

Belief propagation is an iterative message-passing algorithm used to decode errors by working on the Tanner graph 2 of the code. This graph has two types of nodes:

  • Variable nodes represent the physical qubits, which may or may not have experienced an error. These correspond to the bits of the error vector \(e = (e_1, e_2, \dots, e_n)\).

  • Check nodes represent stabilizers, which enforce parity constraints on subsets of qubits. Each check node corresponds to a row of the parity-check matrix \(H\).

There is an edge between a check node \(c\) and a variable node \(v\) if and only if \(H_{cv} = 1\), meaning that qubit \(v\) participates in stabilizer \(c\).

The goal is to estimate the probability that each qubit has been flipped (i.e., that \(e_v = 1\)), given the observed syndrome bits \(s_c\). BP updates these beliefs iteratively by exchanging messages between variable and check nodes.

The Sum-Product Algorithm

The BP decoder is based on the sum-product algorithm, which computes marginal probabilities over the graph. Here’s the procedure:

  1. Initialization

    Each variable node \(v\) sends an initial message to its neighboring checks that reflects the intrinsic belief about whether an error has occurred. This is the log-likelihood ratio (LLR) based on the physical error rate \(p\):

    \[L_0 = \log\frac{1 - p}{p}\]

    This expresses the prior belief: if \(p\) is small, \(L_0\) is large and positive (\(L_0 \approx 4.6\) for \(p = 0.01\)), favoring no error; if \(p\) is close to 0.5, \(L_0\) is near zero (no strong prior). In general, \(p\) is a parameter of the algorithm that can be tuned to your specific noise source.

  2. Variable-to-Check Update

    Each variable node updates its message to a neighboring check \(c\) by combining its intrinsic belief with the incoming messages from other connected checks:

    \[m_{v \to c} = L_0 + \sum_{c' \in N(v) \setminus c} m_{c' \to v}\]

    Here:

    • \(m_{v \to c}\) is the message from variable \(v\) to check \(c\).

    • \(N(v)\) is the set of checks connected to variable \(v\).

    • \(m_{c' \to v}\) are messages received from neighboring checks other than \(c\).

  3. Check-to-Variable Update

    Each check node updates its message to a neighboring variable \(v\) based on the syndrome bit \(s_c\) and the incoming messages from the other variables connected to it:

    \[m_{c \to v} = (-1)^{s_c} \; 2 \, \operatorname{arctanh} \biggl( \prod_{v' \in N(c) \setminus v} \tanh\frac{m_{v' \to c}}{2} \biggr)\]

    Here:

    • \(m_{c \to v}\) is the message from check \(c\) to variable \(v\).

    • \(s_c\) is the syndrome bit for check \(c\) (0 if the stabilizer is satisfied, 1 if violated).

    • \(N(c)\) is the set of variables connected to check \(c\).

    • The \(\tanh\) and \(\operatorname{arctanh}\) functions implement the sum-product rule for combining binary parity checks derived from classical probability theory.

    What’s going on? When the product of the incoming \(\tanh\) terms is close to \(+1\) (or \(-1\)), the other variables in the check confidently have even (or odd) parity, while a product near \(0\) means at least one of them is uncertain. The \(\operatorname{arctanh}\) converts that product back into an LLR-style message. The \((-1)^{s_c}\) factor flips the sign if the syndrome is 1, signaling that a parity error was detected.

  4. Iteration

    Steps 2 and 3 are repeated for a fixed number of iterations (e.g., 10–20) or until the messages converge (i.e., stop changing significantly). A common heuristic is to allow a number of iterations on the order of the code length \(n\); implementations also often stop early once the hard decision from step 5 reproduces the observed syndrome.

  5. Decision Rule

    After the iterations, each variable node computes its posterior LLR by summing its intrinsic belief and all incoming messages:

    \[L_v = L_0 + \sum_{c \in N(v)} m_{c \to v}\]

    The decoder then makes a hard decision:

    • If \(L_v < 0\), it guesses \(e_v = 1\) (error detected).

    • If \(L_v > 0\), it guesses \(e_v = 0\) (no error).
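To see why the \(\tanh\) rule in step 3 is the right way to combine messages, note that with \(L = \log\frac{1-p}{p}\) one has \(\tanh(L/2) = 1 - 2p\). The NumPy sketch below (helper names and example LLRs are ours) compares the step-3 message with a direct probability calculation for a single check, assuming independent neighbours: the target bit must equal \(s_c\) XOR the parity of the other bits. The two agree for either syndrome value.

```python
import numpy as np


def tanh_rule(llrs, s_c):
    """Check-to-variable message from the other variables' LLRs (step 3)."""
    prod = np.prod(np.tanh(np.asarray(llrs) / 2.0))
    return (-1) ** s_c * 2.0 * np.arctanh(prod)


def direct_llr(llrs, s_c):
    """Same message computed from probabilities, assuming independent neighbours."""
    p = 1.0 / (1.0 + np.exp(np.asarray(llrs)))    # P(e_v' = 1) from each LLR
    p_odd = 0.5 * (1.0 - np.prod(1.0 - 2.0 * p))  # P(the other bits have odd parity)
    p_flip = p_odd if s_c == 0 else 1.0 - p_odd   # e_v = s_c XOR (parity of others)
    return np.log((1.0 - p_flip) / p_flip)


incoming = [1.2, -0.4, 2.0]  # arbitrary example LLRs
for s_c in (0, 1):
    print(s_c, tanh_rule(incoming, s_c), direct_llr(incoming, s_c))
```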

Why This Works

Belief propagation computes exact marginals when the Tanner graph is a tree, i.e., has no cycles. The Tanner graphs of quantum codes generally do contain cycles, yet BP, using only local, iterative computations, often still provides a good approximation of the per-qubit error probabilities. Its performance degrades, however, when the Tanner graph contains many short cycles (a common characteristic of popular quantum codes) and when many distinct low-weight errors share the same syndrome (degeneracy); both can lead to poor convergence. In practice, extensions such as BP-OSD 3, BP-LSD 4, or Ambiguity Clustering 5 are used to address these issues.

See the following summary article 6 as well as Chapter 5 in Bayesian Reasoning and Machine Learning 7 for a deeper dive into message passing algorithms on graphs.

BP in JAX

Below, we implement a BP decoder using JAX, broken down into its core components.

Before we can pass messages, we need to establish the connectivity between nodes. The _build_graph function scans the parity‑check matrix once and records, for every variable node, which checks touch it, and for every check node, which variables it touches. We convert the neighbour lists to tuples so they become immutable, hashable static data. When JAX traces the decoder, the Python loops over these tuples are unrolled and the individual indices are baked into the XLA program as compile‑time constants. The result is a decoder compiled specifically for our parity‑check matrix, and the compiled kernel can be reused across calls.

def _build_graph(
    pcm: ArrayLike,
) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]:
    """
    Pre‑compute variable‑node and check‑node neighbors.

    Returns
    -------
    var_neighbors : tuple[tuple[int, ...], ...]  # length = n
    check_neighbors : tuple[tuple[int, ...], ...]  # length = m
    """
    m, n = pcm.shape
    vars_, checks_ = [[] for _ in range(n)], [[] for _ in range(m)]

    for c in range(m):
        for v in range(n):
            if pcm[c, v]:
                vars_[v].append(c)
                checks_[c].append(v)

    return tuple(map(tuple, vars_)), tuple(map(tuple, checks_))

A nice way to visualize this Tanner graph is with the networkx package. Below is an example for the Steane code.

import matplotlib.pyplot as plt
import networkx as nx

var_nbrs, check_nbrs = _build_graph(H_steane)  # avoid shadowing the built-in vars()
G = nx.Graph()
num_vars = len(var_nbrs)
num_checks = len(check_nbrs)

# build the nx graph object from our variable and check neighbor lists
for v in range(num_vars):
    G.add_node(f"v{v}", bipartite=0)
for c in range(num_checks):
    G.add_node(f"c{c}", bipartite=1)
for c in range(num_checks):
    for v in check_nbrs[c]:
        G.add_edge(f"c{c}", f"v{v}")

pos = nx.bipartite_layout(G, nodes=[f"v{i}" for i in range(num_vars)])

plt.figure(figsize=(10, 7))
nx.draw(G, pos, with_labels=True, node_color="skyblue", node_size=500, font_weight="bold")
plt.title("Bipartite Graph for H_steane", fontsize=16)
plt.show()
Bipartite Graph for H_steane

The _c2v_update helper function performs one full sweep of message updates. It fuses steps 2 and 3 of the sum‑product algorithm: for each edge, it recomputes the needed variable‑to‑check messages on the fly from the previous check‑to‑variable messages and then produces the new check‑to‑variable message. It takes the previous messages, the syndrome, the neighbor tables, and two scalars (L_int for the intrinsic log‑likelihood ratio and eps for numerical safety). It loops only over existing edges, multiplies the relevant \(\tanh\) terms, clips the product away from \(\pm 1\) (where \(\operatorname{arctanh}\) diverges), applies \(\operatorname{arctanh}\), and writes the new message into the next matrix.

def _c2v_update(
    m_c2v_prev: ArrayLike,
    syndrome: ArrayLike,
    var_nei: tuple[tuple[int, ...], ...],
    check_nei: tuple[tuple[int, ...], ...],
    L_int: float,
    eps: float,
) -> ArrayLike:
    """
    Compute the next round of check‑to‑variable messages.
    """
    m, n = m_c2v_prev.shape
    m_c2v_next = jnp.zeros_like(m_c2v_prev)

    # Loop over checks (outer) then their vars (inner)
    for c in range(m):
        Vc = check_nei[c]
        if len(Vc) < 2:
            continue  # degree‑1 checks carry no new info

        for v in Vc:
            prod = 1.0
            # product over all v' ≠ v in this check
            for v_p in Vc:
                if v_p == v:
                    continue
                incoming = L_int
                for c_p in var_nei[v_p]:
                    if c_p != c:
                        incoming += m_c2v_prev[c_p, v_p]
                prod *= jnp.tanh(incoming / 2.0)

            prod = jnp.clip(prod, -1.0 + eps, 1.0 - eps)
            msg = ((-1) ** syndrome[c]) * 2.0 * jnp.arctanh(prod)
            m_c2v_next = m_c2v_next.at[c, v].set(msg)

    return m_c2v_next

Once the main loop finishes, we still need a hard decision. The function _posterior_llrs folds every final check‑to‑variable message for bit v into its intrinsic LLR, yielding the posterior belief for that bit. A negative value means “error likely,” a positive value means “no error.”

def _posterior_llrs(
    m_c2v_final: ArrayLike, var_nei: tuple[tuple[int, ...], ...], L_int: float
) -> ArrayLike:
    """
    Combine intrinsic LLR with all incoming messages.
    """
    n = m_c2v_final.shape[1]
    llr = jnp.full(n, L_int)
    for v in range(n):
        for c in var_nei[v]:
            llr = llr.at[v].add(m_c2v_final[c, v])
    return llr

build_bp_decoder serves as the main entry point for compiling our decoder. It takes the parity‑check matrix and channel error rate, builds the Tanner graph, pre‑computes the intrinsic LLR, and returns a JIT‑compiled function _decode.

Inside _decode, the following steps are executed:

  1. All messages are zero-initialized.

  2. _c2v_update is called inside a jax.lax.fori_loop for max_iter rounds.

  3. Final messages are converted to posterior LLRs with _posterior_llrs.

  4. A binary error vector is output by thresholding the LLRs at zero.

Because the whole _decode body is wrapped in @jax.jit, the first call traces and compiles everything into an XLA kernel; subsequent calls with inputs of the same shape and dtype reuse the compiled kernel without retracing.

def build_bp_decoder(
    parity_check_matrix: ArrayLike,
    error_rate: float,
    max_iter: int = 10,
    epsilon: float = 1e-9,
) -> Callable[[ArrayLike], ArrayLike]:
    """
    Return a JIT‑compiled BP decoder for the given code and channel.

    Parameters
    ----------
    parity_check_matrix : array‑like (m, n)
    error_rate : float              # BSC crossover probability p
    max_iter : int
    epsilon : float                 # numerical safety margin
    """
    pcm = jnp.asarray(parity_check_matrix, dtype=jnp.int32)
    m, n = pcm.shape
    L_int = jnp.log((1.0 - error_rate) / error_rate)

    var_nei, check_nei = _build_graph(pcm)

    @jax.jit
    def _decode(syndrome: ArrayLike) -> ArrayLike:
        syndrome = jnp.asarray(syndrome, dtype=jnp.int32)

        # Initialise all messages to zero
        m_c2v = jnp.zeros((m, n), dtype=jnp.float32)

        # BP loop
        def body(_, msgs):
            return _c2v_update(msgs, syndrome, var_nei, check_nei, L_int, epsilon)

        m_c2v = jax.lax.fori_loop(0, max_iter, body, m_c2v)

        # Hard decision from posterior LLRs
        llr = _posterior_llrs(m_c2v, var_nei, L_int)
        return (llr < 0).astype(jnp.int32)

    # optionally we can force our decoder to compile right away by calling it on a test input
    _decode(jnp.zeros(m, dtype=jnp.int32))

    return _decode

Let’s test the performance of the BP decoder on the Steane code compared to the LUT decoder.

bp_steane = build_bp_decoder(H_steane, error_rate=0.05, max_iter=7)

n_bits = H_steane.shape[0]
correct = 0
total_syndromes = 2**n_bits

table_data = []
headers = ["Syndrome", "BP Estimated Error", "LUT Exact Error", "Match"]

for i in range(total_syndromes):
    syndrome_binary_string = f"{i:0{n_bits}b}"
    s_array = jnp.array([int(x) for x in syndrome_binary_string])

    # Get error patterns from BP decoder and LUT
    bp_pattern = bp_steane(s_array)
    lut_pattern = lut_steane(s_array)
    match = jnp.all(bp_pattern == lut_pattern)

    bp_pattern_str = "".join(map(str, bp_pattern.tolist()))
    lut_pattern_str = "".join(map(str, lut_pattern.tolist()))

    table_data.append([syndrome_binary_string, bp_pattern_str, lut_pattern_str, str(match)])

    # Increment correct count if patterns match
    if match:
        correct += 1

print(tabulate(table_data, headers=headers))

# Calculate and print the BP accuracy
accuracy = (correct / total_syndromes) * 100 if total_syndromes > 0 else 0
print(f"\nBP Accuracy: {accuracy:.2f}%")
Syndrome    BP Estimated Error    LUT Exact Error  Match
----------  --------------------  -----------------  -------
       000               0000000            0000000  True
       001               1000000            1000000  True
       010               0100000            0100000  True
       011               0010000            0010000  True
       100               0001000            0001000  True
       101               0000100            0000100  True
       110               0000010            0000010  True
       111               0000001            0000001  True

BP Accuracy: 100.00%

Before moving on to the Catalyst kernel, let’s test our belief‑propagation (BP) decoder on a bigger example: the \(n\)‑bit repetition code. This code stores each logical bit by repeating it \(n\) times (e.g. \(0 \mapsto 00\ldots0\) and \(1 \mapsto 11\ldots1\)). Its parity‑check matrix consists of \(n-1\) rows, each enforcing that two neighbouring bits are equal. Below, we measure how often the BP decoder agrees with an optimal decoder on a 50‑bit repetition code. Every syndrome of this code is consistent with exactly two error patterns, one the bitwise complement of the other, and for independent bit flips with \(p < 1/2\) the maximum‑likelihood (ML) decoder simply picks the lower‑weight of the two.

def rep_code(n: int) -> ArrayLike:
    """
    Build the (n − 1) × n parity‑check matrix H for the [n, 1] repetition code.

    Each row enforces equality between two neighboring bits:
        H[i] has 1s in positions i and i+1, zeros elsewhere.
    """
    # First row: parity check on bits 0 and 1 → [1, 1, 0, 0, …, 0]
    first_row = jnp.zeros(n, dtype=jnp.int8).at[jnp.array([0, 1])].set(1)
    rows = [first_row]

    # Remaining rows: slide the two‑bit “window” to the right
    for _ in range(n - 2):
        rows.append(jnp.roll(rows[-1], 1))  # shift previous row by 1 position

    return jnp.stack(rows)  # shape = (n‑1, n)


@jax.jit
def ml_rep_decoder(syndrome: ArrayLike) -> ArrayLike:
    """
    Minimum‑weight decoder for the repetition code.

    Parameters
    ----------
    syndrome : ArrayLike, shape (n‑1,)
        The syndrome s = H e (mod 2).

    Returns
    -------
    error : ArrayLike, shape (n,)
        A lowest‑weight error vector consistent with `syndrome`.
    """
    # Candidate 1: assume e[0] = 0, then recover the rest via cumulative XOR.
    #   e[k+1] = e[k] ⊕ s[k]  ⇒  e = [0, cumsum(s) mod 2]
    e0 = jnp.concatenate((jnp.array([0], dtype=jnp.int32), jnp.mod(jnp.cumsum(syndrome), 2)))

    # Candidate 2: flip every bit (equivalent to choosing e[0] = 1).
    e1 = (e0 + 1) & 1  # fast “add‑one then mod 2”

    # Compare Hamming weights.
    w0, w1 = jnp.sum(e0), jnp.sum(e1)

    # Return the lighter candidate (ties resolved in favour of e0).
    return jax.lax.cond(w0 <= w1, lambda _: e0, lambda _: e1, operand=None)

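We run a short experiment on a \(50\)-bit repetition code. We sample 10,000 uniformly random syndrome vectors and measure how often our BP decoder returns the same error as the baseline ml_rep_decoder. Uniformly random syndromes correspond to uniformly random errors, whose typical weight is close to \(n/2\), so this is a stress test far from the low-noise regime.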

H_rep = rep_code(n := 50)
bp_rep = build_bp_decoder(parity_check_matrix=H_rep, error_rate=0.1, max_iter=n)

# sample random syndromes
N = 10_000
key = jax.random.PRNGKey(0)
syndromes = jax.random.randint(key, shape=(N, n - 1), minval=0, maxval=2)

# use jax to map the decoder over the syndromes
# since our decoders are jit compiled jax functions they can be used with jax.vmap
success_rate = jnp.mean(
    jnp.all(jax.vmap(ml_rep_decoder)(syndromes) == jax.vmap(bp_rep)(syndromes), axis=1)
)

print(f"Decoding success rate: {success_rate * 100:.2f}%")
Decoding success rate: 85.73%
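Because uniformly random syndromes correspond to very heavy errors, a complementary check samples independent bit flips at a fixed physical rate and counts logical failures of the minimum-weight decoder. This is a NumPy sketch rather than the JAX decoders above; the code length, \(p\), shot count, and seed are arbitrary choices.

```python
import numpy as np


def min_weight_rep_decode(s: np.ndarray) -> np.ndarray:
    """Lower-weight of the two error patterns consistent with syndrome s."""
    e0 = np.concatenate(([0], np.cumsum(s) % 2))  # e[0] = 0, e[k+1] = e[k] XOR s[k]
    e1 = 1 - e0  # the complementary candidate (e[0] = 1)
    return e0 if e0.sum() <= e1.sum() else e1


rng = np.random.default_rng(0)
n, p, shots = 5, 0.1, 20_000
H = np.eye(n - 1, n, dtype=int) + np.eye(n - 1, n, k=1, dtype=int)  # e[i] = e[i+1] checks

failures = 0
for _ in range(shots):
    e = (rng.random(n) < p).astype(int)  # i.i.d. bit flips
    residual = (e + min_weight_rep_decode((H @ e) % 2)) % 2
    failures += int(residual.all())  # residual is 00...0 (success) or 11...1 (logical flip)
print(f"physical p = {p}, logical error rate ~ {failures / shots:.4f}")
```

The residual always has a zero syndrome, so it is either trivial or the all-ones logical flip, which is why a single `.all()` check suffices.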

Catalyst hybrid kernel

Now that we understand a good chunk of the theory behind CSS codes, the Steane code, and decoding algorithms, let’s put it into action with Catalyst!

Catalyst lets us build hybrid quantum-classical workflows, compiling both quantum operations and classical decoding logic into a single, efficient kernel. We’ll start with a quantum-classical circuit to prepare the logical zero state \(|0\rangle_L\) for our Steane code. The same method can be used to prepare the logical zero state of any CSS code.

Start with a \(+1\) eigenstate (or stabilizer state) of all the \(Z\)-type stabilizers. The state \(|0\ldots 0\rangle\) is a \(+1\) eigenstate of every \(Z\)-type Pauli operator, making it a suitable choice.

Then, for each X-type generator:

  • Prepare an ancilla qubit in the \(|+\rangle\) state.

  • Apply CNOTs controlled on the ancilla and targeting each data qubit in the generator’s support.

  • Measure the ancilla in the \(X\) basis.

Next:

  • Use measurement outcomes (syndromes) to determine necessary corrections using our decoder.

  • Apply Z-type corrections based on decoding results.

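This procedure uses projective measurements to force the data qubits into the \(+1\) eigenstate of our \(X\)-type generators; a \(-1\) outcome is fixed by the decoder’s \(Z\) correction, which anticommutes with the violated \(X\) checks. Since the state was already a \(+1\) eigenstate of our \(Z\)-type generators, and in a CSS code all \(X\) and \(Z\) generators commute, we are left with a state in the \(+1\) eigenspace of all the generators. For the Steane code, this state is specifically \(|0\rangle_L\): the logical operator \(\overline{Z} = Z^{\otimes 7}\) also stabilizes \(|0\ldots 0\rangle\) and commutes with every \(X\)-type generator (each has even weight) as well as with the \(Z\) corrections, so its \(+1\) eigenvalue is preserved.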
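The projection argument can be checked with a small classical state-vector calculation. The NumPy sketch below starts from \(|0000000\rangle\) and applies the projector \((I + X_g)/2\) for each \(X\)-type generator, which corresponds to the branch where every measured syndrome bit is \(0\) (the Catalyst circuit instead fixes other outcomes with \(Z\) corrections). Wire 0 is taken as the most significant bit, matching the bitstrings PennyLane prints.

```python
import numpy as np

H_X = np.array([[0, 0, 0, 1, 1, 1, 1],
                [0, 1, 1, 0, 0, 1, 1],
                [1, 0, 1, 0, 1, 0, 1]])
n = H_X.shape[1]

psi = np.zeros(2**n)
psi[0] = 1.0  # |0000000>

for row in H_X:
    # X_g flips the bits in the generator's support: |b> -> |b XOR mask>
    mask = int("".join(map(str, row)), 2)
    flipped = psi[np.arange(2**n) ^ mask]
    psi = (psi + flipped) / 2  # apply the projector (I + X_g) / 2
psi /= np.linalg.norm(psi)

support = sorted(f"{int(i):0{n}b}" for i in np.flatnonzero(psi))
print(support)
print(psi[np.flatnonzero(psi)])  # eight equal amplitudes of 1/sqrt(8)
```

The surviving bitstrings form the binary row space of \(H_X\), i.e., the eight terms of the logical zero state.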

import pennylane as qml
from jax import random
import catalyst

r, n = H_steane.shape
n_wires = n + r

dev = qml.device("lightning.qubit", wires=n_wires)


def measure_x_stabilizers(H: ArrayLike):
    """
    Measure all X-type stabilizers given by the parity-check matrix H, then apply Z-type corrections from our decoder.
    :param H: X-type parity-check matrix (H_X)
    """
    r, n = H.shape

    # Encode logical |0>
    # (Hadamard on ancillas, controlled X stabilizers)
    for a in range(r):
        qml.H(wires=n + a)
    for a, row in enumerate(H):
        for q, x in enumerate(row):
            if x:
                qml.CNOT(wires=[n + a, q])
    for a in range(r):
        qml.H(wires=n + a)

    # Measure + reset ancillas (X stabilizers)
    sx = jnp.stack([catalyst.measure(n + a) for a in range(r)])
    for a, bit in enumerate(sx):
        if bit:
            qml.PauliX(wires=n + a)  # reset ancilla

    # Z-correction
    # The BP and LUT decoders agree on every Steane syndrome,
    # so we use the LUT for simplicity
    rec_z = lut_steane(sx)
    for q, bit in enumerate(rec_z):
        if bit:
            qml.PauliZ(wires=q)


@qml.qjit(autograph=True)
@qml.qnode(dev)
def encode_zero_steane():
    measure_x_stabilizers(H_steane)
    return qml.state()

A simple utility function to display the state:

from pprint import pprint

def state_vector_to_dict(
    sv: ArrayLike,
    wires: Optional[Sequence[int]],
    eps: float = 1e-8,
    probability: bool = False,
    display: bool = True,
) -> Dict[str, Union[float, complex]]:
    """
    Convert a state vector into {bitstring: amplitude | probability}.
    """
    n_qubits = int(jnp.log2(len(sv)))


    out: Dict[str, Union[float, complex]] = {}

    for idx, amp in enumerate(sv):
        mag = jnp.abs(amp) ** 2 if probability else jnp.abs(amp)
        if mag <= eps:
            continue

        bitstring = f"{idx:0{n_qubits}b}"
        key = "".join(b for i, b in enumerate(bitstring) if wires is None or i in wires)

        if probability:
            out[key] = out.get(key, 0.0) + float(mag)
        else:
            out[key] = amp.item()

    if display:
        pprint(out)

    return out

We run the encode_zero_steane function and see that we recover the correct logical zero state for the Steane code:

\[\begin{split}\begin{aligned}|\overline{0}\rangle= & \frac{1}{\sqrt{8}}(|0000000\rangle+|1010101\rangle+|0110011\rangle+|1100110\rangle \\ & +|0001111\rangle+|1011010\rangle+|0111100\rangle+|1101001\rangle)\end{aligned}\end{split}\]
sv_clean = encode_zero_steane()
state_vector_to_dict(sv_clean, display=True, wires=range(n))
{'0000000': (0.35355339059327373+0j),
 '0001111': (0.35355339059327373+0j),
 '0110011': (0.35355339059327373+0j),
 '0111100': (0.35355339059327373+0j),
 '1010101': (0.35355339059327373-0j),
 '1011010': (0.35355339059327373-0j),
 '1100110': (0.35355339059327373-0j),
 '1101001': (0.35355339059327373-0j)}
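The eight bitstrings are exactly the mod-2 row span of the Steane parity-check matrix. Starting from |0000000⟩, measuring the \(X\) stabilizers and applying the \(Z\) correction leaves the data qubits in the uniform superposition over that span. The classical sketch below checks this by enumerating the span. Here `H_demo` is a hypothetical stand-in for `H_steane`, assumed to have the standard Hamming layout in which column q is the binary expansion of q + 1.

```python
import itertools

import jax.numpy as jnp

# Hypothetical stand-in for H_steane: column q is the binary expansion of q + 1.
H_demo = jnp.array([
    [0, 0, 0, 1, 1, 1, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [1, 0, 1, 0, 1, 0, 1],
])


def row_span(H):
    """All mod-2 linear combinations of the rows of H, as bitstrings."""
    words = set()
    for coeffs in itertools.product([0, 1], repeat=H.shape[0]):
        word = (jnp.array(coeffs) @ H) % 2
        words.add("".join(str(int(b)) for b in word))
    return words


print(sorted(row_span(H_demo)))
```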


Simulating Errors and Full Correction

We’re now ready to wrap everything together:

  • Prepare the logical zero state.

  • Simulate noise using a depolarizing channel.

  • Perform one complete round of stabilizer measurements and corrections.

def noise_channel(n: int, p_err: float, key: random.PRNGKey):
    """
    Apply a single‑qubit Pauli noise channel independently to each of `n` qubits.

    For every qubit the channel does:
        0 → I       with probability 1 - p_err
        1 → X       with probability p_err / 3
        2 → Z       with probability p_err / 3
        3 → Y       with probability p_err / 3
    """
    probs = jnp.array([1.0 - p_err, p_err / 3, p_err / 3, p_err / 3])
    outcomes = random.choice(key, 4, shape=(n,), p=probs)

    for idx, outcome in enumerate(outcomes):
        if outcome == 1:
            qml.X(wires=idx)
        elif outcome == 2:
            qml.Z(wires=idx)
        elif outcome == 3:
            qml.Y(wires=idx)


# Helper that returns the specific error applied in a given round, reproduced from the same key
def get_error(n: int, p_err: float, key: random.PRNGKey):
    err = []
    probs = jnp.array([1.0 - p_err, p_err / 3, p_err / 3, p_err / 3])
    outcomes = random.choice(key, 4, shape=(n,), p=probs)

    for idx, outcome in enumerate(outcomes):
        if outcome == 1:
            err.append(qml.X(wires=idx))
        elif outcome == 2:
            err.append(qml.Z(wires=idx))
        elif outcome == 3:
            err.append(qml.Y(wires=idx))
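    if not err:  # no error sampled: return the identity instead of an empty product
        return qml.Identity(wires=0)
    return qml.ops.prod(*err)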
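Each sampled outcome also has a binary symplectic form: \(X\) sets the x-bit, \(Z\) sets the z-bit, and \(Y = iXZ\) sets both. The helpers below are hypothetical and not part of the demo. `sample_symplectic_error` draws with the same `random.choice` call as `noise_channel` and `get_error`, so a given key should reproduce the same pattern in (x | z) form. For a CSS code, the Z-stabilizer measurement detects the x-part and the X-stabilizer measurement detects the z-part.

```python
import jax.numpy as jnp
from jax import random


def outcomes_to_symplectic(outcomes):
    """Map noise outcomes (0=I, 1=X, 2=Z, 3=Y) to binary (x, z) vectors."""
    x = ((outcomes == 1) | (outcomes == 3)).astype(jnp.int32)  # X and Y flip the bit
    z = ((outcomes == 2) | (outcomes == 3)).astype(jnp.int32)  # Z and Y flip the phase
    return x, z


def sample_symplectic_error(n, p_err, key):
    """Hypothetical helper: draw outcomes exactly as noise_channel does, in (x, z) form."""
    probs = jnp.array([1.0 - p_err, p_err / 3, p_err / 3, p_err / 3])
    outcomes = random.choice(key, 4, shape=(n,), p=probs)
    return outcomes_to_symplectic(outcomes)


x, z = sample_symplectic_error(7, 0.1, random.PRNGKey(10))
print("x =", x, " z =", z)
```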

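The function measure_z_stabilizers is similar to measure_x_stabilizers, except that we now apply CNOTs from the data qubits to ancillas prepared in the \(|0\rangle\) state and measure the ancillas in the \(Z\) basis. The resulting syndrome detects \(X\) errors, which we correct by applying Pauli-\(X\) gates.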

def measure_z_stabilizers(H):
    r, n = H.shape
    for a, row in enumerate(H):
        for q, x in enumerate(row):
            if x:
                qml.CNOT(wires=[q, n + a])

    sz = jnp.stack([catalyst.measure(n + a) for a in range(r)])
    for a, bit in enumerate(sz):
        if bit:
            qml.PauliX(wires=n + a)

    rec_x = lut_steane(sz)
    for q, bit in enumerate(rec_x):
        if bit:
            qml.PauliX(wires=q)
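Why does this CNOT pattern measure a \(Z\)-type stabilizer? On computational-basis inputs, each CNOT from data qubit q to ancilla a XORs bit q into the ancilla, so ancilla a ends up holding the parity of the data bits selected by row a of H. The classical sketch below reproduces this on bitstrings, with `H_demo` again a hypothetical stand-in for `H_steane`. A term of the logical zero state gives syndrome 000, while a single \(X\) flip on qubit q returns column q of H. This column is what lets a lookup table identify the flipped qubit.

```python
import jax.numpy as jnp

# Hypothetical stand-in for H_steane (column q is the binary expansion of q + 1).
H_demo = jnp.array([
    [0, 0, 0, 1, 1, 1, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [1, 0, 1, 0, 1, 0, 1],
])


def cnot_ladder_syndrome(H, data_bits):
    """Classically apply the CNOT pattern of measure_z_stabilizers to a basis state.

    Each CNOT(data q -> ancilla a) XORs data bit q into ancilla bit a.
    """
    r, n = H.shape
    ancillas = [0] * r  # ancillas start in |0>
    for a in range(r):
        for q in range(n):
            if int(H[a, q]):
                ancillas[a] ^= int(data_bits[q])
    return jnp.array(ancillas)


codeword = jnp.array([0, 0, 0, 1, 1, 1, 1])  # one term of the logical zero state
flipped = codeword.at[4].set(1 - codeword[4])  # X error on qubit 4
print(cnot_ladder_syndrome(H_demo, codeword))  # all parity checks satisfied
print(cnot_ladder_syndrome(H_demo, flipped))  # equals column 4 of H_demo
```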

Now, let’s define qec_round. It prepares the logical zero state, injects one round of noise, and then performs one round of \(X\)- and \(Z\)-stabilizer measurement and correction. We’ll print the error sampled by the noise channel and compare the output state with the noiseless state sv_clean we obtained previously.

@qml.qjit(autograph=True)
@qml.qnode(dev, interface="jax")
def qec_round(H: ArrayLike, p_err=1e-3, key=random.PRNGKey(0)):
    """One round of Steane code QEC with LUT decoding."""

    measure_x_stabilizers(H)  # prepare the logical |0> state
    noise_channel(n, p_err, key)  # inject IID Pauli noise
    measure_x_stabilizers(H)  # X-stabilizer syndrome: detect and correct Z errors
    measure_z_stabilizers(H)  # Z-stabilizer syndrome: detect and correct X errors

    return qml.state()


p_err = 0.1
key = random.PRNGKey(10)
print(f"Running Steane Code QEC Round with error: {get_error(n, p_err=p_err, key=key)}")
state_vector_to_dict(qec_round(H_steane, p_err=p_err, key=key), display=True, wires=range(n))
Running Steane Code QEC Round with error: X(0) @ X(6)
{'0010110': (0.3535533905932738+0j),
 '0011001': (0.3535533905932738+0j),
 '0100101': (0.3535533905932738+0j),
 '0101010': (0.3535533905932738+0j),
 '1000011': (0.3535533905932738+0j),
 '1001100': (0.3535533905932738+0j),
 '1110000': (0.3535533905932738+0j),
 '1111111': (0.3535533905932738+0j)}


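The sampled error X(0) @ X(6) has weight two, which is more than the distance-3 Steane code can correct. The correction inferred from the syndrome combines with the error into a logical \(X\) operator, so the round ends in the logical one state: every bitstring above is the bitwise complement of one in sv_clean. Increasing the error probability makes such uncorrectable patterns more likely: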

p_err = 0.3
key = random.PRNGKey(8)
print(f"Running Steane Code QEC Round with error: {get_error(n, p_err=p_err, key=key)}")
state_vector_to_dict(qec_round(H_steane, p_err=p_err, key=key), display=True, wires=range(n))
Running Steane Code QEC Round with error: X(3) @ X(6)
{'0010110': (0.3535533905932738+0j),
 '0011001': (0.3535533905932738+0j),
 '0100101': (0.3535533905932738+0j),
 '0101010': (0.3535533905932738+0j),
 '1000011': (0.3535533905932738+0j),
 '1001100': (0.3535533905932738+0j),
 '1110000': (0.3535533905932738+0j),
 '1111111': (0.3535533905932738+0j)}
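The classical sketch below traces both rounds above. It assumes that the lookup table flips the single qubit whose column of H matches the syndrome; `H_demo` is a hypothetical stand-in for `H_steane`, and the helper names are ours. For X(0) @ X(6), the syndrome points to qubit 5. For X(3) @ X(6), it points to qubit 2. In both cases the error plus correction is a weight-3 logical \(X\), and applying it to the support of \(|\overline{0}\rangle\) reproduces the eight bitstrings printed above.

```python
import itertools

# Hypothetical stand-in for H_steane (column q is the binary expansion of q + 1).
H_demo = [
    [0, 0, 0, 1, 1, 1, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [1, 0, 1, 0, 1, 0, 1],
]
N_QUBITS = len(H_demo[0])


def syndrome(bits):
    return [sum(h * b for h, b in zip(row, bits)) % 2 for row in H_demo]


def lut_correction(s):
    """Assumed LUT: flip the one qubit whose column of H_demo equals the syndrome."""
    if not any(s):
        return [0] * N_QUBITS
    return [1 if [row[q] for row in H_demo] == s else 0 for q in range(N_QUBITS)]


def zero_support():
    """Bitstrings in the logical zero state: the mod-2 row span of H_demo."""
    span = set()
    for coeffs in itertools.product([0, 1], repeat=len(H_demo)):
        word = [sum(c * row[q] for c, row in zip(coeffs, H_demo)) % 2 for q in range(N_QUBITS)]
        span.add(tuple(word))
    return span


def support_after(x_error_qubits):
    """Apply an X error and its LUT correction to every term of the logical zero state."""
    error = [1 if q in x_error_qubits else 0 for q in range(N_QUBITS)]
    correction = lut_correction(syndrome(error))
    residual = [e ^ c for e, c in zip(error, correction)]
    out = {tuple(b ^ r for b, r in zip(word, residual)) for word in zero_support()}
    return residual, sorted("".join(map(str, w)) for w in out)


for err in [(0, 6), (3, 6)]:
    residual, support = support_after(err)
    print(f"X on qubits {err}: residual {residual} (weight {sum(residual)})")
    print(support)
```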


Benchmarking logical vs. physical error rates

In the final section of this demo, we estimate the average performance of our Steane code error-correction circuit over a range of physical error rates. A trial counts as a logical error unless the noisy output state has equal amplitudes on all the basis states of the clean logical zero state sv_clean. For the code states produced here, this means the output matches sv_clean up to a global phase.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

clean_idx = jnp.where(sv_clean)[0]


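def logical_error(sv):
    """Return 1 if sv is not the clean logical zero state (up to a global phase), else 0."""
    st = sv[clean_idx]  # amplitudes on the basis states of sv_clean
    return 1 - jnp.all(jnp.isclose(st / st[0], 1))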
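To see what this criterion accepts and rejects, the sketch below applies the same test to 7-qubit toy states built from the support of \(|\overline{0}\rangle\). The names `zero_bar` and `is_logical_error` are ours, for illustration. A global phase passes. So does the logical \(\overline{Z} = Z^{\otimes 7}\): every bitstring in \(|\overline{0}\rangle\) has even weight, so \(\overline{Z}|\overline{0}\rangle = |\overline{0}\rangle\). The logical \(\overline{X} = X^{\otimes 7}\) maps the state to \(|\overline{1}\rangle\), whose amplitudes vanish on the clean indices. The ratio st / st[0] then becomes NaN, and the check reports an error. With this encoded state, the benchmark therefore counts logical bit flips only.

```python
import jax.numpy as jnp

zero_terms = ["0000000", "0001111", "0110011", "0111100",
              "1010101", "1011010", "1100110", "1101001"]

# 7-qubit toy version of the logical zero state (no ancillas).
zero_bar = jnp.zeros(2**7, dtype=jnp.complex64)
for bits in zero_terms:
    zero_bar = zero_bar.at[int(bits, 2)].set(1 / jnp.sqrt(8.0))
clean_idx_demo = jnp.where(zero_bar)[0]


def is_logical_error(sv):
    """Same test as logical_error, applied to 7-qubit toy states."""
    st = sv[clean_idx_demo]
    return 1 - jnp.all(jnp.isclose(st / st[0], 1))


# Z^{⊗7} multiplies basis state i by (-1)^(number of ones in i).
signs = jnp.array([(-1) ** bin(i).count("1") for i in range(2**7)])
z_bar = signs * zero_bar
# X^{⊗7} flips every bit, sending basis index i to 127 - i.
x_bar = zero_bar[::-1]
phase = jnp.exp(0.3j) * zero_bar  # global phase

for name, sv in [("clean", zero_bar), ("global phase", phase),
                 ("Z-bar", z_bar), ("X-bar", x_bar)]:
    print(name, int(is_logical_error(sv)))
```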

We simulate 1000 noisy trials for each of several noise levels. jax.vmap lets us map our Catalyst kernel efficiently over a batch of random keys.

@jax.jit
def single_trial_error(key, p_err, H_steane):
    """Performs one QEC round and checks for a logical error."""
    round_output = qec_round(H_steane, p_err, key)
    err = logical_error(round_output)
    return err


batch_trial_errors = jax.vmap(single_trial_error, in_axes=(0, None, None))

N = 1000
p_rng = 2 ** jnp.arange(-5, -1.75, 0.25, dtype=jnp.float32)
res = []
master_key = random.PRNGKey(0)

for p in tqdm(p_rng):
    keys_for_batch, master_key = random.split(master_key)
    all_keys = random.split(keys_for_batch, N)

    errors_batch = batch_trial_errors(all_keys, p, H_steane)

    p_value = p.item()
    for idx, err in enumerate(errors_batch):
        res.append({"p": p_value, "seed": idx, "err": err.item()})

df = pd.DataFrame(res)
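Under the code-capacity model, the benchmark can also be cross-checked without any state-vector simulation. Since the criterion above only detects logical bit flips, only the \(X\) part of each error matters, and a qubit picks up an \(X\) component (from \(X\) or \(Y\)) with probability 2p/3. The pure-JAX sketch below is an illustrative stand-in, not the demo's method. It samples those patterns, applies an assumed minimum-weight lookup-table correction for a hypothetical `H_demo`, and counts residuals that anticommute with \(\overline{Z} = Z^{\otimes 7}\). It reuses the same vmap-over-keys pattern as `batch_trial_errors`.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical stand-in for H_steane: column q is the binary expansion of q + 1
# (most significant bit in row 0).
H_demo = jnp.array([
    [0, 0, 0, 1, 1, 1, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [1, 0, 1, 0, 1, 0, 1],
], dtype=jnp.int32)


def x_failure(x):
    """Return 1 if LUT correction of the X-error pattern x leaves a logical X, else 0."""
    s = (H_demo @ x) % 2
    pos = 4 * s[0] + 2 * s[1] + s[2]  # syndrome read as an integer; 0 means "no flip"
    # Assumed LUT: flip the single qubit pos - 1 (no-op when pos == 0).
    correction = (jnp.arange(7) == pos - 1).astype(jnp.int32)
    residual = (x + correction) % 2  # has zero syndrome by construction
    # A zero-syndrome X residual is logical iff it anticommutes with Z^{⊗7},
    # i.e. iff it has odd weight.
    return jnp.sum(residual) % 2


def sample_failure(key, p):
    # Depolarizing noise: X or Y (probability p/3 each) flips the bit, so P(x_i = 1) = 2p/3.
    x = random.bernoulli(key, 2.0 * p / 3.0, shape=(7,)).astype(jnp.int32)
    return x_failure(x)


# Same vmap-over-keys pattern as batch_trial_errors, but purely classical.
batch_failures = jax.jit(jax.vmap(sample_failure, in_axes=(0, None)))

keys = random.split(random.PRNGKey(0), 1000)
for p in [0.03, 0.1, 0.2]:
    rate = float(batch_failures(keys, p).mean())
    print(f"p = {p:.2f}  estimated P_L = {rate:.3f}")
```

If the assumptions hold, these estimates should agree with the corresponding points of df within sampling error.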

Finally, we plot the results with seaborn, along with the break-even line where the logical and physical error rates are equal.

p_rng_min = p_rng[0]
p_rng_max = p_rng[-1]

sns.set_theme(style="whitegrid", context="talk")

plt.figure(figsize=(10, 7))
# Each row of df is a 0/1 outcome, so lineplot's per-p mean is the logical error rate
sns.lineplot(
    data=df,
    x="p",
    y="err",
    marker="o",
    markersize=8,
    linewidth=2.5,
    label="Simulated Logical Error Rate",
)

plt.plot(
    [p_rng_min, p_rng_max],
    [p_rng_min, p_rng_max],
    linestyle="--",
    color="gray",
    linewidth=1.5,
    label="$p_{physical} = p_{logical}$",  # Label for legend
)
plt.xlabel("Physical Error Rate ($p$)", fontsize=16)
plt.ylabel("Logical Error Rate ($P_L$)", fontsize=16)
plt.xscale("log", base=2)
plt.yscale("log", base=2)
plt.title("Logical vs. Physical Error Rate", fontsize=18, pad=20)
plt.legend(fontsize=14)
plt.grid(True, which="both", ls="--", c="lightgray", alpha=0.7)  # 'both' for major and minor ticks
sns.despine()
plt.tight_layout()
plt.show()
Logical vs. Physical Error Rate
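The following back-of-the-envelope check on the slope of this curve rests on three assumptions: code-capacity noise, a lookup table that corrects any single \(X\) flip, and a logical-error criterion that only sees bit flips. Zero or one \(X\) component is always corrected. Every weight-two \(X\) pattern has a nonzero syndrome that points to a third qubit, which leaves a weight-three logical \(X\). At small \(p\), this gives, to leading order,

\[P_L = \binom{7}{2}\, q^2 (1-q)^5 + O(q^3) \approx 21\, q^2 = \frac{28}{3}\, p^2, \qquad q = \frac{2p}{3}.\]

On the log-log axes this is a line of slope 2. At leading order, it crosses the break-even line \(P_L = p\) near \(p = 3/28 \approx 0.11\); higher-order terms shift the exact crossing.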

Conclusion and Limitations

In this tutorial, we built, simulated, and decoded a simple quantum error correction cycle using the Steane code. We encoded a logical qubit, introduced errors through noise simulation, and performed error correction using stabilizer measurements combined with classical decoding. Finally, we benchmarked performance by comparing logical and physical error rates.

However, our approach relied on a significant simplifying assumption known as the code capacity model, where errors are assumed to occur at only one stage of the circuit, with otherwise perfect encoding and syndrome extraction. A more realistic approach—called circuit-level noise—accounts for potential errors at every gate and measurement within the circuit. This model significantly complicates decoding, as it requires mapping every possible error location not only in space but also across multiple syndrome measurement rounds, forming a complex space-time hypergraph. Decoding then involves interpreting error events over both spatial and temporal dimensions.

Nevertheless, the fundamental decoding principles explored here, particularly belief propagation, remain highly relevant. BP operates on a factor graph, so it extends naturally to circuit-level decoding hypergraphs. In these hypergraphs, each possible fault becomes a variable node connected to the checks it triggers.

References

[1] Pesah, Arthur. “The stabilizer trilogy I — Stabilizer codes.” Arthur Pesah, 31 Jan. 2023, https://arthurpesah.me/blog/2023-01-31-stabilizer-formalism-1/.

[2] Wiberg, Niclas. “Codes and Decoding on General Graphs.” 2001, https://www.essrl.wustl.edu/~jao/itrg/wiberg.pdf.

[3] Panteleev, Pavel. “Degenerate Quantum LDPC Codes With Good Finite Length Performance.” arXiv.org, 04 Apr. 2019, https://arxiv.org/abs/1904.02703v3.

[4] Hillmann, Timo. “Localized statistics decoding: A parallel decoding algorithm for quantum low-density parity-check codes.” arXiv.org, 26 Jun. 2024, https://arxiv.org/abs/2406.18655v1.

[5] Wolanski, Stasiu. “Ambiguity Clustering: an accurate and efficient decoder for qLDPC codes.” arXiv.org, 20 Jun. 2024, https://arxiv.org/abs/2406.14527v2.

[6] Loeliger, Hans-Andrea. “An introduction to factor graphs.” IEEE Signal Processing Magazine, vol. 21, no. 1, pp. 28-41, Jan. 2004, https://www.isiweb.ee.ethz.ch/papers/arch/aloe-2004-spmagffg.pdf.

[7] Barber, David. “Bayesian Reasoning and Machine Learning.” Cambridge University Press, USA, 2012, http://web4.cs.ucl.ac.uk/staff/D.Barber/textbook/180325.pdf#page=107.50

About the author

Tom Ginsberg

Tom leads BEIT's efforts in quantum error correction & fault tolerant compilation.

Total running time of the script: (0 minutes 32.253 seconds)


© Copyright 2025 | Xanadu | All rights reserved
