Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion web_api/Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
g++ \
python3-dev \
&& rm -rf /var/lib/apt/lists/*

# ---- Copy metadata ---------------------------------------------------
COPY pyproject.toml web_api/requirements.txt /opt/http/

Expand All @@ -44,6 +44,12 @@
# ---- 6. Install project + modules ------------------------------------
RUN pip install --no-cache-dir '.[modules]'

# ---- Install cue CLI (needed for builder.build with spec JSON files) --
RUN apt-get update && apt-get install -y --no-install-recommends curl \
&& curl -fsSL https://github.com/cue-lang/cue/releases/download/v0.11.1/cue_v0.11.1_linux_amd64.tar.gz \
| tar xz -C /usr/local/bin cue \
&& apt-get purge -y curl && apt-get autoremove -y && rm -rf /var/lib/apt/lists/*

# ---- 7. Copy full source ---------------------------------------------
COPY . /opt/http

Expand Down
111 changes: 101 additions & 10 deletions web_api/app/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io
import json
import struct
import time

import numpy as np
import torch
Expand All @@ -14,10 +15,26 @@
from zetta_utils.internal.alignment.manual_correspondence import (
apply_correspondences_to_image,
)
from zetta_utils.internal.alignment.misalignment_detector import MisalignmentDetector
from zetta_utils.internal.alignment.sift import compute_sift_correspondences

from .utils import generic_exception_handler

MISD_MODEL_PATH = (
"gs://zetta-research-nico/training_artifacts/aced_misd_general/"
"4.0.1_dsfactor2_thr2.0_lr0.0001_z2/last.ckpt.model.spec.json"
)

_misd_detector = None


def _get_misd_detector():
global _misd_detector # pylint: disable=global-statement
if _misd_detector is None:
_misd_detector = MisalignmentDetector(model_path=MISD_MODEL_PATH)
return _misd_detector


api = FastAPI()


Expand Down Expand Up @@ -83,6 +100,12 @@ class ApplyCorrespondencesResponse(BaseModel):
warped_image_shape: list[int] = Field(
..., description="Shape of warped_image array [C, H, W, 1]"
)
misd_image: str | None = Field(
None, description="Base64-encoded misalignment mask (float32 binary)"
)
misd_image_shape: list[int] | None = Field(
None, description="Shape of misd_image array [C, H, W, 1]"
)


def _decompress_if_gzipped(data: bytes) -> bytes:
Expand Down Expand Up @@ -260,32 +283,48 @@ async def _parse_multipart_request(request: Request, device: torch.device):
)


def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarray):
def _build_json_response(
relaxed_field_np: np.ndarray,
warped_image_np: np.ndarray,
misd_image_np: np.ndarray | None = None,
):
relaxed_field_b64 = base64.b64encode(relaxed_field_np.tobytes()).decode()
warped_image_b64 = base64.b64encode(warped_image_np.tobytes()).decode()
misd_b64 = (
base64.b64encode(misd_image_np.tobytes()).decode() if misd_image_np is not None else None
)
misd_shape = list(misd_image_np.shape) if misd_image_np is not None else None
return ApplyCorrespondencesResponse(
relaxed_field=relaxed_field_b64,
relaxed_field_shape=list(relaxed_field_np.shape),
warped_image=warped_image_b64,
warped_image_shape=list(warped_image_np.shape),
misd_image=misd_b64,
misd_image_shape=misd_shape,
)


def _build_binary_response(
relaxed_field_np: np.ndarray, warped_image_np: np.ndarray, compress: bool
relaxed_field_np: np.ndarray,
warped_image_np: np.ndarray,
compress: bool,
misd_image_np: np.ndarray | None = None,
):
header_json = json.dumps(
{
"relaxed_field_shape": list(relaxed_field_np.shape),
"warped_image_shape": list(warped_image_np.shape),
}
).encode()
header = {
"relaxed_field_shape": list(relaxed_field_np.shape),
"warped_image_shape": list(warped_image_np.shape),
}
if misd_image_np is not None:
header["misd_image_shape"] = list(misd_image_np.shape)
header_json = json.dumps(header).encode()

buf = io.BytesIO()
buf.write(struct.pack("<I", len(header_json)))
buf.write(header_json)
buf.write(relaxed_field_np.tobytes())
buf.write(warped_image_np.tobytes())
if misd_image_np is not None:
buf.write(misd_image_np.tobytes())
payload = buf.getvalue()

if compress:
Expand All @@ -312,6 +351,46 @@ def _wants_binary_response(request: Request) -> bool:
return "application/octet-stream" in accept or fmt == "binary-v1"


def _run_misd_detection(warped_image, tgt_image_tensor, input_dtype):
t1 = time.time()
misd = _get_misd_detector()
t_model = time.time() - t1
print(f"[apply_correspondences] misd detector init: {t_model:.2f}s")

t2 = time.time()
warped_int8 = (warped_image * 127).clamp(-128, 127).round().to(torch.int8)
tgt_int8 = (tgt_image_tensor * 127).clamp(-128, 127).round().to(torch.int8)
misd_mask = misd(warped_int8, tgt_int8)
zero_mask = (warped_int8 == 0).all(dim=0, keepdim=True) | (tgt_int8 == 0).all(
dim=0, keepdim=True
)
misd_mask[zero_mask] = 0
print(
f"[apply_correspondences] misd uint8: "
f"min={misd_mask.min().item()} "
f"max={misd_mask.max().item()} "
f"mean={misd_mask.float().mean().item():.1f}"
)
t_misd = time.time() - t2
print(f"[apply_correspondences] misd inference: {t_misd:.2f}s")

misd_float = misd_mask.float() / 255.0
if input_dtype == torch.int8:
misd_as_input = (misd_float * 255 - 128).clamp(-128, 127).to(torch.int8)
elif input_dtype == torch.uint8:
misd_as_input = misd_mask
else:
misd_as_input = misd_float * 2.0 - 1.0
print(
f"[apply_correspondences] misd output "
f"(dtype={misd_as_input.dtype}): "
f"min={misd_as_input.min().item()} "
f"max={misd_as_input.max().item()} "
f"mean={misd_as_input.float().mean().item():.1f}"
)
return misd_as_input


@api.post("/apply_correspondences")
async def apply_correspondences(request: Request):
"""Apply correspondences to image using relaxation and warping.
Expand Down Expand Up @@ -376,6 +455,7 @@ async def apply_correspondences(request: Request):
)
print(f"[apply_correspondences] params: {params}")

t0 = time.time()
relaxed_field, warped_image = apply_correspondences_to_image(
correspondences_dict=correspondences_dict,
image=image_tensor,
Expand All @@ -384,15 +464,26 @@ async def apply_correspondences(request: Request):
tgt_image=tgt_image_tensor,
**params,
)
t_corr = time.time() - t0
print(f"[apply_correspondences] correspondence relaxation: {t_corr:.2f}s")

misd_as_input = None
if tgt_image_tensor is not None:
misd_as_input = _run_misd_detection(warped_image, tgt_image_tensor, image_tensor.dtype)

print(f"[apply_correspondences] total: {time.time() - t0:.2f}s")

relaxed_field_np = relaxed_field.cpu().numpy().astype(np.float32)
warped_image_np = warped_image.cpu().numpy().astype(np.float32)
misd_image_np = (
misd_as_input.cpu().numpy().astype(np.float32) if misd_as_input is not None else None
)

if _wants_binary_response(request):
compress = "gzip" in request.headers.get("accept-encoding", "")
return _build_binary_response(relaxed_field_np, warped_image_np, compress)
return _build_binary_response(relaxed_field_np, warped_image_np, compress, misd_image_np)

return _build_json_response(relaxed_field_np, warped_image_np)
return _build_json_response(relaxed_field_np, warped_image_np, misd_image_np)


class ComputeSiftCorrespondencesRequest(BaseModel):
Expand Down
Loading