Skip to content

Commit 7ffa49d

Browse files
authored
Merge pull request #130 from djarecka/checksum_resets
changes in checksums/results and some fixes (fixes#107)
2 parents f5c43ec + 2ec3b7f commit 7ffa49d

File tree

5 files changed

+169
-37
lines changed

5 files changed

+169
-37
lines changed

pydra/engine/core.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def __init__(
9595
for field in dc.fields(klass)
9696
if field.name not in ["_func", "_graph_checksums"]
9797
]
98+
# dictionary to save the connections with lazy fields
99+
self.inp_lf = {}
98100
self.state = None
99101
self._output = {}
100102
self._result = {}
@@ -124,8 +126,6 @@ def __init__(
124126
self.allow_cache_override = True
125127
self._checksum = None
126128

127-
# dictionary of results from tasks
128-
self.results_dict = {}
129129
self.plugin = None
130130
self.hooks = TaskHook()
131131

@@ -165,6 +165,10 @@ def version(self):
165165
def checksum(self):
166166
"""calculating checksum
167167
"""
168+
# if checksum is called before run the _graph_checksums is not ready
169+
if is_workflow(self) and self.inputs._graph_checksums is None:
170+
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]
171+
168172
input_hash = self.inputs.hash
169173
if self.state is None:
170174
self._checksum = create_checksum(self.__class__.__name__, input_hash)
@@ -176,6 +180,30 @@ def checksum(self):
176180
)
177181
return self._checksum
178182

