Skip to content

Commit 230bb64

Browse files
vdusek and claude authored
refactor: Adopt asyncio.TaskGroup for structured concurrency (#643)
## Summary - Replace manual task management (`create_task` + `cancel` + `gather`) with `asyncio.TaskGroup` in `RequestQueueClientAsync.batch_add_requests` for better error propagation and automatic cleanup - Update docs example (`02_tasks_async.py`) to use `asyncio.TaskGroup` instead of `asyncio.gather` - `StreamedLogAsync` and `StatusMessageWatcherAsync` left unchanged — `TaskGroup` doesn't fit their `start()`/`stop()` lifecycle pattern and adds no benefit for single-task cases ## Issue - Closes #598 Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a293601 commit 230bb64

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

docs/03_examples/code/02_tasks_async.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ async def main() -> None:
3030
print('Task clients created:', apify_task_clients)
3131

3232
# Execute Apify tasks
33-
task_run_results = await asyncio.gather(
34-
*[client.call() for client in apify_task_clients]
35-
)
33+
async with asyncio.TaskGroup() as tg:
34+
tasks = [tg.create_task(client.call()) for client in apify_task_clients]
35+
36+
task_run_results = [task.result() for task in tasks]
3637

3738
# Filter out None results (tasks that failed to return a run)
3839
successful_runs = [run for run in task_run_results if run is not None]

src/apify_client/_resource_clients/request_queue.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,6 @@ async def batch_add_requests(
798798
if min_delay_between_unprocessed_requests_retries:
799799
logger.warning('`min_delay_between_unprocessed_requests_retries` is deprecated and not used anymore.')
800800

801-
tasks = set[asyncio.Task]()
802801
asyncio_queue: asyncio.Queue[Iterable[dict]] = asyncio.Queue()
803802
request_params = self._build_params(clientKey=self.client_key, forefront=forefront)
804803

@@ -815,29 +814,31 @@ async def batch_add_requests(
815814
for batch in batches:
816815
await asyncio_queue.put(batch)
817816

818-
# Start a required number of worker tasks to process the batches.
819-
for i in range(max_parallel):
820-
coro = self._batch_add_requests_worker(
821-
asyncio_queue,
822-
request_params,
823-
)
824-
task = asyncio.create_task(coro, name=f'batch_add_requests_worker_{i}')
825-
tasks.add(task)
826-
827-
# Wait for all batches to be processed.
828-
await asyncio_queue.join()
829-
830-
# Send cancellation signals to all worker tasks and wait for them to finish.
831-
for task in tasks:
832-
task.cancel()
833-
834-
results: list[BatchAddResponse] = await asyncio.gather(*tasks)
817+
# Use TaskGroup for structured concurrency — automatic cleanup and error propagation.
818+
try:
819+
async with asyncio.TaskGroup() as tg:
820+
workers = [
821+
tg.create_task(
822+
self._batch_add_requests_worker(asyncio_queue, request_params),
823+
name=f'batch_add_requests_worker_{i}',
824+
)
825+
for i in range(max_parallel)
826+
]
827+
828+
# Wait for all batches to be processed, then cancel idle workers.
829+
await asyncio_queue.join()
830+
for worker in workers:
831+
worker.cancel()
832+
except ExceptionGroup as eg:
833+
# Re-raise the first worker exception directly to maintain backward-compatible error types.
834+
raise eg.exceptions[0] from None
835835

836836
# Combine the results from all workers and return them.
837837
processed_requests = list[AddedRequest]()
838838
unprocessed_requests = list[RequestDraft]()
839839

840-
for result in results:
840+
for worker in workers:
841+
result = worker.result()
841842
processed_requests.extend(result.data.processed_requests)
842843
unprocessed_requests.extend(result.data.unprocessed_requests)
843844

0 commit comments

Comments (0)