diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index e3d52a9da..ec3350cfb 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -1,9 +1,11 @@ # pylint: disable=import-error +import asyncio import base64 import gzip import io import json import struct +import threading import time import numpy as np @@ -37,6 +39,9 @@ def _get_misd_detector(): api = FastAPI() +# Limits concurrent GPU computations to 1; additional requests queue asynchronously. +_gpu_semaphore = asyncio.Semaphore(1) + @api.exception_handler(Exception) async def generic_handler(request: Request, exc: Exception): @@ -455,23 +460,47 @@ async def apply_correspondences(request: Request): ) print(f"[apply_correspondences] params: {params}") - t0 = time.time() - relaxed_field, warped_image = apply_correspondences_to_image( - correspondences_dict=correspondences_dict, - image=image_tensor, - src_mask=src_mask_tensor, - tgt_mask=tgt_mask_tensor, - tgt_image=tgt_image_tensor, - **params, - ) - t_corr = time.time() - t0 - print(f"[apply_correspondences] correspondence relaxation: {t_corr:.2f}s") - - misd_as_input = None - if tgt_image_tensor is not None: - misd_as_input = _run_misd_detection(warped_image, tgt_image_tensor, image_tensor.dtype) - - print(f"[apply_correspondences] total: {time.time() - t0:.2f}s") + cancel_event = threading.Event() + + def _run_computation(): + t0 = time.time() + relaxed_field, warped_image = apply_correspondences_to_image( + correspondences_dict=correspondences_dict, + image=image_tensor, + src_mask=src_mask_tensor, + tgt_mask=tgt_mask_tensor, + tgt_image=tgt_image_tensor, + cancel_event=cancel_event, + **params, + ) + t_corr = time.time() - t0 + print(f"[apply_correspondences] correspondence relaxation: {t_corr:.2f}s") + + misd_as_input = None + if tgt_image_tensor is not None: + misd_as_input = _run_misd_detection(warped_image, tgt_image_tensor, image_tensor.dtype) + + print(f"[apply_correspondences] total: {time.time() - t0:.2f}s") + return relaxed_field, warped_image, misd_as_input + + async with _gpu_semaphore: + if await request.is_disconnected(): + print("[apply_correspondences] Client disconnected while waiting in queue") + return Response(status_code=499) + + compute_task = asyncio.get_event_loop().run_in_executor(None, _run_computation) + + # Poll for client disconnect while computation runs in thread pool. + while not compute_task.done(): + if await request.is_disconnected(): + print("[apply_correspondences] Client disconnected, cancelling computation") + cancel_event.set() + # Wait for the thread to finish (it will exit early on next iteration check) + await compute_task + return Response(status_code=499) + await asyncio.sleep(0.5) + + relaxed_field, warped_image, misd_as_input = await compute_task relaxed_field_np = relaxed_field.cpu().numpy().astype(np.float32) warped_image_np = warped_image.cpu().numpy().astype(np.float32) diff --git a/zetta_utils/internal b/zetta_utils/internal index 169abd527..bb0e1032c 160000 --- a/zetta_utils/internal +++ b/zetta_utils/internal @@ -1 +1 @@ -Subproject commit 169abd52771d8a8a58c06d3adc30847d5d45067b +Subproject commit bb0e1032ce81467db17d786d7fcefbeb41434cc5