From 19a0e8997676dfdfc209527c5984527c4efaedb9 Mon Sep 17 00:00:00 2001 From: Sergiy Popovich Date: Sun, 15 Mar 2026 10:51:25 -0400 Subject: [PATCH 1/4] Add misalignment detection to apply_correspondences endpoint --- web_api/Dockerfile.gpu | 8 +++- web_api/app/alignment.py | 81 +++++++++++++++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 11 deletions(-) diff --git a/web_api/Dockerfile.gpu b/web_api/Dockerfile.gpu index 0425fa9f0..25d7f96f8 100644 --- a/web_api/Dockerfile.gpu +++ b/web_api/Dockerfile.gpu @@ -22,7 +22,7 @@ g++ \ python3-dev \ && rm -rf /var/lib/apt/lists/* - + # ---- Copy metadata --------------------------------------------------- COPY pyproject.toml web_api/requirements.txt /opt/http/ @@ -44,6 +44,12 @@ # ---- 6. Install project + modules ------------------------------------ RUN pip install --no-cache-dir '.[modules]' + # ---- Install cue CLI (needed for builder.build with spec JSON files) -- + RUN apt-get update && apt-get install -y --no-install-recommends curl \ + && curl -fsSL https://github.com/cue-lang/cue/releases/download/v0.11.1/cue_v0.11.1_linux_amd64.tar.gz \ + | tar xz -C /usr/local/bin cue \ + && apt-get purge -y curl && apt-get autoremove -y && rm -rf /var/lib/apt/lists/* + # ---- 7. Copy full source --------------------------------------------- COPY . /opt/http diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index 1c368018a..907091cd0 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -4,6 +4,7 @@ import io import json import struct +import time import numpy as np import torch @@ -14,10 +15,24 @@ from zetta_utils.internal.alignment.manual_correspondence import ( apply_correspondences_to_image, ) +from zetta_utils.internal.alignment.misalignment_detector import MisalignmentDetector from zetta_utils.internal.alignment.sift import compute_sift_correspondences from .utils import generic_exception_handler +MISD_MODEL_PATH = ( + "gs://zetta-research-nico/training_artifacts/aced_misd_general/" + "4.0.1_dsfactor2_thr2.0_lr0.0001_z2/last.ckpt.model.spec.json" +) + +_misd_detector = None + +def _get_misd_detector(): + global _misd_detector + if _misd_detector is None: + _misd_detector = MisalignmentDetector(model_path=MISD_MODEL_PATH) + return _misd_detector + api = FastAPI() @@ -83,6 +98,8 @@ class ApplyCorrespondencesResponse(BaseModel): warped_image_shape: list[int] = Field( ..., description="Shape of warped_image array [C, H, W, 1]" ) + misd_image: str | None = Field(None, description="Base64-encoded misalignment mask (float32 binary)") + misd_image_shape: list[int] | None = Field(None, description="Shape of misd_image array [C, H, W, 1]") def _decompress_if_gzipped(data: bytes) -> bytes: @@ -260,32 +277,39 @@ async def _parse_multipart_request(request: Request, device: torch.device): ) -def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarray): +def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarray, misd_image_np: np.ndarray | None = None): relaxed_field_b64 = base64.b64encode(relaxed_field_np.tobytes()).decode() warped_image_b64 = base64.b64encode(warped_image_np.tobytes()).decode() + misd_b64 = base64.b64encode(misd_image_np.tobytes()).decode() if misd_image_np is not None else None + misd_shape = list(misd_image_np.shape) if misd_image_np is not None else None return ApplyCorrespondencesResponse( relaxed_field=relaxed_field_b64, relaxed_field_shape=list(relaxed_field_np.shape), warped_image=warped_image_b64, warped_image_shape=list(warped_image_np.shape), + misd_image=misd_b64, + misd_image_shape=misd_shape, ) def _build_binary_response( - relaxed_field_np: np.ndarray, warped_image_np: np.ndarray, compress: bool + relaxed_field_np: np.ndarray, warped_image_np: np.ndarray, compress: bool, misd_image_np: np.ndarray | None = None ): - header_json = json.dumps( - { - "relaxed_field_shape": list(relaxed_field_np.shape), - "warped_image_shape": list(warped_image_np.shape), - } - ).encode() + header = { + "relaxed_field_shape": list(relaxed_field_np.shape), + "warped_image_shape": list(warped_image_np.shape), + } + if misd_image_np is not None: + header["misd_image_shape"] = list(misd_image_np.shape) + header_json = json.dumps(header).encode() buf = io.BytesIO() buf.write(struct.pack(" [-1, 1] + misd_as_input = misd_float * 2.0 - 1.0 + print(f"[apply_correspondences] misd output (dtype={misd_as_input.dtype}): min={misd_as_input.min().item()} max={misd_as_input.max().item()} mean={misd_as_input.float().mean().item():.1f}") + + print(f"[apply_correspondences] total: {time.time() - t0:.2f}s") relaxed_field_np = relaxed_field.cpu().numpy().astype(np.float32) warped_image_np = warped_image.cpu().numpy().astype(np.float32) + misd_image_np = misd_as_input.cpu().numpy().astype(np.float32) if tgt_image_tensor is not None else None if _wants_binary_response(request): compress = "gzip" in request.headers.get("accept-encoding", "") - return _build_binary_response(relaxed_field_np, warped_image_np, compress) + return _build_binary_response(relaxed_field_np, warped_image_np, compress, misd_image_np) - return _build_json_response(relaxed_field_np, warped_image_np) + return _build_json_response(relaxed_field_np, warped_image_np, misd_image_np) class ComputeSiftCorrespondencesRequest(BaseModel): From 85d6c4c91a0be93cc4008637bee3e569fd5d500d Mon Sep 17 00:00:00 2001 From: dmytroprokopenko-techmagic Date: Tue, 17 Mar 2026 19:49:33 +0200 Subject: [PATCH 2/4] feat: fix build issues --- web_api/app/alignment.py | 117 ++++++++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 39 deletions(-) diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index 907091cd0..bc9b6bb8a 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -3,15 +3,13 @@ import gzip import io import json +import numpy as np import struct import time - -import numpy as np import torch from fastapi import FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from pydantic import BaseModel, Field - from zetta_utils.internal.alignment.manual_correspondence import ( apply_correspondences_to_image, ) @@ -28,7 +26,7 @@ _misd_detector = None def _get_misd_detector(): - global _misd_detector + global _misd_detector # pylint: disable=global-statement if _misd_detector is None: _misd_detector = MisalignmentDetector(model_path=MISD_MODEL_PATH) return _misd_detector @@ -98,8 +96,12 @@ class ApplyCorrespondencesResponse(BaseModel): warped_image_shape: list[int] = Field( ..., description="Shape of warped_image array [C, H, W, 1]" ) - misd_image: str | None = Field(None, description="Base64-encoded misalignment mask (float32 binary)") - misd_image_shape: list[int] | None = Field(None, description="Shape of misd_image array [C, H, W, 1]") + misd_image: str | None = Field( + None, description="Base64-encoded misalignment mask (float32 binary)" + ) + misd_image_shape: list[int] | None = Field( + None, description="Shape of misd_image array [C, H, W, 1]" + ) def _decompress_if_gzipped(data: bytes) -> bytes: @@ -277,10 +279,18 @@ async def _parse_multipart_request(request: Request, device: torch.device): ) -def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarray, misd_image_np: np.ndarray | None = None): +def _build_json_response( + relaxed_field_np: np.ndarray, + warped_image_np: np.ndarray, + misd_image_np: np.ndarray | None = None, +): relaxed_field_b64 = base64.b64encode(relaxed_field_np.tobytes()).decode() warped_image_b64 = base64.b64encode(warped_image_np.tobytes()).decode() - misd_b64 = base64.b64encode(misd_image_np.tobytes()).decode() if misd_image_np is not None else None + misd_b64 = ( + base64.b64encode(misd_image_np.tobytes()).decode() + if misd_image_np is not None + else None + ) misd_shape = list(misd_image_np.shape) if misd_image_np is not None else None return ApplyCorrespondencesResponse( relaxed_field=relaxed_field_b64, @@ -293,7 +303,10 @@ def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarr def _build_binary_response( - relaxed_field_np: np.ndarray, warped_image_np: np.ndarray, compress: bool, misd_image_np: np.ndarray | None = None + relaxed_field_np: np.ndarray, + warped_image_np: np.ndarray, + compress: bool, + misd_image_np: np.ndarray | None = None, ): header = { "relaxed_field_shape": list(relaxed_field_np.shape), @@ -336,6 +349,53 @@ def _wants_binary_response(request: Request) -> bool: return "application/octet-stream" in accept or fmt == "binary-v1" +def _run_misd_detection(warped_image, tgt_image_tensor, input_dtype): + t1 = time.time() + misd = _get_misd_detector() + t_model = time.time() - t1 + print(f"[apply_correspondences] misd detector init: {t_model:.2f}s") + + t2 = time.time() + warped_int8 = ( + (warped_image * 127).clamp(-128, 127).round().to(torch.int8) + ) + tgt_int8 = ( + (tgt_image_tensor * 127).clamp(-128, 127).round().to(torch.int8) + ) + misd_mask = misd(warped_int8, tgt_int8) + zero_mask = ( + (warped_int8 == 0).all(dim=0, keepdim=True) + | (tgt_int8 == 0).all(dim=0, keepdim=True) + ) + misd_mask[zero_mask] = 0 + print( + f"[apply_correspondences] misd uint8: " + f"min={misd_mask.min().item()} " + f"max={misd_mask.max().item()} " + f"mean={misd_mask.float().mean().item():.1f}" + ) + t_misd = time.time() - t2 + print(f"[apply_correspondences] misd inference: {t_misd:.2f}s") + + misd_float = misd_mask.float() / 255.0 + if input_dtype == torch.int8: + misd_as_input = ( + (misd_float * 255 - 128).clamp(-128, 127).to(torch.int8) + ) + elif input_dtype == torch.uint8: + misd_as_input = misd_mask + else: + misd_as_input = misd_float * 2.0 - 1.0 + print( + f"[apply_correspondences] misd output " + f"(dtype={misd_as_input.dtype}): " + f"min={misd_as_input.min().item()} " + f"max={misd_as_input.max().item()} " + f"mean={misd_as_input.float().mean().item():.1f}" + ) + return misd_as_input + + @api.post("/apply_correspondences") async def apply_correspondences(request: Request): """Apply correspondences to image using relaxation and warping. @@ -400,7 +460,6 @@ async def apply_correspondences(request: Request): ) print(f"[apply_correspondences] params: {params}") - import time t0 = time.time() relaxed_field, warped_image = apply_correspondences_to_image( correspondences_dict=correspondences_dict, @@ -413,41 +472,21 @@ async def apply_correspondences(request: Request): t_corr = time.time() - t0 print(f"[apply_correspondences] correspondence relaxation: {t_corr:.2f}s") - # Run misalignment detection between warped source and target + misd_as_input = None if tgt_image_tensor is not None: - t1 = time.time() - misd = _get_misd_detector() - t_model = time.time() - t1 - print(f"[apply_correspondences] misd detector init: {t_model:.2f}s") - - t2 = time.time() - # Input images are float [-1, 1] (int8 / 127). Convert to int8 for misd model. - warped_int8 = (warped_image * 127).clamp(-128, 127).round().to(torch.int8) - tgt_int8 = (tgt_image_tensor * 127).clamp(-128, 127).round().to(torch.int8) - misd_mask = misd(warped_int8, tgt_int8) # (1, H, W, 1) uint8 [0, 255] - # Zero out where warped source or target is zero (no data) - zero_mask = (warped_int8 == 0).all(dim=0, keepdim=True) | (tgt_int8 == 0).all(dim=0, keepdim=True) - misd_mask[zero_mask] = 0 - print(f"[apply_correspondences] misd uint8: min={misd_mask.min().item()} max={misd_mask.max().item()} mean={misd_mask.float().mean().item():.1f}") - t_misd = time.time() - t2 - print(f"[apply_correspondences] misd inference: {t_misd:.2f}s") - - # Remap misd [0, 255] uint8 to match input image range - misd_float = misd_mask.float() / 255.0 # [0, 1] - if image_tensor.dtype == torch.int8: - misd_as_input = (misd_float * 255 - 128).clamp(-128, 127).to(torch.int8) - elif image_tensor.dtype == torch.uint8: - misd_as_input = misd_mask - else: - # Input is float [-1, 1], remap [0, 1] -> [-1, 1] - misd_as_input = misd_float * 2.0 - 1.0 - print(f"[apply_correspondences] misd output (dtype={misd_as_input.dtype}): min={misd_as_input.min().item()} max={misd_as_input.max().item()} mean={misd_as_input.float().mean().item():.1f}") + misd_as_input = _run_misd_detection( + warped_image, tgt_image_tensor, image_tensor.dtype + ) print(f"[apply_correspondences] total: {time.time() - t0:.2f}s") relaxed_field_np = relaxed_field.cpu().numpy().astype(np.float32) warped_image_np = warped_image.cpu().numpy().astype(np.float32) - misd_image_np = misd_as_input.cpu().numpy().astype(np.float32) if tgt_image_tensor is not None else None + misd_image_np = ( + misd_as_input.cpu().numpy().astype(np.float32) + if misd_as_input is not None + else None + ) if _wants_binary_response(request): compress = "gzip" in request.headers.get("accept-encoding", "") From 8da463e5935f4efc8bcccd94643dd0c885626b0e Mon Sep 17 00:00:00 2001 From: dmytroprokopenko-techmagic Date: Tue, 17 Mar 2026 19:53:13 +0200 Subject: [PATCH 3/4] feat: fix isort issue --- web_api/app/alignment.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index bc9b6bb8a..ca890834a 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -25,12 +25,14 @@ _misd_detector = None + def _get_misd_detector(): global _misd_detector # pylint: disable=global-statement if _misd_detector is None: _misd_detector = MisalignmentDetector(model_path=MISD_MODEL_PATH) return _misd_detector + api = FastAPI() @@ -287,9 +289,7 @@ def _build_json_response( relaxed_field_b64 = base64.b64encode(relaxed_field_np.tobytes()).decode() warped_image_b64 = base64.b64encode(warped_image_np.tobytes()).decode() misd_b64 = ( - base64.b64encode(misd_image_np.tobytes()).decode() - if misd_image_np is not None - else None + base64.b64encode(misd_image_np.tobytes()).decode() if misd_image_np is not None else None ) misd_shape = list(misd_image_np.shape) if misd_image_np is not None else None return ApplyCorrespondencesResponse( @@ -356,16 +356,11 @@ def _run_misd_detection(warped_image, tgt_image_tensor, input_dtype): print(f"[apply_correspondences] misd detector init: {t_model:.2f}s") t2 = time.time() - warped_int8 = ( - (warped_image * 127).clamp(-128, 127).round().to(torch.int8) - ) - tgt_int8 = ( - (tgt_image_tensor * 127).clamp(-128, 127).round().to(torch.int8) - ) + warped_int8 = (warped_image * 127).clamp(-128, 127).round().to(torch.int8) + tgt_int8 = (tgt_image_tensor * 127).clamp(-128, 127).round().to(torch.int8) misd_mask = misd(warped_int8, tgt_int8) - zero_mask = ( - (warped_int8 == 0).all(dim=0, keepdim=True) - | (tgt_int8 == 0).all(dim=0, keepdim=True) + zero_mask = (warped_int8 == 0).all(dim=0, keepdim=True) | (tgt_int8 == 0).all( + dim=0, keepdim=True ) misd_mask[zero_mask] = 0 print( @@ -379,9 +374,7 @@ def _run_misd_detection(warped_image, tgt_image_tensor, input_dtype): misd_float = misd_mask.float() / 255.0 if input_dtype == torch.int8: - misd_as_input = ( - (misd_float * 255 - 128).clamp(-128, 127).to(torch.int8) - ) + misd_as_input = (misd_float * 255 - 128).clamp(-128, 127).to(torch.int8) elif input_dtype == torch.uint8: misd_as_input = misd_mask else: @@ -474,18 +467,14 @@ async def apply_correspondences(request: Request): misd_as_input = None if tgt_image_tensor is not None: - misd_as_input = _run_misd_detection( - warped_image, tgt_image_tensor, image_tensor.dtype - ) + misd_as_input = _run_misd_detection(warped_image, tgt_image_tensor, image_tensor.dtype) print(f"[apply_correspondences] total: {time.time() - t0:.2f}s") relaxed_field_np = relaxed_field.cpu().numpy().astype(np.float32) warped_image_np = warped_image.cpu().numpy().astype(np.float32) misd_image_np = ( - misd_as_input.cpu().numpy().astype(np.float32) - if misd_as_input is not None - else None + misd_as_input.cpu().numpy().astype(np.float32) if misd_as_input is not None else None ) if _wants_binary_response(request): From 96a48c50c2a1e3d6a6fd5ef95264a2260395d78e Mon Sep 17 00:00:00 2001 From: dmytroprokopenko-techmagic Date: Tue, 17 Mar 2026 20:20:04 +0200 Subject: [PATCH 4/4] feat: fix isort issue --- web_api/app/alignment.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py index ca890834a..e3d52a9da 100644 --- a/web_api/app/alignment.py +++ b/web_api/app/alignment.py @@ -3,13 +3,15 @@ import gzip import io import json -import numpy as np import struct import time + +import numpy as np import torch from fastapi import FastAPI, HTTPException, Request from fastapi.responses import Response, StreamingResponse from pydantic import BaseModel, Field + from zetta_utils.internal.alignment.manual_correspondence import ( apply_correspondences_to_image, ) @@ -282,9 +284,9 @@ async def _parse_multipart_request(request: Request, device: torch.device): def _build_json_response( - relaxed_field_np: np.ndarray, - warped_image_np: np.ndarray, - misd_image_np: np.ndarray | None = None, + relaxed_field_np: np.ndarray, + warped_image_np: np.ndarray, + misd_image_np: np.ndarray | None = None, ): relaxed_field_b64 = base64.b64encode(relaxed_field_np.tobytes()).decode() warped_image_b64 = base64.b64encode(warped_image_np.tobytes()).decode() @@ -303,10 +305,10 @@ def _build_json_response( def _build_binary_response( - relaxed_field_np: np.ndarray, - warped_image_np: np.ndarray, - compress: bool, - misd_image_np: np.ndarray | None = None, + relaxed_field_np: np.ndarray, + warped_image_np: np.ndarray, + compress: bool, + misd_image_np: np.ndarray | None = None, ): header = { "relaxed_field_shape": list(relaxed_field_np.shape),