Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@
AsDiscrete,
DistanceTransformEDT,
FillHoles,
GenerateHeatmap,
Invert,
KeepLargestConnectedComponent,
LabelFilter,
Expand All @@ -319,6 +320,9 @@
FillHolesD,
FillHolesd,
FillHolesDict,
GenerateHeatmapd,
GenerateHeatmapD,
GenerateHeatmapDict,
InvertD,
Invertd,
InvertDict,
Expand Down
158 changes: 157 additions & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@
remove_small_objects,
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
from monai.utils import (
TransformBackends,
convert_data_type,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
)
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
Expand All @@ -54,6 +61,7 @@
"SobelGradients",
"VoteEnsemble",
"Invert",
"GenerateHeatmap",
"DistanceTransformEDT",
]

Expand Down Expand Up @@ -742,6 +750,154 @@ def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayO
return self.post_convert(out_pt, img)


class GenerateHeatmap(Transform):
    """
    Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.

    Every landmark is rendered into its own channel: a unit impulse is written at the rounded
    landmark position and the channel is then blurred with a Gaussian filter.

    Notes:
        - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.
        - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.
        - Output layout uses channel-first convention with one channel per landmark.
        - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions (2 or 3).
        - Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D.
        - Each channel index corresponds to one landmark.
        - Landmarks with non-finite coordinates, or lying outside ``[0, size)`` along any axis, are
          skipped and their channel is left filled with zeros.

    Args:
        sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.
        spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.
        truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.
        normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.
        dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).

    Raises:
        ValueError: when ``sigma`` is empty or non-positive, ``truncated`` is non-positive,
            ``dtype`` is not a floating-point type, or ``spatial_shape`` cannot be resolved.

    """

    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

    def __init__(
        self,
        sigma: Sequence[float] | float = 5.0,
        spatial_shape: Sequence[int] | None = None,
        truncated: float = 4.0,
        normalize: bool = True,
        dtype: np.dtype | torch.dtype | type = np.float32,
    ) -> None:
        if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):
            # Fail fast: an empty sequence can never match any spatial dimensionality at call time.
            if not sigma:
                raise ValueError("Argument `sigma` must not be an empty sequence.")
            if any(s <= 0 for s in sigma):
                raise ValueError("Argument `sigma` values must be positive.")
            self._sigma = tuple(float(s) for s in sigma)
        else:
            if float(sigma) <= 0:
                raise ValueError("Argument `sigma` must be positive.")
            self._sigma = (float(sigma),)
        if truncated <= 0:
            raise ValueError("Argument `truncated` must be positive.")
        self.truncated = float(truncated)
        self.normalize = normalize
        self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)
        self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)
        # Validate that dtype is floating-point for meaningful Gaussian values
        if not self.torch_dtype.is_floating_point:
            raise ValueError(f"Argument `dtype` must be a floating-point type, got {self.torch_dtype}")
        self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)

    def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:
        """
        Args:
            points: landmark coordinates as ndarray/Tensor with shape (N, D),
                ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number
                of landmarks and D is the spatial dimensionality.
            spatial_shape: spatial size as a sequence. If None, uses the value provided at construction.

        Returns:
            Heatmaps with shape (N, *spatial), one channel per landmark. Channels of landmarks that are
            non-finite or out of bounds are all zeros.

        Raises:
            ValueError: if points shape/dimension or spatial_shape is invalid.
        """
        original_points = points
        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)

        if points_t.ndim != 2:
            raise ValueError(
                f"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}."
            )

        if points_t.shape[-1] not in (2, 3):
            raise ValueError("GenerateHeatmap only supports 2D or 3D landmarks.")

        device = points_t.device
        num_points, spatial_dims = points_t.shape

        target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)
        sigma = self._resolve_sigma(spatial_dims)

        # Create sparse image with impulses at landmark locations
        heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device)
        bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype)

        # Only finite landmarks inside [0, size) on every axis receive an impulse.
        in_bounds = (points_t >= 0).all(dim=1) & (points_t < bounds_t).all(dim=1)
        valid = torch.isfinite(points_t).all(dim=1) & in_bounds
        channels = torch.nonzero(valid, as_tuple=True)[0]

        # Round to nearest integer for impulse placement, then clamp to [0, size-1]
        # to avoid out-of-bounds indices (e.g., 9.7 rounds to 10 in a size-10 axis).
        centers = points_t[channels].round().long()
        centers = torch.minimum(centers.clamp_min(0), (bounds_t - 1).long())

        # Each channel holds a single landmark, so a plain assignment of the unit impulse is sufficient.
        heatmap[(channels, *centers.unbind(dim=1))] = 1

        # An empty landmark set has nothing to blur or normalize.
        if num_points > 0:
            # Apply Gaussian blur using GaussianFilter; the landmark axis acts as the batch axis and a
            # singleton channel axis is added so every landmark is filtered independently.
            gaussian_filter = GaussianFilter(
                spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx="erf", requires_grad=False
            ).to(device=device, dtype=self.torch_dtype)
            heatmap = gaussian_filter(heatmap.unsqueeze(1)).squeeze(1)

            # Normalize per channel if requested
            if self.normalize:
                peaks = heatmap.flatten(start_dim=1).amax(dim=1)
                # Channels without a positive peak (skipped landmarks) are divided by 1, i.e. left unchanged.
                scale = torch.where(peaks > 0, peaks, torch.ones_like(peaks))
                heatmap = heatmap / scale.reshape(-1, *([1] * spatial_dims))

        target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype
        converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)
        return converted

    def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:
        """
        Pick the spatial shape to use (call-time value first, constructor value otherwise).

        A single-element shape is broadcast to ``spatial_dims`` entries.

        Raises:
            ValueError: if no shape is available, its length does not match ``spatial_dims``,
                or any size is not a positive integer.
        """
        shape = call_shape if call_shape is not None else self.spatial_shape
        if shape is None:
            raise ValueError("Argument `spatial_shape` must be provided either at construction time or call time.")
        shape_tuple = ensure_tuple(shape)
        if len(shape_tuple) != spatial_dims:
            if len(shape_tuple) == 1:
                shape_tuple = shape_tuple * spatial_dims  # type: ignore
            else:
                raise ValueError(
                    "Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast)."
                )
        resolved = tuple(int(s) for s in shape_tuple)
        # Reject degenerate sizes here so callers get the documented ValueError instead of a backend error.
        if any(s <= 0 for s in resolved):
            raise ValueError(f"Argument `spatial_shape` values must be positive, got {resolved}.")
        return resolved

    def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:
        """
        Return one sigma per spatial dimension, broadcasting a single stored value when needed.

        Raises:
            ValueError: if the stored sigma sequence has a length other than 1 or ``spatial_dims``.
        """
        if len(self._sigma) == spatial_dims:
            return self._sigma
        if len(self._sigma) == 1:
            return self._sigma * spatial_dims
        raise ValueError("Argument `sigma` sequence length must equal the number of spatial dimensions.")


class ProbNMS(Transform):
"""
Performs probability based non-maximum suppression (NMS) on the probabilities map via
Expand Down
Loading
Loading