@@ -76,6 +76,9 @@ def add_op(self, operation):
7676 for n in operation .needs :
7777 self .graph .add_edge (DataPlaceholderNode (n ), operation )
7878
79+ if operation .color :
80+ self .graph .nodes [operation ]['color' ] = operation .color
81+
7982 # add nodes and edges to graph describing what this layer provides
8083 for p in operation .provides :
8184 self .graph .add_edge (operation , DataPlaceholderNode (p ))
@@ -93,6 +96,7 @@ def show_layers(self):
9396 print ("layer_name: " , name )
9497 print ("\t " , "needs: " , step .needs )
9598 print ("\t " , "provides: " , step .provides )
99+ print ("\t " , "color: " , step .color )
96100 print ("" )
97101
98102 def compile (self ):
@@ -136,7 +140,7 @@ def compile(self):
136140 else :
137141 raise TypeError ("Unrecognized network graph node" )
138142
139- def _find_necessary_steps (self , outputs , inputs ):
143+ def _find_necessary_steps (self , outputs , inputs , color = None ):
140144 """
141145 Determines what graph steps need to be run to get to the requested
142146 outputs from the provided inputs. Eliminates steps that come before
@@ -152,6 +156,9 @@ def _find_necessary_steps(self, outputs, inputs):
152156 :param dict inputs:
153157 A dictionary mapping names to values for all provided inputs.
154158
159+ :param str color:
160+ A color to filter nodes by.
161+
155162 :returns:
156163 Returns a list of all the steps that need to be run for the
157164 provided inputs and requested outputs.
@@ -160,7 +167,7 @@ def _find_necessary_steps(self, outputs, inputs):
160167 # return steps if it has already been computed before for this set of inputs and outputs
161168 outputs = tuple (sorted (outputs )) if isinstance (outputs , (list , set )) else outputs
162169 inputs_keys = tuple (sorted (inputs .keys ()))
163- cache_key = (inputs_keys , outputs )
170+ cache_key = (inputs_keys , outputs , color )
164171 if cache_key in self ._necessary_steps_cache :
165172 return self ._necessary_steps_cache [cache_key ]
166173
@@ -199,15 +206,23 @@ def _find_necessary_steps(self, outputs, inputs):
199206 # Get rid of the unnecessary nodes from the set of necessary ones.
200207 necessary_nodes -= unnecessary_nodes
201208
202- necessary_steps = [step for step in self .steps if step in necessary_nodes ]
209+ necessary_steps = []
210+
211+ for step in self .steps :
212+ if isinstance (step , Operation ):
213+ if step .color == color and step in necessary_nodes :
214+ necessary_steps .append (step )
215+ else :
216+ if step in necessary_nodes :
217+ necessary_steps .append (step )
203218
204219 # save this result in a precomputed cache for future lookup
205220 self ._necessary_steps_cache [cache_key ] = necessary_steps
206221
207222 # Return an ordered list of the needed steps.
208223 return necessary_steps
209224
210- def compute (self , outputs , named_inputs ):
225+ def compute (self , outputs , named_inputs , color = None ):
211226 """
212227 This method runs the graph one operation at a time in a single thread
213228 Any inputs to the network must be passed in by name.
@@ -222,6 +237,8 @@ def compute(self, outputs, named_inputs):
222237 and the values are the concrete values you
223238 want to set for the data node.
224239
240+ :param str color: Only the subgraph of nodes with color will be evaluted.
241+
225242 :returns: a dictionary of output data objects, keyed by name.
226243 """
227244
@@ -238,7 +255,7 @@ def compute(self, outputs, named_inputs):
238255
239256 # Find the subset of steps we need to run to get to the requested
240257 # outputs from the provided inputs.
241- all_steps = self ._find_necessary_steps (outputs , named_inputs )
258+ all_steps = self ._find_necessary_steps (outputs , named_inputs , color )
242259
243260 self .times = {}
244261 for step in all_steps :
@@ -281,9 +298,9 @@ def compute(self, outputs, named_inputs):
281298 raise TypeError ("Unrecognized instruction." )
282299
283300 if not outputs :
284- # Return the whole cache as output, including input and
285- # intermediate data nodes .
286- return cache
301+ # Return cache as output including intermediate data nodes,
302+ # but excluding input .
303+ return { k : cache [ k ] for k in set ( cache ) - set ( named_inputs )}
287304
288305 else :
289306 # Filter outputs to just return what's needed.
0 commit comments