Skip to content

Commit 3a87959

Browse files
committed
refact(plot): separate graphviz building from IO
1 parent b08a363 commit 3a87959

File tree

4 files changed

+94
-87
lines changed

4 files changed

+94
-87
lines changed

graphkit/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def plot(self, filename=None, show=False, jupyter=None,
177177
:param str filename:
178178
Write diagram into a file.
179179
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
180-
call :func:`network.supported_plot_formats()` for more.
180+
call :func:`plot.supported_plot_formats()` for more.
181181
:param show:
182182
If it evaluates to true, opens the diagram in a matplotlib window.
183183
If it equals `-1`, it plots but does not open the Window.

graphkit/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def plot(self, filename=None, show=False, jupyter=None,
382382
:param str filename:
383383
Write diagram into a file.
384384
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
385-
call :func:`network.supported_plot_formats()` for more.
385+
call :func:`plot.supported_plot_formats()` for more.
386386
:param show:
387387
If it evaluates to true, opens the diagram in a matplotlib window.
388388
If it equals `-1``, it plots but does not open the Window.

graphkit/plot.py

Lines changed: 90 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,90 @@ def supported_plot_formats():
1515
return [".%s" % f for f in pydot.Dot().formats]
1616

1717

18+
def build_pydot(graph, steps=None, inputs=None, outputs=None, solution=None):
19+
""" Build a Graphviz graph """
20+
import pydot
21+
22+
assert graph is not None
23+
24+
def get_node_name(a):
25+
if isinstance(a, Operation):
26+
return a.name
27+
return a
28+
29+
dot = pydot.Dot(graph_type="digraph")
30+
31+
# draw nodes
32+
for nx_node in graph.nodes:
33+
kw = {}
34+
if isinstance(nx_node, str):
35+
# Only DeleteInstructions data in steps.
36+
if nx_node in steps:
37+
kw = {"color": "red", "penwidth": 2}
38+
39+
# SHAPE change if in inputs/outputs.
40+
# tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
41+
shape = "rect"
42+
if inputs and outputs and nx_node in inputs and nx_node in outputs:
43+
shape = "hexagon"
44+
else:
45+
if inputs and nx_node in inputs:
46+
shape = "invhouse"
47+
if outputs and nx_node in outputs:
48+
shape = "house"
49+
50+
# LABEL change from solution.
51+
if solution and nx_node in solution:
52+
kw["style"] = "filled"
53+
kw["fillcolor"] = "gray"
54+
# kw["tooltip"] = nx_node, solution.get(nx_node)
55+
node = pydot.Node(name=nx_node, shape=shape, URL="fdgfdf", **kw)
56+
else: # Operation
57+
kw = {}
58+
shape = "oval" if isinstance(nx_node, NetworkOperation) else "circle"
59+
if nx_node in steps:
60+
kw["style"] = "bold"
61+
node = pydot.Node(name=nx_node.name, shape=shape, **kw)
62+
63+
dot.add_node(node)
64+
65+
# draw edges
66+
for src, dst in graph.edges:
67+
src_name = get_node_name(src)
68+
dst_name = get_node_name(dst)
69+
kw = {}
70+
if isinstance(dst, Operation) and any(
71+
n == src and isinstance(n, optional) for n in dst.needs
72+
):
73+
kw["style"] = "dashed"
74+
edge = pydot.Edge(src=src_name, dst=dst_name, **kw)
75+
dot.add_edge(edge)
76+
77+
# draw steps sequence
78+
if steps and len(steps) > 1:
79+
it1 = iter(steps)
80+
it2 = iter(steps)
81+
next(it2)
82+
for i, (src, dst) in enumerate(zip(it1, it2), 1):
83+
src_name = get_node_name(src)
84+
dst_name = get_node_name(dst)
85+
edge = pydot.Edge(
86+
src=src_name,
87+
dst=dst_name,
88+
label=str(i),
89+
style="dotted",
90+
color="green",
91+
fontcolor="green",
92+
fontname="bold",
93+
fontsize=18,
94+
penwidth=3,
95+
arrowhead="vee",
96+
)
97+
dot.add_edge(edge)
98+
99+
return dot
100+
101+
18102
def plot_graph(
19103
graph,
20104
filename=None,
@@ -55,7 +139,7 @@ def plot_graph(
55139
:param str filename:
56140
Write diagram into a file.
57141
Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
58-
call :func:`network.supported_plot_formats()` for more.
142+
call :func:`plot.supported_plot_formats()` for more.
59143
:param show:
60144
If it evaluates to true, opens the diagram in a matplotlib window.
61145
If it equals `-1``, it plots but does not open the Window.
@@ -93,84 +177,7 @@ def plot_graph(
93177
>>> pipeline.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']);
94178
95179
"""
96-
import pydot
97-
98-
assert graph is not None
99-
100-
def get_node_name(a):
101-
if isinstance(a, Operation):
102-
return a.name
103-
return a
104-
105-
g = pydot.Dot(graph_type="digraph")
106-
107-
# draw nodes
108-
for nx_node in graph.nodes:
109-
kw = {}
110-
if isinstance(nx_node, str):
111-
# Only DeleteInstructions data in steps.
112-
if nx_node in steps:
113-
kw = {"color": "red", "penwidth": 2}
114-
115-
# SHAPE change if in inputs/outputs.
116-
# tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
117-
shape = "rect"
118-
if inputs and outputs and nx_node in inputs and nx_node in outputs:
119-
shape = "hexagon"
120-
else:
121-
if inputs and nx_node in inputs:
122-
shape = "invhouse"
123-
if outputs and nx_node in outputs:
124-
shape = "house"
125-
126-
# LABEL change from solution.
127-
if solution and nx_node in solution:
128-
kw["style"] = "filled"
129-
kw["fillcolor"] = "gray"
130-
# kw["tooltip"] = nx_node, solution.get(nx_node)
131-
node = pydot.Node(name=nx_node, shape=shape, URL="fdgfdf", **kw)
132-
else: # Operation
133-
kw = {}
134-
shape = "oval" if isinstance(nx_node, NetworkOperation) else "circle"
135-
if nx_node in steps:
136-
kw["style"] = "bold"
137-
node = pydot.Node(name=nx_node.name, shape=shape, **kw)
138-
139-
g.add_node(node)
140-
141-
# draw edges
142-
for src, dst in graph.edges:
143-
src_name = get_node_name(src)
144-
dst_name = get_node_name(dst)
145-
kw = {}
146-
if isinstance(dst, Operation) and any(
147-
n == src and isinstance(n, optional) for n in dst.needs
148-
):
149-
kw["style"] = "dashed"
150-
edge = pydot.Edge(src=src_name, dst=dst_name, **kw)
151-
g.add_edge(edge)
152-
153-
# draw steps sequence
154-
if steps and len(steps) > 1:
155-
it1 = iter(steps)
156-
it2 = iter(steps)
157-
next(it2)
158-
for i, (src, dst) in enumerate(zip(it1, it2), 1):
159-
src_name = get_node_name(src)
160-
dst_name = get_node_name(dst)
161-
edge = pydot.Edge(
162-
src=src_name,
163-
dst=dst_name,
164-
label=str(i),
165-
style="dotted",
166-
color="green",
167-
fontcolor="green",
168-
fontname="bold",
169-
fontsize=18,
170-
penwidth=3,
171-
arrowhead="vee",
172-
)
173-
g.add_edge(edge)
180+
dot = build_pydot(graph, steps, inputs, outputs, solution)
174181

175182
# Save plot
176183
#
@@ -183,26 +190,26 @@ def get_node_name(a):
183190
" File extensions must be one of: %s" % (ext, " ".join(formats))
184191
)
185192

186-
g.write(filename, format=ext.lower()[1:])
193+
dot.write(filename, format=ext.lower()[1:])
187194

188195
## Return an SVG renderable in jupyter.
189196
#
190197
if jupyter:
191198
from IPython.display import SVG
192199

193-
g = SVG(data=g.create_svg())
200+
dot = SVG(data=dot.create_svg())
194201

195202
## Display graph via matplotlib
196203
#
197204
if show:
198205
import matplotlib.pyplot as plt
199206
import matplotlib.image as mpimg
200207

201-
png = g.create_png()
208+
png = dot.create_png()
202209
sio = io.BytesIO(png)
203210
img = mpimg.imread(sio)
204211
plt.imshow(img, aspect="equal")
205212
if show != -1:
206213
plt.show()
207214

208-
return g
215+
return dot

test/test_plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_plot_formats(pipeline, input_names, outputs, solution, tmp_path):
5959
# ...these are not working on my PC, or travis.
6060
forbidden_formats = ".dia .hpgl .mif .mp .pcl .pic .vtx .xlib".split()
6161
prev_dot = None
62-
for ext in network.supported_plot_formats():
62+
for ext in plot.supported_plot_formats():
6363
if ext not in forbidden_formats:
6464
dot = pipeline.plot(inputs=input_names, outputs=outputs, solution=solution)
6565
assert dot
@@ -72,7 +72,7 @@ def test_plot_bad_format(pipeline, tmp_path):
7272
pipeline.plot(filename="bad.format")
7373

7474
## Check help msg lists all siupported formats
75-
for ext in network.supported_plot_formats():
75+
for ext in plot.supported_plot_formats():
7676
assert exinfo.match(ext)
7777

7878

0 commit comments

Comments
 (0)