/
pass_manager_visualization.py
319 lines (257 loc) · 11 KB
/
pass_manager_visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# This code is part of Qiskit.
#
# (C) Copyright IBM 2019.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Visualization function for a pass manager. Passes are grouped based on their
flow controller, and coloured based on the type of pass.
"""
from __future__ import annotations
import os
import inspect
import tempfile
from qiskit.utils import optionals as _optionals
from qiskit.passmanager.base_tasks import BaseController, GenericPass
from qiskit.passmanager.flow_controllers import FlowControllerLinear
from qiskit.transpiler.basepasses import AnalysisPass, TransformationPass
from .exceptions import VisualizationError
# Default mapping from pass base class to node color. The drawers use it when
# the caller supplies no style, and the color lookup falls back to it for pass
# types that a user-provided style does not cover.
DEFAULT_STYLE = {AnalysisPass: "red", TransformationPass: "blue"}
@_optionals.HAS_GRAPHVIZ.require_in_call
@_optionals.HAS_PYDOT.require_in_call
def pass_manager_drawer(pass_manager, filename=None, style=None, raw=False):
    """
    Render a pass manager as a graph of its passes.

    Rendering relies on `pydot <https://github.com/pydot/pydot>`__, which in turn
    needs `Graphviz <https://www.graphviz.org/>`__ to be installed.

    Args:
        pass_manager (PassManager): the pass manager to be drawn
        filename (str): file path to save image to
        style (dict or OrderedDict): maps pass classes to the color used for them.
            See ``DEFAULT_STYLE`` for an example. Lookup follows the dict's iteration
            order, so an ordered dict gives a priority coloring when a pass matches
            several classes. Pass types not covered by the provided dict are looked
            up in the default dict.
        raw (Bool): True to save the raw Dot output instead of an image. The
            default is False.

    Returns:
        PIL.Image or None: an in-memory representation of the pass manager. Or None if
        no image was generated or PIL is not installed.

    Raises:
        MissingOptionalLibraryError: when Graphviz or pydot is not installed.
        VisualizationError: If raw=True and filename=None.

    Example:
        .. code-block::

            %matplotlib inline
            from qiskit import QuantumCircuit
            from qiskit.compiler import transpile
            from qiskit.transpiler import PassManager
            from qiskit.visualization import pass_manager_drawer
            from qiskit.transpiler.passes import Unroller

            circ = QuantumCircuit(3)
            circ.ccx(0, 1, 2)
            circ.draw()

            pass_ = Unroller(['u1', 'u2', 'u3', 'cx'])
            pm = PassManager(pass_)
            new_circ = pm.run(circ)
            new_circ.draw(output='mpl')

            pass_manager_drawer(pm, "passmanager.jpg")
    """
    import pydot

    # an absent or empty style means "use the defaults"
    palette = style or DEFAULT_STYLE

    # top-level graph that collects one cluster per controller group
    dot_graph = pydot.Dot()

    # Node identifiers must be unique. Python's id() is not usable because the
    # very same pass object may have been appended more than once, so a running
    # counter is threaded through draw_subgraph instead.
    next_id = 0
    last_node = None

    for position, group in enumerate(pass_manager.to_flow_controller().tasks):
        cluster, next_id, last_node = draw_subgraph(
            group, next_id, palette, last_node, position
        )
        dot_graph.add_subgraph(cluster)

    return make_output(dot_graph, raw, filename)
def _get_node_color(pss, style):
    """Return the node color for ``pss``.

    The user-provided ``style`` mapping is searched first, in its iteration
    order; the first class that ``pss`` is an instance of wins. If nothing
    matches, ``DEFAULT_STYLE`` is searched the same way. ``"black"`` is
    returned when neither mapping has a matching class.
    """
    # private sentinel so that any value stored in a style dict is returned as-is
    no_match = object()

    # user-provided mapping has priority
    shade = next(
        (value for kind, value in style.items() if isinstance(pss, kind)), no_match
    )
    if shade is not no_match:
        return shade

    # otherwise consult the defaults
    shade = next(
        (value for kind, value in DEFAULT_STYLE.items() if isinstance(pss, kind)),
        no_match,
    )
    if shade is not no_match:
        return shade

    return "black"
@_optionals.HAS_GRAPHVIZ.require_in_call
@_optionals.HAS_PYDOT.require_in_call
def staged_pass_manager_drawer(pass_manager, filename=None, style=None, raw=False):
    """
    Render a staged pass manager, one labelled cluster per populated stage.

    Rendering relies on `pydot <https://github.com/pydot/pydot>`__, which in turn
    needs `Graphviz <https://www.graphviz.org/>`__ to be installed.

    Args:
        pass_manager (StagedPassManager): the staged pass manager to be drawn
        filename (str): file path to save image to
        style (dict or OrderedDict): maps pass classes to the color used for them.
            See ``DEFAULT_STYLE`` for an example. Lookup follows the dict's iteration
            order, so an ordered dict gives a priority coloring when a pass matches
            several classes. Pass types not covered by the provided dict are looked
            up in the default dict.
        raw (Bool): True to save the raw Dot output instead of an image. The
            default is False.

    Returns:
        PIL.Image or None: an in-memory representation of the pass manager. Or None if
        no image was generated or PIL is not installed.

    Raises:
        MissingOptionalLibraryError: when Graphviz or pydot is not installed.
        VisualizationError: If raw=True and filename=None.

    Example:
        .. code-block::

            %matplotlib inline
            from qiskit.providers.fake_provider import GenericBackendV2
            from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

            pass_manager = generate_preset_pass_manager(3, GenericBackendV2(num_qubits=5))
            pass_manager.draw()
    """
    import pydot

    # drop placeholder (None) entries from the expanded stage names
    stage_names = [name for name in pass_manager.expanded_stages if name is not None]

    # an absent or empty style means "use the defaults"
    palette = style or DEFAULT_STYLE

    # top-level graph that collects one cluster per stage
    dot_graph = pydot.Dot()

    # Node identifiers must be unique. Python's id() is not usable because the
    # very same pass object may have been appended more than once, so a running
    # counter is threaded through draw_subgraph instead.
    next_id = 0
    # controller-group index, counted continuously across all stages
    group_index = 0
    last_node = None

    for name in stage_names:
        stage_pm = getattr(pass_manager, name)
        if stage_pm is None:
            # stage exists by name but holds no passes
            continue

        stage_cluster = pydot.Cluster(
            str(name), label=str(name), fontname="helvetica", labeljust="l"
        )
        for group in stage_pm.to_flow_controller().tasks:
            cluster, next_id, last_node = draw_subgraph(
                group, next_id, palette, last_node, group_index
            )
            stage_cluster.add_subgraph(cluster)
            group_index += 1

        dot_graph.add_subgraph(stage_cluster)

    return make_output(dot_graph, raw, filename)
def draw_subgraph(controller_group, component_id, style, prev_node, idx):
    """Build the pydot cluster for one controller group of a pass manager.

    Each task in the group becomes a rectangular node, chained by an edge to the
    node drawn before it. The parameters of the task's ``__init__`` are added as
    elliptical input nodes pointing at the task node; parameters that have a
    default value are drawn dashed.

    Args:
        controller_group: a flow controller, a single pass, or a list/tuple of
            tasks. Any other object yields an empty cluster.
        component_id (int): next unused node identifier.
        style (dict): maps pass classes to colors, see ``_get_node_color``.
        prev_node: the last task node drawn so far, or None if there is none.
        idx (int): position of this group, used as a prefix of the cluster label.

    Returns:
        tuple: ``(subgraph, component_id, prev_node)`` with the new cluster, the
        next unused node identifier and the last task node drawn.
    """
    import pydot

    # label is the name of the flow controller parameter
    label = f"[{idx}] "
    if isinstance(controller_group, BaseController) and not isinstance(
        controller_group, FlowControllerLinear
    ):
        label += f"{controller_group.__class__.__name__}"

    # create the subgraph for this controller
    subgraph = pydot.Cluster(str(component_id), label=label, fontname="helvetica", labeljust="l")
    component_id += 1

    if isinstance(controller_group, BaseController):
        # Assume linear pipeline
        # TODO: support pipeline branching when such controller is introduced
        tasks = getattr(controller_group, "tasks", [])
    elif isinstance(controller_group, GenericPass):
        tasks = [controller_group]
    elif isinstance(controller_group, (list, tuple)):
        tasks = controller_group
    else:
        # Invalid data
        return subgraph, component_id, prev_node

    flatten_tasks = []
    for task in tasks:
        # Flatten nested linear flow controller.
        # This situation often occurs in the builtin pass managers because it constructs
        # some stages by appending other pass manager instance converted into a linear controller.
        # Flattening inner linear controller tasks doesn't change the execution.
        if isinstance(task, FlowControllerLinear):
            flatten_tasks.extend(task.tasks)
        else:
            flatten_tasks.append(task)

    for task in flatten_tasks:
        if isinstance(task, BaseController):
            # Partly nested flow controller
            # TODO recursively inject subgraph into subgraph
            node = pydot.Node(
                str(component_id),
                label="Nested flow controller",
                # "k" is a matplotlib shorthand, not a Graphviz color name
                color="black",
                shape="rectangle",
                fontname="helvetica",
            )
        else:
            # label is the name of the pass
            node = pydot.Node(
                str(component_id),
                label=str(type(task).__name__),
                color=_get_node_color(task, style),
                shape="rectangle",
                fontname="helvetica",
            )

        subgraph.add_node(node)
        component_id += 1

        # the arguments that were provided to the pass when it was created
        arg_spec = inspect.getfullargspec(task.__init__)
        # 0 is the args, 1: to remove the self arg
        args = arg_spec[0][1:]

        # index 3 holds the positional defaults (None when there are none)
        num_optional = len(arg_spec[3]) if arg_spec[3] else 0
        # defaults align with the tail of the argument list
        first_optional = len(args) - num_optional

        # add in the inputs to the pass
        for arg_index, arg in enumerate(args):
            # any optional args are dashed
            nd_style = "dashed" if arg_index >= first_optional else "solid"

            input_node = pydot.Node(
                # string id, consistent with the task nodes above
                str(component_id),
                label=arg,
                color="black",
                shape="ellipse",
                fontsize=10,
                style=nd_style,
                fontname="helvetica",
            )
            subgraph.add_node(input_node)
            component_id += 1
            subgraph.add_edge(pydot.Edge(input_node, node))

        # if there is a previous node, add an edge between them
        if prev_node is not None:
            subgraph.add_edge(pydot.Edge(prev_node, node))

        prev_node = node

    return subgraph, component_id, prev_node
def make_output(graph, raw, filename):
    """Produce output for pass_manager.

    Args:
        graph: the pydot graph to emit.
        raw (bool): if true, write the raw Dot text to ``filename`` instead of
            rendering an image.
        filename (str or None): destination path. Required when ``raw`` is true;
            otherwise optional.

    Returns:
        PIL.Image or None: the rendered image, or None when only a file was
        written (raw output, or PNG output without PIL available).

    Raises:
        VisualizationError: if ``raw`` is true and no ``filename`` is given.
    """
    if raw:
        if not filename:
            raise VisualizationError("if format=raw, then a filename is required.")
        graph.write(filename, format="raw")
        return None

    if not _optionals.HAS_PIL and filename:
        # Without PIL no in-memory image can be built, but the PNG can still be
        # written straight to the requested path.
        # pylint says this isn't a method - it is
        graph.write_png(filename)
        return None

    # From here on an in-memory image is needed, so PIL is mandatory.
    _optionals.HAS_PIL.require_now("pass manager drawer")

    with tempfile.TemporaryDirectory() as tmpdirname:
        from PIL import Image

        tmppath = os.path.join(tmpdirname, "pass_manager.png")

        # pylint says this isn't a method - it is
        graph.write_png(tmppath)

        image = Image.open(tmppath)
        # Image.open is lazy and keeps the file open. Read the pixel data now so
        # the handle is released before the temporary directory (and the file in
        # it) is deleted on leaving this block; some platforms refuse to delete
        # a file that is still open.
        image.load()

        if filename:
            image.save(filename, "PNG")
        return image