Skip to content

Commit 39c06c3

Browse files
authored
Merge branch 'dev' into optimize/vectorize-dice-metric
2 parents 3c27d3f + 583d5ca commit 39c06c3

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

monai/losses/dice.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import warnings
1515
from collections.abc import Callable, Sequence
16-
from typing import Any
1716

1817
import numpy as np
1918
import torch
@@ -239,11 +238,52 @@ class MaskedDiceLoss(DiceLoss):
239238
240239
"""
241240

242-
def __init__(self, *args: Any, **kwargs: Any) -> None:
241+
def __init__(
242+
self,
243+
include_background: bool = True,
244+
to_onehot_y: bool = False,
245+
sigmoid: bool = False,
246+
softmax: bool = False,
247+
other_act: Callable | None = None,
248+
squared_pred: bool = False,
249+
jaccard: bool = False,
250+
reduction: LossReduction | str = LossReduction.MEAN,
251+
smooth_nr: float = 1e-5,
252+
smooth_dr: float = 1e-5,
253+
batch: bool = False,
254+
weight: Sequence[float] | float | int | torch.Tensor | None = None,
255+
soft_label: bool = False,
256+
) -> None:
243257
"""
244258
Args follow :py:class:`monai.losses.DiceLoss`.
245259
"""
246-
super().__init__(*args, **kwargs)
260+
if other_act is not None and not callable(other_act):
261+
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
262+
if sigmoid and softmax:
263+
raise ValueError("Incompatible values: sigmoid=True and softmax=True.")
264+
if other_act is not None and (sigmoid or softmax):
265+
raise ValueError("Incompatible values: other_act is not None and sigmoid=True or softmax=True.")
266+
267+
self.pre_sigmoid = sigmoid
268+
self.pre_softmax = softmax
269+
self.pre_other_act = other_act
270+
271+
super().__init__(
272+
include_background=include_background,
273+
to_onehot_y=to_onehot_y,
274+
sigmoid=False,
275+
softmax=False,
276+
other_act=None,
277+
squared_pred=squared_pred,
278+
jaccard=jaccard,
279+
reduction=reduction,
280+
smooth_nr=smooth_nr,
281+
smooth_dr=smooth_dr,
282+
batch=batch,
283+
weight=weight,
284+
soft_label=soft_label,
285+
)
286+
247287
self.spatial_weighted = MaskedLoss(loss=super().forward)
248288

249289
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
@@ -253,6 +293,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
253293
target: the shape should be BNH[WD].
254294
mask: the shape should B1H[WD] or 11H[WD].
255295
"""
296+
297+
if self.pre_sigmoid:
298+
input = torch.sigmoid(input)
299+
300+
n_pred_ch = input.shape[1]
301+
if self.pre_softmax:
302+
if n_pred_ch == 1:
303+
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
304+
else:
305+
input = torch.softmax(input, 1)
306+
307+
if self.pre_other_act is not None:
308+
input = self.pre_other_act(input)
256309
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]
257310

258311

tests/losses/test_masked_dice_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
2828
"mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]),
2929
},
30-
0.500,
30+
0.333333,
3131
],
3232
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
3333
{"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -36,7 +36,7 @@
3636
"target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
3737
"mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]),
3838
},
39-
0.422969,
39+
0.301128,
4040
],
4141
[ # shape: (2, 2, 3), (2, 1, 3)
4242
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
@@ -54,7 +54,7 @@
5454
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
5555
"mask": torch.tensor([[[1.0, 1.0, 0.0]]]),
5656
},
57-
0.47033,
57+
0.579184,
5858
],
5959
[ # shape: (2, 2, 3), (2, 1, 3)
6060
{

0 commit comments

Comments (0)