5959from monai .networks .nets import resnet
6060from monai .utils import BlendMode , PytorchPadMode , ensure_tuple_rep , optional_import
6161
62- BalancedPositiveNegativeSampler , _ = optional_import ("torchvision.models.detection._utils" , name = "BalancedPositiveNegativeSampler" )
62+ BalancedPositiveNegativeSampler , _ = optional_import (
63+ "torchvision.models.detection._utils" , name = "BalancedPositiveNegativeSampler"
64+ )
6365Matcher , _ = optional_import ("torchvision.models.detection._utils" , name = "Matcher" )
6466
6567
@@ -326,7 +328,9 @@ def set_box_regression_loss(self, box_loss: nn.Module, encode_gt: bool, decode_p
326328 self .encode_gt = encode_gt
327329 self .decode_pred = decode_pred
328330
329- def set_regular_matcher (self , fg_iou_thresh : float , bg_iou_thresh : float , allow_low_quality_matches : bool = True ) -> None :
331+ def set_regular_matcher (
332+ self , fg_iou_thresh : float , bg_iou_thresh : float , allow_low_quality_matches : bool = True
333+ ) -> None :
330334 """
331335 Using for training. Set torchvision matcher that matches anchors with ground truth boxes.
332336
@@ -340,7 +344,9 @@ def set_regular_matcher(self, fg_iou_thresh: float, bg_iou_thresh: float, allow_
340344 raise ValueError (
341345 f"Require fg_iou_thresh >= bg_iou_thresh. Got fg_iou_thresh={ fg_iou_thresh } , bg_iou_thresh={ bg_iou_thresh } ."
342346 )
343- self .proposal_matcher = Matcher (fg_iou_thresh , bg_iou_thresh , allow_low_quality_matches = allow_low_quality_matches )
347+ self .proposal_matcher = Matcher (
348+ fg_iou_thresh , bg_iou_thresh , allow_low_quality_matches = allow_low_quality_matches
349+ )
344350
345351 def set_atss_matcher (self , num_candidates : int = 4 , center_in_gt : bool = False ) -> None :
346352 """
@@ -489,7 +495,9 @@ def forward(
489495 """
490496 # 1. Check if input arguments are valid
491497 if self .training :
492- targets = check_training_targets (input_images , targets , self .spatial_dims , self .target_label_key , self .target_box_key )
498+ targets = check_training_targets (
499+ input_images , targets , self .spatial_dims , self .target_label_key , self .target_box_key
500+ )
493501 self ._check_detector_training_components ()
494502
495503 # 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
@@ -509,8 +517,12 @@ def forward(
509517 ensure_dict_value_to_list_ (head_outputs )
510518 else :
511519 if self .inferer is None :
512- raise ValueError ("`self.inferer` is not defined.Please refer to function self.set_sliding_window_inferer(*)." )
513- head_outputs = predict_with_inferer (images , self .network , keys = [self .cls_key , self .box_reg_key ], inferer = self .inferer )
520+ raise ValueError (
521+ "`self.inferer` is not defined.Please refer to function self.set_sliding_window_inferer(*)."
522+ )
523+ head_outputs = predict_with_inferer (
524+ images , self .network , keys = [self .cls_key , self .box_reg_key ], inferer = self .inferer
525+ )
514526
515527 # 4. Generate anchors and store it in self.anchors: List[Tensor]
516528 self .generate_anchors (images , head_outputs )
@@ -532,10 +544,7 @@ def forward(
532544
533545 # 6(2). If during inference, return detection results
534546 detections = self .postprocess_detections (
535- head_outputs ,
536- self .anchors ,
537- image_sizes ,
538- num_anchor_locs_per_level , # type: ignore
547+ head_outputs , self .anchors , image_sizes , num_anchor_locs_per_level # type: ignore
539548 )
540549 return detections
541550
@@ -544,7 +553,9 @@ def _check_detector_training_components(self):
544553 Check if self.proposal_matcher and self.fg_bg_sampler have been set for training.
545554 """
546555 if not hasattr (self , "proposal_matcher" ):
547- raise AttributeError ("Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*)." )
556+ raise AttributeError (
557+ "Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*)."
558+ )
548559 if self .fg_bg_sampler is None and self .debug :
549560 warnings .warn (
550561 "No balanced sampler is used. Negative samples are likely to "
@@ -641,7 +652,9 @@ def postprocess_detections(
641652 """
642653
643654 # recover level sizes, HWA or HWDA for each level
644- num_anchors_per_level = [num_anchor_locs * self .num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level ]
655+ num_anchors_per_level = [
656+ num_anchor_locs * self .num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level
657+ ]
645658
646659 # split outputs per level
647660 split_head_outputs : dict [str , list [Tensor ]] = {}
@@ -658,7 +671,9 @@ def postprocess_detections(
658671 detections : list [dict [str , Tensor ]] = []
659672
660673 for index in range (num_images ):
661- box_regression_per_image = [br [index ] for br in box_regression ] # List[Tensor], each sized (HWA, 2*spatial_dims)
674+ box_regression_per_image = [
675+ br [index ] for br in box_regression
676+ ] # List[Tensor], each sized (HWA, 2*spatial_dims)
662677 logits_per_image = [cl [index ] for cl in class_logits ] # List[Tensor], each sized (HWA, self.num_classes)
663678 anchors_per_image , img_spatial_size = split_anchors [index ], image_sizes [index ]
664679 # decode box regression into boxes
@@ -671,11 +686,13 @@ def postprocess_detections(
671686 boxes_per_image , logits_per_image , img_spatial_size
672687 )
673688
674- detections .append ({
675- self .target_box_key : selected_boxes , # Tensor, sized (N, 2*spatial_dims)
676- self .pred_score_key : selected_scores , # Tensor, sized (N, )
677- self .target_label_key : selected_labels , # Tensor, sized (N, )
678- })
689+ detections .append (
690+ {
691+ self .target_box_key : selected_boxes , # Tensor, sized (N, 2*spatial_dims)
692+ self .pred_score_key : selected_scores , # Tensor, sized (N, )
693+ self .target_label_key : selected_labels , # Tensor, sized (N, )
694+ }
695+ )
679696
680697 return detections
681698
@@ -704,7 +721,9 @@ def compute_loss(
704721 """
705722 matched_idxs = self .compute_anchor_matched_idxs (anchors , targets , num_anchor_locs_per_level )
706723 losses_cls = self .compute_cls_loss (head_outputs_reshape [self .cls_key ], targets , matched_idxs )
707- losses_box_regression = self .compute_box_loss (head_outputs_reshape [self .box_reg_key ], targets , anchors , matched_idxs )
724+ losses_box_regression = self .compute_box_loss (
725+ head_outputs_reshape [self .box_reg_key ], targets , anchors , matched_idxs
726+ )
708727 return {self .cls_key : losses_cls , self .box_reg_key : losses_box_regression }
709728
710729 def compute_anchor_matched_idxs (
@@ -737,7 +756,9 @@ def compute_anchor_matched_idxs(
737756 # anchors_per_image: Tensor, targets_per_image: Dict[str, Tensor]
738757 if targets_per_image [self .target_box_key ].numel () == 0 :
739758 # if no GT boxes
740- matched_idxs .append (torch .full ((anchors_per_image .size (0 ),), - 1 , dtype = torch .int64 , device = anchors_per_image .device ))
759+ matched_idxs .append (
760+ torch .full ((anchors_per_image .size (0 ),), - 1 , dtype = torch .int64 , device = anchors_per_image .device )
761+ )
741762 continue
742763
743764 # matched_idxs_per_image (Tensor[int64]): Tensor sized (sum(HWA),) or (sum(HWDA),)
@@ -777,7 +798,9 @@ def compute_anchor_matched_idxs(
777798 matched_idxs .append (matched_idxs_per_image )
778799 return matched_idxs
779800
780- def compute_cls_loss (self , cls_logits : Tensor , targets : list [dict [str , Tensor ]], matched_idxs : list [Tensor ]) -> Tensor :
801+ def compute_cls_loss (
802+ self , cls_logits : Tensor , targets : list [dict [str , Tensor ]], matched_idxs : list [Tensor ]
803+ ) -> Tensor :
781804 """
782805 Compute classification losses.
783806
@@ -895,7 +918,9 @@ def get_cls_train_sample_per_image(
895918 gt_classes_target = torch .zeros_like (cls_logits_per_image ) # (sum(HW(D)A), self.num_classes)
896919 gt_classes_target [
897920 foreground_idxs_per_image , # fg anchor idx in
898- targets_per_image [self .target_label_key ][matched_idxs_per_image [foreground_idxs_per_image ]], # fg class label
921+ targets_per_image [self .target_label_key ][
922+ matched_idxs_per_image [foreground_idxs_per_image ]
923+ ], # fg class label
899924 ] = 1.0
900925
901926 if self .fg_bg_sampler is None :
@@ -967,9 +992,9 @@ def get_box_train_sample_per_image(
967992
968993 # select only the foreground boxes
969994 # matched GT boxes for foreground anchors
970- matched_gt_boxes_per_image = targets_per_image [self .target_box_key ][matched_idxs_per_image [ foreground_idxs_per_image ]]. to (
971- box_regression_per_image . device
972- )
995+ matched_gt_boxes_per_image = targets_per_image [self .target_box_key ][
996+ matched_idxs_per_image [ foreground_idxs_per_image ]
997+ ]. to ( box_regression_per_image . device )
973998 # predicted box regression for foreground anchors
974999 box_regression_per_image = box_regression_per_image [foreground_idxs_per_image , :]
9751000 # foreground anchors
0 commit comments