Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ def _run_cmd(self, cmd: str, devices_info: str = "") -> subprocess.CompletedProc
look_up_option(self.device_setting["MN_START_METHOD"], ["bcprun"])
except ValueError as err:
raise NotImplementedError(
f"{self.device_setting['MN_START_METHOD']} is not supported yet."
"Try modify BundleAlgo._run_cmd for your cluster."
f"{self.device_setting['MN_START_METHOD']} is not supported yet. Try modify BundleAlgo._run_cmd for your cluster."
) from err

return _run_cmd_bcprun(cmd, n=self.device_setting["NUM_NODES"], p=self.device_setting["n_devices"])
Expand Down
5 changes: 2 additions & 3 deletions monai/apps/detection/networks/retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,7 @@ def set_regular_matcher(
"""
if fg_iou_thresh < bg_iou_thresh:
raise ValueError(
"Require fg_iou_thresh >= bg_iou_thresh. "
f"Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
f"Require fg_iou_thresh >= bg_iou_thresh. Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
)
self.proposal_matcher = Matcher(
fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches
Expand Down Expand Up @@ -519,7 +518,7 @@ def forward(
else:
if self.inferer is None:
raise ValueError(
"`self.inferer` is not defined." "Please refer to function self.set_sliding_window_inferer(*)."
"`self.inferer` is not defined. Please refer to function self.set_sliding_window_inferer(*)."
)
head_outputs = predict_with_inferer(
images, self.network, keys=[self.cls_key, self.box_reg_key], inferer=self.inferer
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/detection/transforms/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def convert_box_to_mask(
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
# apply to global mask
slicing = [b]
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type: ignore
boxes_mask_np[tuple(slicing)] = boxes_only_mask
return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]

Expand Down
5 changes: 2 additions & 3 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
root: str
child_key: str
(root, _, child_key) = keys
root, _, child_key = keys
if root not in self.ops:
self.ops[root] = [{}]
self.ops[root][0].update({child_key: None})
Expand Down Expand Up @@ -952,8 +952,7 @@ def __call__(self, data: dict) -> dict:
self.hist_range = nr_channels * self.hist_range
if len(self.hist_range) != nr_channels:
raise ValueError(
f"There is a mismatch between the number of channels ({nr_channels}) "
f"and histogram ranges ({len(self.hist_range)})."
f"There is a mismatch between the number of channels ({nr_channels}) and histogram ranges ({len(self.hist_range)})."
)

# perform calculation
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ def create_workflow(

"""
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
(workflow_name, config_file) = _pop_args(
workflow_name, config_file = _pop_args(
_args, workflow_name=ConfigWorkflow, config_file=None
) # the default workflow name is "ConfigWorkflow"
if isinstance(workflow_name, str):
Expand Down
4 changes: 2 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class DatasetFunc(Dataset):
"""

def __init__(self, data: Any, func: Callable, **kwargs) -> None:
super().__init__(data=None, transform=None) # type:ignore
super().__init__(data=None, transform=None) # type: ignore
self.src = data
self.func = func
self.kwargs = kwargs
Expand Down Expand Up @@ -1635,7 +1635,7 @@ def _cachecheck(self, item_transformed):
return (_data, _meta)
return _data
else:
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type: ignore
for i, _item in enumerate(item_transformed):
for k in _item:
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
Expand Down
3 changes: 1 addition & 2 deletions monai/data/wsi_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,7 @@ def get_data(
# Check if there are three color channels for RGB
elif mode in "RGB" and patch.shape[self.channel_dim] != 3:
raise ValueError(
f"The image is expected to have three color channels in '{mode}' mode but has "
f"{patch.shape[self.channel_dim]}. "
f"The image is expected to have three color channels in '{mode}' mode but has {patch.shape[self.channel_dim]}. "
)
# Get patch-related metadata
metadata: dict = self._get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def stopping_fn_from_loss() -> Callable[[Engine], Any]:
"""

def stopping_fn(engine: Engine) -> Any:
return -engine.state.output # type:ignore
return -engine.state.output # type: ignore

return stopping_fn

Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def get_edge_surface_distance(
edges_spacing = None
if use_subvoxels:
edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape))
(edges_pred, edges_gt, *areas) = get_mask_edges(
edges_pred, edges_gt, *areas = get_mask_edges(
y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False
)
if not edges_gt.any():
Expand Down
5 changes: 4 additions & 1 deletion monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from monai.config.deviceconfig import USE_COMPILED
from monai.networks.layers.spatial_transforms import grid_pull
from monai.networks.utils import meshgrid_ij
from monai.transforms.spatial.functional import _compiled_unsupported
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import

_C, _ = optional_import("monai._C")
Expand Down Expand Up @@ -138,7 +139,9 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
grid = self.get_reference_grid(ddf, jitter=self.jitter) + ddf
grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims)

if not USE_COMPILED: # pytorch native grid_sample
_use_compiled = USE_COMPILED and not _compiled_unsupported(image.device)

if not _use_compiled: # pytorch native grid_sample
for i, dim in enumerate(grid.shape[1:-1]):
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""
A collection of "vanilla" transforms for IO functions.
"""

from __future__ import annotations

import inspect
Expand Down
1 change: 0 additions & 1 deletion monai/transforms/regularization/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class Mixer(RandomizableTransform):

def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
"""
Mixer is a base class providing the basic logic for the mixup-class of
Expand Down
22 changes: 19 additions & 3 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
_compiled_unsupported,
affine_func,
convert_box_to_points,
convert_points_to_box,
Expand Down Expand Up @@ -2104,14 +2105,15 @@ def __call__(
_align_corners = self.align_corners if align_corners is None else align_corners
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)
sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3)
_use_compiled = USE_COMPILED and not _compiled_unsupported(img_t.device)
backend, _interp_mode, _padding_mode, _ = resolves_modes(
self.mode if mode is None else mode,
self.padding_mode if padding_mode is None else padding_mode,
backend=None,
use_compiled=USE_COMPILED,
use_compiled=_use_compiled,
)

