Skip to content

Commit 0dc1293

Browse files
committed
WIP/FIX(prune,#26): PIN intermediate inputs if operation before must run
- WIP: PARALLEL execution not adding PINS! + Insert "PinInstructions" in the execution-plan to avoid overwritting. + Add `_overwrite_collector` in `compose()` to collect re-calculated values. + FIX the last TC in #25.
1 parent 0830b7c commit 0dc1293

File tree

3 files changed

+156
-28
lines changed

3 files changed

+156
-28
lines changed

graphkit/base.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Copyright 2016, Yahoo Inc.
22
# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
3+
try:
4+
from collections import abc
5+
except ImportError:
6+
import collections as abc
7+
38

49
class Data(object):
510
"""
@@ -151,9 +156,12 @@ def __init__(self, **kwargs):
151156

152157
# set execution mode to single-threaded sequential by default
153158
self._execution_method = "sequential"
159+
self._overwrites_collector = None
154160

155161
def _compute(self, named_inputs, outputs=None):
156-
return self.net.compute(outputs, named_inputs, method=self._execution_method)
162+
return self.net.compute(
163+
outputs, named_inputs, method=self._execution_method,
164+
overwrites_collector=self._overwrites_collector)
157165

158166
def __call__(self, *args, **kwargs):
159167
return self._compute(*args, **kwargs)
@@ -162,15 +170,35 @@ def set_execution_method(self, method):
162170
"""
163171
Determine how the network will be executed.
164172
165-
Args:
166-
method: str
167-
If "parallel", execute graph operations concurrently
168-
using a threadpool.
173+
:param str method:
174+
If "parallel", execute graph operations concurrently
175+
using a threadpool.
169176
"""
170-
options = ['parallel', 'sequential']
171-
assert method in options
177+
choices = ['parallel', 'sequential']
178+
if method not in choices:
179+
raise ValueError(
180+
"Invalid computation method %r! Must be one of %s"
181+
(method, choices))
172182
self._execution_method = method
173183

184+
def set_overwrites_collector(self, collector):
185+
"""
186+
Asks to put all *overwrites* into the `collector` after computing
187+
188+
An "overwrites" is intermediate value calculated but NOT stored
189+
into the results, becaues it has been given also as an intemediate
190+
input value, and the operation that would overwrite it MUST run for
191+
its other results.
192+
193+
:param collector:
194+
a mutable dict to be fillwed with named values
195+
"""
196+
if collector is not None and not isinstance(collector, abc.MutableMapping):
197+
raise ValueError(
198+
"Overwrites collector was not a MutableMapping, but: %r"
199+
% collector)
200+
self._overwrites_collector = collector
201+
174202
def plot(self, filename=None, show=False):
175203
self.net.plot(filename=filename, show=show)
176204

graphkit/network.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,26 @@ def __repr__(self):
3434
return 'DeleteInstruction("%s")' % self
3535

3636

37+
class PinInstruction(str):
38+
"""
39+
An instruction in the *execution plan* not to store the newly compute value
40+
into network's values-cache but to pin it instead to some given value.
41+
It is used ensure that given intermediate values are not overwritten when
42+
their providing functions could not be avoided, because their other outputs
43+
are needed elesewhere.
44+
"""
45+
def __repr__(self):
46+
return 'PinInstruction("%s")' % self
47+
48+
3749
class Network(object):
3850
"""
3951
This is the main network implementation. The class contains all of the
4052
code necessary to weave together operations into a directed-acyclic-graph (DAG)
4153
and pass data through.
4254
4355
The computation, ie the execution of the *operations* for given *inputs*
44-
and asked *outputs* is based on 3 data-structures:
56+
and asked *outputs* is based on 4 data-structures:
4557
4658
- The ``networkx`` :attr:`graph` DAG, containing interchanging layers of
4759
:class:`Operation` and :class:`DataPlaceholderNode` nodes.
@@ -68,6 +80,12 @@ class Network(object):
6880
- the :var:`cache` local-var, initialized on each run of both
6981
``_compute_xxx`` methods (for parallel or sequential executions), to
7082
hold all given input & generated (aka intermediate) data values.
83+
84+
- the :var:`overwrites` local-var, initialized on each run of both
85+
``_compute_xxx`` methods (for parallel or sequential executions), to
86+
hold values calculated but overwritten (aka "pinned") by intermediate
87+
input-values.
88+
7189
"""
7290

7391
def __init__(self, **kwargs):
@@ -122,7 +140,7 @@ def add_op(self, operation):
122140

123141
def list_layers(self, debug=False):
124142
# Make a generic plan.
125-
plan = self._build_execution_plan(self.graph)
143+
plan = self._build_execution_plan(self.graph, ())
126144
return [n for n in plan if debug or isinstance(n, Operation)]
127145

128146

@@ -134,15 +152,15 @@ def show_layers(self, debug=False, ret=False):
134152
else:
135153
print(s)
136154

137-
def _build_execution_plan(self, dag):
155+
def _build_execution_plan(self, dag, inputs):
138156
"""
139157
Create the list of operation-nodes & *instructions* evaluating all
140158
141159
operations & instructions needed a) to free memory and b) avoid
142160
overwritting given intermediate inputs.
143161
144162
:param dag:
145-
the original dag but "shrinked", not "broken"
163+
The original dag, pruned; not broken.
146164
147165
In the list :class:`DeleteInstructions` steps (DA) are inserted between
148166
operation nodes to reduce the memory footprint of cached results.
@@ -158,11 +176,15 @@ def _build_execution_plan(self, dag):
158176
# create an execution order such that each layer's needs are provided.
159177
ordered_nodes = iset(nx.topological_sort(dag))
160178

