Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 46 additions & 17 deletions web_api/app/alignment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# pylint: disable=import-error
import asyncio
import base64
import gzip
import io
import json
import struct
import threading
import time

import numpy as np
Expand Down Expand Up @@ -37,6 +39,9 @@ def _get_misd_detector():

# FastAPI application instance; handlers in this module register on it
# (e.g. via the @api.exception_handler decorator).
api = FastAPI()

# Limits concurrent GPU computations to 1; additional requests queue asynchronously.
# Held with ``async with`` around the executor-run computation in apply_correspondences.
# NOTE(review): this is constructed at import time. On Python < 3.10 asyncio
# synchronization primitives bind to the event loop that is current at
# construction -- confirm that matches the loop the server actually runs on.
_gpu_semaphore = asyncio.Semaphore(1)


@api.exception_handler(Exception)
async def generic_handler(request: Request, exc: Exception):
Expand Down Expand Up @@ -455,23 +460,47 @@ async def apply_correspondences(request: Request):
)
print(f"[apply_correspondences] params: {params}")

t0 = time.time()
relaxed_field, warped_image = apply_correspondences_to_image(
correspondences_dict=correspondences_dict,
image=image_tensor,
src_mask=src_mask_tensor,
tgt_mask=tgt_mask_tensor,
tgt_image=tgt_image_tensor,
**params,
)
t_corr = time.time() - t0
print(f"[apply_correspondences] correspondence relaxation: {t_corr:.2f}s")

misd_as_input = None
if tgt_image_tensor is not None:
misd_as_input = _run_misd_detection(warped_image, tgt_image_tensor, image_tensor.dtype)

print(f"[apply_correspondences] total: {time.time() - t0:.2f}s")
cancel_event = threading.Event()

def _run_computation():
    """Run correspondence relaxation followed by optional misalignment detection.

    All inputs (tensors, ``params``, ``cancel_event``) are read from the
    enclosing scope. ``cancel_event`` is forwarded to
    ``apply_correspondences_to_image`` so the caller can signal it while this
    function is executing elsewhere (it is submitted via ``run_in_executor``).

    Returns:
        A ``(relaxed_field, warped_image, misd_as_input)`` tuple;
        ``misd_as_input`` is ``None`` when no target image tensor is available.
    """
    t0 = time.time()
    field, warped = apply_correspondences_to_image(
        correspondences_dict=correspondences_dict,
        image=image_tensor,
        src_mask=src_mask_tensor,
        tgt_mask=tgt_mask_tensor,
        tgt_image=tgt_image_tensor,
        cancel_event=cancel_event,
        **params,
    )
    t_corr = time.time() - t0
    print(f"[apply_correspondences] correspondence relaxation: {t_corr:.2f}s")

    # Detection needs a target image to compare the warped result against.
    misd = (
        _run_misd_detection(warped, tgt_image_tensor, image_tensor.dtype)
        if tgt_image_tensor is not None
        else None
    )

    # Total covers both relaxation and (when run) detection.
    print(f"[apply_correspondences] total: {time.time() - t0:.2f}s")
    return field, warped, misd

async with _gpu_semaphore:
if await request.is_disconnected():
print("[apply_correspondences] Client disconnected while waiting in queue")
return Response(status_code=499)

compute_task = asyncio.get_event_loop().run_in_executor(None, _run_computation)

# Poll for client disconnect while computation runs in thread pool.
while not compute_task.done():
if await request.is_disconnected():
print("[apply_correspondences] Client disconnected, cancelling computation")
cancel_event.set()
# Wait for the thread to finish (it will exit early on next iteration check)
await compute_task
return Response(status_code=499)
await asyncio.sleep(0.5)

relaxed_field, warped_image, misd_as_input = await compute_task

relaxed_field_np = relaxed_field.cpu().numpy().astype(np.float32)
warped_image_np = warped_image.cpu().numpy().astype(np.float32)
Expand Down
2 changes: 1 addition & 1 deletion zetta_utils/internal
Loading