# Source code for qiskit.transpiler.passes.basis.basis_translator

# This code is part of Qiskit.
# (C) Copyright IBM 2017, 2020.
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Translates gates to a target basis using a given equivalence library."""

import sys
import time
import logging

from itertools import zip_longest
from collections import defaultdict

import rustworkx

from qiskit.circuit import Gate, ParameterVector, QuantumRegister, ControlFlowOp, QuantumCircuit
from qiskit.dagcircuit import DAGCircuit
from qiskit.converters import circuit_to_dag, dag_to_circuit
from qiskit.circuit.equivalence import Key, NodeData
from qiskit.transpiler.basepasses import TransformationPass
from qiskit.transpiler.exceptions import TranspilerError

if sys.version_info >= (3, 8):
    from functools import singledispatchmethod
else:
    from singledispatchmethod import singledispatchmethod

logger = logging.getLogger(__name__)

class BasisTranslator(TransformationPass):
    """Translates gates to a target basis by searching for a set of translations
    from a given EquivalenceLibrary.

    This pass operates in several steps:

    * Determine the source basis from the input circuit.
    * Perform a Dijkstra search over basis sets, starting from the device's
      target_basis. New gates are generated using the rules from the provided
      EquivalenceLibrary and the search stops once all gates in the source basis
      have been generated.
    * The found path, as a set of rules from the EquivalenceLibrary, is composed
      into a set of gate replacement rules.
    * The composed replacement rules are applied in-place to each op node which
      is not already in the target_basis.

    If the target keyword argument is specified and that
    :class:`~qiskit.transpiler.Target` object contains operations
    which are non-global (i.e. they are defined only for a subset of qubits),
    as calculated by :meth:`~qiskit.transpiler.Target.get_non_global_operation_names`,
    this pass will attempt to match the output translation to those constraints.
    For 1 qubit operations this is straightforward, the pass will perform a
    search using the union of the set of global operations with the set of
    operations defined solely on that qubit. For multi-qubit gates this is a bit
    more involved, while the behavior is initially similar to the single qubit
    case, just using all the qubits the operation is run on (where order is not
    significant) isn't sufficient. We also need to consider any potential local
    qubits defined on subsets of the quantum arguments for the multi-qubit
    operation. This means the target used for the search of a non-global
    multi-qubit gate is the union of global operations, non-global multi-qubit
    gates sharing the same qubits, and any non-global gates defined on any
    subset of the qubits used.

    .. note::

        In the case of non-global operations it is possible for a single
        execution of this pass to output an incomplete translation if any
        non-global gates are defined on qubits that are a subset of a larger
        multi-qubit gate. For example, if you have a ``u`` gate only defined on
        qubit 0 and an ``x`` gate only on qubit 1 it is possible when
        translating a 2 qubit operation on qubit 0 and 1 that the output might
        have ``u`` on qubit 1 and ``x`` on qubit 0. Typically running this pass
        a second time will correct these issues.

    .. _translation_errors:

    Translation Errors
    ------------------

    This pass will error if there is no path to translate an input gate to
    the specified basis. However, during a typical/default preset passmanager
    this pass gets run multiple times at different stages of the compilation
    pipeline. This means that potentially the input gates that are getting
    translated were not in the input circuit to :func:`~.transpile` as they
    were generated by an intermediate transform in the circuit.

    When this error occurs it typically means that either the target basis
    is not universal or there are additional equivalence rules needed in the
    :class:`~.EquivalenceLibrary` instance being used by the
    :class:`~.BasisTranslator` pass. You can refer to
    :ref:`custom_basis_gates` for details on adding custom equivalence rules.
    """

    def __init__(self, equivalence_library, target_basis, target=None):
        """Initialize a BasisTranslator instance.

        Args:
            equivalence_library (EquivalenceLibrary): The equivalence library
                which will be used by the BasisTranslator pass. (Instructions in
                this library will not be unrolled by this pass.)
            target_basis (list[str]): Target basis names to unroll to, e.g. `['u3', 'cx']`.
            target (Target): The backend compilation target
        """

        super().__init__()
        self._equiv_lib = equivalence_library
        self._target_basis = target_basis
        self._target = target
        self._non_global_operations = None
        # Maps a qargs key of the target to the names of the non-global
        # operations defined on it. Stays an empty plain dict without a target.
        self._qargs_with_non_global_operation = {}
        if target is not None:
            self._non_global_operations = self._target.get_non_global_operation_names()
            self._qargs_with_non_global_operation = defaultdict(set)
            for gate in self._non_global_operations:
                for qarg in self._target[gate]:
                    self._qargs_with_non_global_operation[qarg].add(gate)

    def run(self, dag):
        """Translate an input DAGCircuit to the target basis.

        Args:
            dag (DAGCircuit): input dag

        Raises:
            TranspilerError: if the target basis cannot be reached

        Returns:
            DAGCircuit: translated circuit.
        """
        if self._target_basis is None and self._target is None:
            return dag

        qarg_indices = {qubit: index for index, qubit in enumerate(dag.qubits)}

        # Names of instructions assumed to be supported by any backend.
        if self._target is None:
            basic_instrs = ["measure", "reset", "barrier", "snapshot", "delay"]
            target_basis = set(self._target_basis)
            source_basis = set(self._extract_basis(dag))
            qargs_local_source_basis = {}
        else:
            basic_instrs = ["barrier", "snapshot"]
            target_basis = self._target.keys() - set(self._non_global_operations)
            source_basis, qargs_local_source_basis = self._extract_basis_target(dag, qarg_indices)

        target_basis = set(target_basis).union(basic_instrs)

        logger.info(
            "Begin BasisTranslator from source basis %s to target basis %s.",
            source_basis,
            target_basis,
        )

        # Search for a path from source to target basis.
        search_start_time = time.time()
        basis_transforms = _basis_search(self._equiv_lib, source_basis, target_basis)

        qarg_local_basis_transforms = {}
        for qarg, local_source_basis in qargs_local_source_basis.items():
            expanded_target = set(target_basis)
            # For any multiqubit operation that contains a subset of qubits that
            # has a non-local operation, include that non-local operation in the
            # search. This matches with the check we did above to include those
            # subset non-local operations in the check here.
            if len(qarg) > 1:
                for non_local_qarg, local_basis in self._qargs_with_non_global_operation.items():
                    if qarg.issuperset(non_local_qarg):
                        expanded_target |= local_basis
            else:
                expanded_target |= self._qargs_with_non_global_operation[tuple(qarg)]

            logger.info(
                "Performing BasisTranslator search from source basis %s to target "
                "basis %s on qarg %s.",
                local_source_basis,
                expanded_target,
                qarg,
            )
            local_basis_transforms = _basis_search(
                self._equiv_lib, local_source_basis, expanded_target
            )

            if local_basis_transforms is None:
                raise TranspilerError(
                    "Unable to translate the operations in the circuit: "
                    f"{[x[0] for x in local_source_basis]} to the backend's (or manually "
                    f"specified) target basis: {list(expanded_target)}. This likely means the "
                    "target basis is not universal or there are additional equivalence rules "
                    "needed in the EquivalenceLibrary being used. For more details on this "
                    "error see: "
                    "https://qiskit.org/documentation/stubs/qiskit.transpiler.passes."
                    "BasisTranslator.html#translation_errors"
                )

            qarg_local_basis_transforms[qarg] = local_basis_transforms

        search_end_time = time.time()
        logger.info(
            "Basis translation path search completed in %.3fs.", search_end_time - search_start_time
        )

        if basis_transforms is None:
            raise TranspilerError(
                "Unable to translate the operations in the circuit: "
                f"{[x[0] for x in source_basis]} to the backend's (or manually specified) target "
                f"basis: {list(target_basis)}. This likely means the target basis is not universal "
                "or there are additional equivalence rules needed in the EquivalenceLibrary being "
                "used. For more details on this error see: "
                "https://qiskit.org/documentation/stubs/qiskit.transpiler.passes.BasisTranslator."
                "html#translation_errors"
            )

        # Compose found path into a set of instruction substitution rules.
        compose_start_time = time.time()
        instr_map = _compose_transforms(basis_transforms, source_basis, dag)
        extra_instr_map = {
            qarg: _compose_transforms(transforms, qargs_local_source_basis[qarg], dag)
            for qarg, transforms in qarg_local_basis_transforms.items()
        }

        compose_end_time = time.time()
        logger.info(
            "Basis translation paths composed in %.3fs.", compose_end_time - compose_start_time
        )

        # Replace source instructions with target translations.
        replace_start_time = time.time()

        def apply_translation(dag, wire_map):
            """Translate ``dag`` in place; return True if anything in it changed.

            ``wire_map`` maps each qubit of ``dag`` to its index in the
            outermost circuit so per-qarg lookups work inside nested blocks.
            """
            dag_updated = False
            for node in dag.op_nodes():
                node_qargs = tuple(wire_map[bit] for bit in node.qargs)
                qubit_set = frozenset(node_qargs)
                if node.name in target_basis:
                    if isinstance(node.op, ControlFlowOp):
                        flow_blocks = []
                        for block in node.op.blocks:
                            dag_block = circuit_to_dag(block)
                            # Track each block separately. Assigning straight to
                            # ``dag_updated`` would discard an earlier update of
                            # this DAG when a later block needs no translation,
                            # and the caller would then keep a stale circuit.
                            block_updated = apply_translation(
                                dag_block,
                                {
                                    inner: wire_map[outer]
                                    for inner, outer in zip(block.qubits, node.qargs)
                                },
                            )
                            if block_updated:
                                flow_circ_block = dag_to_circuit(dag_block)
                                dag_updated = True
                            else:
                                flow_circ_block = block
                            flow_blocks.append(flow_circ_block)
                        node.op = node.op.replace_blocks(flow_blocks)
                    continue
                if (
                    node_qargs in self._qargs_with_non_global_operation
                    and node.name in self._qargs_with_non_global_operation[node_qargs]
                ):
                    continue

                if dag.has_calibration_for(node):
                    continue
                if qubit_set in extra_instr_map:
                    self._replace_node(dag, node, extra_instr_map[qubit_set])
                elif (node.op.name, node.op.num_qubits) in instr_map:
                    self._replace_node(dag, node, instr_map)
                else:
                    raise TranspilerError(f"BasisTranslator did not map {node.name}.")
                dag_updated = True
            return dag_updated

        apply_translation(dag, qarg_indices)
        replace_end_time = time.time()
        logger.info(
            "Basis translation instructions replaced in %.3fs.",
            replace_end_time - replace_start_time,
        )

        return dag

    def _replace_node(self, dag, node, instr_map):
        """Substitute ``node`` in ``dag`` with its translation from ``instr_map``.

        Args:
            dag (DAGCircuit): DAG that is modified in place.
            node (DAGOpNode): node to replace.
            instr_map (dict): maps ``(name, num_qubits)`` to a
                ``(placeholder_params, replacement_dag)`` pair.

        Raises:
            TranspilerError: if the translation expects a different number of
                parameters than the node's operation carries.
        """
        target_params, target_dag = instr_map[node.op.name, node.op.num_qubits]
        if len(node.op.params) != len(target_params):
            raise TranspilerError(
                "Translation num_params not equal to op num_params."
                "Op: {} {} Translation: {}\n{}".format(
                    node.op.params, node.op.name, target_params, target_dag
                )
            )
        if node.op.params:
            # Convert target to circ and back to assign_parameters, since
            # DAGCircuits won't have a ParameterTable.
            target_circuit = dag_to_circuit(target_dag)

            target_circuit.assign_parameters(
                dict(zip_longest(target_params, node.op.params)), inplace=True
            )

            bound_target_dag = circuit_to_dag(target_circuit)
        else:
            bound_target_dag = target_dag

        target_nodes = bound_target_dag.op_nodes()
        if len(target_nodes) == 1 and len(target_nodes[0].qargs) == len(node.qargs):
            # Single-op translation on the same number of qubits: swap the op
            # directly instead of splicing in a whole DAG.
            dag_op = target_nodes[0].op
            # dag_op may be the same instance as other ops in the dag,
            # so if there is a condition, need to copy
            if getattr(node.op, "condition", None):
                dag_op = dag_op.copy()
            dag.substitute_node(node, dag_op, inplace=True)

            if bound_target_dag.global_phase:
                dag.global_phase += bound_target_dag.global_phase
        else:
            dag.substitute_node_with_dag(node, bound_target_dag)

    @singledispatchmethod
    def _extract_basis(self, circuit):
        """Fallback for types with no registered overload: return the input as is."""
        return circuit

    @_extract_basis.register
    def _(self, dag: DAGCircuit):
        """Yield ``(name, num_qubits)`` for the uncalibrated ops of ``dag``,
        recursing into control-flow blocks."""
        for node in dag.op_nodes():
            if not dag.has_calibration_for(node):
                yield (node.name, node.op.num_qubits)
            if isinstance(node.op, ControlFlowOp):
                for block in node.op.blocks:
                    yield from self._extract_basis(block)

    @_extract_basis.register
    def _(self, circ: QuantumCircuit):
        """Yield ``(name, num_qubits)`` for the uncalibrated instructions of
        ``circ``, recursing into control-flow blocks."""
        for instr_context in circ.data:
            instr, _, _ = instr_context
            if not circ.has_calibration_for(instr_context):
                yield (instr.name, instr.num_qubits)
            if isinstance(instr, ControlFlowOp):
                for block in instr.blocks:
                    yield from self._extract_basis(block)

    def _extract_basis_target(
        self, dag, qarg_indices, source_basis=None, qargs_local_source_basis=None
    ):
        """Collect the global and per-qarg source bases of ``dag``.

        Args:
            dag (DAGCircuit): DAG to scan.
            qarg_indices (dict): maps each qubit of ``dag`` to its index in the
                outermost circuit.
            source_basis (set): accumulator for ``(name, num_qubits)`` pairs
                handled by the global search; created when ``None``.
            qargs_local_source_basis (dict): accumulator mapping a frozenset of
                qubit indices to the ``(name, num_qubits)`` pairs that need a
                qarg-specific search; created when ``None``.

        Returns:
            tuple: ``(source_basis, qargs_local_source_basis)``.
        """
        if source_basis is None:
            source_basis = set()
        if qargs_local_source_basis is None:
            qargs_local_source_basis = defaultdict(set)
        for node in dag.op_nodes():
            qargs = tuple(qarg_indices[bit] for bit in node.qargs)
            if dag.has_calibration_for(node):
                continue
            # Treat the instruction as on an incomplete basis if the qargs are in the
            # qargs_with_non_global_operation dictionary or if any of the qubits in qargs
            # are a superset for a non-local operation. For example, if the qargs
            # are (0, 1) and that's a global (ie no non-local operations on (0, 1)
            # operation but there is a non-local operation on (1,) we need to
            # do an extra non-local search for this op to ensure we include any
            # single qubit operation for (1,) as valid. This pattern also holds
            # true for > 2q ops too (so for 4q operations we need to check for 3q, 2q,
            # and 1q operations in the same manner)
            if qargs in self._qargs_with_non_global_operation or any(
                frozenset(qargs).issuperset(incomplete_qargs)
                for incomplete_qargs in self._qargs_with_non_global_operation
            ):
                qargs_local_source_basis[frozenset(qargs)].add((node.name, node.op.num_qubits))
            else:
                source_basis.add((node.name, node.op.num_qubits))
            if isinstance(node.op, ControlFlowOp):
                for block in node.op.blocks:
                    block_dag = circuit_to_dag(block)
                    source_basis, qargs_local_source_basis = self._extract_basis_target(
                        block_dag,
                        {
                            inner: qarg_indices[outer]
                            for inner, outer in zip(block.qubits, node.qargs)
                        },
                        source_basis=source_basis,
                        qargs_local_source_basis=qargs_local_source_basis,
                    )
        return source_basis, qargs_local_source_basis
class StopIfBasisRewritable(Exception):
    """Custom exception that signals `rustworkx.dijkstra_search` to stop."""


class BasisSearchVisitor(rustworkx.visit.DijkstraVisitor):
    """Handles events emitted during `rustworkx.dijkstra_search`."""

    def __init__(self, graph, source_basis, target_basis):
        self.graph = graph
        self.target_basis = set(target_basis)
        self._source_gates_remain = set(source_basis)
        # For every rule (keyed by its edge-data index): how many of its gates
        # still have to be generated before the rule can be applied.
        self._num_gates_remain_for_rule = {}
        save_index = -1
        for edata in self.graph.edges():
            # Skip consecutive edges carrying the same rule index.
            if save_index == edata.index:
                continue
            self._num_gates_remain_for_rule[edata.index] = edata.num_gates
            save_index = edata.index

        self._basis_transforms = []
        self._predecessors = {}
        self._opt_cost_map = {}

    def discover_vertex(self, v, score):
        gate = self.graph[v].key
        self._source_gates_remain.discard(gate)
        self._opt_cost_map[gate] = score
        rule = self._predecessors.get(gate, None)
        if rule is not None:
            logger.debug(
                "Gate %s generated using rule \n%s\n with total cost of %s.",
                gate.name,
                rule.circuit,
                score,
            )
            self._basis_transforms.append((gate.name, gate.num_qubits, rule.params, rule.circuit))
        # we can stop the search if we have found all gates in the original circuit.
        if not self._source_gates_remain:
            # if we start from source gates and apply `basis_transforms` in reverse order, we'll end
            # up with gates in the target basis. Note though that `basis_transforms` may include
            # additional transformations that are not required to map our source gates to the given
            # target basis.
            self._basis_transforms.reverse()
            raise StopIfBasisRewritable

    def examine_edge(self, edge):
        _, target, edata = edge
        if edata is None:
            return

        self._num_gates_remain_for_rule[edata.index] -= 1

        target = self.graph[target].key
        # if there are gates in this `rule` that we have not yet generated, we can't apply
        # this `rule`. if `target` is already in basis, it's not beneficial to use this rule.
        if self._num_gates_remain_for_rule[edata.index] > 0 or target in self.target_basis:
            raise rustworkx.visit.PruneSearch

    def edge_relaxed(self, edge):
        _, target, edata = edge
        if edata is not None:
            gate = self.graph[target].key
            self._predecessors[gate] = edata.rule

    def edge_cost(self, edge_data):
        """Returns the cost of an edge.

        This function computes the cost of this edge rule by summing
        the costs of all gates in the rule equivalence circuit. In the
        end, we need to subtract the cost of the source since `dijkstra`
        will later add it.
        """

        if edge_data is None:
            # the target of the edge is a gate in the target basis,
            # so we return a default value of 1.
            return 1

        cost_tot = 0
        for instruction in edge_data.rule.circuit:
            key = Key(name=instruction.operation.name, num_qubits=len(instruction.qubits))
            cost_tot += self._opt_cost_map[key]

        return cost_tot - self._opt_cost_map[edge_data.source]

    @property
    def basis_transforms(self):
        """Returns the gate basis transforms."""
        return self._basis_transforms


def _basis_search(equiv_lib, source_basis, target_basis):
    """Search for a set of transformations from source_basis to target_basis.

    Args:
        equiv_lib (EquivalenceLibrary): Source of valid translations
        source_basis (Set[Tuple[gate_name: str, gate_num_qubits: int]]): Starting basis.
        target_basis (Set[gate_name: str]): Target basis.

    Returns:
        Optional[List[Tuple[gate, equiv_params, equiv_circuit]]]: List of (gate,
            equiv_params, equiv_circuit) tuples which, if applied in order
            will map from source_basis to target_basis. Returns None if no path
            was found.
    """

    logger.debug("Begining basis search from %s to %s.", source_basis, target_basis)

    source_basis = {
        (gate_name, gate_num_qubits)
        for gate_name, gate_num_qubits in source_basis
        if gate_name not in target_basis
    }

    # if source basis is empty, no work to be done.
    if not source_basis:
        return []

    # This is only necessary since gates in target basis are currently reported by
    # their names and we need to have in addition the number of qubits they act on.
    target_basis_keys = [key for key in equiv_lib.keys() if key.name in target_basis]

    graph = equiv_lib.graph
    vis = BasisSearchVisitor(graph, source_basis, target_basis_keys)

    # we add a dummy node and connect it with gates in the target basis.
    # we'll start the search from this dummy node.
    dummy = graph.add_node(NodeData(key="key", equivs=[("dummy starting node", 0)]))

    try:
        graph.add_edges_from_no_data(
            [(dummy, equiv_lib.node_index(key)) for key in target_basis_keys]
        )
        rtn = None
        try:
            rustworkx.digraph_dijkstra_search(graph, [dummy], vis.edge_cost, vis)
        except StopIfBasisRewritable:
            # The visitor raises this once every source gate has been reached.
            rtn = vis.basis_transforms

            logger.debug("Transformation path:")
            for gate_name, gate_num_qubits, params, equiv in rtn:
                logger.debug("%s/%s => %s\n%s", gate_name, gate_num_qubits, params, equiv)
    finally:
        # Remove dummy node in order to return graph to original state
        graph.remove_node(dummy)

    return rtn


def _compose_transforms(basis_transforms, source_basis, source_dag):
    """Compose a set of basis transforms into a set of replacements.

    Args:
        basis_transforms (List[Tuple[gate_name, params, equiv]]): List of
            transforms to compose.
        source_basis (Set[Tuple[gate_name: str, gate_num_qubits: int]]): Names
            of gates which need to be translated.
        source_dag (DAGCircuit): DAG with example gates from source_basis.
            (Used to determine num_params for gate in source_basis.)

    Returns:
        Dict[gate_name, Tuple(params, dag)]: Dictionary mapping between each gate
            in source_basis and a DAGCircuit instance to replace it. Gates in
            source_basis but not affected by basis_transforms will be included
            as a key mapping to itself.
    """
    example_gates = _get_example_gates(source_dag)
    mapped_instrs = {}

    for gate_name, gate_num_qubits in source_basis:
        # Need to grab a gate instance to find num_qubits and num_params.
        # Can be removed following https://github.com/Qiskit/qiskit-terra/pull/3947 .
        example_gate = example_gates[gate_name, gate_num_qubits]
        num_params = len(example_gate.params)

        placeholder_params = ParameterVector(gate_name, num_params)
        placeholder_gate = Gate(gate_name, gate_num_qubits, list(placeholder_params))
        placeholder_gate.params = list(placeholder_params)

        dag = DAGCircuit()
        qr = QuantumRegister(gate_num_qubits)
        dag.add_qreg(qr)
        dag.apply_operation_back(placeholder_gate, qr[:], [])
        mapped_instrs[gate_name, gate_num_qubits] = placeholder_params, dag

    for gate_name, gate_num_qubits, equiv_params, equiv in basis_transforms:
        logger.debug(
            "Composing transform step: %s/%s %s =>\n%s",
            gate_name,
            gate_num_qubits,
            equiv_params,
            equiv,
        )

        for mapped_instr_name, (dag_params, dag) in mapped_instrs.items():
            # Nodes of this replacement DAG that the current transform rewrites.
            doomed_nodes = [
                node
                for node in dag.op_nodes()
                if (node.op.name, node.op.num_qubits) == (gate_name, gate_num_qubits)
            ]

            if doomed_nodes and logger.isEnabledFor(logging.DEBUG):

                logger.debug(
                    "Updating transform for mapped instr %s %s from \n%s",
                    mapped_instr_name,
                    dag_params,
                    dag_to_circuit(dag, copy_operations=False),
                )

            for node in doomed_nodes:

                replacement = equiv.assign_parameters(
                    dict(zip_longest(equiv_params, node.op.params))
                )

                replacement_dag = circuit_to_dag(replacement)

                dag.substitute_node_with_dag(node, replacement_dag)

            if doomed_nodes and logger.isEnabledFor(logging.DEBUG):

                logger.debug(
                    "Updated transform for mapped instr %s %s to\n%s",
                    mapped_instr_name,
                    dag_params,
                    dag_to_circuit(dag, copy_operations=False),
                )

    return mapped_instrs


def _get_example_gates(source_dag):
    """Return a dict mapping ``(name, num_qubits)`` to one operation instance
    found in ``source_dag``, including operations nested in control-flow blocks."""

    def recurse(dag, example_gates=None):
        example_gates = example_gates or {}
        for node in dag.op_nodes():
            example_gates[(node.op.name, node.op.num_qubits)] = node.op
            if isinstance(node.op, ControlFlowOp):
                for block in node.op.blocks:
                    example_gates = recurse(circuit_to_dag(block), example_gates)
        return example_gates

    return recurse(source_dag)