161-
# add Operations evaluation steps, and instructions to free data.
179+
# Add Operations evaluation steps, and instructions to free and "pin"
180+
# data.
162181
for i, node in enumerate(ordered_nodes):
163182

164183
if isinstance(node, DataPlaceholderNode):
165-
continue
184+
if node in inputs and dag.pred[node]:
185+
# Command pinning only when there is another operation
186+
# generating this data as output.
187+
plan.append(PinInstruction(node))
166188

167189
elif isinstance(node, Operation):
168190

@@ -291,13 +313,11 @@ def _solve_dag(self, outputs, inputs):
291313
broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs))
292314

293315

294-
# Prune (un-satifiable) operations with partial inputs.
295-
# See yahoo/graphkit#18
296-
#
316+
# Prune unsatisfied operations (those with partial inputs or no outputs).
297317
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
298-
shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
318+
pruned_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
299319

300-
plan = self._build_execution_plan(shrinked_dag)
320+
plan = self._build_execution_plan(pruned_dag, inputs)
301321

302322
return plan
303323

@@ -331,7 +351,8 @@ def compile(self, outputs=(), inputs=()):
331351

332352

333353

334-
def compute(self, outputs, named_inputs, method=None):
354+
def compute(
355+
self, outputs, named_inputs, method=None, overwrites_collector=None):
335356
"""
336357
Run the graph. Any inputs to the network must be passed in by name.
337358
@@ -350,6 +371,10 @@ def compute(self, outputs, named_inputs, method=None):
350371
Set when invoking a composed graph or by
351372
:meth:`~NetworkOperation.set_execution_method()`.
352373
374+
:param overwrites_collector:
375+
(optional) a mutable dict to be fillwed with named values.
376+
If missing, values are simply discarded.
377+
353378
:returns: a dictionary of output data objects, keyed by name.
354379
"""
355380

@@ -364,23 +389,34 @@ def compute(self, outputs, named_inputs, method=None):
364389

365390
# choose a method of execution
366391
if method == "parallel":
367-
self._compute_thread_pool_barrier_method(cache)
392+
self._compute_thread_pool_barrier_method(
393+
cache, overwrites_collector, named_inputs)
368394
else:
369-
self._compute_sequential_method(cache, outputs)
395+
self._compute_sequential_method(
396+
cache, overwrites_collector, named_inputs, outputs)
370397

371398
if not outputs:
372399
# Return the whole cache as output, including input and
373400
# intermediate data nodes.
374-
return cache
401+
result = cache
375402

376403
else:
377404
# Filter outputs to just return what's needed.
378405
# Note: list comprehensions exist in python 2.7+
379-
return dict(i for i in cache.items() if i[0] in outputs)
406+
result = dict(i for i in cache.items() if i[0] in outputs)
407+
408+
return result
409+
410+
411+
def _pin_data_in_cache(self, value_name, cache, inputs, overwrites):
412+
value_name = str(value_name)
413+
if overwrites is not None:
414+
overwrites[value_name] = cache[value_name]
415+
cache[value_name] = inputs[value_name]
380416

381417

