Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions web_api/app/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from scipy.ndimage import binary_closing
from scipy.ndimage import label as scipy_label

from zetta_utils.internal.alignment.manual_correspondence import (
apply_correspondences_to_image,
Expand Down Expand Up @@ -94,6 +96,11 @@ class ApplyCorrespondencesRequest(BaseModel):
description="Weight for MSE loss between warped source and target images. "
"Only used when tgt_image is provided.",
)
downsample_factor: int = Field(
16,
description="Multi-scale downsampling factor for field propagation. "
"1 = no downsampling. 8 or 16 for faster initialization.",
)


class ApplyCorrespondencesResponse(BaseModel):
Expand Down Expand Up @@ -154,6 +161,7 @@ def _parse_json_request(body: dict, device: torch.device):
"lr": req.lr,
"optimizer_type": req.optimizer_type,
"mse_weight": req.mse_weight,
"downsample_factor": req.downsample_factor,
}
return (
correspondences_dict,
Expand Down Expand Up @@ -277,6 +285,7 @@ async def _parse_multipart_request(request: Request, device: torch.device):
"lr": metadata.get("lr", 1e-3),
"optimizer_type": metadata.get("optimizer_type", "adam"),
"mse_weight": metadata.get("mse_weight", 1.0),
"downsample_factor": metadata.get("downsample_factor", 16),
}
return (
correspondences_dict,
Expand Down Expand Up @@ -370,6 +379,26 @@ def _run_misd_detection(warped_image, tgt_image_tensor, input_dtype):
dim=0, keepdim=True
)
misd_mask[zero_mask] = 0

structure = np.ones((3, 3), dtype=bool)
misd_np = misd_mask.cpu().numpy() # (1, X, Y, Z)
out_np = np.zeros_like(misd_np)
for z in range(misd_np.shape[3]):
sl = misd_np[0, :, :, z] > 128
if not sl.any():
continue
sl = binary_closing(sl, structure=structure, iterations=1)
if not sl.any():
continue
labels, _ = scipy_label(sl)
ids, counts = np.unique(labels, return_counts=True)
keep_mask = np.zeros_like(sl)
for seg_id, count in zip(ids, counts):
if seg_id != 0 and count >= 600:
keep_mask |= labels == seg_id
out_np[0, :, :, z][keep_mask] = 255
misd_mask = torch.from_numpy(out_np).to(misd_mask.device)

print(
f"[apply_correspondences] misd uint8: "
f"min={misd_mask.min().item()} "
Expand Down Expand Up @@ -515,6 +544,11 @@ def _run_computation():
return _build_json_response(relaxed_field_np, warped_image_np, misd_image_np)


class ExistingCorrespondenceItem(BaseModel):
    """A previously drawn correspondence arrow sent by the client.

    Used by the SIFT endpoint for deduplication/refiltering of newly computed
    matches against correspondences that already exist. Both endpoints are
    pixel coordinates in [y, x] order.
    """

    # Arrow start point, [y, x] in pixel coordinates.
    start: list[float] = Field(..., description="Start point [y, x]")
    # Arrow end point, [y, x] in pixel coordinates.
    end: list[float] = Field(..., description="End point [y, x]")


class ComputeSiftCorrespondencesRequest(BaseModel):
src_image: str = Field(..., description="Base64-encoded uint8 image bytes")
tgt_image: str = Field(..., description="Base64-encoded uint8 image bytes")
Expand All @@ -536,6 +570,11 @@ class ComputeSiftCorrespondencesRequest(BaseModel):
True,
description="If true, output [y, x] (Portal convention). If false, output [x, y].",
)
existing_correspondences: list[ExistingCorrespondenceItem] | None = Field(
None,
description="Existing correspondence arrows for deduplication/refiltering. "
"Each item has start [y, x] and end [y, x] in pixel coords.",
)


class ComputeSiftCorrespondencesResponse(BaseModel):
Expand Down Expand Up @@ -564,18 +603,19 @@ class ComputeSiftCorrespondencesResponse(BaseModel):
"edge_threshold": 10,
"sigma": 1.6,
"ratio_test_fraction": 0.7,
"ransac_threshold": 3.0,
"use_ransac": False,
"ransac_threshold": 5.0,
"use_ransac": True,
"spatial_weight": 0.7,
"swap_xy": True,
}


async def _parse_sift_request(
request: Request, use_ransac_default: bool
) -> tuple[np.ndarray, np.ndarray, dict]:
) -> tuple[np.ndarray, np.ndarray, dict, list[dict] | None]:
content_type = request.headers.get("content-type", "")
defaults = {**_SIFT_PARAM_DEFAULTS, "use_ransac": use_ransac_default}
existing_correspondences = None

if "multipart/form-data" in content_type:
form = await request.form()
Expand Down Expand Up @@ -607,6 +647,7 @@ async def _parse_sift_request(
tgt_image = np.frombuffer(tgt_bytes, dtype=np.uint8).reshape(metadata["tgt_image_shape"])

sift_params = {k: metadata.get(k, defaults[k]) for k in _SIFT_PARAM_KEYS}
existing_correspondences = metadata.get("existing_correspondences")
else:
body = await request.json()
req = ComputeSiftCorrespondencesRequest(
Expand All @@ -621,16 +662,27 @@ async def _parse_sift_request(
)

sift_params = {k: getattr(req, k) for k in _SIFT_PARAM_KEYS}
if req.existing_correspondences is not None:
existing_correspondences = [
{"start": ec.start, "end": ec.end} for ec in req.existing_correspondences
]

return src_image, tgt_image, sift_params
return src_image, tgt_image, sift_params, existing_correspondences


async def _run_sift_correspondences(
request: Request, use_ransac_default: bool
) -> ComputeSiftCorrespondencesResponse:
src_image, tgt_image, sift_params = await _parse_sift_request(request, use_ransac_default)
src_image, tgt_image, sift_params, existing_correspondences = await _parse_sift_request(
request, use_ransac_default
)

result = compute_sift_correspondences(src=src_image, tgt=tgt_image, **sift_params)
result = compute_sift_correspondences(
src=src_image,
tgt=tgt_image,
existing_correspondences=existing_correspondences,
**sift_params,
)

return ComputeSiftCorrespondencesResponse(
lines=[CorrespondenceLine(**line) for line in result["lines"]],
Expand Down
2 changes: 1 addition & 1 deletion zetta_utils/internal
Loading