From 7d8e0b7c8e0434f6a30604815d1ef837d01cc31f Mon Sep 17 00:00:00 2001 From: "sewon.jeon" Date: Wed, 6 Aug 2025 18:43:06 +0900 Subject: [PATCH] Fix Invertd confusion with postprocessing transforms - Modified Invertd to only invert preprocessing transforms - When orig_key == key, limits transform_info to preprocessing transforms only - Counts invertible transforms in preprocessing and uses only that many from applied_operations - Fixes issue #8396 where Lambdad before Invertd caused errors - Added test case to verify the fix The issue occurred because Invertd would try to invert all transforms in applied_operations, including those from postprocessing. When Lambdad was applied in postprocessing before Invertd, it would be at the top of the stack, causing ID mismatch errors when preprocessing transforms tried to pop themselves. Fix Invertd confusion with postprocessing transforms - Modified Invertd to only invert preprocessing transforms - When orig_key == key, limits transform_info to preprocessing transforms only - Counts invertible transforms in preprocessing and uses only that many from applied_operations - Fixes issue #8396 where Lambdad before Invertd caused errors - Added test case to verify the fix The issue occurred because Invertd would try to invert all transforms in applied_operations, including those from postprocessing. When Lambdad was applied in postprocessing before Invertd, it would be at the top of the stack, causing ID mismatch errors when preprocessing transforms tried to pop themselves. Signed-off-by: sewon.jeon --- monai/transforms/post/dictionary.py | 28 +++++++++++++-- tests/transforms/inverse/test_invertd.py | 43 ++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7e1e074f71..201d5dcb98 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -687,9 +687,33 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]: orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}" if orig_key in d and isinstance(d[orig_key], MetaTensor): - transform_info = d[orig_key].applied_operations + all_transforms = d[orig_key].applied_operations meta_info = d[orig_key].meta - else: + + # If orig_key == key, the data at d[orig_key] may have been modified by + # postprocessing transforms. We need to exclude any transforms that were + # added after the preprocessing pipeline completed. + # When orig_key == key, filter out postprocessing transforms to prevent + # confusion during inversion (see issue #8396) + if orig_key == key: + num_preproc_transforms = 0 + try: + if hasattr(self.transform, 'transforms'): + for t in self.transform.flatten().transforms: + if isinstance(t, InvertibleTransform): + num_preproc_transforms += 1 + elif isinstance(self.transform, InvertibleTransform): + num_preproc_transforms = 1 + except AttributeError: + # Fallback: use all transforms if flatten fails + num_preproc_transforms = len(all_transforms) + + if num_preproc_transforms > 0: + transform_info = all_transforms[:num_preproc_transforms] + else: + transform_info = all_transforms + else: + transform_info = all_transforms transform_info = d[InvertibleTransform.trace_key(orig_key)] meta_info = d.get(orig_meta_key, {}) if nearest_interp: diff --git a/tests/transforms/inverse/test_invertd.py b/tests/transforms/inverse/test_invertd.py index 2b5e9da85d..e80f105f6a 100644 --- a/tests/transforms/inverse/test_invertd.py +++ b/tests/transforms/inverse/test_invertd.py @@ -137,6 +137,49 @@ def test_invert(self): set_determinism(seed=None) + def test_invert_with_postproc_lambdad(self): + """Test that Invertd works correctly when postprocessing contains invertible transforms like Lambdad.""" + set_determinism(seed=0) + + # Create test images + im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100)) + + # Define preprocessing transforms + preproc = Compose([ + LoadImaged(KEYS, image_only=True), + EnsureChannelFirstd(KEYS), + Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32), + ScaleIntensityd("image", minv=1, maxv=10), + ResizeWithPadOrCropd(KEYS, 100), + ]) + + # Define postprocessing with Lambdad before Invertd (the problematic case) + from monai.transforms import Lambdad + postproc = Compose([ + # This Lambdad should not interfere with Invertd + Lambdad(["pred"], lambda x: x), # Identity transform + # Invertd should only invert the preprocessing transforms + Invertd(["pred"], preproc, orig_keys=["image"], nearest_interp=True), + ]) + + # Apply preprocessing + data = {"image": im_fname, "label": seg_fname} + preprocessed = preproc(data) + + # Create prediction (copy from preprocessed image) + preprocessed["pred"] = preprocessed["image"].clone() + + # Apply postprocessing with Lambdad before Invertd + # This should work without errors - the main issue was that it would fail + result = postproc(preprocessed) + # Check that the inversion was successful + self.assertIn("pred", result) + # Check that the shape was correctly inverted + self.assertTupleEqual(result["pred"].shape[1:], (101, 100, 107)) + # The fact that we got here without an exception means the fix is working + + set_determinism(seed=None) + if __name__ == "__main__": unittest.main()