Skip to content

Commit 3d8b6dd

Browse files
committed
Add support for colors.
1 parent 87edead commit 3d8b6dd

File tree

4 files changed

+36
-10
lines changed

4 files changed

+36
-10
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,5 @@ docs/_build/
5252

5353
# PyBuilder
5454
target/
55+
56+
.pytest_cache

graphkit/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self, **kwargs):
5353
self.needs = kwargs.get('needs')
5454
self.provides = kwargs.get('provides')
5555
self.params = kwargs.get('params', {})
56+
self.color = kwargs.get('color', None)
5657

5758
# call _after_init as final step of initialization
5859
self._after_init()
@@ -151,8 +152,8 @@ def __init__(self, **kwargs):
151152
self.net = kwargs.pop('net')
152153
Operation.__init__(self, **kwargs)
153154

154-
def _compute(self, named_inputs, outputs=None):
155-
return self.net.compute(outputs, named_inputs)
155+
def _compute(self, named_inputs, outputs=None, color=None):
156+
return self.net.compute(outputs, named_inputs, color)
156157

157158
def __call__(self, *args, **kwargs):
158159
return self._compute(*args, **kwargs)

graphkit/functional.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class operation(Operation):
6767
A dict of key/value pairs representing constant parameters
6868
associated with your operation. These can correspond to either
6969
``args`` or ``kwargs`` of ``fn`.
70+
71+
:param str color:
72+
A color for the node in the computation graph.
7073
"""
7174

7275
def __init__(self, fn=None, **kwargs):
@@ -93,6 +96,9 @@ def _normalize_kwargs(self, kwargs):
9396
if type(kwargs['params']) is not dict:
9497
kwargs['params'] = {}
9598

99+
if 'color' in kwargs and type(kwargs['color']) == str:
100+
assert kwargs['color'], "empty string provided for `color` parameters"
101+
96102
return kwargs
97103

98104
def __call__(self, fn=None, **kwargs):

graphkit/network.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def add_op(self, operation):
7676
for n in operation.needs:
7777
self.graph.add_edge(DataPlaceholderNode(n), operation)
7878

79+
if operation.color:
80+
self.graph.nodes[operation]['color'] = operation.color
81+
7982
# add nodes and edges to graph describing what this layer provides
8083
for p in operation.provides:
8184
self.graph.add_edge(operation, DataPlaceholderNode(p))
@@ -93,6 +96,7 @@ def show_layers(self):
9396
print("layer_name: ", name)
9497
print("\t", "needs: ", step.needs)
9598
print("\t", "provides: ", step.provides)
99+
print("\t", "color: ", step.color)
96100
print("")
97101

98102
def compile(self):
@@ -136,7 +140,7 @@ def compile(self):
136140
else:
137141
raise TypeError("Unrecognized network graph node")
138142

139-
def _find_necessary_steps(self, outputs, inputs):
143+
def _find_necessary_steps(self, outputs, inputs, color=None):
140144
"""
141145
Determines what graph steps need to be run to get to the requested
142146
outputs from the provided inputs. Eliminates steps that come before
@@ -152,6 +156,9 @@ def _find_necessary_steps(self, outputs, inputs):
152156
:param dict inputs:
153157
A dictionary mapping names to values for all provided inputs.
154158
159+
:param str color:
160+
A color to filter nodes by.
161+
155162
:returns:
156163
Returns a list of all the steps that need to be run for the
157164
provided inputs and requested outputs.
@@ -160,7 +167,7 @@ def _find_necessary_steps(self, outputs, inputs):
160167
# return steps if it has already been computed before for this set of inputs and outputs
161168
outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
162169
inputs_keys = tuple(sorted(inputs.keys()))
163-
cache_key = (inputs_keys, outputs)
170+
cache_key = (inputs_keys, outputs, color)
164171
if cache_key in self._necessary_steps_cache:
165172
return self._necessary_steps_cache[cache_key]
166173

@@ -199,15 +206,23 @@ def _find_necessary_steps(self, outputs, inputs):
199206
# Get rid of the unnecessary nodes from the set of necessary ones.
200207
necessary_nodes -= unnecessary_nodes
201208

202-
necessary_steps = [step for step in self.steps if step in necessary_nodes]
209+
necessary_steps = []
210+
211+
for step in self.steps:
212+
if isinstance(step, Operation):
213+
if step.color == color and step in necessary_nodes:
214+
necessary_steps.append(step)
215+
else:
216+
if step in necessary_nodes:
217+
necessary_steps.append(step)
203218

204219
# save this result in a precomputed cache for future lookup
205220
self._necessary_steps_cache[cache_key] = necessary_steps
206221

207222
# Return an ordered list of the needed steps.
208223
return necessary_steps
209224

210-
def compute(self, outputs, named_inputs):
225+
def compute(self, outputs, named_inputs, color=None):
211226
"""
212227
This method runs the graph one operation at a time in a single thread
213228
Any inputs to the network must be passed in by name.
@@ -222,6 +237,8 @@ def compute(self, outputs, named_inputs):
222237
and the values are the concrete values you
223238
want to set for the data node.
224239
240+
:param str color: Only the subgraph of nodes with color will be evaluted.
241+
225242
:returns: a dictionary of output data objects, keyed by name.
226243
"""
227244

@@ -238,7 +255,7 @@ def compute(self, outputs, named_inputs):
238255

239256
# Find the subset of steps we need to run to get to the requested
240257
# outputs from the provided inputs.
241-
all_steps = self._find_necessary_steps(outputs, named_inputs)
258+
all_steps = self._find_necessary_steps(outputs, named_inputs, color)
242259

243260
self.times = {}
244261
for step in all_steps:
@@ -281,9 +298,9 @@ def compute(self, outputs, named_inputs):
281298
raise TypeError("Unrecognized instruction.")
282299

283300
if not outputs:
284-
# Return the whole cache as output, including input and
285-
# intermediate data nodes.
286-
return cache
301+
# Return cache as output including intermediate data nodes,
302+
# but excluding input.
303+
return {k: cache[k] for k in set(cache) - set(named_inputs)}
287304

288305
else:
289306
# Filter outputs to just return what's needed.

0 commit comments

Comments
 (0)