77
88from io import StringIO
99
10- from .base import Operation
10+ from .base import Operation , Control
1111
1212
1313class DataPlaceholderNode (str ):
@@ -83,6 +83,10 @@ def add_op(self, operation):
8383 for p in operation .provides :
8484 self .graph .add_edge (operation , DataPlaceholderNode (p ))
8585
86+ if isinstance (operation , Control ) and hasattr (operation , 'condition_needs' ):
87+ for n in operation .condition_needs :
88+ self .graph .add_edge (DataPlaceholderNode (n ), operation )
89+
8690 # clear compiled steps (must recompile after adding new layers)
8791 self .steps = []
8892
@@ -97,6 +101,8 @@ def show_layers(self):
97101 print ("\t " , "needs: " , step .needs )
98102 print ("\t " , "provides: " , step .provides )
99103 print ("\t " , "color: " , step .color )
104+ if hasattr (step , 'condition_needs' ):
105+ print ("\t " , "condition needs: " , step .condition_needs )
100106 print ("" )
101107
102108 def compile (self ):
@@ -107,14 +113,37 @@ def compile(self):
107113 self .steps = []
108114
109115 # create an execution order such that each layer's needs are provided.
110- ordered_nodes = list (nx .dag .topological_sort (self .graph ))
116+ try :
117+ def key (node ):
118+
119+ if hasattr (node , 'order' ):
120+ return node .order
121+ elif isinstance (node , DataPlaceholderNode ):
122+ return float ('-inf' )
123+ else :
124+ return 0
125+
126+ ordered_nodes = list (nx .dag .lexicographical_topological_sort (self .graph ,
127+ key = key ))
128+ except TypeError as e :
129+ if self ._debug :
130+ print ("Lexicographical topological sort failed! Falling back to topological sort." )
131+
132+ if not any (map (lambda node : isinstance (node , Control ), self .graph .nodes )):
133+ ordered_nodes = list (nx .dag .topological_sort (self .graph ))
134+ else :
135+ print ("Topological sort failed!" )
136+ raise e
111137
112138 # add Operations evaluation steps, and instructions to free data.
113139 for i , node in enumerate (ordered_nodes ):
114140
115141 if isinstance (node , DataPlaceholderNode ):
116142 continue
117143
144+ elif isinstance (node , Control ):
145+ self .steps .append (node )
146+
118147 elif isinstance (node , Operation ):
119148
120149 # add layer to list of steps
@@ -256,11 +285,24 @@ def compute(self, outputs, named_inputs, color=None):
256285 # Find the subset of steps we need to run to get to the requested
257286 # outputs from the provided inputs.
258287 all_steps = self ._find_necessary_steps (outputs , named_inputs , color )
259-
288+ # import pdb
260289 self .times = {}
290+ if_true = False
261291 for step in all_steps :
262292
263- if isinstance (step , Operation ):
293+ if isinstance (step , Control ):
294+ # pdb.set_trace()
295+ if hasattr (step , 'condition' ):
296+ if_true = step ._compute_condition (cache )
297+ if if_true :
298+ layer_outputs = step ._compute (cache )
299+ cache .update (layer_outputs )
300+ elif not if_true :
301+ layer_outputs = step ._compute (cache )
302+ cache .update (layer_outputs )
303+ if_true = False
304+
305+ elif isinstance (step , Operation ):
264306
265307 if self ._debug :
266308 print ("-" * 32 )
0 commit comments