@@ -15,6 +15,90 @@ def supported_plot_formats():
1515 return [".%s" % f for f in pydot .Dot ().formats ]
1616
1717
18+ def build_pydot (graph , steps = None , inputs = None , outputs = None , solution = None ):
19+ """ Build a Graphviz graph """
20+ import pydot
21+
22+ assert graph is not None
23+
24+ def get_node_name (a ):
25+ if isinstance (a , Operation ):
26+ return a .name
27+ return a
28+
29+ dot = pydot .Dot (graph_type = "digraph" )
30+
31+ # draw nodes
32+ for nx_node in graph .nodes :
33+ kw = {}
34+ if isinstance (nx_node , str ):
35+ # Only DeleteInstructions data in steps.
36+ if nx_node in steps :
37+ kw = {"color" : "red" , "penwidth" : 2 }
38+
39+ # SHAPE change if in inputs/outputs.
40+ # tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
41+ shape = "rect"
42+ if inputs and outputs and nx_node in inputs and nx_node in outputs :
43+ shape = "hexagon"
44+ else :
45+ if inputs and nx_node in inputs :
46+ shape = "invhouse"
47+ if outputs and nx_node in outputs :
48+ shape = "house"
49+
50+ # LABEL change from solution.
51+ if solution and nx_node in solution :
52+ kw ["style" ] = "filled"
53+ kw ["fillcolor" ] = "gray"
54+ # kw["tooltip"] = nx_node, solution.get(nx_node)
55+ node = pydot .Node (name = nx_node , shape = shape , URL = "fdgfdf" , ** kw )
56+ else : # Operation
57+ kw = {}
58+ shape = "oval" if isinstance (nx_node , NetworkOperation ) else "circle"
59+ if nx_node in steps :
60+ kw ["style" ] = "bold"
61+ node = pydot .Node (name = nx_node .name , shape = shape , ** kw )
62+
63+ dot .add_node (node )
64+
65+ # draw edges
66+ for src , dst in graph .edges :
67+ src_name = get_node_name (src )
68+ dst_name = get_node_name (dst )
69+ kw = {}
70+ if isinstance (dst , Operation ) and any (
71+ n == src and isinstance (n , optional ) for n in dst .needs
72+ ):
73+ kw ["style" ] = "dashed"
74+ edge = pydot .Edge (src = src_name , dst = dst_name , ** kw )
75+ dot .add_edge (edge )
76+
77+ # draw steps sequence
78+ if steps and len (steps ) > 1 :
79+ it1 = iter (steps )
80+ it2 = iter (steps )
81+ next (it2 )
82+ for i , (src , dst ) in enumerate (zip (it1 , it2 ), 1 ):
83+ src_name = get_node_name (src )
84+ dst_name = get_node_name (dst )
85+ edge = pydot .Edge (
86+ src = src_name ,
87+ dst = dst_name ,
88+ label = str (i ),
89+ style = "dotted" ,
90+ color = "green" ,
91+ fontcolor = "green" ,
92+ fontname = "bold" ,
93+ fontsize = 18 ,
94+ penwidth = 3 ,
95+ arrowhead = "vee" ,
96+ )
97+ dot .add_edge (edge )
98+
99+ return dot
100+
101+
18102def plot_graph (
19103 graph ,
20104 filename = None ,
@@ -55,7 +139,7 @@ def plot_graph(
55139 :param str filename:
56140 Write diagram into a file.
57141 Common extensions are ``.png .dot .jpg .jpeg .pdf .svg``
58- call :func:`network .supported_plot_formats()` for more.
142+ call :func:`plot .supported_plot_formats()` for more.
59143 :param show:
60144 If it evaluates to true, opens the diagram in a matplotlib window.
61145 If it equals `-1``, it plots but does not open the Window.
@@ -93,84 +177,7 @@ def plot_graph(
93177 >>> pipeline.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']);
94178
95179 """
96- import pydot
97-
98- assert graph is not None
99-
100- def get_node_name (a ):
101- if isinstance (a , Operation ):
102- return a .name
103- return a
104-
105- g = pydot .Dot (graph_type = "digraph" )
106-
107- # draw nodes
108- for nx_node in graph .nodes :
109- kw = {}
110- if isinstance (nx_node , str ):
111- # Only DeleteInstructions data in steps.
112- if nx_node in steps :
113- kw = {"color" : "red" , "penwidth" : 2 }
114-
115- # SHAPE change if in inputs/outputs.
116- # tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html
117- shape = "rect"
118- if inputs and outputs and nx_node in inputs and nx_node in outputs :
119- shape = "hexagon"
120- else :
121- if inputs and nx_node in inputs :
122- shape = "invhouse"
123- if outputs and nx_node in outputs :
124- shape = "house"
125-
126- # LABEL change from solution.
127- if solution and nx_node in solution :
128- kw ["style" ] = "filled"
129- kw ["fillcolor" ] = "gray"
130- # kw["tooltip"] = nx_node, solution.get(nx_node)
131- node = pydot .Node (name = nx_node , shape = shape , URL = "fdgfdf" , ** kw )
132- else : # Operation
133- kw = {}
134- shape = "oval" if isinstance (nx_node , NetworkOperation ) else "circle"
135- if nx_node in steps :
136- kw ["style" ] = "bold"
137- node = pydot .Node (name = nx_node .name , shape = shape , ** kw )
138-
139- g .add_node (node )
140-
141- # draw edges
142- for src , dst in graph .edges :
143- src_name = get_node_name (src )
144- dst_name = get_node_name (dst )
145- kw = {}
146- if isinstance (dst , Operation ) and any (
147- n == src and isinstance (n , optional ) for n in dst .needs
148- ):
149- kw ["style" ] = "dashed"
150- edge = pydot .Edge (src = src_name , dst = dst_name , ** kw )
151- g .add_edge (edge )
152-
153- # draw steps sequence
154- if steps and len (steps ) > 1 :
155- it1 = iter (steps )
156- it2 = iter (steps )
157- next (it2 )
158- for i , (src , dst ) in enumerate (zip (it1 , it2 ), 1 ):
159- src_name = get_node_name (src )
160- dst_name = get_node_name (dst )
161- edge = pydot .Edge (
162- src = src_name ,
163- dst = dst_name ,
164- label = str (i ),
165- style = "dotted" ,
166- color = "green" ,
167- fontcolor = "green" ,
168- fontname = "bold" ,
169- fontsize = 18 ,
170- penwidth = 3 ,
171- arrowhead = "vee" ,
172- )
173- g .add_edge (edge )
180+ dot = build_pydot (graph , steps , inputs , outputs , solution )
174181
175182 # Save plot
176183 #
@@ -183,26 +190,26 @@ def get_node_name(a):
183190 " File extensions must be one of: %s" % (ext , " " .join (formats ))
184191 )
185192
186- g .write (filename , format = ext .lower ()[1 :])
193+ dot .write (filename , format = ext .lower ()[1 :])
187194
188195 ## Return an SVG renderable in jupyter.
189196 #
190197 if jupyter :
191198 from IPython .display import SVG
192199
193- g = SVG (data = g .create_svg ())
200+ dot = SVG (data = dot .create_svg ())
194201
195202 ## Display graph via matplotlib
196203 #
197204 if show :
198205 import matplotlib .pyplot as plt
199206 import matplotlib .image as mpimg
200207
201- png = g .create_png ()
208+ png = dot .create_png ()
202209 sio = io .BytesIO (png )
203210 img = mpimg .imread (sio )
204211 plt .imshow (img , aspect = "equal" )
205212 if show != - 1 :
206213 plt .show ()
207214
208- return g
215+ return dot
0 commit comments