Skip to content

Commit 4924aa6

Browse files
Fix TrainableBilateralFilter 3D input validation (#7444)
- Fix dimension comparison to use spatial dims instead of total dims
- Add validation for minimum input dimensions
- Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma)
- Move spatial dimension validation before unsqueeze operations

The forward() method was incorrectly comparing self.len_spatial_sigma (the number of spatial dimensions) with len(input_tensor.shape) (the total number of dimensions, including batch and channel), causing valid 3D inputs to be rejected.

Fixes #7444

Signed-off-by: Abdoulaye Diallo <abdoulayediallo338@gmail.com>
1 parent b106a4c commit 4924aa6

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

monai/networks/layers/filtering.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __init__(self, spatial_sigma, color_sigma):
221221
self.len_spatial_sigma = 3
222222
else:
223223
raise ValueError(
224-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
224+
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
225225
)
226226

227227
# Register sigmas as trainable parameters.
@@ -231,6 +231,10 @@ def __init__(self, spatial_sigma, color_sigma):
231231
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))
232232

233233
def forward(self, input_tensor):
234+
if len(input_tensor.shape) < 3:
235+
raise ValueError(
236+
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
237+
)
234238
if input_tensor.shape[1] != 1:
235239
raise ValueError(
236240
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
@@ -239,24 +243,25 @@ def forward(self, input_tensor):
239243
)
240244

241245
len_input = len(input_tensor.shape)
246+
spatial_dims = len_input - 2
242247

243248
# C++ extension so far only supports 5-dim inputs.
244-
if len_input == 3:
249+
if spatial_dims == 1:
245250
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
246-
elif len_input == 4:
251+
elif spatial_dims == 2:
247252
input_tensor = input_tensor.unsqueeze(4)
248253

249-
if self.len_spatial_sigma != len_input:
250-
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
254+
if self.len_spatial_sigma != spatial_dims:
255+
raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).")
251256

252257
prediction = TrainableBilateralFilterFunction.apply(
253258
input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
254259
)
255260

256261
# Make sure to return tensor of the same shape as the input.
257-
if len_input == 3:
262+
if spatial_dims == 1:
258263
prediction = prediction.squeeze(4).squeeze(3)
259-
elif len_input == 4:
264+
elif spatial_dims == 2:
260265
prediction = prediction.squeeze(4)
261266

262267
return prediction
@@ -389,7 +394,7 @@ def __init__(self, spatial_sigma, color_sigma):
389394
self.len_spatial_sigma = 3
390395
else:
391396
raise ValueError(
392-
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
397+
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
393398
)
394399

395400
# Register sigmas as trainable parameters.

0 commit comments

Comments (0)