382418
def _compute_thread_pool_barrier_method(
383-
self, cache, thread_pool_size=10
419+
self, cache, overwrites, inputs, thread_pool_size=10
384420
):
385421
"""
386422
This method runs the graph using a parallel pool of thread executors.
@@ -436,7 +472,7 @@ def _compute_thread_pool_barrier_method(
436472
has_executed.add(op)
437473

438474

439-
def _compute_sequential_method(self, cache, outputs):
475+
def _compute_sequential_method(self, cache, overwrites, inputs, outputs):
440476
"""
441477
This method runs the graph one operation at a time in a single thread
442478
"""
@@ -477,6 +513,8 @@ def _compute_sequential_method(self, cache, outputs):
477513
print("removing data '%s' from cache." % step)
478514
cache.pop(step)
479515

516+
elif isinstance(step, PinInstruction):
517+
self._pin_data_in_cache(step, cache, inputs, overwrites)
480518
else:
481519
raise AssertionError("Unrecognized instruction.%r" % step)
482520

test/test_graphkit.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from graphkit import operation, compose, Operation
1414

1515

16+
def scream(*args, **kwargs):
17+
raise AssertionError(
18+
"Must not have run!\n args: %s\n kwargs: %s", (args, kwargs))
19+
20+
1621
def identity(x):
1722
return x
1823

@@ -200,9 +205,9 @@ def test_pruning_raises_for_bad_output():
200205

201206

202207
def test_pruning_not_overrides_given_intermediate():
203-
# Test #25: v1.2.4 overrides intermediate data when no output asked
208+
# Test #25: v1.2.4 overwrites intermediate data when no output asked
204209
netop = compose(name="netop")(
205-
operation(name="unjustly run", needs=["a"], provides=["overriden"])(identity),
210+
operation(name="unjustly run", needs=["a"], provides=["overriden"])(scream),
206211
operation(name="op", needs=["overriden", "c"], provides=["asked"])(add),
207212
)
208213

@@ -212,11 +217,24 @@ def test_pruning_not_overrides_given_intermediate():
212217
# FAILs
213218
# - on v1.2.4 with (overriden, asked): = (5, 7) instead of (1, 3)
214219
# - on #18(unsatisfied) + #23(ordered-sets) with (overriden, asked) = (5, 7) instead of (1, 3)
220+
# FIXED on #26
221+
assert netop({"a": 5, "overriden": 1, "c": 2}) == exp
222+
223+
## Test OVERWITES
224+
#
225+
overwrites = {}
226+
netop.set_overwrites_collector(overwrites)
227+
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
228+
assert overwrites == {} # unjust must have been pruned
229+
230+
overwrites = {}
231+
netop.set_overwrites_collector(overwrites)
215232
assert netop({"a": 5, "overriden": 1, "c": 2}) == exp
233+
assert overwrites == {} # unjust must have been pruned
216234

217235

218236
def test_pruning_multiouts_not_override_intermediates1():
219-
# Test #25: v.1.2.4 overrides intermediate data when a previous operation
237+
# Test #25: v.1.2.4 overwrites intermediate data when a previous operation
220238
# must run for its other outputs (outputs asked or not)
221239
netop = compose(name="netop")(
222240
operation(name="must run", needs=["a"], provides=["overriden", "calced"])
@@ -228,11 +246,30 @@ def test_pruning_multiouts_not_override_intermediates1():
228246
# FAILs
229247
# - on v1.2.4 with (overriden, asked) = (5, 15) instead of (1, 11)
230248
# - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
249+
# FIXED on #26
231250
assert netop({"a": 5, "overriden": 1}) == exp
232251
# FAILs
233252
# - on v1.2.4 with KeyError: 'e',
234253
# - on #18(unsatisfied) + #23(ordered-sets) with empty result.
254+
# FIXED on #26
255+
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
256+
257+
## Test OVERWITES
258+
#
259+
overwrites = {}
260+
netop.set_overwrites_collector(overwrites)
261+
assert netop({"a": 5, "overriden": 1}) == exp
262+
assert overwrites == {'overriden': 5}
263+
264+
overwrites = {}
265+
netop.set_overwrites_collector(overwrites)
235266
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
267+
assert overwrites == {'overriden': 5}
268+
269+
# ## Test parallel
270+
# netop.set_execution_method("parallel")
271+
# assert netop({"a": 5, "overriden": 1}) == exp
272+
# assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
236273

237274

238275
def test_pruning_multiouts_not_override_intermediates2():
@@ -249,11 +286,25 @@ def test_pruning_multiouts_not_override_intermediates2():
249286
# FAILs
250287
# - on v1.2.4 with (overriden, asked) = (5, 70) instead of (1, 13)
251288
# - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
289+
# FIXED on #26
252290
assert netop({"a": 5, "overriden": 1, "c": 2}) == exp
253291
# FAILs
254292
# - on v1.2.4 with KeyError: 'e',
255293
# - on #18(unsatisfied) + #23(ordered-sets) with empty result.
256294
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
295+
# FIXED on #26
296+
297+
## Test OVERWITES
298+
#
299+
overwrites = {}
300+
netop.set_overwrites_collector(overwrites)
301+
assert netop({"a": 5, "overriden": 1, "c": 2}) == exp
302+
assert overwrites == {'overriden': 5}
303+
304+
overwrites = {}
305+
netop.set_overwrites_collector(overwrites)
306+
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
307+
assert overwrites == {'overriden': 5}
257308

258309

259310
def test_pruning_with_given_intermediate_and_asked_out():
@@ -274,6 +325,17 @@ def test_pruning_with_given_intermediate_and_asked_out():
274325
# FIXED on #18+#26 (new dag solver).
275326
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
276327

328+
## Test OVERWITES
329+
#
330+
overwrites = {}
331+
netop.set_overwrites_collector(overwrites)
332+
assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
333+
assert overwrites == {}
334+
335+
overwrites = {}
336+
netop.set_overwrites_collector(overwrites)
337+
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
338+
assert overwrites == {}
277339

278340
def test_unsatisfied_operations():
279341
# Test that operations with partial inputs are culled and not failing.

0 commit comments

Comments
 (0)