@@ -34,14 +34,26 @@ def __repr__(self):
         return 'DeleteInstruction("%s")' % self


+class PinInstruction(str):
+    """
+    An instruction in the *execution plan* not to store the newly computed value
+    into the network's values-cache but to pin it instead to some given value.
+    It is used to ensure that given intermediate values are not overwritten when
+    their providing functions could not be avoided, because their other outputs
+    are needed elsewhere.
+    """
+    def __repr__(self):
+        return 'PinInstruction("%s")' % self
+
+
 class Network(object):
     """
     This is the main network implementation. The class contains all of the
     code necessary to weave together operations into a directed-acyclic-graph (DAG)
     and pass data through.

     The computation, ie the execution of the *operations* for given *inputs*
-    and asked *outputs* is based on 3 data-structures:
+    and asked *outputs* is based on 4 data-structures:

     - The ``networkx`` :attr:`graph` DAG, containing interchanging layers of
       :class:`Operation` and :class:`DataPlaceholderNode` nodes.
@@ -68,6 +80,12 @@ class Network(object):
     - the :var:`cache` local-var, initialized on each run of both
       ``_compute_xxx`` methods (for parallel or sequential executions), to
       hold all given input & generated (aka intermediate) data values.
+
+    - the :var:`overwrites` local-var, initialized on each run of both
+      ``_compute_xxx`` methods (for parallel or sequential executions), to
+      hold any values that were calculated but then overwritten ("pinned")
+      by given intermediate input-values.
+
     """

     def __init__(self, **kwargs):
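
A minimal sketch of the new instruction type in isolation (the data name "b" is
illustrative, not from this commit): PinInstruction subclasses str, per the hunk
above, so the execution plan can carry just the name of the value to pin.

    class PinInstruction(str):
        def __repr__(self):
            return 'PinInstruction("%s")' % self

    pin = PinInstruction("b")
    assert isinstance(pin, str) and pin == "b"
    print(repr(pin))      # -> PinInstruction("b")
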
@@ -122,7 +140,7 @@ def add_op(self, operation):

     def list_layers(self, debug=False):
         # Make a generic plan.
-        plan = self._build_execution_plan(self.graph)
+        plan = self._build_execution_plan(self.graph, ())
         return [n for n in plan if debug or isinstance(n, Operation)]


@@ -134,15 +152,15 @@ def show_layers(self, debug=False, ret=False):
         else:
             print(s)

-    def _build_execution_plan(self, dag):
+    def _build_execution_plan(self, dag, inputs):
         """
         Create the list of operation-nodes & *instructions* evaluating all

         operations & instructions needed a) to free memory and b) avoid
         overwritting given intermediate inputs.

         :param dag:
-            the original dag but "shrinked", not "broken"
+            The original dag, pruned; not broken.

         In the list :class:`DeleteInstructions` steps (DA) are inserted between
         operation nodes to reduce the memory footprint of cached results.
@@ -158,11 +176,15 @@ def _build_execution_plan(self, dag):
         # create an execution order such that each layer's needs are provided.
         ordered_nodes = iset(nx.topological_sort(dag))

-        # add Operations evaluation steps, and instructions to free data.
+        # Add Operations evaluation steps, and instructions to free and "pin"
+        # data.
         for i, node in enumerate(ordered_nodes):

             if isinstance(node, DataPlaceholderNode):
-                continue
+                if node in inputs and dag.pred[node]:
+                    # Command pinning only when there is another operation
+                    # generating this data as output.
+                    plan.append(PinInstruction(node))

             elif isinstance(node, Operation):

@@ -291,13 +313,11 @@ def _solve_dag(self, outputs, inputs):
         broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs))


-        # Prune (un-satifiable) operations with partial inputs.
-        # See yahoo/graphkit#18
-        #
+        # Prune unsatisfied operations (those with partial inputs or no outputs).
         unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
-        shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
+        pruned_dag = dag.subgraph(broken_dag.nodes - unsatisfied)

-        plan = self._build_execution_plan(shrinked_dag)
+        plan = self._build_execution_plan(pruned_dag, inputs)

         return plan

@@ -331,7 +351,8 @@ def compile(self, outputs=(), inputs=()):



-    def compute(self, outputs, named_inputs, method=None):
+    def compute(
+        self, outputs, named_inputs, method=None, overwrites_collector=None):
         """
         Run the graph. Any inputs to the network must be passed in by name.

@@ -350,6 +371,10 @@ def compute(self, outputs, named_inputs, method=None):
             Set when invoking a composed graph or by
             :meth:`~NetworkOperation.set_execution_method()`.

+        :param overwrites_collector:
+            (optional) a mutable dict to be filled with any overwritten ("pinned") values.
+            If missing, such values are simply discarded.
+
         :returns: a dictionary of output data objects, keyed by name.
         """

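
A usage sketch of the new keyword (the network instance, data names and values below
are hypothetical, not part of this commit): whatever mutable dict the caller passes
as overwrites_collector ends up holding the computed values that were shadowed by
same-named inputs.

    overwrites = {}
    results = net.compute(                  # `net`: some already-compiled Network
        outputs=["c"],
        named_inputs={"a": 1, "b": 2},      # suppose "b" is also computed internally
        overwrites_collector=overwrites,
    )
    # The value computed for "b" (before the given 2 replaced it) is kept here
    # instead of being silently discarded:
    print(overwrites)                       # e.g. {"b": <computed value>}
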
@@ -364,23 +389,34 @@ def compute(self, outputs, named_inputs, method=None):

         # choose a method of execution
         if method == "parallel":
-            self._compute_thread_pool_barrier_method(cache)
+            self._compute_thread_pool_barrier_method(
+                cache, overwrites_collector, named_inputs)
         else:
-            self._compute_sequential_method(cache, outputs)
+            self._compute_sequential_method(
+                cache, overwrites_collector, named_inputs, outputs)

         if not outputs:
             # Return the whole cache as output, including input and
             # intermediate data nodes.
-            return cache
+            result = cache

         else:
             # Filter outputs to just return what's needed.
             # Note: list comprehensions exist in python 2.7+
-            return dict(i for i in cache.items() if i[0] in outputs)
+            result = dict(i for i in cache.items() if i[0] in outputs)
+
+        return result
+
+
+    def _pin_data_in_cache(self, value_name, cache, inputs, overwrites):
+        value_name = str(value_name)
+        if overwrites is not None:
+            overwrites[value_name] = cache[value_name]
+        cache[value_name] = inputs[value_name]


     def _compute_thread_pool_barrier_method(
-        self, cache, thread_pool_size=10
+        self, cache, overwrites, inputs, thread_pool_size=10
     ):
         """
         This method runs the graph using a parallel pool of thread executors.
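
To see what the _pin_data_in_cache() helper added above does at run-time, here is a
small standalone sketch with made-up values: the freshly computed value is handed to
the overwrites dict (when one was supplied) and the originally given input wins in
the cache.

    cache = {"b": 42}             # value an operation just computed for "b"
    inputs = {"a": 1, "b": 2}     # ..."b" was also given as an input
    overwrites = {}

    # the same two steps _pin_data_in_cache("b", cache, inputs, overwrites) performs:
    overwrites["b"] = cache["b"]  # keep the computed 42 for the collector
    cache["b"] = inputs["b"]      # pin the cache back to the given 2

    assert overwrites == {"b": 42} and cache["b"] == 2
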
@@ -436,7 +472,7 @@ def _compute_thread_pool_barrier_method(
                 has_executed.add(op)


-    def _compute_sequential_method(self, cache, outputs):
+    def _compute_sequential_method(self, cache, overwrites, inputs, outputs):
         """
         This method runs the graph one operation at a time in a single thread
         """
@@ -477,6 +513,8 @@ def _compute_sequential_method(self, cache, outputs):
                     print("removing data '%s' from cache." % step)
                 cache.pop(step)

+            elif isinstance(step, PinInstruction):
+                self._pin_data_in_cache(step, cache, inputs, overwrites)
             else:
                 raise AssertionError("Unrecognized instruction.%r" % step)