183+
def checksum_states(self, state_index=None):
184+
""" calculating checksum for the specific state or all of the states
185+
replace lists in the inputs fields with a specific values for states
186+
can be used only for tasks with a state
187+
"""
188+
if state_index is not None:
189+
if self.state is None:
190+
raise Exception("can't use state_index if no splitter is used")
191+
inputs_copy = deepcopy(self.inputs)
192+
for key, ind in self.state.inputs_ind[state_index].items():
193+
setattr(
194+
inputs_copy,
195+
key.split(".")[1],
196+
getattr(inputs_copy, key.split(".")[1])[ind],
197+
)
198+
input_hash = inputs_copy.hash
199+
checksum_ind = create_checksum(self.__class__.__name__, input_hash)
200+
return checksum_ind
201+
else:
202+
checksum_list = []
203+
for ind in range(len(self.state.inputs_ind)):
204+
checksum_list.append(self.checksum_states(state_index=ind))
205+
return checksum_list
206+
179207
def set_state(self, splitter, combiner=None):
180208
if splitter is not None:
181209
self.state = state.State(
@@ -226,14 +254,7 @@ def cache_locations(self, locations):
226254
@property
227255
def output_dir(self):
228256
if self.state:
229-
if self.results_dict:
230-
return [
231-
self._cache_dir / res[1] for (_, res) in self.results_dict.items()
232-
]
233-
else:
234-
raise Exception(
235-
f"output_dir not available, will be ready after running {self.name}"
236-
)
257+
return [self._cache_dir / checksum for checksum in self.checksum_states()]
237258
else:
238259
return self._cache_dir / self.checksum
239260

@@ -399,7 +420,7 @@ def _combined_output(self):
399420
for (gr, ind_l) in self.state.final_groups_mapping.items():
400421
combined_results.append([])
401422
for ind in ind_l:
402-
result = load_result(self.results_dict[ind][1], self.cache_locations)
423+
result = load_result(self.checksum_states(ind), self.cache_locations)
403424
if result is None:
404425
return None
405426
combined_results[gr].append(result)
@@ -419,10 +440,8 @@ def result(self, state_index=None):
419440
return self._combined_output()
420441
else:
421442
results = []
422-
for (ii, val) in enumerate(self.state.states_val):
423-
result = load_result(
424-
self.results_dict[ii][1], self.cache_locations
425-
)
443+
for checksum in self.checksum_states():
444+
result = load_result(checksum, self.cache_locations)
426445
if result is None:
427446
return None
428447
results.append(result)
@@ -431,19 +450,25 @@ def result(self, state_index=None):
431450
if self.state.combiner:
432451
return self._combined_output()[state_index]
433452
result = load_result(
434-
self.results_dict[state_index][1], self.cache_locations
453+
self.checksum_states(state_index), self.cache_locations
435454
)
436455
return result
437456
else:
438457
if state_index is not None:
439458
raise ValueError("Task does not have a state")
440-
if self.results_dict:
441-
checksum = self.results_dict[None][1]
442-
else:
443-
checksum = self.checksum
459+
checksum = self.checksum
444460
result = load_result(checksum, self.cache_locations)
445461
return result
446462

463+
def _reset(self):
464+
"""resetting the connections between inputs and LazyFields"""
465+
for field in dc.fields(self.inputs):
466+
if field.name in self.inp_lf:
467+
setattr(self.inputs, field.name, self.inp_lf[field.name])
468+
if is_workflow(self):
469+
for task in self.graph.nodes:
470+
task._reset()
471+
447472

448473
class Workflow(TaskBase):
449474
def __init__(
@@ -534,6 +559,8 @@ def create_connections(self, task):
534559
for field in dc.fields(task.inputs):
535560
val = getattr(task.inputs, field.name)
536561
if isinstance(val, LazyField):
562+
# saving all connections with LazyFields
563+
task.inp_lf[field.name] = val
537564
# adding an edge to the graph if task id expecting output from a different task
538565
if val.name != self.name:
539566
# checking if the connection is already in the graph
@@ -558,7 +585,7 @@ def create_connections(self, task):
558585
task.state = state.State(task.name, other_states=other_states)
559586

560587
async def _run(self, submitter=None, **kwargs):
561-
self.inputs = dc.replace(self.inputs, **kwargs)
588+
# self.inputs = dc.replace(self.inputs, **kwargs) don't need it?
562589
checksum = self.checksum
563590
lockfile = self.cache_dir / (checksum + ".lock")
564591
# Eagerly retrieve cached
@@ -610,7 +637,6 @@ async def _run_task(self, submitter):
610637
if not submitter:
611638
raise Exception("Submitter should already be set.")
612639
# at this point Workflow is stateless so this should be fine
613-
self.results_dict[None] = (None, self.checksum)
614640
await submitter._run_workflow(self)
615641

616642
def set_output(self, connections):

pydra/engine/state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(self, name, splitter=None, combiner=None, other_states=None):
2020
self.set_input_groups()
2121
self.set_splitter_final()
2222
self.states_val = []
23+
self.inputs_ind = []
2324
self.final_groups_mapping = {}
2425

2526
def __str__(self):

pydra/engine/submitter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,15 @@ def __call__(self, runnable, cache_locations=None):
4141
self.loop.run_until_complete(self.submit_workflow(runnable))
4242
else:
4343
self.loop.run_until_complete(self.submit(runnable, wait=True))
44+
if is_workflow(runnable):
45+
# resetting all connections with LazyFields
46+
runnable._reset()
4447
return runnable.result()
4548

4649
async def submit_workflow(self, workflow):
4750
"""Distributes or initiates workflow execution"""
4851
if workflow.plugin and workflow.plugin != self.plugin:
52+
# dj: this is not tested!!!
4953
await self.worker.run_el(workflow)
5054
else:
5155
await workflow._run(self)
@@ -81,8 +85,6 @@ async def submit(self, runnable, wait=False):
8185
)
8286
for sidx in range(len(runnable.state.states_val)):
8387
job = runnable.to_job(sidx)
84-
job.results_dict[None] = (sidx, job.checksum)
85-
runnable.results_dict[sidx] = (None, job.checksum)
8688
logger.debug(
8789
f'Submitting runnable {job}{str(sidx) if sidx is not None else ""}'
8890
)

pydra/engine/tests/test_node_task.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,6 @@ def test_odir_init():
160160
assert nn.output_dir
161161

162162

163-
def test_odir_init_error():
164-
""" checking if output_dir raises an error for task with a state
165-
if the task doesn't have result (before running)
166-
"""
167-
nn = fun_addtwo(name="NA").split(splitter="a", a=[3, 5])
168-
169-
with pytest.raises(Exception) as excinfo:
170-
assert nn.output_dir
171-
assert "output_dir not available" in str(excinfo.value)
172-
173-
174163
# Tests for tasks without state (i.e. no splitter)
175164

176165

@@ -224,6 +213,19 @@ def test_task_nostate_1_call_plug(plugin):
224213
assert nn.output_dir.exists()
225214

226215

216+
def test_task_nostate_1_call_updateinp():
217+
""" task without splitter"""
218+
nn = fun_addtwo(name="NA", a=30)
219+
# updating input when calling the node
220+
nn(a=3)
221+
222+
# checking the results
223+
results = nn.result()
224+
assert results.output.out == 5
225+
# checking the output_dir
226+
assert nn.output_dir.exists()
227+
228+
227229
@pytest.mark.parametrize("plugin", Plugins)
228230
def test_task_nostate_2(plugin):
229231
""" task with a list as an input, but no splitter"""

0 commit comments

Comments
 (0)