@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
5959 super (Workflow , self ).__init__ (name , base_dir )
6060 self ._graph = nx .DiGraph ()
6161
62+ self ._nodes_cache = set ()
63+ self ._nested_workflows_cache = set ()
64+
6265 # PUBLIC API
6366 def clone (self , name ):
6467 """Clone a workflow
@@ -141,7 +144,7 @@ def connect(self, *args, **kwargs):
141144 self .disconnect (connection_list )
142145 return
143146
144- newnodes = []
147+ newnodes = set ()
145148 for srcnode , destnode , _ in connection_list :
146149 if self in [srcnode , destnode ]:
147150 msg = (
@@ -151,9 +154,9 @@ def connect(self, *args, **kwargs):
151154
152155 raise IOError (msg )
153156 if (srcnode not in newnodes ) and not self ._has_node (srcnode ):
154- newnodes .append (srcnode )
157+ newnodes .add (srcnode )
155158 if (destnode not in newnodes ) and not self ._has_node (destnode ):
156- newnodes .append (destnode )
159+ newnodes .add (destnode )
157160 if newnodes :
158161 self ._check_nodes (newnodes )
159162 for node in newnodes :
@@ -163,15 +166,16 @@ def connect(self, *args, **kwargs):
163166 connected_ports = {}
164167 for srcnode , destnode , connects in connection_list :
165168 if destnode not in connected_ports :
166- connected_ports [destnode ] = []
169+ connected_ports [destnode ] = set ()
167170 # check to see which ports of destnode are already
168171 # connected.
169172 if not disconnect and (destnode in self ._graph .nodes ()):
170173 for edge in self ._graph .in_edges (destnode ):
171174 data = self ._graph .get_edge_data (* edge )
172- for sourceinfo , destname in data ["connect" ]:
173- if destname not in connected_ports [destnode ]:
174- connected_ports [destnode ] += [destname ]
175+ connected_ports [destnode ].update (
176+ destname
177+ for _ , destname in data ["connect" ]
178+ )
175179 for source , dest in connects :
176180 # Currently datasource/sink/grabber.io modules
177181 # determine their inputs/outputs depending on
@@ -226,7 +230,7 @@ def connect(self, *args, **kwargs):
226230 )
227231 if sourcename and not srcnode ._check_outputs (sourcename ):
228232 not_found .append (["out" , srcnode .name , sourcename ])
229- connected_ports [destnode ] += [ dest ]
233+ connected_ports [destnode ]. add ( dest )
230234 infostr = []
231235 for info in not_found :
232236 infostr += [
@@ -269,6 +273,9 @@ def connect(self, *args, **kwargs):
269273 "(%s, %s): new edge data: %s" , srcnode , destnode , str (edge_data )
270274 )
271275
276+ if newnodes :
277+ self ._update_node_cache ()
278+
272279 def disconnect (self , * args ):
273280 """Disconnect nodes
274281 See the docstring for connect for format.
@@ -325,7 +332,7 @@ def add_nodes(self, nodes):
325332 newnodes = []
326333 all_nodes = self ._get_all_nodes ()
327334 for node in nodes :
328- if self . _has_node ( node ) :
335+ if node in all_nodes :
329336 raise IOError ("Node %s already exists in the workflow" % node )
330337 if isinstance (node , Workflow ):
331338 for subnode in node ._get_all_nodes ():
@@ -346,6 +353,7 @@ def add_nodes(self, nodes):
346353 if node ._hierarchy is None :
347354 node ._hierarchy = self .name
348355 self ._graph .add_nodes_from (newnodes )
356+ self ._update_node_cache ()
349357
350358 def remove_nodes (self , nodes ):
351359 """Remove nodes from a workflow
@@ -356,6 +364,7 @@ def remove_nodes(self, nodes):
356364 A list of EngineBase-based objects
357365 """
358366 self ._graph .remove_nodes_from (nodes )
367+ self ._update_node_cache ()
359368
360369 # Input-Output access
361370 @property
@@ -895,22 +904,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
895904 node .set_input (param , deepcopy (newval ))
896905
def _get_all_nodes(self):
    """Return every non-workflow node reachable from this workflow.

    The result is a new ``set`` built from the cached node set with the
    cached nested workflows taken out, then extended with whatever each
    nested workflow reports from its own ``_get_all_nodes``.  Neither
    cache is modified by this call.

    Returns
    -------
    set
        Nodes held directly by this workflow plus those of all nested
        workflows, with the nested workflow objects themselves excluded.
    """
    nested_workflows = self._nested_workflows_cache
    # Set difference allocates a fresh set, so the caches stay untouched
    # while the result is grown below.
    collected = self._nodes_cache - nested_workflows
    for sub_workflow in nested_workflows:
        collected |= sub_workflow._get_all_nodes()
    return collected
905911
906- def _has_node (self , wanted_node ):
907- for node in self ._graph .nodes ():
908- if wanted_node == node :
909- return True
def _update_node_cache(self):
    """Resynchronise the cached node sets with the current graph.

    ``_nodes_cache`` is replaced by a fresh set of the nodes currently in
    ``self._graph``.  ``_nested_workflows_cache`` is adjusted in place:
    nodes that left the graph are dropped from it, and newly added nodes
    that are ``Workflow`` instances are inserted.  Only the difference
    against the previous cache is inspected, not the whole graph.
    """
    previous_nodes = self._nodes_cache
    current_nodes = set(self._graph)
    self._nodes_cache = current_nodes

    nested_workflows = self._nested_workflows_cache
    # Forget nested workflows that are no longer part of the graph.
    nested_workflows.difference_update(previous_nodes - current_nodes)
    # Only nodes that were not cached before need the isinstance check.
    nested_workflows.update(
        node
        for node in current_nodes - previous_nodes
        if isinstance(node, Workflow)
    )
924+
def _has_node(self, wanted_node):
    """Tell whether ``wanted_node`` belongs to this workflow.

    The cached node set of this workflow is consulted first; if the node
    is not found there, each cached nested workflow is asked in turn via
    its own ``_has_node``.

    Parameters
    ----------
    wanted_node
        The node (or nested workflow) to look for.

    Returns
    -------
    bool
        ``True`` if the node is in this workflow's cache or any nested
        workflow reports having it, ``False`` otherwise.
    """
    if wanted_node in self._nodes_cache:
        return True
    for sub_workflow in self._nested_workflows_cache:
        if sub_workflow._has_node(wanted_node):
            return True
    return False
914933
915934 def _create_flat_graph (self ):
916935 """Make a simple DAG where no node is a workflow."""
@@ -939,7 +958,7 @@ def _generate_flatgraph(self):
939958 raise Exception (
940959 ("Workflow: %s is not a directed acyclic graph " "(DAG)" ) % self .name
941960 )
942- nodes = list (nx . topological_sort ( self ._graph ) )
961+ nodes = list (self ._graph . nodes )
943962 for node in nodes :
944963 logger .debug ("processing node: %s" , node )
945964 if isinstance (node , Workflow ):
0 commit comments