if USE_COMPILED or backend == TransformBackends.NUMPY:
if _use_compiled or backend == TransformBackends.NUMPY:
grid_t, *_ = convert_to_dst_type(grid[:sr], img_t, dtype=grid.dtype, wrap_sequence=True)
if isinstance(grid, torch.Tensor) and grid_t.data_ptr() == grid.data_ptr():
grid_t = grid_t.clone(memory_format=torch.contiguous_format)
Expand All @@ -2122,7 +2124,7 @@ def __call__(
grid_t[i] = ((_dim - 1) / _dim) * grid_t[i] + t if _align_corners else grid_t[i] + t
elif _align_corners:
grid_t[i] = ((_dim - 1) / _dim) * (grid_t[i] + 0.5)
if USE_COMPILED and backend == TransformBackends.TORCH: # compiled is using torch backend param name
if _use_compiled and backend == TransformBackends.TORCH: # compiled is using torch backend param name
grid_t = moveaxis(grid_t, 0, -1) # type: ignore
out = grid_pull(
img_t.unsqueeze(0),
Expand All @@ -2140,6 +2142,20 @@ def __call__(
[_map_coord(c, grid_np, order=_interp_mode, mode=_padding_mode) for c in img_np]
)
out = convert_to_dst_type(out, img_t)[0]
else:
# Fallback to PyTorch grid_sample when compiled extension is unsupported.
# Convert grid coordinates from compiled convention [0, size-1] to PyTorch [-1, 1]
for i, dim in enumerate(img_t.shape[1 : 1 + sr]):
_dim = max(2, dim)
grid_t[i] = (grid_t[i] * 2.0 / _dim) - 1.0
grid_t = moveaxis(grid_t, 0, -1) # type: ignore
out = torch.nn.functional.grid_sample(
img_t.unsqueeze(0),
grid_t.unsqueeze(0),
mode=_interp_mode,
padding_mode=_padding_mode,
align_corners=None if _align_corners == TraceKeys.NONE else _align_corners, # type: ignore
)[0]
else:
grid_t = moveaxis(grid[list(range(sr - 1, -1, -1))], 0, -1) # type: ignore
grid_t = convert_to_dst_type(grid_t, img_t, wrap_sequence=True)[0].unsqueeze(0)
Expand Down
32 changes: 31 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@
__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"]


def _compiled_unsupported(device: torch.device) -> bool:
"""
Return True if ``monai._C`` (the compiled C extension providing ``grid_pull``) is not
compiled with support for the given CUDA device's compute capability.

Args:
device: The torch device to check for compiled extension support.

Returns:
True if the device is CUDA with compute capability major >= 12 (Blackwell+),
False otherwise. Always returns False for CPU devices.

Note:
``monai._C`` is built at install time against a fixed set of CUDA architectures.
NVIDIA Blackwell GPUs (sm_120, compute capability 12.x) and newer were not included in
the default ``TORCH_CUDA_ARCH_LIST`` when the MONAI slim image was originally built,
so executing ``grid_pull`` on those devices produces incorrect results. Falling back to
the PyTorch-native ``affine_grid`` + ``grid_sample`` path (``USE_COMPILED=False``) gives
correct output on all architectures.

The threshold (``major >= 12``) matches the first architecture family (Blackwell, sm_120)
that shipped after the highest sm supported in the current default build list (sm_90,
Hopper). Adjust this constant when ``monai._C`` is rebuilt with sm_120+ support.
"""
if device.type != "cuda":
return False
return torch.cuda.get_device_properties(device).major >= 12


def _maybe_new_metatensor(img, dtype=None, device=None):
"""create a metatensor with fresh metadata if track_meta is True otherwise convert img into a torch tensor"""
return convert_to_tensor(
Expand Down Expand Up @@ -158,7 +187,8 @@ def spatial_resample(
xform_shape = [-1] + in_sp_size
img = img.reshape(xform_shape)
img = img.to(dtype_pt)
if isinstance(mode, int) or USE_COMPILED:
_use_compiled = USE_COMPILED and not _compiled_unsupported(img.device)
if isinstance(mode, int) or _use_compiled:
dst_xform = create_translate(spatial_rank, [float(d - 1) / 2 for d in spatial_size])
xform = xform @ convert_to_dst_type(dst_xform, xform)[0]
affine_xform = monai.transforms.Affine(
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def __init__(
# if the root log level is higher than INFO, set a separate stream handler to record
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
console.is_data_stats_handler = True # type:ignore[attr-defined]
console.is_data_stats_handler = True # type: ignore[attr-defined]
_logger.addHandler(console)

def __call__(
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_loader_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.
"""this test should not generate errors or
UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores"""

from __future__ import annotations

import multiprocessing as mp
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/pyspy_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions tests/transforms/croppad/test_pad_nd_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Tests for pad_nd dtype support and backend selection.
Validates PyTorch padding preference and NumPy fallback behavior.
"""

from __future__ import annotations

import unittest
Expand Down
83 changes: 83 additions & 0 deletions tests/transforms/test_spatial_gpu_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test GPU support detection and fallback paths for spatial transforms."""

from __future__ import annotations

import unittest

import torch

from monai.transforms.spatial.functional import _compiled_unsupported


class TestCompiledUnsupported(unittest.TestCase):
"""Test _compiled_unsupported device detection."""

def test_cpu_device_always_supported(self):
"""CPU devices should never be marked unsupported."""
device = torch.device("cpu")
self.assertFalse(_compiled_unsupported(device))

def test_non_cuda_device_always_supported(self):
"""Non-CUDA devices should always be supported."""
device = torch.device("cpu")
self.assertFalse(_compiled_unsupported(device))

@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_device_detection(self):
"""Verify CUDA compute capability detection."""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
cc_major = torch.cuda.get_device_properties(device).major
unsupported = _compiled_unsupported(device)
# Device is unsupported if cc_major >= 12
if cc_major >= 12:
self.assertTrue(unsupported)
else:
self.assertFalse(unsupported)

def test_compiled_unsupported_return_type(self):
"""Verify return type is bool."""
device = torch.device("cpu")
result = _compiled_unsupported(device)
self.assertIsInstance(result, bool)


class TestResampleFallback(unittest.TestCase):
"""Test Resample fallback behavior on unsupported devices."""

@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
def test_resample_compilation_flag_respected(self):
"""Verify Resample respects _compiled_unsupported check."""
# This would require internal inspection or output verification
# Could test with mock device properties or actual Blackwell GPU
Comment on lines +59 to +63
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Empty test method provides no coverage.

test_resample_compilation_flag_respected has no assertions or body—it will pass silently. Either implement the test (mock device properties or check output behavior) or remove the placeholder with a TODO issue.

♻️ Option: Skip explicitly with reason
     `@unittest.skipIf`(not torch.cuda.is_available(), reason="CUDA not available")
+    `@unittest.skip`("TODO: implement with mock device properties or Blackwell GPU")
     def test_resample_compilation_flag_respected(self):
         """Verify Resample respects _compiled_unsupported check."""
-        # This would require internal inspection or output verification
-        # Could test with mock device properties or actual Blackwell GPU
+        pass
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/transforms/test_spatial_gpu_support.py` around lines 59 - 63, The test
method test_resample_compilation_flag_respected is empty and will always pass;
either implement a real check or explicitly skip/remove it. Fix by implementing
a unit test that verifies Resample respects the _compiled_unsupported flag
(e.g., mock device properties or simulate a Blackwell GPU and assert Resample
raises/skips compilation) or replace the body with unittest.skip("TODO:
implement test for _compiled_unsupported") or raise unittest.SkipTest with a
clear reason; target the test method name
test_resample_compilation_flag_respected and any helpers used to construct
Resample/mocked device properties.


def test_compiled_unsupported_logic(self):
"""Test that unsupported devices are correctly detected."""
# CPU should be supported
cpu_device = torch.device("cpu")
self.assertFalse(_compiled_unsupported(cpu_device))

# Verify logic: return True if CUDA and cc_major >= 12
cuda_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if cuda_device.type == "cuda":
cc_major = torch.cuda.get_device_properties(cuda_device).major
expected = cc_major >= 12
actual = _compiled_unsupported(cuda_device)
self.assertEqual(actual, expected)


if __name__ == "__main__":
unittest.main()
if __name__ == "__main__":
unittest.main()
Comment on lines +80 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Duplicated if __name__ == "__main__" block.

Copy-paste error—remove the duplicate.

🐛 Fix
 if __name__ == "__main__":
     unittest.main()
-if __name__ == "__main__":
-    unittest.main()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if __name__ == "__main__":
unittest.main()
if __name__ == "__main__":
unittest.main()
if __name__ == "__main__":
unittest.main()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/transforms/test_spatial_gpu_support.py` around lines 80 - 83, Remove
the duplicated module entry point: there are two identical "if __name__ ==
\"__main__\": unittest.main()" blocks in
tests/transforms/test_spatial_gpu_support.py—delete the extra one so only a
single main invocation remains at the end of the file.

1 change: 1 addition & 0 deletions versioneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@
[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer

"""

# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring
# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements
# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error
Expand Down
Loading