Skip to content

Commit d403783

Browse files
committed
WIP/FIX(PIN): PARALLEL DELs decide on PRUNED-dag (not full)...
- WIP: x4 TCs FAIL and still not discovered the bug :-( + BUT ALL+AUGMENTED PARALLEL TCs pass (yahoo#26 were failing some) + refact: net stores also `pruned_dag` (not only `steps`). + refact: _solve_dag() --> _prune_dag(). + doc: +a lot. + TODO: store pruned_dag in own ExePlan class.
1 parent 1cc733e commit d403783

File tree

2 files changed

+112
-78
lines changed

2 files changed

+112
-78
lines changed

graphkit/network.py

Lines changed: 86 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -54,54 +54,65 @@ def __repr__(self):
5454

5555
class Network(object):
5656
"""
57-
Assemble operations & data into a directed-acyclic-graph (DAG) and run them
57+
Assemble operations & data into a directed-acyclic-graph (DAG) to run them.
5858
59-
based on the given input values and requested outputs.
59+
The execution of the contained *operations* in the dag (the computation)
60+
is split in 2 phases:
6061
61-
The execution of *operations* (a computation) is splitted in 2 phases:
62-
63-
- COMPILE: prune, sort topologically the nodes in the dag, solve it, and
62+
- COMPILE: prune unsatisfied nodes, sort dag topologically & solve it, and
6463
derive the *execution plan* (see below) based on the given *inputs*
6564
and asked *outputs*.
6665
6766
- EXECUTE: sequential or parallel invocation of the underlying functions
68-
of the operations.
69-
70-
is based on 4 data-structures:
71-
72-
- the ``networkx`` :attr:`graph` DAG, containing interchanging layers of
73-
:class:`Operation` and :class:`DataPlaceholderNode` nodes.
74-
They are layed out and connected by repeated calls of :meth:`add_OP`.
75-
76-
The computation starts with :meth:`_solve_dag()` extracting
77-
a *DAG subgraph* by *pruning* nodes based on given inputs and
78-
requested outputs.
79-
This subgraph is used to decide the `execution_plan` (see below), and
80-
and is cached in :attr:`_cached_execution_plans` across runs with
81-
inputs/outputs as key.
82-
83-
- the :attr:`execution_plan` is the list of the operation-nodes only
84-
from the dag (above), topologically sorted, and interspersed with
85-
*instructions steps* needed to complete the run.
86-
It is built by :meth:`_build_execution_plan()` based on the subgraph dag
87-
extracted above. The *instructions* items achieve the following:
88-
89-
- :class:`DeleteInstruction`: delete items from values-cache as soon as
90-
they are not needed further down the dag, to reduce memory footprint
91-
while computing.
92-
93-
- :class:`PinInstruction`: avoid overwritting any given intermediate
94-
inputs, and still allow their providing operations to run
95-
(because they are needed for their other outputs).
96-
97-
- the :var:`cache` local-var in :meth:`compute()`, initialized on each run
98-
to hold the values of the given inputs, generated (aka intermediate) data,
99-
and output values.
100-
101-
- the :var:`overwrites` local-var, initialized on each run of both
102-
``_compute_xxx`` methods (for parallel or sequential executions), to
103-
hold values calculated but overwritten (aka "pinned") by intermediate
104-
input-values.
67+
of the operations with arguments from the ``cache``.
68+
69+
is based on 5 data-structures:
70+
71+
:ivar graph:
72+
A ``networkx`` DAG containing interchanging layers of
73+
:class:`Operation` and :class:`DataPlaceholderNode` nodes.
74+
They are laid out and connected by repeated calls of :meth:`add_OP`.
75+
76+
The computation starts with :meth:`_prune_dag()` extracting
77+
a *DAG subgraph* by *pruning* its nodes based on given inputs and
78+
requested outputs in :meth:`compute()`.
79+
:ivar execution_dag:
80+
It contains the nodes of the *pruned dag* from the last call to
81+
:meth:`compile()`. This pruned subgraph is used to decide
82+
the :attr:`execution_plan` (below).
83+
It is cached in :attr:`_cached_compilations` across runs with
84+
inputs/outputs as key.
85+
86+
:ivar execution_plan:
87+
It is the list of the operation-nodes only
88+
from the dag (above), topologically sorted, and interspersed with
89+
*instructions steps* needed to complete the run.
90+
It is built by :meth:`_build_execution_plan()` based on the subgraph dag
91+
extracted above.
92+
It is cached in :attr:`_cached_compilations` across runs with
93+
inputs/outputs as key.
94+
95+
The *instructions* items achieve the following:
96+
97+
- :class:`DeleteInstruction`: delete items from values-cache as soon as
98+
they are not needed further down the dag, to reduce memory footprint
99+
while computing.
100+
101+
- :class:`PinInstruction`: avoid overwriting any given intermediate
102+
inputs, and still allow their providing operations to run
103+
(because they are needed for their other outputs).
104+
105+
:var cache:
106+
a local-var in :meth:`compute()`, initialized on each run
107+
to hold the values of the given inputs, generated (intermediate) data,
108+
and output values.
109+
It is returned as is if no specific outputs requested; no data-eviction
110+
happens then.
111+
112+
:arg overwrites:
113+
The optional argument given to :meth:`compute()` to collect the
114+
intermediate *calculated* values that are overwritten by intermediate
115+
(aka "pinned") input-values.
105116
106117
"""
107118

@@ -119,11 +130,14 @@ def __init__(self, **kwargs):
119130
#: The list of operation-nodes & *instructions* needed to evaluate
120131
#: the given inputs &amp; asked outputs, free memory and avoid overwriting
121132
#: any given intermediate inputs.
122-
self.execution_plan = []
133+
self.execution_plan = ()
134+
135+
#: Pruned graph of the last compilation.
136+
self.execution_dag = ()
123137

124138
#: Speed up :meth:`compile()` call and avoid a multithreading issue(?)
125139
#: that is occurring when accessing the dag in networkx.
126-
self._cached_execution_plans = {}
140+
self._cached_compilations = {}
127141

128142

129143
def add_op(self, operation):
@@ -143,8 +157,9 @@ def add_op(self, operation):
143157
# assert layer is only added once to graph
144158
assert operation not in self.graph.nodes, "Operation may only be added once"
145159

146-
## Invalidate old plans.
147-
self._cached_execution_plans = {}
160+
self.execution_dag = None
161+
self.execution_plan = None
162+
self._cached_compilations = {}
148163

149164
# add nodes and edges to graph describing the data needs for this layer
150165
for n in operation.needs:
@@ -246,11 +261,11 @@ def _collect_unsatisfied_operations(self, dag, inputs):
246261
all its needs have been accounted, so we can get its satisfaction.
247262
248263
- Their provided outputs are not linked to any data in the dag.
249-
An operation might not have any output link when :meth:`_solve_dag()`
264+
An operation might not have any output link when :meth:`_prune_dag()`
250265
has broken them, due to given intermediate inputs.
251266
252267
:param dag:
253-
the graph to consider
268+
a graph with broken edges those arriving to existing inputs
254269
:param inputs:
255270
an iterable of the names of the input values
256271
return:
@@ -288,13 +303,12 @@ def _collect_unsatisfied_operations(self, dag, inputs):
288303
return unsatisfied
289304

290305

291-
def _solve_dag(self, outputs, inputs):
306+
def _prune_dag(self, outputs, inputs):
292307
"""
293308
Determines what graph steps need to run to get to the requested
294-
outputs from the provided inputs. Eliminates steps that come before
295-
(in topological order) any inputs that have been provided. Also
296-
eliminates steps that are not on a path from the provided inputs to
297-
the requested outputs.
309+
outputs from the provided inputs. :
310+
- Eliminate steps that are not on a path arriving to requested outputs.
311+
- Eliminate unsatisfied operations: partial inputs or no outputs needed.
298312
299313
:param iterable outputs:
300314
A list of desired output names. This can also be ``None``, in which
@@ -305,7 +319,7 @@ def _solve_dag(self, outputs, inputs):
305319
The inputs names of all given inputs.
306320
307321
:return:
308-
the *execution plan*
322+
the *pruned_dag*
309323
"""
310324
dag = self.graph
311325

@@ -341,18 +355,16 @@ def _solve_dag(self, outputs, inputs):
341355

342356
# Prune unsatisfied operations (those with partial inputs or no outputs).
343357
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
344-
pruned_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
358+
pruned_dag = dag.subgraph(self.graph.nodes - unsatisfied)
345359

346-
plan = self._build_execution_plan(pruned_dag, inputs, outputs)
347-
348-
return plan
360+
return pruned_dag.copy() # clone so that it is picklable
349361

350362

351363
def compile(self, outputs=(), inputs=()):
352364
"""
353365
Solve dag, set the :attr:`execution_plan`, and cache it.
354366
355-
See :meth:`_solve_dag()` for detailed description.
367+
See :meth:`_prune_dag()` for detailed description.
356368
357369
:param iterable outputs:
358370
A list of desired output names. This can also be ``None``, in which
@@ -368,12 +380,20 @@ def compile(self, outputs=(), inputs=()):
368380
outputs = tuple(sorted(outputs))
369381
inputs_keys = tuple(sorted(inputs))
370382
cache_key = (inputs_keys, outputs)
371-
if cache_key in self._cached_execution_plans:
372-
self.execution_plan = self._cached_execution_plans[cache_key]
383+
384+
if cache_key in self._cached_compilations:
385+
dag, plan = self._cached_compilations[cache_key]
373386
else:
374-
plan = self._solve_dag(outputs, inputs)
375-
# save this result in a precomputed cache for future lookup
376-
self.execution_plan = self._cached_execution_plans[cache_key] = plan
387+
dag = self._prune_dag(outputs, inputs)
388+
plan = self._build_execution_plan(dag, inputs, outputs)
389+
390+
# Cache compilation results to speed up future runs
391+
# with different values (but same number of inputs/outputs).
392+
self._cached_compilations[cache_key] = dag, plan
393+
394+
## TODO: Extract into Solution class
395+
self.execution_dag = dag
396+
self.execution_plan = plan
377397

378398

379399

@@ -494,7 +514,6 @@ def _execute_thread_pool_barrier_method(
494514
self._pin_data_in_cache(node, cache, inputs, overwrites)
495515

496516

497-
498517
# stop if no nodes left to schedule, exit out of the loop
499518
if len(upnext) == 0:
500519
break
@@ -636,7 +655,7 @@ def _can_schedule_operation(self, op, executed_nodes):
636655
execution based on what has already been executed.
637656
"""
638657
# unordered, not iterated
639-
dependencies = set(n for n in nx.ancestors(self.graph, op)
658+
dependencies = set(n for n in nx.ancestors(self.execution_dag, op)
640659
if isinstance(n, Operation))
641660
return dependencies.issubset(executed_nodes)
642661

@@ -654,7 +673,7 @@ def _can_evict_value(self, name, executed_nodes):
654673
"""
655674
data_node = self.get_data_node(name)
656675
return data_node and set(
657-
self.graph.successors(data_node)).issubset(executed_nodes)
676+
self.execution_dag.successors(data_node)).issubset(executed_nodes)
658677

659678
def get_data_node(self, name):
660679
"""

test/test_graphkit.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,19 @@ def test_pruning_not_overrides_given_intermediate():
233233
assert pipeline({"a": 5, "overriden": 1, "c": 2}) == exp
234234
assert overwrites == {} # unjust must have been pruned
235235

236+
## Test Parallel
237+
#
238+
pipeline.set_execution_method("parallel")
239+
overwrites = {}
240+
pipeline.set_overwrites_collector(overwrites)
241+
#assert pipeline({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
242+
assert overwrites == {} # unjust must have been pruned
243+
244+
overwrites = {}
245+
pipeline.set_overwrites_collector(overwrites)
246+
assert pipeline({"a": 5, "overriden": 1, "c": 2}) == exp
247+
assert overwrites == {} # unjust must have been pruned
248+
236249

237250
def test_pruning_multiouts_not_override_intermediates1():
238251
# Test #25: v.1.2.4 overwrites intermediate data when a previous operation
@@ -348,9 +361,9 @@ def test_pruning_with_given_intermediate_and_asked_out():
348361
## Test parallel
349362
# FAIL! in #26!
350363
#
351-
# pipeline.set_execution_method("parallel")
352-
# assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
353-
# assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
364+
pipeline.set_execution_method("parallel")
365+
assert pipeline({"given-1": 5, "b": 2, "given-2": 2}) == exp
366+
assert pipeline({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
354367

355368
def test_unsatisfied_operations():
356369
# Test that operations with partial inputs are culled and not failing.
@@ -395,16 +408,17 @@ def test_unsatisfied_operations_same_out():
395408
assert pipeline({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
396409

397410
## Test parallel
411+
#
398412
# FAIL! in #26
413+
pipeline.set_execution_method("parallel")
414+
exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
415+
assert pipeline({"a": 10, "b1": 2, "c": 1}) == exp
416+
assert pipeline({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
399417
#
400-
# pipeline.set_execution_method("parallel")
401-
# exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
402-
# assert pipeline({"a": 10, "b1": 2, "c": 1}) == exp
403-
# assert pipeline({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
404-
405-
# exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
406-
# assert pipeline({"a": 10, "b2": 2, "c": 1}) == exp
407-
# assert pipeline({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
418+
# FAIL! in #26
419+
exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
420+
assert pipeline({"a": 10, "b2": 2, "c": 1}) == exp
421+
assert pipeline({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == filtdict(exp, "ab_plus_c")
408422

409423

410424
def test_optional():
@@ -624,6 +638,7 @@ def compute(self, inputs):
624638
outputs.append(p)
625639
return outputs
626640

641+
627642
def test_backwards_compatibility():
628643

629644
sum_op1 = Sum(

0 commit comments

Comments
 (0)