/
commutation_checker.py
438 lines (372 loc) · 16.1 KB
/
commutation_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Code from commutative_analysis pass that checks commutation relations between DAG nodes."""
from functools import lru_cache
from typing import List, Union
import numpy as np
from qiskit import QiskitError
from qiskit.circuit import Qubit
from qiskit.circuit.operation import Operation
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit.quantum_info.operators import Operator
# Names of operations for which is_commutation_skipped() reports that the commutation check is
# skipped.
_skipped_op_names = {"measure", "reset", "delay", "initialize"}
# Names of operations for which CommutationChecker.commute() bypasses the cache entirely: the
# result is neither looked up nor stored, it is always recomputed.
_no_cache_op_names = {"annotated"}
@lru_cache(maxsize=None)
def _identity_op(num_qubits):
    """Return the identity ``Operator`` acting on ``num_qubits`` qubits.

    The result is memoized per ``num_qubits`` so repeated requests reuse the same object.
    """
    # Every subsystem is a qubit, hence dimension 2 on both the input and the output side.
    qubit_dims = (2,) * num_qubits
    identity_matrix = np.eye(2**num_qubits)
    return Operator(identity_matrix, input_dims=qubit_dims, output_dims=qubit_dims)
class CommutationChecker:
    """Decide whether two operations commute, remembering previously computed answers.

    This code is essentially copy-pasted from commutative_analysis.py. Both commuting and
    non-commuting results are cached, keyed by the pair of operation names, the relative qubit
    placement and (when present) the operation parameters, which seems quite efficient for large
    Clifford circuits.

    There may be other possible efficiency improvements: using rule-based commutativity analysis,
    evicting less useful entries from the cache, etc.
    """

    def __init__(self, standard_gate_commutations: dict = None, cache_max_entries: int = 10**6):
        """
        Args:
            standard_gate_commutations: optional library of known commutation relations that is
                consulted before the session cache. An empty library is used when omitted.
            cache_max_entries: number of cached results at which the session cache is cleared
                before the next result is stored.
        """
        super().__init__()
        self._standard_commutations = (
            {} if standard_gate_commutations is None else standard_gate_commutations
        )
        self._cache_max_entries = cache_max_entries

        # self._cached_commutations has the same structure as standard_gate_commutations, i.e. a
        # dict[pair of gate names][relative placement][tuple of gate parameters] := True/False
        self._cached_commutations = {}
        self._current_cache_entries = 0
        self._cache_hit = 0
        self._cache_miss = 0

    def commute(
        self,
        op1: Operation,
        qargs1: List,
        cargs1: List,
        op2: Operation,
        qargs2: List,
        cargs2: List,
        max_num_qubits: int = 3,
    ) -> bool:
        """
        Checks if two Operations commute. A return value of `True` means that the operations
        truly commute. A return value of `False` means that either the operations do not commute
        or that the commutation check was skipped (for example, when the operations have
        conditions or have too many qubits).

        Args:
            op1: first operation.
            qargs1: first operation's qubits.
            cargs1: first operation's clbits.
            op2: second operation.
            qargs2: second operation's qubits.
            cargs2: second operation's clbits.
            max_num_qubits: the maximum number of qubits to consider, the check may be skipped if
                the number of qubits for either operation exceeds this amount.

        Returns:
            bool: whether two operations commute.
        """
        # Cheap structural checks first; they answer many queries without any matrix work.
        precheck_result = _commutation_precheck(
            op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits
        )
        if precheck_result is not None:
            return precheck_result

        # Put the pair into canonical order so that lookups and stored keys agree.
        (first_op, first_qargs, _), (second_op, second_qargs, _) = _order_operations(
            op1, qargs1, cargs1, op2, qargs2, cargs2
        )

        # Some operations are never cached: always recompute for them.
        if first_op.name in _no_cache_op_names or second_op.name in _no_cache_op_names:
            return _commute_matmul(first_op, first_qargs, second_op, second_qargs)

        known_result = self.check_commutation_entries(
            first_op, first_qargs, second_op, second_qargs
        )
        if known_result is not None:
            return known_result

        # Nothing stored: compute commutation via matrix multiplication
        is_commuting = _commute_matmul(first_op, first_qargs, second_op, second_qargs)

        # Store the result in this session's cache.
        # TODO implement LRU cache or similar
        # For now the whole cache is dropped once it has reached its maximum size.
        if self._current_cache_entries >= self._cache_max_entries:
            self.clear_cached_commutations()

        first_params = getattr(first_op, "params", [])
        second_params = getattr(second_op, "params", [])

        entries_for_names = self._cached_commutations.setdefault(
            (first_op.name, second_op.name), {}
        )
        placement = _get_relative_placement(first_qargs, second_qargs)
        if len(first_params) > 0 or len(second_params) > 0:
            # Parameterized operations get one extra dictionary level keyed by the parameters.
            entries_for_placement = entries_for_names.setdefault(placement, {})
            parameter_key = (
                _hashable_parameters(first_params),
                _hashable_parameters(second_params),
            )
            entries_for_placement[parameter_key] = is_commuting
        else:
            entries_for_names[placement] = is_commuting

        self._current_cache_entries += 1
        return is_commuting

    def num_cached_entries(self):
        """Return how many results have been stored since the cache was last cleared."""
        return self._current_cache_entries

    def clear_cached_commutations(self):
        """Drop every cached commutation and reset the entry, hit and miss counters."""
        self._cached_commutations = {}
        self._current_cache_entries = 0
        self._cache_hit = 0
        self._cache_miss = 0

    def check_commutation_entries(
        self,
        first_op: Operation,
        first_qargs: List,
        second_op: Operation,
        second_qargs: List,
    ) -> Union[bool, None]:
        """Returns the stored commutation relation, if any.

        The standard commutation library is consulted first, then the session cache. Only
        lookups that reach the session cache update the hit/miss counters.

        Args:
            first_op: first operation.
            first_qargs: first operation's qubits.
            second_op: second operation.
            second_qargs: second operation's qubits.

        Return:
            bool: True if the gates commute, False if they do not, and None if no relation is
            stored for this pair.
        """
        # We don't precompute commutations for parameterized gates, yet
        from_standard_library = _query_commutation(
            first_op,
            first_qargs,
            second_op,
            second_qargs,
            self._standard_commutations,
        )
        if from_standard_library is not None:
            return from_standard_library

        from_cache = _query_commutation(
            first_op,
            first_qargs,
            second_op,
            second_qargs,
            self._cached_commutations,
        )
        if from_cache is not None:
            self._cache_hit += 1
        else:
            self._cache_miss += 1
        return from_cache
def _hashable_parameters(params):
    """Turn gate parameters into something usable as a dictionary key.

    Already-hashable input is returned unchanged. Lists and tuples are converted element by
    element into tuples, numpy arrays are represented by their raw bytes, and anything else falls
    back to its string form.

    This aims to be fast in common cases, and is not intended to work outside of the lifetime of a
    single commutation pass; it does not handle mutable state correctly if the state is actually
    changed."""
    try:
        hash(params)
    except TypeError:
        # Unhashable: convert below.
        pass
    else:
        return params

    if isinstance(params, np.ndarray):
        # Using the bytes of the matrix as key is runtime efficient but requires more space: 128
        # bits times the number of parameters instead of a single 64 bit id. However, by using the
        # bytes as an id, we can reuse the cached commutations between different passes.
        return (np.ndarray, params.tobytes())
    if isinstance(params, (list, tuple)):
        return tuple(map(_hashable_parameters, params))
    # Catch anything else with a slow conversion.
    return ("fallback", str(params))
def is_commutation_supported(op):
    """
    Tell whether commutation analysis may be attempted for ``op``.

    Some operations are excluded due to bugs in transpiler passes invoking commutation analysis.

    Args:
        op (Operation): operation to be checked for commutation relation

    Return:
        True if determining the commutation of op is currently supported
    """
    # Conditioned operations: bug in CommutativeCancellation, e.g. see gh-8553
    has_condition = getattr(op, "condition", False)
    # Commutation of ControlFlow gates is also not supported yet. This may be pending a control
    # flow graph.
    if has_condition or op.name in CONTROL_FLOW_OP_NAMES:
        return False
    return True
def is_commutation_skipped(op, qargs, max_num_qubits):
    """
    Filter operations whose commutation will not be determined.

    An operation is skipped when it acts on more than ``max_num_qubits`` qubits, is marked as a
    directive, has a name listed in ``_skipped_op_names``, or reports that it is parameterized.

    Args:
        op (Operation): operation to be checked for commutation relation
        qargs (List): operation qubits
        max_num_qubits (int): the maximum number of qubits to consider, the check may be skipped if
            the number of qubits for either operation exceeds this amount.

    Return:
        True if the commutation check for op is skipped, False otherwise
    """
    if (
        len(qargs) > max_num_qubits
        or getattr(op, "_directive", False)
        or op.name in _skipped_op_names
    ):
        return True

    # Objects without an ``is_parameterized`` method are treated as not parameterized.
    if getattr(op, "is_parameterized", False) and op.is_parameterized():
        return True

    # Whether a matrix can actually be obtained for the operation is not decided here: the former
    # attribute probe returned False on both of its paths. _commute_matmul tries to build an
    # Operator and reports non-commutation when that raises QiskitError.
    return False
def _commutation_precheck(
    op1: Operation,
    qargs1: List,
    cargs1: List,
    op2: Operation,
    qargs2: List,
    cargs2: List,
    max_num_qubits,
):
    """Try to settle the commutation question from structure alone.

    Return:
        False if either operation is unsupported or skipped, True if the operations share neither
        qubits nor clbits, and None if a full commutation check is required.
    """
    if not (is_commutation_supported(op1) and is_commutation_supported(op2)):
        return False

    # Operations acting on entirely separate bits always commute.
    qubits_disjoint = set(qargs1).isdisjoint(qargs2)
    if qubits_disjoint and set(cargs1).isdisjoint(cargs2):
        return True

    operations = ((op1, qargs1), (op2, qargs2))
    if any(is_commutation_skipped(op, qargs, max_num_qubits) for op, qargs in operations):
        return False

    return None
def _get_relative_placement(first_qargs: List[Qubit], second_qargs: List[Qubit]) -> tuple:
    """Determines the relative qubit placement of two gates. Note: this is NOT symmetric.

    Args:
        first_qargs: qubits of the first gate
        second_qargs: qubits of the second gate

    Return:
        A tuple with one entry per qubit of the first gate: the position of that qubit among the
        second gate's qubits, or None if the second gate does not act on it. E.g.
        _get_relative_placement(CX(0, 1), CX(1, 2)) would return (None, 0) as there is no overlap on
        the first qubit of the first gate but there is an overlap on the second qubit of the first
        gate, i.e. qubit 0 of the second gate. Likewise,
        _get_relative_placement(CX(1, 2), CX(0, 1)) would return (1, None)
    """
    position_in_second = {}
    for position, qubit in enumerate(second_qargs):
        position_in_second[qubit] = position
    return tuple(map(position_in_second.get, first_qargs))
@lru_cache(maxsize=10**3)
def _persistent_id(op_name: str) -> int:
    """Map a string to an integer id that is the same in every python execution.

    The built-in hash() can not be used for this, as its value for a string can change between
    two python executions. Here the UTF-8 bytes of the string are read as one signed big-endian
    integer.

    Args:
        op_name (str): The string whose integer id should be determined.

    Return:
        The integer id of the input string.
    """
    utf8_bytes = op_name.encode("utf-8")
    return int.from_bytes(utf8_bytes, byteorder="big", signed=True)
def _order_operations(
    op1: Operation, qargs1: List, cargs1: List, op2: Operation, qargs2: List, cargs2: List
):
    """Orders two operations in a canonical way that is persistent over
    different python versions and executions

    The operation with fewer qubits comes first. When both act on the same number of qubits, the
    order is decided by comparing the persistent ids of the operation names.

    Args:
        op1: first operation.
        qargs1: first operation's qubits.
        cargs1: first operation's clbits.
        op2: second operation.
        qargs2: second operation's qubits.
        cargs2: second operation's clbits.

    Return:
        The input operations in a persistent, canonical order, as two
        ``(operation, qargs, cargs)`` tuples.
    """
    op1_tuple = (op1, qargs1, cargs1)
    op2_tuple = (op2, qargs2, cargs2)

    width1, width2 = op1.num_qubits, op2.num_qubits
    if width1 != width2:
        # prefer operation with the least number of qubits as first key as this results in
        # shorter keys
        op1_goes_first = width1 < width2
    else:
        op1_goes_first = _persistent_id(op1.name) < _persistent_id(op2.name)

    if op1_goes_first:
        return op1_tuple, op2_tuple
    return op2_tuple, op1_tuple
def _query_commutation(
    first_op: Operation,
    first_qargs: List,
    second_op: Operation,
    second_qargs: List,
    _commutation_lib: dict,
) -> Union[bool, None]:
    """Queries and returns the commutation of a pair of operations from a provided commutation library

    Args:
        first_op: first operation.
        first_qargs: first operation's qubits.
        second_op: second operation.
        second_qargs: second operation's qubits.
        _commutation_lib (dict): dictionary of commutation relations

    Return:
        True if first_op and second_op commute, False if they do not commute and
        None if the commutation is not in the library

    Raises:
        ValueError: if the library entry for the pair of names is neither None, a bool nor a dict.
    """
    entry = _commutation_lib.get((first_op.name, second_op.name), None)

    # Missing entry, or an answer that holds for every relative placement of the operations.
    if entry is None or isinstance(entry, bool):
        return entry

    if not isinstance(entry, dict):
        raise ValueError("Expected commutation to be None, bool or a dict")

    # The entry depends on the placement of the operations and possibly on their parameters.
    entry_for_placement = entry.get(_get_relative_placement(first_qargs, second_qargs), None)
    if not isinstance(entry_for_placement, dict):
        # queried commutation is True, False or None
        return entry_for_placement

    # Another dict at this level means that the commutation depends on the parameters.
    first_params = getattr(first_op, "params", [])
    second_params = getattr(second_op, "params", [])
    parameter_key = (_hashable_parameters(first_params), _hashable_parameters(second_params))
    return entry_for_placement.get(parameter_key, None)
def _commute_matmul(
    first_ops: Operation, first_qargs: List, second_op: Operation, second_qargs: List
):
    """Decide commutation by comparing the two orderings of the operations' matrix product.

    Args:
        first_ops: first operation.
        first_qargs: first operation's qubits.
        second_op: second operation.
        second_qargs: second operation's qubits.

    Return:
        True if both orderings give equal operators, False if they differ or if an Operator could
        not be constructed for one of the operations.
    """
    # Give every distinct qubit a local index: qubits of the first operation get the lowest
    # indices, qubits that only the second operation touches are appended after them.
    local_index = {qubit: position for position, qubit in enumerate(first_qargs)}
    for qubit in second_qargs:
        local_index.setdefault(qubit, len(local_index))
    num_qubits = len(local_index)

    first_qarg = tuple(local_index[qubit] for qubit in first_qargs)
    second_qarg = tuple(local_index[qubit] for qubit in second_qargs)
    first_dims = (2,) * len(first_qarg)
    second_dims = (2,) * len(second_qarg)

    # try to generate an Operator out of op, if this succeeds we can determine commutativity,
    # otherwise return false
    try:
        operator_1 = Operator(first_ops, input_dims=first_dims, output_dims=first_dims)
        operator_2 = Operator(second_op, input_dims=second_dims, output_dims=second_dims)
    except QiskitError:
        return False

    if first_qarg == second_qarg:
        # Use full composition if possible to get the fastest matmul paths.
        op12 = operator_1.compose(operator_2)
        op21 = operator_2.compose(operator_1)
        return op12 == op21

    # Expand operator_1 to be large enough to contain operator_2 as well; this relies on qargs1
    # being the lowest possible indices so the identity can be tensored before it.
    padding_qubits = num_qubits - len(first_qarg)
    if padding_qubits:
        operator_1 = _identity_op(padding_qubits).tensor(operator_1)
    op12 = operator_1.compose(operator_2, qargs=second_qarg, front=False)
    op21 = operator_1.compose(operator_2, qargs=second_qarg, front=True)
    return op12 == op21