Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,8 +1161,7 @@ def __call__(
d = dict(data)
first_key: Hashable = self.first_key(d)
if first_key == ():
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
return d

self.randomize(None)
# all the keys share the same random Affine factor
Expand Down Expand Up @@ -1322,8 +1321,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
first_key: Hashable = self.first_key(d)

if first_key == ():
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
return d

self.randomize(None)
device = self.rand_2d_elastic.device
Expand Down Expand Up @@ -1473,8 +1471,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
first_key: Hashable = self.first_key(d)

if first_key == ():
out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
return d

self.randomize(None)
if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore
Expand Down Expand Up @@ -2134,8 +2131,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = No
d = dict(data)
first_key: Hashable = self.first_key(d)
if first_key == ():
out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
return d

self.randomize(None)

Expand Down Expand Up @@ -2305,13 +2301,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
d = dict(data)
self.randomize(None)
if not self._do_transform:
out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

first_key: Hashable = self.first_key(d)
if first_key == ():
out = convert_to_tensor(d, track_meta=get_track_meta())
return out
return d
if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations: # type: ignore
warnings.warn(f"data['{first_key}'] has pending operations, transform may return incorrect results.")
self.rand_grid_distortion.randomize(d[first_key].shape[1:])
Expand Down Expand Up @@ -2633,8 +2629,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
d = dict(data)
first_key: Hashable = self.first_key(d)
if first_key == ():
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
return d

self.randomize(None)

Expand Down
14 changes: 14 additions & 0 deletions tests/transforms/test_rand_grid_distortiond.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,28 @@


class TestRandGridDistortiond(unittest.TestCase):
"""Test cases for RandGridDistortiond dictionary transform."""

@parameterized.expand(TESTS)
def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask):
"""Verify distortion produces expected output for image and mask keys."""
g = RandGridDistortiond(**input_param)
g.set_random_state(seed=seed)
result = g(input_data)
assert_allclose(result["img"], expected_val_img, type_test=False, rtol=1e-4, atol=1e-4)
assert_allclose(result["mask"], expected_val_mask, type_test=False, rtol=1e-4, atol=1e-4)

def test_no_transform_with_non_tensor_metadata(self):
"""When _do_transform is False, non-tensor values in the dict should not cause an error."""
img = np.indices([6, 6]).astype(np.float32)
data = {"img": img, "extra_info": 42, "label_name": "tumor"}
g = RandGridDistortiond(keys=["img"], prob=0.0) # prob=0 ensures _do_transform is False
result = g(data)
# non-tensor metadata should pass through unchanged
self.assertEqual(result["extra_info"], 42)
self.assertEqual(result["label_name"], "tumor")
assert_allclose(result["img"], img, type_test=False)


if __name__ == "__main__":
unittest.main()
Loading