@@ -95,6 +95,8 @@ def __init__(
9595 for field in dc .fields (klass )
9696 if field .name not in ["_func" , "_graph_checksums" ]
9797 ]
98+ # dictionary to save the connections with lazy fields
99+ self .inp_lf = {}
98100 self .state = None
99101 self ._output = {}
100102 self ._result = {}
@@ -124,8 +126,6 @@ def __init__(
124126 self .allow_cache_override = True
125127 self ._checksum = None
126128
127- # dictionary of results from tasks
128- self .results_dict = {}
129129 self .plugin = None
130130 self .hooks = TaskHook ()
131131
@@ -165,6 +165,10 @@ def version(self):
165165 def checksum (self ):
166166 """calculating checksum
167167 """
168+ # if checksum is called before run the _graph_checksums is not ready
169+ if is_workflow (self ) and self .inputs ._graph_checksums is None :
170+ self .inputs ._graph_checksums = [nd .checksum for nd in self .graph_sorted ]
171+
168172 input_hash = self .inputs .hash
169173 if self .state is None :
170174 self ._checksum = create_checksum (self .__class__ .__name__ , input_hash )
@@ -176,6 +180,30 @@ def checksum(self):
176180 )
177181 return self ._checksum
178182
183+ def checksum_states (self , state_index = None ):
184+ """ calculating checksum for the specific state or all of the states
185+ replace lists in the inputs fields with a specific values for states
186+ can be used only for tasks with a state
187+ """
188+ if state_index is not None :
189+ if self .state is None :
190+ raise Exception ("can't use state_index if no splitter is used" )
191+ inputs_copy = deepcopy (self .inputs )
192+ for key , ind in self .state .inputs_ind [state_index ].items ():
193+ setattr (
194+ inputs_copy ,
195+ key .split ("." )[1 ],
196+ getattr (inputs_copy , key .split ("." )[1 ])[ind ],
197+ )
198+ input_hash = inputs_copy .hash
199+ checksum_ind = create_checksum (self .__class__ .__name__ , input_hash )
200+ return checksum_ind
201+ else :
202+ checksum_list = []
203+ for ind in range (len (self .state .inputs_ind )):
204+ checksum_list .append (self .checksum_states (state_index = ind ))
205+ return checksum_list
206+
179207 def set_state (self , splitter , combiner = None ):
180208 if splitter is not None :
181209 self .state = state .State (
@@ -226,14 +254,7 @@ def cache_locations(self, locations):
226254 @property
227255 def output_dir (self ):
228256 if self .state :
229- if self .results_dict :
230- return [
231- self ._cache_dir / res [1 ] for (_ , res ) in self .results_dict .items ()
232- ]
233- else :
234- raise Exception (
235- f"output_dir not available, will be ready after running { self .name } "
236- )
257+ return [self ._cache_dir / checksum for checksum in self .checksum_states ()]
237258 else :
238259 return self ._cache_dir / self .checksum
239260
@@ -399,7 +420,7 @@ def _combined_output(self):
399420 for (gr , ind_l ) in self .state .final_groups_mapping .items ():
400421 combined_results .append ([])
401422 for ind in ind_l :
402- result = load_result (self .results_dict [ ind ][ 1 ] , self .cache_locations )
423+ result = load_result (self .checksum_states ( ind ) , self .cache_locations )
403424 if result is None :
404425 return None
405426 combined_results [gr ].append (result )
@@ -419,10 +440,8 @@ def result(self, state_index=None):
419440 return self ._combined_output ()
420441 else :
421442 results = []
422- for (ii , val ) in enumerate (self .state .states_val ):
423- result = load_result (
424- self .results_dict [ii ][1 ], self .cache_locations
425- )
443+ for checksum in self .checksum_states ():
444+ result = load_result (checksum , self .cache_locations )
426445 if result is None :
427446 return None
428447 results .append (result )
@@ -431,19 +450,25 @@ def result(self, state_index=None):
431450 if self .state .combiner :
432451 return self ._combined_output ()[state_index ]
433452 result = load_result (
434- self .results_dict [ state_index ][ 1 ] , self .cache_locations
453+ self .checksum_states ( state_index ) , self .cache_locations
435454 )
436455 return result
437456 else :
438457 if state_index is not None :
439458 raise ValueError ("Task does not have a state" )
440- if self .results_dict :
441- checksum = self .results_dict [None ][1 ]
442- else :
443- checksum = self .checksum
459+ checksum = self .checksum
444460 result = load_result (checksum , self .cache_locations )
445461 return result
446462
463+ def _reset (self ):
464+ """resetting the connections between inputs and LazyFields"""
465+ for field in dc .fields (self .inputs ):
466+ if field .name in self .inp_lf :
467+ setattr (self .inputs , field .name , self .inp_lf [field .name ])
468+ if is_workflow (self ):
469+ for task in self .graph .nodes :
470+ task ._reset ()
471+
447472
448473class Workflow (TaskBase ):
449474 def __init__ (
@@ -534,6 +559,8 @@ def create_connections(self, task):
534559 for field in dc .fields (task .inputs ):
535560 val = getattr (task .inputs , field .name )
536561 if isinstance (val , LazyField ):
562+ # saving all connections with LazyFields
563+ task .inp_lf [field .name ] = val
537564 # adding an edge to the graph if task id expecting output from a different task
538565 if val .name != self .name :
539566 # checking if the connection is already in the graph
@@ -558,7 +585,7 @@ def create_connections(self, task):
558585 task .state = state .State (task .name , other_states = other_states )
559586
560587 async def _run (self , submitter = None , ** kwargs ):
561- self .inputs = dc .replace (self .inputs , ** kwargs )
588+ # self.inputs = dc.replace(self.inputs, **kwargs) don't need it?
562589 checksum = self .checksum
563590 lockfile = self .cache_dir / (checksum + ".lock" )
564591 # Eagerly retrieve cached
@@ -610,7 +637,6 @@ async def _run_task(self, submitter):
610637 if not submitter :
611638 raise Exception ("Submitter should already be set." )
612639 # at this point Workflow is stateless so this should be fine
613- self .results_dict [None ] = (None , self .checksum )
614640 await submitter ._run_workflow (self )
615641
616642 def set_output (self , connections ):
0 commit comments