77import threading
88import time
99from abc import ABC , abstractmethod
10- from collections import Counter
1110from concurrent .futures import Future , ThreadPoolExecutor
1211from dataclasses import dataclass
1312from enum import Enum
2221from aws_durable_execution_sdk_python .identifier import OperationIdentifier
2322from aws_durable_execution_sdk_python .lambda_service import ErrorObject
2423from aws_durable_execution_sdk_python .operation .child import child_handler
25- from aws_durable_execution_sdk_python .types import BatchResult as BatchResultProtocol
24+ from aws_durable_execution_sdk_python .types import (
25+ BatchItem ,
26+ BatchItemStatus ,
27+ BatchResult ,
28+ )
2629
2730if TYPE_CHECKING :
2831 from collections .abc import Callable
4548
4649
4750# region Result models
48- class BatchItemStatus (Enum ):
49- SUCCEEDED = "SUCCEEDED"
50- FAILED = "FAILED"
51- STARTED = "STARTED"
52-
53-
54- class CompletionReason (Enum ):
55- ALL_COMPLETED = "ALL_COMPLETED"
56- MIN_SUCCESSFUL_REACHED = "MIN_SUCCESSFUL_REACHED"
57- FAILURE_TOLERANCE_EXCEEDED = "FAILURE_TOLERANCE_EXCEEDED"
58-
59-
6051@dataclass (frozen = True )
6152class SuspendResult :
6253 should_suspend : bool
@@ -71,173 +62,6 @@ def suspend(exception: SuspendExecution) -> SuspendResult:
7162 return SuspendResult (should_suspend = True , exception = exception )
7263
7364
74- @dataclass (frozen = True )
75- class BatchItem (Generic [R ]):
76- index : int
77- status : BatchItemStatus
78- result : R | None = None
79- error : ErrorObject | None = None
80-
81- def to_dict (self ) -> dict :
82- return {
83- "index" : self .index ,
84- "status" : self .status .value ,
85- "result" : self .result ,
86- "error" : self .error .to_dict () if self .error else None ,
87- }
88-
89- @classmethod
90- def from_dict (cls , data : dict ) -> BatchItem [R ]:
91- return cls (
92- index = data ["index" ],
93- status = BatchItemStatus (data ["status" ]),
94- result = data .get ("result" ),
95- error = ErrorObject .from_dict (data ["error" ]) if data .get ("error" ) else None ,
96- )
97-
98-
99- @dataclass (frozen = True )
100- class BatchResult (Generic [R ], BatchResultProtocol [R ]): # noqa: PYI059
101- all : list [BatchItem [R ]]
102- completion_reason : CompletionReason
103-
104- @classmethod
105- def from_dict (
106- cls , data : dict , completion_config : CompletionConfig | None = None
107- ) -> BatchResult [R ]:
108- batch_items : list [BatchItem [R ]] = [
109- BatchItem .from_dict (item ) for item in data ["all" ]
110- ]
111-
112- completion_reason_value = data .get ("completionReason" )
113- if completion_reason_value is None :
114- # Infer completion reason from batch item statuses and completion config
115- # This aligns with the TypeScript implementation that uses completion config
116- # to accurately reconstruct the completion reason during replay
117- result = cls .from_items (batch_items , completion_config )
118- logger .warning (
119- "Missing completionReason in BatchResult deserialization, "
120- "inferred '%s' from batch item statuses. "
121- "This may indicate incomplete serialization data." ,
122- result .completion_reason .value ,
123- )
124- return result
125-
126- completion_reason = CompletionReason (completion_reason_value )
127- return cls (batch_items , completion_reason )
128-
129- @classmethod
130- def from_items (
131- cls ,
132- items : list [BatchItem [R ]],
133- completion_config : CompletionConfig | None = None ,
134- ):
135- """
136- Infer completion reason based on batch item statuses and completion config.
137-
138- This follows the same logic as the TypeScript implementation:
139- - If all items completed: ALL_COMPLETED
140- - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED
141- - Otherwise: FAILURE_TOLERANCE_EXCEEDED
142- """
143-
144- statuses = (item .status for item in items )
145- counts = Counter (statuses )
146- succeeded_count = counts .get (BatchItemStatus .SUCCEEDED , 0 )
147- failed_count = counts .get (BatchItemStatus .FAILED , 0 )
148- started_count = counts .get (BatchItemStatus .STARTED , 0 )
149-
150- completed_count = succeeded_count + failed_count
151- total_count = started_count + completed_count
152-
153- # If all items completed (no started items), it's ALL_COMPLETED
154- if completed_count == total_count :
155- completion_reason = CompletionReason .ALL_COMPLETED
156- elif ( # If we have completion config and minSuccessful threshold is met
157- completion_config
158- and (min_successful := completion_config .min_successful ) is not None
159- and succeeded_count >= min_successful
160- ):
161- completion_reason = CompletionReason .MIN_SUCCESSFUL_REACHED
162- else :
163- # Otherwise, assume failure tolerance was exceeded
164- completion_reason = CompletionReason .FAILURE_TOLERANCE_EXCEEDED
165-
166- return cls (items , completion_reason )
167-
168- def to_dict (self ) -> dict :
169- return {
170- "all" : [item .to_dict () for item in self .all ],
171- "completionReason" : self .completion_reason .value ,
172- }
173-
174- def succeeded (self ) -> list [BatchItem [R ]]:
175- return [
176- item
177- for item in self .all
178- if item .status is BatchItemStatus .SUCCEEDED and item .result is not None
179- ]
180-
181- def failed (self ) -> list [BatchItem [R ]]:
182- return [
183- item
184- for item in self .all
185- if item .status is BatchItemStatus .FAILED and item .error is not None
186- ]
187-
188- def started (self ) -> list [BatchItem [R ]]:
189- return [item for item in self .all if item .status is BatchItemStatus .STARTED ]
190-
191- @property
192- def status (self ) -> BatchItemStatus :
193- return BatchItemStatus .FAILED if self .has_failure else BatchItemStatus .SUCCEEDED
194-
195- @property
196- def has_failure (self ) -> bool :
197- return any (item .status is BatchItemStatus .FAILED for item in self .all )
198-
199- def throw_if_error (self ) -> None :
200- first_error = next (
201- (item .error for item in self .all if item .status is BatchItemStatus .FAILED ),
202- None ,
203- )
204- if first_error :
205- raise first_error .to_callable_runtime_error ()
206-
207- def get_results (self ) -> list [R ]:
208- return [
209- item .result
210- for item in self .all
211- if item .status is BatchItemStatus .SUCCEEDED and item .result is not None
212- ]
213-
214- def get_errors (self ) -> list [ErrorObject ]:
215- return [
216- item .error
217- for item in self .all
218- if item .status is BatchItemStatus .FAILED and item .error is not None
219- ]
220-
221- @property
222- def success_count (self ) -> int :
223- return sum (1 for item in self .all if item .status is BatchItemStatus .SUCCEEDED )
224-
225- @property
226- def failure_count (self ) -> int :
227- return sum (1 for item in self .all if item .status is BatchItemStatus .FAILED )
228-
229- @property
230- def started_count (self ) -> int :
231- return sum (1 for item in self .all if item .status is BatchItemStatus .STARTED )
232-
233- @property
234- def total_count (self ) -> int :
235- return len (self .all )
236-
237-
238- # endregion Result models
239-
240-
24165# region concurrency models
24266@dataclass (frozen = True )
24367class Executable (Generic [CallableType ]):
@@ -367,12 +191,12 @@ class ExecutionCounters:
367191 def __init__ (
368192 self ,
369193 total_tasks : int ,
370- min_successful : int ,
194+ min_successful : int | None ,
371195 tolerated_failure_count : int | None ,
372196 tolerated_failure_percentage : float | None ,
373197 ):
374198 self .total_tasks : int = total_tasks
375- self .min_successful : int = min_successful
199+ self .min_successful : int | None = min_successful
376200 self .tolerated_failure_count : int | None = tolerated_failure_count
377201 self .tolerated_failure_percentage : float | None = tolerated_failure_percentage
378202 self .success_count : int = 0
@@ -421,24 +245,26 @@ def is_complete(self) -> bool:
421245 """
422246 Check if execution should complete (based on completion criteria).
423247 Matches TypeScript isComplete() logic.
248+
249+ Note: This method only checks completion criteria (all done, or min_successful met).
250+ Failure tolerance is enforced separately by should_continue() and combined in should_complete().
424251 """
425252 with self ._lock :
426253 completed_count = self .success_count + self .failure_count
427254
428255 # All tasks completed
429256 if completed_count == self .total_tasks :
430- # Complete if no failure tolerance OR no failures OR min successful reached
431- return (
432- (
433- self .tolerated_failure_count is None
434- and self .tolerated_failure_percentage is None
435- )
436- or self .failure_count == 0
437- or self .success_count >= self .min_successful
438- )
257+ # If min_successful is explicitly set, check if we met it
258+ # Otherwise, complete when all tasks are done
259+ if self .min_successful is not None :
260+ return self .success_count >= self .min_successful
261+ return True
439262
440- # when we breach min successful, we've completed
441- return self .success_count >= self .min_successful
263+ # Early completion: once success_count reaches min_successful (only if explicitly set)
264+ return (
265+ self .min_successful is not None
266+ and self .success_count >= self .min_successful
267+ )
442268
443269 def should_complete (self ) -> bool :
444270 """
@@ -453,9 +279,12 @@ def is_all_completed(self) -> bool:
453279 return self .success_count == self .total_tasks
454280
455281 def is_min_successful_reached (self ) -> bool :
456- """True if minimum successful tasks reached."""
282+ """True if a minimum successful task count is both set and reached."""
457283 with self ._lock :
458- return self .success_count >= self .min_successful
284+ return (
285+ self .min_successful is not None
286+ and self .success_count >= self .min_successful
287+ )
459288
460289 def is_failure_tolerance_exceeded (self ) -> bool :
461290 """True if failure tolerance was exceeded."""
@@ -594,7 +423,9 @@ def __init__(
594423 self ._suspend_exception : SuspendExecution | None = None
595424
596425 # ExecutionCounters will keep track of completion criteria and on-going counters
597- min_successful = self .completion_config .min_successful or len (self .executables )
426+ # Note: min_successful should remain None if not explicitly set
427+ # When None, the operation completes when all tasks finish (respecting failure tolerance)
428+ min_successful = self .completion_config .min_successful
598429 tolerated_failure_count = self .completion_config .tolerated_failure_count
599430 tolerated_failure_percentage = (
600431 self .completion_config .tolerated_failure_percentage
0 commit comments