From 6f0cf5b7394397bcb4313f6139b39cc4dbf01b60 Mon Sep 17 00:00:00 2001 From: teo-zetta Date: Mon, 23 Mar 2026 16:52:30 -0500 Subject: [PATCH] fix collection annotation input where a single z layer results in bbox with z size 0 --- zetta_utils/geometry/bbox.py | 15 ++++++++++++--- .../training/datasets/collection_dataset.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/zetta_utils/geometry/bbox.py b/zetta_utils/geometry/bbox.py index f776f9d5f..b64fe7048 100644 --- a/zetta_utils/geometry/bbox.py +++ b/zetta_utils/geometry/bbox.py @@ -515,7 +515,7 @@ def snapped( self, grid_offset: Sequence[float], grid_size: Sequence[float], - mode: Literal["shrink", "expand"], + mode: Literal["shrink", "expand", "floor"], ) -> BBox3D: """Returns a BoundingBox snapped to a grid with the given offset and size. @@ -536,8 +536,7 @@ def snapped( floor(round((b[1] - o) / s, VEC3D_PRECISION)) * s + o for b, o, s in zip(self.bounds, grid_offset, grid_size) ) - else: - assert mode == "expand", "Typechecking error" + elif mode == "expand": start_final = tuple( floor(round((b[0] - o) / s, VEC3D_PRECISION)) * s + o for b, o, s in zip(self.bounds, grid_offset, grid_size) @@ -546,6 +545,16 @@ def snapped( floor(round((b[1] - o) / s + 1, VEC3D_PRECISION) - EPS) * s + o for b, o, s in zip(self.bounds, grid_offset, grid_size) ) + else: + assert mode == "floor", "Typechecking error" + start_final = tuple( + floor(round((b[0] - o) / s, VEC3D_PRECISION)) * s + o + for b, o, s in zip(self.bounds, grid_offset, grid_size) + ) + end_final = tuple( + floor(round((b[1] - o) / s, VEC3D_PRECISION)) * s + o + for b, o, s in zip(self.bounds, grid_offset, grid_size) + ) return BBox3D.from_coords( start_coord=cast(tuple[float, float, float], start_final), end_coord=cast(tuple[float, float, float], end_final), diff --git a/zetta_utils/training/datasets/collection_dataset.py b/zetta_utils/training/datasets/collection_dataset.py index 68da6c163..606a54754 100644 --- a/zetta_utils/training/datasets/collection_dataset.py +++ b/zetta_utils/training/datasets/collection_dataset.py @@ -80,7 +80,7 @@ def build_collection_dataset( # pylint: disable=too-many-locals this_resolution = [resolution[0], resolution[1], z_resolution] if isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation): bbox = BBox3D.from_ng_bbox(annotation.ng_annotation, (1, 1, 1)).snapped( - (0, 0, 0), this_resolution, "shrink" + (0, 0, 0), this_resolution, "floor" ) this_dset = LayerDataset(