1313
1414import warnings
1515from collections .abc import Callable , Sequence
16- from typing import Any
1716
1817import numpy as np
1918import torch
@@ -239,11 +238,52 @@ class MaskedDiceLoss(DiceLoss):
239238
240239 """
241240
242- def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
241+ def __init__ (
242+ self ,
243+ include_background : bool = True ,
244+ to_onehot_y : bool = False ,
245+ sigmoid : bool = False ,
246+ softmax : bool = False ,
247+ other_act : Callable | None = None ,
248+ squared_pred : bool = False ,
249+ jaccard : bool = False ,
250+ reduction : LossReduction | str = LossReduction .MEAN ,
251+ smooth_nr : float = 1e-5 ,
252+ smooth_dr : float = 1e-5 ,
253+ batch : bool = False ,
254+ weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
255+ soft_label : bool = False ,
256+ ) -> None :
243257 """
244258 Args follow :py:class:`monai.losses.DiceLoss`.
245259 """
246- super ().__init__ (* args , ** kwargs )
260+ if other_act is not None and not callable (other_act ):
261+ raise TypeError (f"other_act must be None or callable but is { type (other_act ).__name__ } ." )
262+ if sigmoid and softmax :
263+ raise ValueError ("Incompatible values: sigmoid=True and softmax=True." )
264+ if other_act is not None and (sigmoid or softmax ):
265+ raise ValueError ("Incompatible values: other_act is not None and sigmoid=True or softmax=True." )
266+
267+ self .pre_sigmoid = sigmoid
268+ self .pre_softmax = softmax
269+ self .pre_other_act = other_act
270+
271+ super ().__init__ (
272+ include_background = include_background ,
273+ to_onehot_y = to_onehot_y ,
274+ sigmoid = False ,
275+ softmax = False ,
276+ other_act = None ,
277+ squared_pred = squared_pred ,
278+ jaccard = jaccard ,
279+ reduction = reduction ,
280+ smooth_nr = smooth_nr ,
281+ smooth_dr = smooth_dr ,
282+ batch = batch ,
283+ weight = weight ,
284+ soft_label = soft_label ,
285+ )
286+
247287 self .spatial_weighted = MaskedLoss (loss = super ().forward )
248288
249289 def forward (self , input : torch .Tensor , target : torch .Tensor , mask : torch .Tensor | None = None ) -> torch .Tensor :
@@ -253,6 +293,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
253293 target: the shape should be BNH[WD].
254294 mask: the shape should B1H[WD] or 11H[WD].
255295 """
296+
297+ if self .pre_sigmoid :
298+ input = torch .sigmoid (input )
299+
300+ n_pred_ch = input .shape [1 ]
301+ if self .pre_softmax :
302+ if n_pred_ch == 1 :
303+ warnings .warn ("single channel prediction, `softmax=True` ignored." , stacklevel = 2 )
304+ else :
305+ input = torch .softmax (input , 1 )
306+
307+ if self .pre_other_act is not None :
308+ input = self .pre_other_act (input )
256309 return self .spatial_weighted (input = input , target = target , mask = mask ) # type: ignore[no-any-return]
257310
258311
0 commit comments