diff --git a/web_api/Dockerfile.gpu b/web_api/Dockerfile.gpu index 0425fa9f0..25d7f96f8 100644 --- a/web_api/Dockerfile.gpu +++ b/web_api/Dockerfile.gpu @@ -22,7 +22,7 @@ g++ \ python3-dev \ && rm -rf /var/lib/apt/lists/* - + # ---- Copy metadata --------------------------------------------------- COPY pyproject.toml web_api/requirements.txt /opt/http/ @@ -44,6 +44,12 @@ # ---- 6. Install project + modules ------------------------------------ RUN pip install --no-cache-dir '.[modules]' + # ---- Install cue CLI (needed for builder.build with spec JSON files) -- + RUN apt-get update && apt-get install -y --no-install-recommends curl \ + && curl -fsSL https://github.com/cue-lang/cue/releases/download/v0.11.1/cue_v0.11.1_linux_amd64.tar.gz \ + | tar xz -C /usr/local/bin cue \ + && apt-get purge -y curl && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* + # ---- 7. Copy full source --------------------------------------------- COPY . /opt/http diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index 1c368018a..e3d52a9da 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -4,6 +4,7 @@ import io import json import struct +import time import numpy as np import torch @@ -14,10 +15,26 @@ from zetta_utils.internal.alignment.manual_correspondence import ( apply_correspondences_to_image, ) +from zetta_utils.internal.alignment.misalignment_detector import MisalignmentDetector from zetta_utils.internal.alignment.sift import compute_sift_correspondences from .utils import generic_exception_handler +MISD_MODEL_PATH = ( + "gs://zetta-research-nico/training_artifacts/aced_misd_general/" + "4.0.1_dsfactor2_thr2.0_lr0.0001_z2/last.ckpt.model.spec.json" +) + +_misd_detector = None + + +def _get_misd_detector(): + global _misd_detector # pylint: disable=global-statement + if _misd_detector is None: + _misd_detector = MisalignmentDetector(model_path=MISD_MODEL_PATH) + return _misd_detector + + api = FastAPI() @@ -83,6 +100,12 @@ class ApplyCorrespondencesResponse(BaseModel): warped_image_shape: list[int] = Field( ..., description="Shape of warped_image array [C, H, W, 1]" ) + misd_image: str | None = Field( + None, description="Base64-encoded misalignment mask (float32 binary)" + ) + misd_image_shape: list[int] | None = Field( + None, description="Shape of misd_image array [C, H, W, 1]" + ) def _decompress_if_gzipped(data: bytes) -> bytes: @@ -260,32 +283,48 @@ async def _parse_multipart_request(request: Request, device: torch.device): ) -def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarray): +def _build_json_response( + relaxed_field_np: np.ndarray, + warped_image_np: np.ndarray, + misd_image_np: np.ndarray | None = None, +): relaxed_field_b64 = base64.b64encode(relaxed_field_np.tobytes()).decode() warped_image_b64 = base64.b64encode(warped_image_np.tobytes()).decode() + misd_b64 = ( + base64.b64encode(misd_image_np.tobytes()).decode() if misd_image_np is not None else None + ) + misd_shape = list(misd_image_np.shape) if misd_image_np is not None else None return ApplyCorrespondencesResponse( relaxed_field=relaxed_field_b64, relaxed_field_shape=list(relaxed_field_np.shape), warped_image=warped_image_b64, warped_image_shape=list(warped_image_np.shape), + misd_image=misd_b64, + misd_image_shape=misd_shape, ) def _build_binary_response( - relaxed_field_np: np.ndarray, warped_image_np: np.ndarray, compress: bool + relaxed_field_np: np.ndarray, + warped_image_np: np.ndarray, + compress: bool, + misd_image_np: np.ndarray | None = None, ): - header_json = json.dumps( - { - "relaxed_field_shape": list(relaxed_field_np.shape), - "warped_image_shape": list(warped_image_np.shape), - } - ).encode() + header = { + "relaxed_field_shape": list(relaxed_field_np.shape), + "warped_image_shape": list(warped_image_np.shape), + } + if misd_image_np is not None: + header["misd_image_shape"] = list(misd_image_np.shape) + header_json = json.dumps(header).encode() buf = io.BytesIO() buf.write(struct.pack(" bool: return "application/octet-stream" in accept or fmt == "binary-v1" +def _run_misd_detection(warped_image, tgt_image_tensor, input_dtype): + t1 = time.time() + misd = _get_misd_detector() + t_model = time.time() - t1 + print(f"[apply_correspondences] misd detector init: {t_model:.2f}s") + + t2 = time.time() + warped_int8 = (warped_image * 127).clamp(-128, 127).round().to(torch.int8) + tgt_int8 = (tgt_image_tensor * 127).clamp(-128, 127).round().to(torch.int8) + misd_mask = misd(warped_int8, tgt_int8) + zero_mask = (warped_int8 == 0).all(dim=0, keepdim=True) | (tgt_int8 == 0).all( + dim=0, keepdim=True + ) + misd_mask[zero_mask] = 0 + print( + f"[apply_correspondences] misd uint8: " + f"min={misd_mask.min().item()} " + f"max={misd_mask.max().item()} " + f"mean={misd_mask.float().mean().item():.1f}" + ) + t_misd = time.time() - t2 + print(f"[apply_correspondences] misd inference: {t_misd:.2f}s") + + misd_float = misd_mask.float() / 255.0 + if input_dtype == torch.int8: + misd_as_input = (misd_float * 255 - 128).clamp(-128, 127).to(torch.int8) + elif input_dtype == torch.uint8: + misd_as_input = misd_mask + else: + misd_as_input = misd_float * 2.0 - 1.0 + print( + f"[apply_correspondences] misd output " + f"(dtype={misd_as_input.dtype}): " + f"min={misd_as_input.min().item()} " + f"max={misd_as_input.max().item()} " + f"mean={misd_as_input.float().mean().item():.1f}" + ) + return misd_as_input + + @api.post("/apply_correspondences") async def apply_correspondences(request: Request): """Apply correspondences to image using relaxation and warping. @@ -376,6 +455,7 @@ async def apply_correspondences(request: Request): ) print(f"[apply_correspondences] params: {params}") + t0 = time.time() relaxed_field, warped_image = apply_correspondences_to_image( correspondences_dict=correspondences_dict, image=image_tensor, @@ -384,15 +464,26 @@ async def apply_correspondences(request: Request): tgt_image=tgt_image_tensor, **params, ) + t_corr = time.time() - t0 + print(f"[apply_correspondences] correspondence relaxation: {t_corr:.2f}s") + + misd_as_input = None + if tgt_image_tensor is not None: + misd_as_input = _run_misd_detection(warped_image, tgt_image_tensor, image_tensor.dtype) + + print(f"[apply_correspondences] total: {time.time() - t0:.2f}s") relaxed_field_np = relaxed_field.cpu().numpy().astype(np.float32) warped_image_np = warped_image.cpu().numpy().astype(np.float32) + misd_image_np = ( + misd_as_input.cpu().numpy().astype(np.float32) if misd_as_input is not None else None + ) if _wants_binary_response(request): compress = "gzip" in request.headers.get("accept-encoding", "") - return _build_binary_response(relaxed_field_np, warped_image_np, compress) + return _build_binary_response(relaxed_field_np, warped_image_np, compress, misd_image_np) - return _build_json_response(relaxed_field_np, warped_image_np) + return _build_json_response(relaxed_field_np, warped_image_np, misd_image_np) class ComputeSiftCorrespondencesRequest(BaseModel):