Skip to content

Commit 168dd7d

Browse files
committed
enh(plot.TC): expose supported writers and TC on them
1 parent c02e211 commit 168dd7d

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

graphkit/network.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,17 @@ def _compute_sequential_method(self, named_inputs, outputs):
375375
return {k: cache[k] for k in iter(cache) if k in outputs}
376376

377377

378+
@staticmethod
379+
def supported_plot_writers():
380+
return {
381+
".png": lambda gplot: gplot.create_png(),
382+
".dot": lambda gplot: gplot.to_string(),
383+
".jpg": lambda gplot: gplot.create_jpeg(),
384+
".jpeg": lambda gplot: gplot.create_jpeg(),
385+
".pdf": lambda gplot: gplot.create_pdf(),
386+
".svg": lambda gplot: gplot.create_svg(),
387+
}
388+
378389
def plot(self, filename=None, show=False):
379390
"""
380391
Plot the graph.
@@ -422,23 +433,16 @@ def get_node_name(a):
422433

423434
# save plot
424435
if filename:
425-
supported_plot_formaters = {
426-
".png": g.create_png,
427-
".dot": g.to_string,
428-
".jpg": g.create_jpeg,
429-
".jpeg": g.create_jpeg,
430-
".pdf": g.create_pdf,
431-
".svg": g.create_svg,
432-
}
433436
_basename, ext = os.path.splitext(filename)
434-
plot_formater = supported_plot_formaters.get(ext.lower())
435-
if not plot_formater:
436-
raise Exception(
437+
writers = Network.supported_plot_writers()
438+
plot_writer = Network.supported_plot_writers().get(ext.lower())
439+
if not plot_writer:
440+
raise ValueError(
437441
"Unknown file format for saving graph: %s"
438-
" File extensions must be one of: .png .dot .jpg .jpeg .pdf .svg"
439-
% ext)
442+
" File extensions must be one of: %s"
443+
% (ext, ' '.join(writers)))
440444
with open(filename, "wb") as fh:
441-
fh.write(plot_formater())
445+
fh.write(plot_writer(g))
442446

443447
# display graph via matplotlib
444448
if show:

test/test_graphkit.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,23 @@ def test_plotting():
327327
sum_op3 = operation(name='sum_op3', needs=['sum1', 'c'], provides='sum3')(add)
328328
net1 = compose(name='my network 1')(sum_op1, sum_op2, sum_op3)
329329

330-
for ext in ".png .dot .jpg .jpeg .pdf .svg".split():
330+
for ext in network.Network.supported_plot_writers():
331331
tdir = tempfile.mkdtemp(suffix=ext)
332332
png_file = osp.join(tdir, "workflow.png")
333333
net1.net.plot(png_file)
334334
try:
335335
assert osp.exists(png_file)
336336
finally:
337337
shutil.rmtree(tdir, ignore_errors=True)
338+
try:
339+
net1.net.plot('bad.format')
340+
assert False, "Should had failed writting arbitrary file format!"
341+
except ValueError as ex:
342+
assert "Unknown file format" in str(ex)
343+
344+
## Check help msg lists all siupported formats
345+
for ext in network.Network.supported_plot_writers():
346+
assert ext in str(ex)
338347

339348

340349
####################################

0 commit comments

Comments
 (0)