|
5 | 5 | import os |
6 | 6 | import networkx as nx |
7 | 7 |
|
| 8 | +from collections import defaultdict |
8 | 9 | from io import StringIO |
| 10 | +from itertools import chain |
| 11 | + |
9 | 12 |
|
10 | 13 | from boltons.setutils import IndexedSet as iset |
11 | 14 |
|
@@ -138,66 +141,45 @@ def compile(self): |
138 | 141 | self.steps.append(DeleteInstruction(need)) |
139 | 142 |
|
140 | 143 | else: |
141 | | - raise TypeError("Unrecognized network graph node") |
| 144 | + raise TypeError("Unrecognized network graph node %s" % type(node)) |
142 | 145 |
|
143 | 146 |
|
144 | | - def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited): |
| 147 | + def _collect_unsatisfiable_operations(self, necessary_nodes, inputs): |
145 | 148 | """ |
146 | | - Recusrively check if operation inputs are given/calculated (satisfied), or not. |
147 | | -
|
148 | | - :param satisfiables: |
149 | | - the set to populate with satisfiable operations |
150 | | -
|
151 | | - :param visited: |
152 | | - a cache of operations & needs, not to visit them again |
153 | | - :return: |
154 | | - true if opearation is satisfiable |
| 149 | + Traverse ordered graph and mark satisfied needs on each operation, |
| 150 | +
|
| 151 | + collecting those missing at least one. |
| 152 | + Since the graph is ordered, as soon as we're on an operation, |
| 153 | + all its needs have been accounted, so we can get its satisfaction. |
| 154 | +
|
| 155 | + :param necessary_nodes: |
| 156 | + the subset of the graph to consider but WITHOUT the initial data |
| 157 | + (because that is what :meth:`_find_necessary_steps()` can give us...) |
| 158 | + :param inputs: |
| 159 | + an iterable of the names of the input values |
| 160 | + :return: |
| 161 | + a list of unsatisfiable operations |
155 | 162 | """ |
156 | | - assert isinstance(operation, Operation), ( |
157 | | - "Expected Operation, got:", |
158 | | - type(operation), |
159 | | - ) |
160 | | - |
161 | | - if operation in visited: |
162 | | - return visited[operation] |
163 | | - |
164 | | - |
165 | | - def is_need_satisfiable(need): |
166 | | - if need in visited: |
167 | | - return visited[need] |
168 | | - |
169 | | - if need in inputs: |
170 | | - satisfied = True |
171 | | - else: |
172 | | - need_providers = list(self.graph.predecessors(need)) |
173 | | - satisfied = bool(need_providers) and any( |
174 | | - self._collect_satisfiable_needs(op, inputs, satisfiables, visited) |
175 | | - for op in need_providers |
176 | | - ) |
177 | | - visited[need] = satisfied |
178 | | - |
179 | | - return satisfied |
180 | | - |
181 | | - satisfied = all( |
182 | | - is_need_satisfiable(need) |
183 | | - for need in operation.needs |
184 | | - if not isinstance(need, optional) |
185 | | - ) |
186 | | - if satisfied: |
187 | | - satisfiables.add(operation) |
188 | | - visited[operation] = satisfied |
189 | | - |
190 | | - return satisfied |
191 | | - |
192 | | - |
193 | | - def _collect_satisfiable_operations(self, nodes, inputs): |
194 | | - satisfiables = set() # unordered, not iterated |
195 | | - visited = {} |
196 | | - for node in nodes: |
197 | | - if node not in visited and isinstance(node, Operation): |
198 | | - self._collect_satisfiable_needs(node, inputs, satisfiables, visited) |
| 163 | + G = self.graph # shortcut |
| 164 | + ok_data = set(inputs) # to collect producible data |
| 165 | + op_satisfaction = defaultdict(set) # to collect operation satisfiable needs |
| 166 | + unsatisfiables = [] # to collect operations with partial needs |
| 167 | + # We also need inputs to mark op_satisfaction. |
| 168 | + nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings |
| 169 | + for node in nx.topological_sort(G.subgraph(nodes)): |
| 170 | + if isinstance(node, Operation): |
| 171 | + real_needs = set(n for n in node.needs if not isinstance(n, optional)) |
| 172 | + if real_needs.issubset(op_satisfaction[node]): |
| 173 | + # mark all future data-provides as ok |
| 174 | + ok_data.update(G.adj[node]) |
| 175 | + else: |
| 176 | + unsatisfiables.append(node) |
| 177 | + elif isinstance(node, (DataPlaceholderNode, str)) and node in ok_data: |
| 178 | + # mark satisfied-needs on all future operations |
| 179 | + for future_op in G.adj[node]: |
| 180 | + op_satisfaction[future_op].add(node) |
199 | 181 |
|
200 | | - return satisfiables |
| 182 | + return unsatisfiables |
201 | 183 |
|
202 | 184 |
|
203 | 185 | def _find_necessary_steps(self, outputs, inputs): |
@@ -264,12 +246,10 @@ def _find_necessary_steps(self, outputs, inputs): |
264 | 246 | necessary_nodes -= unnecessary_nodes |
265 | 247 |
|
266 | 248 | # Drop (un-satifiable) operations with partial inputs. |
267 | | - # See https://github.com/yahoo/graphkit/pull/18 |
| 249 | + # See yahoo/graphkit#18 |
268 | 250 | # |
269 | | - satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs) |
270 | | - for node in list(necessary_nodes): |
271 | | - if isinstance(node, Operation) and node not in satisfiables: |
272 | | - necessary_nodes.remove(node) |
| 251 | + unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs) |
| 252 | + necessary_nodes -= set(unsatisfiables) |
273 | 253 |
|
274 | 254 | necessary_steps = [step for step in self.steps if step in necessary_nodes] |
275 | 255 |
|
|
0 commit comments