@@ -375,6 +375,17 @@ def _compute_sequential_method(self, named_inputs, outputs):
375375 return {k : cache [k ] for k in iter (cache ) if k in outputs }
376376
377377
378+ @staticmethod
379+ def supported_plot_writers ():
380+ return {
381+ ".png" : lambda gplot : gplot .create_png (),
382+ ".dot" : lambda gplot : gplot .to_string (),
383+ ".jpg" : lambda gplot : gplot .create_jpeg (),
384+ ".jpeg" : lambda gplot : gplot .create_jpeg (),
385+ ".pdf" : lambda gplot : gplot .create_pdf (),
386+ ".svg" : lambda gplot : gplot .create_svg (),
387+ }
388+
378389 def plot (self , filename = None , show = False ):
379390 """
380391 Plot the graph.
@@ -422,23 +433,16 @@ def get_node_name(a):
422433
423434 # save plot
424435 if filename :
425- supported_plot_formaters = {
426- ".png" : g .create_png ,
427- ".dot" : g .to_string ,
428- ".jpg" : g .create_jpeg ,
429- ".jpeg" : g .create_jpeg ,
430- ".pdf" : g .create_pdf ,
431- ".svg" : g .create_svg ,
432- }
433436 _basename , ext = os .path .splitext (filename )
434- plot_formater = supported_plot_formaters .get (ext .lower ())
435- if not plot_formater :
436- raise Exception (
437+ writers = Network .supported_plot_writers ()
438+ plot_writer = Network .supported_plot_writers ().get (ext .lower ())
439+ if not plot_writer :
440+ raise ValueError (
437441 "Unknown file format for saving graph: %s"
438- " File extensions must be one of: .png .dot .jpg .jpeg .pdf .svg "
439- % ext )
442+ " File extensions must be one of: %s "
443+ % ( ext , ' ' . join ( writers )) )
440444 with open (filename , "wb" ) as fh :
441- fh .write (plot_formater ( ))
445+ fh .write (plot_writer ( g ))
442446
443447 # display graph via matplotlib
444448 if show :
0 commit comments