How to use pulse schedules generated by Qiskit Pulse with JAX transformations

Qiskit Pulse lets you specify the time dependence of quantum systems as pulse schedules. These schedules are built from sequences of instructions, including shaped pulses (see the Qiskit Pulse API Reference for details). Qiskit 0.40.0 added JAX support for the ScalableSymbolicPulse class. This user guide entry demonstrates the technical steps for using this class within JAX-transformable functions.

Note

At present, only the ScalableSymbolicPulse class supports JAX, because the validation performed by other pulse types, such as Gaussian, is not JAX-compatible.

This guide covers the following topics. See the user guide on using JAX for a more detailed explanation of how to work with JAX in Qiskit Dynamics.

  1. Configure JAX.

  2. How to define a Gaussian pulse using ScalableSymbolicPulse.

  3. JAX transforming Pulse to Signal conversion involving ScalableSymbolicPulse.

1. Configure JAX

First, configure JAX to run on CPU in 64-bit mode.

# configure jax to use 64 bit mode
import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')
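A quick way to confirm that both settings took effect is to inspect the default floating-point dtype and the default backend. This is a minimal sketch: `jax.default_backend` is a standard JAX function, and the expected values assume a fresh Python process in which no JAX computation ran before the configuration updates.

```python
import jax
import jax.numpy as jnp

# Same settings as above: 64-bit mode and the CPU backend
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# With 64-bit mode on, floating-point arrays default to float64
x = jnp.array(1.0)
print(x.dtype)  # float64

# Name of the platform JAX runs computations on by default
print(jax.default_backend())  # cpu
```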

2. How to define a Gaussian pulse using ScalableSymbolicPulse

Because the standard Gaussian pulse is not JAX-compatible, defining a Gaussian pulse for use in optimization requires instantiating a ScalableSymbolicPulse with a Gaussian parameterization. First, define the symbolic representation of the envelope in SymPy.

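from qiskit import pulse
from qiskit_dynamics.pulse import InstructionToSignals
import sympy as sym

dt = 0.222  # sample width used when converting the schedule to signals
w = 5.      # carrier frequency of drive channel "d0"

# Helper function that returns a lifted Gaussian symbolic expression:
# a Gaussian centered at `center`, shifted and rescaled so that it equals
# 1 at t = center and 0 at t = t_zero.
# `center` and `t_zero` must be SymPy expressions (not plain Python numbers),
# since .expand() is called on their difference.
def lifted_gaussian(
    t: sym.Symbol,
    center: sym.Expr,
    t_zero: sym.Expr,
    sigma: sym.Expr,
) -> sym.Expr:
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()

    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)

    return (gauss - offset) / (1 - offset)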
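Because the Gaussian is "lifted", the expression equals 1 at `center` and 0 at `t_zero`. Since it depends only on `(t - center) ** 2`, it is also 0 at the mirror point `2 * center - t_zero`. The snippet below checks this with SymPy substitution, using the parameters chosen later in this guide (duration 160, sigma 40, `center = duration / 2`, `t_zero = duration + 1`). It repeats the helper so that it runs on its own.

```python
import sympy as sym

# Copy of the lifted_gaussian helper above, so this snippet is self-contained
def lifted_gaussian(t, center, t_zero, sigma):
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()
    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)
    return (gauss - offset) / (1 - offset)

t, duration, sigma = sym.symbols("t, duration, sigma")
# Same center and t_zero as in the pulse definition: duration / 2 and duration + 1
expr = lifted_gaussian(t, duration / 2, duration + 1, sigma)
params = {duration: 160, sigma: 40}

peak = sym.simplify(expr.subs({**params, t: 80}))         # at the center
right_edge = sym.simplify(expr.subs({**params, t: 161}))  # at t_zero
left_edge = sym.simplify(expr.subs({**params, t: -1}))    # mirror point of t_zero
print(peak, right_edge, left_edge)  # 1 0 0
```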

Next, define the ScalableSymbolicPulse using the above expression.

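# symbols for time and the pulse parameters; their names match the names
# ScalableSymbolicPulse uses (t, duration, amp, angle, and the keys of `parameters`)
_t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
_center = _duration / 2

# complex envelope amp * exp(i * angle) * (lifted Gaussian); the lifted Gaussian
# is zero at t = duration + 1 and, by symmetry, at t = -1
envelope_expr = (
    _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
)

gaussian_pulse = pulse.ScalableSymbolicPulse(
    pulse_type="Gaussian",
    duration=160,
    amp=0.3,
    angle=0,
    parameters={"sigma": 40},
    envelope=envelope_expr,
    constraints=_sigma > 0,
    valid_amp_conditions=sym.Abs(_amp) <= 1.0,
)

gaussian_pulse.draw()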
[Figure: output of gaussian_pulse.draw(), showing the Gaussian pulse]
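To see the numbers that `envelope_expr` produces once the pulse parameters are filled in, the expression can be turned into a NumPy function with `sympy.lambdify` and sampled by hand. This is a rough illustration, not a description of how Qiskit samples the pulse internally; the sample times `t = k + 0.5` are an assumption made here. It also shows the role of `angle`: setting it to pi/2 multiplies the whole envelope by the phase factor i.

```python
import numpy as np
import sympy as sym

# Copy of the lifted_gaussian helper above, so this snippet is self-contained
def lifted_gaussian(t, center, t_zero, sigma):
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()
    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)
    return (gauss - offset) / (1 - offset)

_t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
envelope_expr = (
    _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _duration / 2, _duration + 1, _sigma)
)

# NumPy function of (t, duration, amp, sigma, angle)
envelope_fn = sym.lambdify(
    (_t, _duration, _amp, _sigma, _angle), envelope_expr, modules="numpy"
)

times = np.arange(160) + 0.5  # assumed sample times t = k + 0.5
samples = envelope_fn(times, 160, 0.3, 40, 0.0)        # angle = 0, as in gaussian_pulse
rotated = envelope_fn(times, 160, 0.3, 40, np.pi / 2)  # same pulse with angle = pi/2
```

With `angle = 0` the samples are real-valued, symmetric about the center, and peak just below `amp = 0.3`; `rotated` is the same envelope multiplied by i.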

3. JAX transforming Pulse to Signal conversion involving ScalableSymbolicPulse

Using the Gaussian pulse as an example, we show that a function involving ScalableSymbolicPulse and the pulse-to-signal converter InstructionToSignals can be JAX-compiled (or, more generally, JAX-transformed).

# use the amplitude as the function argument
def jit_func(amp):
    # rebuild the symbolic envelope (same as in the previous section)
    _t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
    _center = _duration / 2

    envelope_expr = (
        _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
    )

    # Validation evaluates the constraints with concrete parameter values, which
    # are not available while jax.jit traces `amp`, so it must be disabled.
    # This is a class-level flag: it stays set for ScalableSymbolicPulse
    # instances created afterwards.
    pulse.ScalableSymbolicPulse.disable_validation = True

    gaussian_pulse = pulse.ScalableSymbolicPulse(
        pulse_type="Gaussian",
        duration=160,
        amp=amp,
        angle=0,
        parameters={"sigma": 40},
        envelope=envelope_expr,
        constraints=_sigma > 0,
        valid_amp_conditions=sym.Abs(_amp) <= 1.0,
    )

    # build a pulse schedule that plays the pulse on drive channel 0
    with pulse.build() as schedule:
        pulse.play(gaussian_pulse, pulse.DriveChannel(0))

    # convert the schedule to a list of signals, one per channel;
    # "d0" is the name of DriveChannel(0), with carrier frequency w
    converter = InstructionToSignals(dt, carriers={"d0": w})

    # return the samples of the signal for channel d0
    return converter.get_signals(schedule)[0].samples

jax.jit(jit_func)(0.4)
Array([0.00461643+0.j, 0.00784044+0.j, 0.01118371+0.j, 0.0146479 +0.j,
       0.01823455+0.j, 0.02194501+0.j, 0.02578049+0.j, 0.029742  +0.j,
       0.03383034+0.j, 0.03804615+0.j, 0.0423898 +0.j, 0.04686147+0.j,
       0.05146109+0.j, 0.05618834+0.j, 0.06104264+0.j, 0.06602316+0.j,
       0.07112877+0.j, 0.07635807+0.j, 0.08170936+0.j, 0.08718063+0.j,
       0.0927696 +0.j, 0.09847362+0.j, 0.10428977+0.j, 0.11021477+0.j,
       0.11624505+0.j, 0.12237668+0.j, 0.12860541+0.j, 0.13492665+0.j,
       0.14133549+0.j, 0.14782668+0.j, 0.15439464+0.j, 0.16103348+0.j,
       0.16773697+0.j, 0.17449859+0.j, 0.18131147+0.j, 0.1881685 +0.j,
       0.19506222+0.j, 0.20198494+0.j, 0.20892866+0.j, 0.21588517+0.j,
       0.22284598+0.j, 0.22980239+0.j, 0.2367455 +0.j, 0.24366621+0.j,
       0.25055524+0.j, 0.25740317+0.j, 0.26420043+0.j, 0.27093735+0.j,
       0.27760417+0.j, 0.28419106+0.j, 0.29068813+0.j, 0.29708551+0.j,
       0.30337328+0.j, 0.3095416 +0.j, 0.31558066+0.j, 0.32148073+0.j,
       0.32723219+0.j, 0.33282555+0.j, 0.33825149+0.j, 0.34350085+0.j,
       0.34856471+0.j, 0.35343437+0.j, 0.35810137+0.j, 0.36255757+0.j,
       0.36679511+0.j, 0.37080648+0.j, 0.3745845 +0.j, 0.37812239+0.j,
       0.38141374+0.j, 0.38445258+0.j, 0.38723335+0.j, 0.38975094+0.j,
       0.39200072+0.j, 0.39397853+0.j, 0.39568069+0.j, 0.39710405+0.j,
       0.39824594+0.j, 0.39910423+0.j, 0.39967732+0.j, 0.39996414+0.j,
       0.39996414+0.j, 0.39967732+0.j, 0.39910423+0.j, 0.39824594+0.j,
       0.39710405+0.j, 0.39568069+0.j, 0.39397853+0.j, 0.39200072+0.j,
       0.38975094+0.j, 0.38723335+0.j, 0.38445258+0.j, 0.38141374+0.j,
       0.37812239+0.j, 0.3745845 +0.j, 0.37080648+0.j, 0.36679511+0.j,
       0.36255757+0.j, 0.35810137+0.j, 0.35343437+0.j, 0.34856471+0.j,
       0.34350085+0.j, 0.33825149+0.j, 0.33282555+0.j, 0.32723219+0.j,
       0.32148073+0.j, 0.31558066+0.j, 0.3095416 +0.j, 0.30337328+0.j,
       0.29708551+0.j, 0.29068813+0.j, 0.28419106+0.j, 0.27760417+0.j,
       0.27093735+0.j, 0.26420043+0.j, 0.25740317+0.j, 0.25055524+0.j,
       0.24366621+0.j, 0.2367455 +0.j, 0.22980239+0.j, 0.22284598+0.j,
       0.21588517+0.j, 0.20892866+0.j, 0.20198494+0.j, 0.19506222+0.j,
       0.1881685 +0.j, 0.18131147+0.j, 0.17449859+0.j, 0.16773697+0.j,
       0.16103348+0.j, 0.15439464+0.j, 0.14782668+0.j, 0.14133549+0.j,
       0.13492665+0.j, 0.12860541+0.j, 0.12237668+0.j, 0.11624505+0.j,
       0.11021477+0.j, 0.10428977+0.j, 0.09847362+0.j, 0.0927696 +0.j,
       0.08718063+0.j, 0.08170936+0.j, 0.07635807+0.j, 0.07112877+0.j,
       0.06602316+0.j, 0.06104264+0.j, 0.05618834+0.j, 0.05146109+0.j,
       0.04686147+0.j, 0.0423898 +0.j, 0.03804615+0.j, 0.03383034+0.j,
       0.029742  +0.j, 0.02578049+0.j, 0.02194501+0.j, 0.01823455+0.j,
       0.0146479 +0.j, 0.01118371+0.j, 0.00784044+0.j, 0.00461643+0.j],      dtype=complex128)
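Beyond `jax.jit`, functions of this kind can also be differentiated with `jax.grad` or batched with `jax.vmap`. Running `jit_func` requires Qiskit and Qiskit Dynamics, so the sketch below uses a hypothetical pure-JAX stand-in, `lifted_gaussian_samples`, instead. It re-implements the lifted Gaussian envelope directly in `jax.numpy`, assuming angle 0, duration 160, sigma 40, and sample times `t = k + 0.5`; with those choices its first sample comes out near 0.0046, in line with the array printed above. It only illustrates the transformations and is not part of the Qiskit or Qiskit Dynamics API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lifted_gaussian_samples(amp, duration=160, sigma=40.0):
    """Hypothetical stand-in: lifted Gaussian envelope samples with angle = 0."""
    t = jnp.arange(duration) + 0.5  # assumed sample times
    center = duration / 2
    t_zero = duration + 1
    gauss = jnp.exp(-(((t - center) / sigma) ** 2) / 2)
    offset = jnp.exp(-(((t_zero - center) / sigma) ** 2) / 2)
    return amp * (gauss - offset) / (1 - offset)

# jit: compile once, then reuse for new amplitude values
samples = jax.jit(lifted_gaussian_samples)(0.4)

# grad: derivative of the summed samples with respect to amp
total_grad = jax.grad(lambda a: jnp.sum(lifted_gaussian_samples(a)))(0.4)

# vmap: evaluate several amplitudes in a single call
batch = jax.vmap(lifted_gaussian_samples)(jnp.array([0.1, 0.2, 0.4]))
print(batch.shape)  # (3, 160)
```

Because the envelope is linear in `amp`, the gradient equals the sum of the samples divided by 0.4. Applying `jax.grad` to a real-valued function of `jit_func`'s output, such as `lambda a: jit_func(a).real.sum()`, would follow the same pattern, though that case is not run here.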