How to use Dyson and Magnus based solvers

Warning

This is an advanced topic: utilizing perturbation-theory based solvers requires detailed knowledge of the structure of the differential equations involved, as well as manual tuning of the solver parameters. See the DysonSolver and MagnusSolver documentation for API details. Also, see [1] for a detailed explanation of the solvers, which modifies and builds on the core idea introduced in [2].

Note

The circumstances under which perturbative solvers outperform traditional solvers, and which parameter sets to use, are nuanced. Perturbative solvers executed with JAX are set up to use more parallelization within a single solver run than typical solvers, so whether the trade-off between the speed of a single run and resource consumption is advantageous depends on the circumstances. Due to this parallelization, the execution times compared in this user guide are highly hardware-dependent.

In this tutorial we walk through how to use perturbation-theory based solvers. For information on how these solvers work, see the DysonSolver and MagnusSolver class documentation, as well as the perturbative expansion background information provided in Time-dependent perturbation theory and multi-variable series expansions review.

We use a simple transmon model:

\[H(t) = 2 \pi \nu N + \pi \alpha N(N-I) + s(t) \times 2 \pi r (a + a^\dagger)\]

where:

  • \(N\), \(a\), and \(a^\dagger\) are, respectively, the number, annihilation, and creation operators.

  • \(\nu\) is the qubit frequency and \(r\) is the drive strength.

  • \(s(t)\) is the drive signal, which we will take to be on resonance with envelope \(f(t) = A \frac{4t (T - t)}{T^2}\) for a given amplitude \(A\) and total time \(T\).
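As a quick illustration of the envelope and drive definitions above, the following NumPy sketch evaluates \(f(t)\) on a grid and forms the on-resonance drive. It assumes the drive is \(s(t) = \mathrm{Re}[f(t) e^{i 2 \pi \nu t}]\) with zero phase. That is our reading of how a Signal with a carrier frequency behaves, so check the Signal documentation to confirm it. The helper names envelope and drive_signal are hypothetical.

```python
import numpy as np

def envelope(t, amp, total_time):
    """f(t) = A * 4 t (T - t) / T**2, as defined in the model above."""
    return amp * 4 * t * (total_time - t) / total_time**2

def drive_signal(t, amp, total_time, carrier_freq):
    """Assumed convention: s(t) = Re[f(t) exp(2 pi i nu t)], zero phase."""
    return np.real(envelope(t, amp, total_time) * np.exp(2j * np.pi * carrier_freq * t))

T = 50.0   # total time in ns (T = 1 / r with r = 0.02 GHz)
v = 5.0    # carrier frequency in GHz, on resonance with the qubit
ts = np.linspace(0.0, T, 5001)
f = envelope(ts, amp=1.0, total_time=T)
s = drive_signal(ts, amp=1.0, total_time=T, carrier_freq=v)
```

The envelope vanishes at both endpoints and peaks at the amplitude \(A\) at \(t = T/2\). Under the assumed convention, the drive is simply \(f(t)\cos(2\pi\nu t)\).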

We will walk through the following steps:

  1. Configure JAX.

  2. Construct the model.

  3. Construct and simulate using the Dyson-based perturbative solver.

  4. Simulate using a traditional ODE solver, comparing speed.

  5. Construct and simulate using the Magnus-based perturbative solver.

1. Configure JAX

First, configure JAX to run on CPU in 64-bit mode. See the user guide on using JAX for a more detailed explanation of how to work with JAX in Qiskit Dynamics.

# configure jax to use 64 bit mode
import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU if using a system without a GPU
jax.config.update('jax_platform_name', 'cpu')

2. Construct the model

First, we construct the model described in the introduction. We use a relatively high dimension for the oscillator state space to accentuate the speed difference between the perturbative solvers and the traditional ODE solver. The higher dimensionality introduces higher frequencies into the model, which slows down both the ODE solver and the initial construction of the perturbative solver. However, after the initial construction, the higher frequencies in the model have no impact on the speed of the perturbative solver.
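To make the frequency argument concrete, the next sketch uses NumPy only and the same parameter values as the model code. It computes the level energies of the static Hamiltonian in GHz and the resulting transition frequencies \(\nu + \alpha n\). Assuming a carrier at \(\nu\), the near-resonant components of the drive term in the rotating frame oscillate at the detunings \(\alpha n\). The largest of these, \(|\alpha|(\mathrm{dim}-2)\), grows with the truncation dimension. This is an illustrative calculation, not something the solvers compute explicitly.

```python
import numpy as np

dim, v, anharm = 10, 5.0, -0.33  # same values as the model code
n = np.arange(dim)

# Diagonal of the static Hamiltonian divided by 2*pi, i.e. level energies in GHz:
# E_n = v * n + (anharm / 2) * n * (n - 1)
energies = v * n + 0.5 * anharm * n * (n - 1)

transition_freqs = np.diff(energies)  # E_{n+1} - E_n = v + anharm * n
detunings = transition_freqs - v      # offsets from a carrier at v

print(transition_freqs)
print(np.abs(detunings).max())        # |anharm| * (dim - 2)
```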

import numpy as np

dim = 10  # Oscillator dimension

v = 5.  # Transmon frequency in GHz
anharm = -0.33  # Transmon anharmonicity in GHz
r = 0.02  # Transmon drive coupling in GHz

# Construct the truncated oscillator (transmon) operators
a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = np.diag(np.sqrt(np.arange(1, dim)), -1)
N = np.diag(np.arange(dim))

# Static part of Hamiltonian
static_hamiltonian = 2 * np.pi * v * N + np.pi * anharm * N * (N - np.eye(dim))
# Drive term of Hamiltonian
drive_hamiltonian = 2 * np.pi * r * (a + adag)

# total simulation time
T = 1. / r

# Drive envelope function
envelope_func = lambda t: t * (T - t) / (T**2 / 4)
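The following sketch runs a few sanity checks on the operators above. It confirms that \(a^\dagger\) is the conjugate transpose of \(a\) and that \(N = a^\dagger a\). It also confirms that, because both factors are diagonal, the elementwise product N * (N - np.eye(dim)) used in the code equals the matrix product \(N(N-I)\) in the Hamiltonian. Finally, both Hamiltonian terms should be Hermitian, which the -1j conversion to an LMDE later relies on.

```python
import numpy as np

dim, v, anharm, r = 10, 5.0, -0.33, 0.02

a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = np.diag(np.sqrt(np.arange(1, dim)), -1)
N = np.diag(np.arange(dim))
I = np.eye(dim)

static_hamiltonian = 2 * np.pi * v * N + np.pi * anharm * N * (N - I)
drive_hamiltonian = 2 * np.pi * r * (a + adag)

# adag is the conjugate transpose of a, and the number operator is adag @ a
assert np.allclose(adag, a.conj().T)
assert np.allclose(adag @ a, N)

# The elementwise product of diagonal matrices equals their matrix product
assert np.allclose(N * (N - I), N @ (N - I))

# Both Hamiltonian terms are Hermitian
for H in (static_hamiltonian, drive_hamiltonian):
    assert np.allclose(H, H.conj().T)
```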

3. Construct and simulate using the Dyson-based perturbative solver

Setting up a DysonSolver requires more work than setting up the standard Solver, as the user must specify several configuration parameters, along with the structure of the differential equation:

  • The DysonSolver requires direct specification of the LMDE to the solver. If we are simulating the Schrödinger equation \(\dot{U}(t) = -i H(t) U(t)\), we need to multiply the Hamiltonian terms by -1j when describing the LMDE operators.

  • The DysonSolver is a fixed step solver, with the step size being fixed at instantiation. This step size must be chosen in conjunction with the expansion_order to ensure that a suitable accuracy is attained.

  • Over each fixed time step, the DysonSolver approximates the solution by computing a truncated perturbative expansion.

    • To compute the truncated perturbative expansion, the signal envelopes are approximated as a linear combination of Chebyshev polynomials.

    • The order of the Chebyshev approximations, along with central carrier frequencies for defining the “envelope” of each Signal, must be provided at instantiation.
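Here is a minimal NumPy sketch of what a truncated perturbative expansion over a single step looks like. It uses a truncated Dyson series for a hypothetical two-level generator \(G(t) = -i(0.5 Z + \cos(3t) X)\). Note the factor of -1j that turns the Hamiltonian into the LMDE generator. The sketch applies the textbook recursion \(D_0 = I\), \(D_k(t) = \int_0^t G(s) D_{k-1}(s)\,ds\) with trapezoidal quadrature, then compares partial sums against a fine-grained reference. It does not use a rotating frame or a Chebyshev envelope approximation, and it is not how DysonSolver computes its expansion terms internally. It only shows how the truncation order controls the error over one step.

```python
import numpy as np

# Hypothetical toy model (not the transmon above): H(t) = 0.5 Z + cos(3 t) X
X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)

def generator(t):
    """Schrodinger LMDE generator G(t) = -1j * H(t), so that y'(t) = G(t) y(t)."""
    return -1j * (0.5 * Z + np.cos(3 * t) * X)

def expm_antihermitian(A):
    """exp(A) for anti-Hermitian A, via the eigendecomposition of the Hermitian 1j * A."""
    evals, evecs = np.linalg.eigh(1j * A)
    return (evecs * np.exp(-1j * evals)) @ evecs.conj().T

def cumulative_trapezoid(values, ts):
    """Running trapezoidal integrals of a stack of matrices, starting from zero."""
    widths = np.diff(ts)[:, None, None]
    increments = 0.5 * (values[1:] + values[:-1]) * widths
    return np.concatenate([np.zeros_like(values[:1]), np.cumsum(increments, axis=0)])

def truncated_dyson(h, order, n_grid=2001):
    """Sum of the Dyson terms D_0, ..., D_order over one step [0, h].

    D_0(t) = I and D_k(t) = int_0^t G(s) D_{k-1}(s) ds.
    """
    ts = np.linspace(0.0, h, n_grid)
    Gs = np.array([generator(t) for t in ts])
    term = np.broadcast_to(np.eye(2, dtype=complex), Gs.shape).copy()  # D_0(t) = I
    total = term[-1].copy()
    for _ in range(order):
        term = cumulative_trapezoid(Gs @ term, ts)  # next Dyson term on the grid
        total += term[-1]                           # add its value at t = h
    return total

def reference_propagator(h, n_sub=5000):
    """High-accuracy reference: product of short exact exponentials (midpoint rule)."""
    U = np.eye(2, dtype=complex)
    edges = np.linspace(0.0, h, n_sub + 1)
    for t0, t1 in zip(edges[:-1], edges[1:]):
        U = expm_antihermitian(generator(0.5 * (t0 + t1)) * (t1 - t0)) @ U
    return U

h = 0.5  # one "time step" of the toy problem
U_ref = reference_propagator(h)
errors = {k: np.linalg.norm(truncated_dyson(h, k) - U_ref) for k in (1, 3, 7)}
print(errors)
```

For these toy values, the error shrinks by several orders of magnitude between order 1 and order 7. This mirrors the trade-off between step size and expansion_order that must be tuned for the DysonSolver.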

See the DysonSolver API docs for more details.

For our example Hamiltonian we configure the DysonSolver as follows:

%%time

from qiskit_dynamics import DysonSolver

dt = 0.1
dyson_solver = DysonSolver(
    operators=[-1j * drive_hamiltonian],
    rotating_frame=-1j * static_hamiltonian,
    dt=dt,
    carrier_freqs=[v],
    chebyshev_orders=[1],
    expansion_order=7,
    integration_method='jax_odeint',
    atol=1e-12,
    rtol=1e-12
)
CPU times: user 2.95 s, sys: 470 ms, total: 3.42 s
Wall time: 3.02 s

The above parameters are chosen so that the DysonSolver is fast and produces high-accuracy solutions (measured and confirmed after the fact). The relatively large step size dt = 0.1 is chosen for speed: the larger the step size, the fewer steps required. To ensure high accuracy given the large step size, we choose a high expansion order and use a linear envelope approximation scheme by setting chebyshev_orders to [1] for the single drive signal.
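The linear (order-1) envelope approximation can be sanity-checked with a short NumPy sketch. It fits envelope_func on each dt-length step with numpy.polynomial.Chebyshev.fit, a least-squares stand-in for whatever Chebyshev approximation DysonSolver builds internally (an assumption, so see the DysonSolver docs for the actual scheme). It then reports the worst per-step error for amplitude \(A = 1\). The helper max_step_error is hypothetical.

```python
import numpy as np
from numpy.polynomial import Chebyshev

T = 1.0 / 0.02   # total time, as in the model
dt = 0.1         # step size passed to DysonSolver above
envelope_func = lambda t: t * (T - t) / (T**2 / 4)

def max_step_error(t0, order, n_samples=201):
    """Max error of a degree-`order` Chebyshev least-squares fit of the envelope on [t0, t0 + dt]."""
    ts = np.linspace(t0, t0 + dt, n_samples)
    fit = Chebyshev.fit(ts, envelope_func(ts), deg=order, domain=[t0, t0 + dt])
    return np.max(np.abs(fit(ts) - envelope_func(ts)))

step_starts = np.arange(int(T // dt)) * dt
errors = [max_step_error(t0, order=1) for t0 in step_starts]
print(max(errors))
```

Because the envelope is quadratic in \(t\), an order-2 approximation would be exact. The order-1 error per step scales as \(dt^2\) times the envelope's curvature, which is small for this slowly varying envelope.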

Similar to the Solver interface, the DysonSolver.solve() method can be called to simulate the system for a given list of signals, initial state, start time, and number of time steps of length dt.

To compare the speed of the DysonSolver with that of a traditional ODE solver, we write JAX-compilable functions wrapping each solver. Given an amplitude value, each function returns the final unitary over the interval [0, (T // dt) * dt] for an on-resonance drive whose envelope shape is given by envelope_func above. Running compiled versions of these functions gives a sense of the speeds attainable by each solver.
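Note that 0.1 is not exactly representable in binary floating point. As a result, Python's floor division evaluates 50.0 // 0.1 as 499.0, even though 50.0 / 0.1 rounds to 500.0. With these parameters, the simulations therefore stop at about t = 49.9 ns rather than at T = 50 ns. Both solvers use the same interval, so the comparison is unaffected. This plain-Python snippet just makes the step count explicit.

```python
T = 1.0 / 0.02   # total time from the model (r = 0.02 GHz), evaluates to 50.0
dt = 0.1         # DysonSolver step size

n_steps = int(T // dt)     # as used in dyson_sim and ode_sim
final_time = n_steps * dt  # end of the simulated interval
print(T, T / dt, n_steps, final_time)

# Rounding the true quotient instead would give the full number of steps
n_steps_rounded = round(T / dt)
```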

from qiskit_dynamics import Signal
from jax import jit

# Jit the function to improve performance for repeated calls
@jit
def dyson_sim(amp):
    """For a given envelope amplitude, simulate the final unitary using the
    Dyson solver.
    """
    drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
    return dyson_solver.solve(
        signals=[drive_signal],
        y0=np.eye(dim, dtype=complex),
        t0=0.,
        n_steps=int(T // dt)
    ).y[-1]

First run includes compile time.

%time yf_dyson = dyson_sim(1.).block_until_ready()
CPU times: user 647 ms, sys: 29.2 ms, total: 676 ms
Wall time: 660 ms

Once JIT compilation has been performed, we can benchmark the JIT-compiled solver:

%time yf_dyson = dyson_sim(1.).block_until_ready()
CPU times: user 12.3 ms, sys: 0 ns, total: 12.3 ms
Wall time: 4.62 ms

4. Comparison to a traditional ODE solver

We now construct the same simulation using a standard solver to compare accuracy and simulation speed.

from qiskit_dynamics import Solver

solver = Solver(
    static_hamiltonian=static_hamiltonian,
    hamiltonian_operators=[drive_hamiltonian],
    rotating_frame=static_hamiltonian
)

# specify tolerance as an argument to run the simulation at different tolerances
def ode_sim(amp, tol):
    drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
    res = solver.solve(
        t_span=[0., int(T // dt) * dt],
        y0=np.eye(dim, dtype=complex),
        signals=[drive_signal],
        method='jax_odeint',
        atol=tol,
        rtol=tol
    )
    return res.y[-1]

Simulate with a low tolerance to obtain a high-accuracy reference solution, and compare it to the Dyson solution.

yf_low_tol = ode_sim(1., 1e-13)
np.linalg.norm(yf_low_tol - yf_dyson)
6.529550206930476e-07
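The accuracy comparisons in this guide apply np.linalg.norm to the difference of two final unitaries, which for a 2-D array defaults to the Frobenius norm. The sketch below computes that metric alongside two common alternatives that the original comparison does not use: the spectral (operator 2-norm) distance and a global-phase-insensitive overlap. The function name unitary_distances is hypothetical.

```python
import numpy as np

def unitary_distances(U, V):
    """Compare two unitaries with a few common metrics."""
    diff = U - V
    d = U.shape[0]
    return {
        # np.linalg.norm on a 2-D array defaults to the Frobenius norm
        "frobenius": np.linalg.norm(diff),
        # largest singular value of the difference (operator 2-norm)
        "spectral": np.linalg.norm(diff, ord=2),
        # 1 - |Tr(U^dagger V)| / d ignores a global phase difference
        "phase_insensitive": 1 - abs(np.trace(U.conj().T @ V)) / d,
    }

theta = 1e-3
U = np.eye(2, dtype=complex)
V = np.diag([1.0, np.exp(1j * theta)])  # small relative phase error
W = np.exp(0.3j) * U                    # pure global phase
print(unitary_distances(U, V))
print(unitary_distances(U, W))
```

A global phase shows up as a large Frobenius distance but a zero phase-insensitive distance. This is worth keeping in mind when interpreting norm-based accuracy numbers.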

For the speed comparison, JIT-compile the ODE simulation at a tolerance that yields accuracy similar to that of the Dyson solver. The first run includes compile time.

jit_ode_sim = jit(lambda amp: ode_sim(amp, 1e-8))

%time yf_ode = jit_ode_sim(1.).block_until_ready()
CPU times: user 447 ms, sys: 16.1 ms, total: 463 ms
Wall time: 457 ms

Measure the run time of the compiled function.

%time yf_ode = jit_ode_sim(1.).block_until_ready()
CPU times: user 46.4 ms, sys: 0 ns, total: 46.4 ms
Wall time: 46.2 ms

Confirm that the ODE solution has similar accuracy.

np.linalg.norm(yf_low_tol - yf_ode)
8.67211035081537e-07

Here we see that, once compiled, the Dyson-based solver has a significant speed advantage over the traditional solver (4.62 ms versus 46.2 ms wall time on the hardware used here). This comes at the expense of the initial construction and compilation time, and of the technical effort of configuring the solver.

5. Construct and simulate using the Magnus-based perturbative solver

Next, we repeat our example using the Magnus-based perturbative solver. Setting up the MagnusSolver is similar to setting up the DysonSolver, but the MagnusSolver uses the Magnus expansion and matrix exponentiation to simulate over each fixed time step.
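The difference between the two approaches on a single step can be illustrated with a minimal NumPy sketch. It reuses a hypothetical two-level generator \(G(t) = -i(0.5 Z + \cos(3t) X)\), not the transmon model, and computes the first two Magnus terms \(\Omega_1 = \int_0^h G\,dt\) and \(\Omega_2 = \tfrac{1}{2}\int_0^h dt_1 \int_0^{t_1} dt_2\,[G(t_1), G(t_2)]\) by trapezoidal quadrature. Exponentiating the truncated Magnus series gives an exactly unitary step propagator, because the exponent is anti-Hermitian. A truncated Dyson series such as \(I + \Omega_1\) is not exactly unitary. This is an illustration of the expansion, not the MagnusSolver's implementation.

```python
import numpy as np

# Hypothetical toy generator: G(t) = -1j * (0.5 Z + cos(3 t) X)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)

def generator(t):
    return -1j * (0.5 * Z + np.cos(3 * t) * X)

def expm_antihermitian(A):
    """exp(A) for anti-Hermitian A via the eigendecomposition of the Hermitian 1j * A."""
    evals, evecs = np.linalg.eigh(1j * A)
    return (evecs * np.exp(-1j * evals)) @ evecs.conj().T

def cumulative_trapezoid(values, ts):
    """Running trapezoidal integrals of a stack of matrices, starting from zero."""
    widths = np.diff(ts)[:, None, None]
    increments = 0.5 * (values[1:] + values[:-1]) * widths
    return np.concatenate([np.zeros_like(values[:1]), np.cumsum(increments, axis=0)])

def magnus_terms(h, n_grid=2001):
    """First two Magnus terms over [0, h]:
    Omega_1 = int_0^h G(t) dt
    Omega_2 = 1/2 int_0^h dt1 int_0^t1 dt2 [G(t1), G(t2)]
    """
    ts = np.linspace(0.0, h, n_grid)
    Gs = np.array([generator(t) for t in ts])
    inner = cumulative_trapezoid(Gs, ts)    # int_0^t1 G(t2) dt2 for each t1
    omega1 = inner[-1]
    commutators = Gs @ inner - inner @ Gs   # [G(t1), int_0^t1 G(t2) dt2]
    omega2 = 0.5 * cumulative_trapezoid(commutators, ts)[-1]
    return omega1, omega2

def reference_propagator(h, n_sub=5000):
    """High-accuracy reference: product of short exact exponentials (midpoint rule)."""
    U = np.eye(2, dtype=complex)
    edges = np.linspace(0.0, h, n_sub + 1)
    for t0, t1 in zip(edges[:-1], edges[1:]):
        U = expm_antihermitian(generator(0.5 * (t0 + t1)) * (t1 - t0)) @ U
    return U

def unitarity_error(U):
    return np.linalg.norm(U.conj().T @ U - np.eye(U.shape[0]))

h = 0.5
omega1, omega2 = magnus_terms(h)
U_ref = reference_propagator(h)
candidates = {
    "dyson_order1": np.eye(2) + omega1,                    # truncated series: not unitary
    "magnus_order1": expm_antihermitian(omega1),           # exponential: unitary
    "magnus_order2": expm_antihermitian(omega1 + omega2),  # exponential: unitary
}
for name, U in candidates.items():
    print(name, np.linalg.norm(U - U_ref), unitarity_error(U))
```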

%%time

from qiskit_dynamics import MagnusSolver

dt = 0.1
magnus_solver = MagnusSolver(
    operators=[-1j * drive_hamiltonian],
    rotating_frame=-1j * static_hamiltonian,
    dt=dt,
    carrier_freqs=[v],
    chebyshev_orders=[1],
    expansion_order=3,
    integration_method='jax_odeint',
    atol=1e-12,
    rtol=1e-12
)
CPU times: user 1.53 s, sys: 45 ms, total: 1.58 s
Wall time: 1.58 s

Set up the simulation function.

@jit
def magnus_sim(amp):
    drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
    return magnus_solver.solve(
        signals=[drive_signal],
        y0=np.eye(dim, dtype=complex),
        t0=0.,
        n_steps=int(T // dt)
    ).y[-1]

First run includes compile time.

%time yf_magnus = magnus_sim(1.).block_until_ready()
CPU times: user 1.28 s, sys: 63.7 ms, total: 1.34 s
Wall time: 1.33 s

The second run demonstrates the speed of the compiled simulation.

%time yf_magnus = magnus_sim(1.).block_until_ready()
CPU times: user 24.8 ms, sys: 0 ns, total: 24.8 ms
Wall time: 20 ms
Compare against the high-accuracy reference solution.

np.linalg.norm(yf_magnus - yf_low_tol)
6.678901371612617e-07

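Observe comparable accuracy at a lower expansion order (3 versus 7), albeit with a modest slowdown compared to the Dyson-based solver (20 ms versus 4.62 ms wall time). The compiled Magnus-based simulation is nonetheless still faster than the traditional ODE solver.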
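For reference, the relative speeds follow directly from the compiled wall times printed above. These come from single runs on one machine, so treat them as indicative only. The short calculation below gives roughly 10x for the Dyson-based solver and 2.3x for the Magnus-based solver relative to the ODE solver.

```python
# Wall-clock times (ms) of the compiled runs reported above; hardware-dependent.
wall_ms = {"ode": 46.2, "dyson": 4.62, "magnus": 20.0}

for name in ("dyson", "magnus"):
    speedup = wall_ms["ode"] / wall_ms[name]
    print(f"{name}: {speedup:.1f}x faster than the compiled ODE solver")

print(f"magnus takes {wall_ms['magnus'] / wall_ms['dyson']:.1f}x as long as dyson")
```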

References