How to optimize a QML model using JAX and Optax

Josh Izaac

Maria Schuld

Published: January 17, 2024. Last updated: May 04, 2026.

Once you have set up a quantum machine learning model, data to train with, and a cost function to minimize as an objective, the next step is to perform the optimization: setting up a classical optimization loop to find the minimal value of your cost function.

In this example, we’ll show you how to use JAX, an autodifferentiable machine learning framework, and Optax, a suite of JAX-compatible gradient-based optimizers, to optimize a PennyLane quantum machine learning model.
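If you have not used JAX and Optax together before, the basic pattern is independent of PennyLane: compute the gradient of a scalar loss, ask the optimizer for an update, and apply it to the parameters. Below is a minimal, purely classical sketch of that pattern using a hypothetical toy loss; the quantum version of the same loop is built up step by step in the rest of this demo.

# Minimal sketch of the JAX + Optax pattern on a hypothetical toy loss
# (not part of the QML model below).
import jax
from jax import numpy as jnp
import optax

def toy_loss(theta):
    # simple quadratic bowl with its minimum at theta = 3.0
    return (theta - 3.0) ** 2

opt = optax.adam(learning_rate=0.1)
theta = jnp.array(0.0)
opt_state = opt.init(theta)

for _ in range(200):
    grads = jax.grad(toy_loss)(theta)  # differentiate the loss
    updates, opt_state = opt.update(grads, opt_state)
    theta = optax.apply_updates(theta, updates)

print(theta)  # should be close to 3.0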


Set up your model, data, and cost

Here, we will create a simple QML model for our optimization. In particular:

  • We will embed our data through a series of rotation gates.

  • We will then have an ansatz of trainable rotation gates with parameters weights; it is these values we will train to minimize our cost function.

  • We will train the QML model on data, a (5, 4) array, and optimize the model so that its outputs match the target predictions given by targets.

import pennylane as qml
import jax
from jax import numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev = qml.device("default.qubit", wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    # data embedding
    for i in range(n_wires):
        # data[i] will be of shape (4,); we are
        # taking advantage of operation vectorization here
        qml.RY(data[i], wires=i)

    # trainable ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

We will define a simple cost function that computes the mean squared error between the model output and the target data, and just-in-time (JIT) compile it:

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss

Note that the model above is just an example for demonstration purposes. When performing QML research, there are important considerations that must be taken into account, including the method of data embedding, the circuit architecture, and the choice of cost function, in order to build models that may be genuinely useful. This is still an active area of research; see our demonstrations for details.

Initialize your parameters

Now, we can generate our trainable parameters weights and bias that will be used to train our QML model.

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

Plugging the trainable parameters, data, and target labels into our cost function, we can see the current loss as well as the parameter gradients:

print(loss_fn(params, data, targets))

print(jax.grad(loss_fn)(params, data, targets))
0.29232612378972245
{'bias': Array(-0.75432075, dtype=float64, weak_type=True), 'weights': Array([[-1.95077284e-01,  5.28546714e-02, -4.89252103e-01],
       [-1.99687796e-02, -5.32871598e-02,  9.22904778e-02],
       [-2.71755508e-03, -9.64678693e-05, -4.79570903e-03],
       [-6.35443899e-02,  3.61110059e-02, -2.05196880e-01],
       [-9.02635457e-02,  1.63759376e-01, -5.64262637e-01]],      dtype=float64)}

Create the optimizer

We can now use Optax to create an optimizer and train our circuit. Here, we choose the Adam optimizer, but any other available Optax optimizer could be used.

opt = optax.adam(learning_rate=0.3)
opt_state = opt.init(params)

We first define our update_step function, which needs to do a few things:

  • Compute the loss function (so we can track training) and the gradients (so we can apply an optimization step). We can do this in one execution via the jax.value_and_grad function.

  • Apply the update step of our optimizer via opt.update.

  • Update the parameters via optax.apply_updates.

def update_step(opt, params, opt_state, data, targets):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

loss_history = []

for i in range(100):
    params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

    loss_history.append(loss_val)
Step: 0 Loss: 0.29232612378972245
Step: 5 Loss: 0.04476625195615539
Step: 10 Loss: 0.03190229898101476
Step: 15 Loss: 0.03623733712884947
Step: 20 Loss: 0.03370063713175687
Step: 25 Loss: 0.028724009131698508
Step: 30 Loss: 0.023012036779565558
Step: 35 Loss: 0.01871610331250138
Step: 40 Loss: 0.014776715787275367
Step: 45 Loss: 0.010427800672797093
Step: 50 Loss: 0.009646036809129437
Step: 55 Loss: 0.024105871070985618
Step: 60 Loss: 0.008081601338350786
Step: 65 Loss: 0.007607466509199649
Step: 70 Loss: 0.007097244525043754
Step: 75 Loss: 0.006785825964850794
Step: 80 Loss: 0.006898310538816934
Step: 85 Loss: 0.00658510644469507
Step: 90 Loss: 0.006029456810315623
Step: 95 Loss: 0.004975663971947959

Jitting the optimization loop

In the above example, we JIT compiled our cost function loss_fn. However, we can also JIT compile the entire optimization loop; this means that the for-loop around optimization is not happening in Python, but is compiled and executed natively. This avoids (potentially costly) data transfer between Python and our JIT compiled cost function with each update step.

# Define the optimizer we want to work with
opt = optax.adam(learning_rate=0.3)

@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, print_training = args

    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # if print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)

    return (params, opt_state, data, targets, print_training)

@jax.jit
def optimization_jit(params, data, targets, print_training=False):

    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args)

    return params

Note that we use jax.lax.fori_loop and jax.lax.cond, rather than a standard Python for loop and if statement, to allow the control flow to be JIT compatible. We also use jax.debug.print to allow printing to take place at function run-time, rather than compile-time.
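If these primitives are new to you, here is a minimal toy sketch (not part of the demo above) showing how jax.lax.fori_loop, jax.lax.cond, and jax.debug.print fit together inside a JIT-compiled function:

@jax.jit
def count_multiples_of_three(n):
    # body(i, carry) is called for i = 0, ..., n - 1 and must return the new carry
    def body(i, total):
        # cond(pred, true_fn, false_fn, operand) selects a branch at run-time
        return jax.lax.cond(jnp.mod(i, 3) == 0, lambda t: t + 1, lambda t: t, total)

    total = jax.lax.fori_loop(0, n, body, 0)
    # debug.print evaluates at run-time, so traced values can be printed
    jax.debug.print("multiples of 3 below {n}: {total}", n=n, total=total)
    return total

count_multiples_of_three(10)  # prints: multiples of 3 below 10: 4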

params = {"weights": weights, "bias": bias}
optimization_jit(params, data, targets, print_training=True)
Step: 0  Loss: 0.29232612378972245
Step: 5  Loss: 0.04476625195615542
Step: 10  Loss: 0.031902298981014814
Step: 15  Loss: 0.036237337128849384
Step: 20  Loss: 0.03370063713175685
Step: 25  Loss: 0.02872400913169845
Step: 30  Loss: 0.023012036779565558
Step: 35  Loss: 0.01871610331250138
Step: 40  Loss: 0.014776715787275276
Step: 45  Loss: 0.010427800672796939
Step: 50  Loss: 0.009646036809129033
Step: 55  Loss: 0.024105871070991523
Step: 60  Loss: 0.008081601338351297
Step: 65  Loss: 0.0076074665091996445
Step: 70  Loss: 0.007097244525043479
Step: 75  Loss: 0.006785825964846236
Step: 80  Loss: 0.006898310538821387
Step: 85  Loss: 0.006585106444693048
Step: 90  Loss: 0.006029456810321116
Step: 95  Loss: 0.004975663971947308

{'bias': Array(-0.75291251, dtype=float64), 'weights': Array([[ 1.63087727,  1.55018804,  0.67214272],
       [ 0.72661427,  0.36422372, -0.7562584 ],
       [ 2.78387239,  0.62720641,  3.44997077],
       [-1.10119028, -0.12680061,  0.89282747],
       [ 1.27236226,  1.10632236,  2.22052939]], dtype=float64)}

Appendix: Timing the two approaches

We can time the two approaches (JIT compiling just the cost function versus JIT compiling the entire optimization loop) to explore the differences in performance:

from timeit import repeat

def optimization(params, data, targets):
    opt = optax.adam(learning_rate=0.3)
    opt_state = opt.init(params)

    for i in range(100):
        params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

    return params

reps = 5
num = 2

times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting just the cost (best of {reps}): {result} sec per loop")

times = repeat("optimization_jit(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting the entire optimization (best of {reps}): {result} sec per loop")
Jitting just the cost (best of 5): 0.37293681099993137 sec per loop
Jitting the entire optimization (best of 5): 0.006363348500030952 sec per loop

In this example, JIT compiling the entire optimization loop is significantly more performant.
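One caveat if you adapt this timing code: JAX dispatches work asynchronously, so a JIT-compiled function can return before its result has actually been computed. When timing manually, it is safer to block on the returned arrays before stopping the clock, as in this small sketch (assuming the variables defined above):

import time

start = time.perf_counter()
trained_params = optimization_jit(params, data, targets)
# wait for all asynchronously dispatched work to finish before stopping the clock
jax.tree_util.tree_map(lambda x: x.block_until_ready(), trained_params)
print(f"Wall time: {time.perf_counter() - start:.4f} sec")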

About the authors

Josh Izaac

Josh is a theoretical physicist, software tinkerer, and occasional baker. At Xanadu, he contributes to the development and growth of Xanadu’s open-source quantum software products.

Maria Schuld

Dedicated to making quantum machine learning a reality one day.

Total running time of the script: (0 minutes 14.690 seconds)
