Skip to content

Commit 36e2623

Browse files
committed
lint
Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com>
1 parent 4f6df07 commit 36e2623

File tree

9 files changed

+164
-66
lines changed

9 files changed

+164
-66
lines changed

monai/apps/auto3dseg/bundle_gen.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,15 @@ def _run_cmd(self, cmd: str, devices_info: str = "") -> subprocess.CompletedProc
269269

270270
return _run_cmd_bcprun(cmd, n=self.device_setting["NUM_NODES"], p=self.device_setting["n_devices"])
271271
elif int(self.device_setting["n_devices"]) > 1:
272-
return _run_cmd_torchrun(cmd, nnodes=1, nproc_per_node=self.device_setting["n_devices"], env=ps_environ, check=True)
272+
return _run_cmd_torchrun(
273+
cmd, nnodes=1, nproc_per_node=self.device_setting["n_devices"], env=ps_environ, check=True
274+
)
273275
else:
274276
return run_cmd(cmd.split(), run_cmd_verbose=True, env=ps_environ, check=True)
275277

276-
def train(self, train_params: None | dict = None, device_setting: None | dict = None) -> subprocess.CompletedProcess:
278+
def train(
279+
self, train_params: None | dict = None, device_setting: None | dict = None
280+
) -> subprocess.CompletedProcess:
277281
"""
278282
Load the run function in the training script of each model. Training parameter is predefined by the
279283
algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.
@@ -364,7 +368,9 @@ def get_output_path(self):
364368

365369

366370
# path to download the algo_templates
367-
default_algo_zip = f"https://github.com/Project-MONAI/research-contributions/releases/download/algo_templates/{ALGO_HASH}.tar.gz"
371+
default_algo_zip = (
372+
f"https://github.com/Project-MONAI/research-contributions/releases/download/algo_templates/{ALGO_HASH}.tar.gz"
373+
)
368374

369375
# default algorithms
370376
default_algos = {
@@ -653,7 +659,6 @@ def generate(
653659
gen_algo.export_to_disk(output_folder, name, fold=f_id)
654660

655661
algo_to_pickle(gen_algo, template_path=algo.template_path)
656-
self.history.append({
657-
AlgoKeys.ID: name,
658-
AlgoKeys.ALGO: gen_algo,
659-
}) # track the previous, may create a persistent history
662+
self.history.append(
663+
{AlgoKeys.ID: name, AlgoKeys.ALGO: gen_algo}
664+
) # track the previous, may create a persistent history

monai/apps/detection/networks/retinanet_detector.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@
5959
from monai.networks.nets import resnet
6060
from monai.utils import BlendMode, PytorchPadMode, ensure_tuple_rep, optional_import
6161

62-
BalancedPositiveNegativeSampler, _ = optional_import("torchvision.models.detection._utils", name="BalancedPositiveNegativeSampler")
62+
BalancedPositiveNegativeSampler, _ = optional_import(
63+
"torchvision.models.detection._utils", name="BalancedPositiveNegativeSampler"
64+
)
6365
Matcher, _ = optional_import("torchvision.models.detection._utils", name="Matcher")
6466

6567

@@ -326,7 +328,9 @@ def set_box_regression_loss(self, box_loss: nn.Module, encode_gt: bool, decode_p
326328
self.encode_gt = encode_gt
327329
self.decode_pred = decode_pred
328330

329-
def set_regular_matcher(self, fg_iou_thresh: float, bg_iou_thresh: float, allow_low_quality_matches: bool = True) -> None:
331+
def set_regular_matcher(
332+
self, fg_iou_thresh: float, bg_iou_thresh: float, allow_low_quality_matches: bool = True
333+
) -> None:
330334
"""
331335
Using for training. Set torchvision matcher that matches anchors with ground truth boxes.
332336
@@ -340,7 +344,9 @@ def set_regular_matcher(self, fg_iou_thresh: float, bg_iou_thresh: float, allow_
340344
raise ValueError(
341345
f"Require fg_iou_thresh >= bg_iou_thresh. Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
342346
)
343-
self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches)
347+
self.proposal_matcher = Matcher(
348+
fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches
349+
)
344350

345351
def set_atss_matcher(self, num_candidates: int = 4, center_in_gt: bool = False) -> None:
346352
"""
@@ -489,7 +495,9 @@ def forward(
489495
"""
490496
# 1. Check if input arguments are valid
491497
if self.training:
492-
targets = check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
498+
targets = check_training_targets(
499+
input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key
500+
)
493501
self._check_detector_training_components()
494502

495503
# 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.
@@ -509,8 +517,12 @@ def forward(
509517
ensure_dict_value_to_list_(head_outputs)
510518
else:
511519
if self.inferer is None:
512-
raise ValueError("`self.inferer` is not defined.Please refer to function self.set_sliding_window_inferer(*).")
513-
head_outputs = predict_with_inferer(images, self.network, keys=[self.cls_key, self.box_reg_key], inferer=self.inferer)
520+
raise ValueError(
521+
"`self.inferer` is not defined.Please refer to function self.set_sliding_window_inferer(*)."
522+
)
523+
head_outputs = predict_with_inferer(
524+
images, self.network, keys=[self.cls_key, self.box_reg_key], inferer=self.inferer
525+
)
514526

515527
# 4. Generate anchors and store it in self.anchors: List[Tensor]
516528
self.generate_anchors(images, head_outputs)
@@ -532,10 +544,7 @@ def forward(
532544

533545
# 6(2). If during inference, return detection results
534546
detections = self.postprocess_detections(
535-
head_outputs,
536-
self.anchors,
537-
image_sizes,
538-
num_anchor_locs_per_level, # type: ignore
547+
head_outputs, self.anchors, image_sizes, num_anchor_locs_per_level # type: ignore
539548
)
540549
return detections
541550

@@ -544,7 +553,9 @@ def _check_detector_training_components(self):
544553
Check if self.proposal_matcher and self.fg_bg_sampler have been set for training.
545554
"""
546555
if not hasattr(self, "proposal_matcher"):
547-
raise AttributeError("Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*).")
556+
raise AttributeError(
557+
"Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*)."
558+
)
548559
if self.fg_bg_sampler is None and self.debug:
549560
warnings.warn(
550561
"No balanced sampler is used. Negative samples are likely to "
@@ -641,7 +652,9 @@ def postprocess_detections(
641652
"""
642653

643654
# recover level sizes, HWA or HWDA for each level
644-
num_anchors_per_level = [num_anchor_locs * self.num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level]
655+
num_anchors_per_level = [
656+
num_anchor_locs * self.num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level
657+
]
645658

646659
# split outputs per level
647660
split_head_outputs: dict[str, list[Tensor]] = {}
@@ -658,7 +671,9 @@ def postprocess_detections(
658671
detections: list[dict[str, Tensor]] = []
659672

660673
for index in range(num_images):
661-
box_regression_per_image = [br[index] for br in box_regression] # List[Tensor], each sized (HWA, 2*spatial_dims)
674+
box_regression_per_image = [
675+
br[index] for br in box_regression
676+
] # List[Tensor], each sized (HWA, 2*spatial_dims)
662677
logits_per_image = [cl[index] for cl in class_logits] # List[Tensor], each sized (HWA, self.num_classes)
663678
anchors_per_image, img_spatial_size = split_anchors[index], image_sizes[index]
664679
# decode box regression into boxes
@@ -671,11 +686,13 @@ def postprocess_detections(
671686
boxes_per_image, logits_per_image, img_spatial_size
672687
)
673688

674-
detections.append({
675-
self.target_box_key: selected_boxes, # Tensor, sized (N, 2*spatial_dims)
676-
self.pred_score_key: selected_scores, # Tensor, sized (N, )
677-
self.target_label_key: selected_labels, # Tensor, sized (N, )
678-
})
689+
detections.append(
690+
{
691+
self.target_box_key: selected_boxes, # Tensor, sized (N, 2*spatial_dims)
692+
self.pred_score_key: selected_scores, # Tensor, sized (N, )
693+
self.target_label_key: selected_labels, # Tensor, sized (N, )
694+
}
695+
)
679696

680697
return detections
681698

@@ -704,7 +721,9 @@ def compute_loss(
704721
"""
705722
matched_idxs = self.compute_anchor_matched_idxs(anchors, targets, num_anchor_locs_per_level)
706723
losses_cls = self.compute_cls_loss(head_outputs_reshape[self.cls_key], targets, matched_idxs)
707-
losses_box_regression = self.compute_box_loss(head_outputs_reshape[self.box_reg_key], targets, anchors, matched_idxs)
724+
losses_box_regression = self.compute_box_loss(
725+
head_outputs_reshape[self.box_reg_key], targets, anchors, matched_idxs
726+
)
708727
return {self.cls_key: losses_cls, self.box_reg_key: losses_box_regression}
709728

710729
def compute_anchor_matched_idxs(
@@ -737,7 +756,9 @@ def compute_anchor_matched_idxs(
737756
# anchors_per_image: Tensor, targets_per_image: Dice[str, Tensor]
738757
if targets_per_image[self.target_box_key].numel() == 0:
739758
# if no GT boxes
740-
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device))
759+
matched_idxs.append(
760+
torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
761+
)
741762
continue
742763

743764
# matched_idxs_per_image (Tensor[int64]): Tensor sized (sum(HWA),) or (sum(HWDA),)
@@ -777,7 +798,9 @@ def compute_anchor_matched_idxs(
777798
matched_idxs.append(matched_idxs_per_image)
778799
return matched_idxs
779800

780-
def compute_cls_loss(self, cls_logits: Tensor, targets: list[dict[str, Tensor]], matched_idxs: list[Tensor]) -> Tensor:
801+
def compute_cls_loss(
802+
self, cls_logits: Tensor, targets: list[dict[str, Tensor]], matched_idxs: list[Tensor]
803+
) -> Tensor:
781804
"""
782805
Compute classification losses.
783806
@@ -895,7 +918,9 @@ def get_cls_train_sample_per_image(
895918
gt_classes_target = torch.zeros_like(cls_logits_per_image) # (sum(HW(D)A), self.num_classes)
896919
gt_classes_target[
897920
foreground_idxs_per_image, # fg anchor idx in
898-
targets_per_image[self.target_label_key][matched_idxs_per_image[foreground_idxs_per_image]], # fg class label
921+
targets_per_image[self.target_label_key][
922+
matched_idxs_per_image[foreground_idxs_per_image]
923+
], # fg class label
899924
] = 1.0
900925

901926
if self.fg_bg_sampler is None:
@@ -967,9 +992,9 @@ def get_box_train_sample_per_image(
967992

968993
# select only the foreground boxes
969994
# matched GT boxes for foreground anchors
970-
matched_gt_boxes_per_image = targets_per_image[self.target_box_key][matched_idxs_per_image[foreground_idxs_per_image]].to(
971-
box_regression_per_image.device
972-
)
995+
matched_gt_boxes_per_image = targets_per_image[self.target_box_key][
996+
matched_idxs_per_image[foreground_idxs_per_image]
997+
].to(box_regression_per_image.device)
973998
# predicted box regression for foreground anchors
974999
box_regression_per_image = box_regression_per_image[foreground_idxs_per_image, :]
9751000
# foreground anchors

monai/apps/detection/utils/anchor_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def __init__(
136136
self.indexing = look_up_option(indexing, ["ij", "xy"])
137137

138138
self.aspect_ratios = aspect_ratios
139-
self.cell_anchors = [self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)]
139+
self.cell_anchors = [
140+
self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)
141+
]
140142

141143
# This comment comes from torchvision.
142144
# TODO: https://github.com/pytorch/pytorch/issues/26792
@@ -251,7 +253,8 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
251253
# compute anchor centers regarding to the image.
252254
# shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
253255
shifts_centers = [
254-
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] for axis in range(self.spatial_dims)
256+
torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
257+
for axis in range(self.spatial_dims)
255258
]
256259

257260
# to support torchscript, cannot directly use torch.meshgrid(shifts_centers).
@@ -304,7 +307,10 @@ def forward(self, images: Tensor, feature_maps: list[Tensor]) -> list[Tensor]:
304307
batchsize = images.shape[0]
305308
dtype, device = feature_maps[0].dtype, feature_maps[0].device
306309
strides = [
307-
[torch.tensor(image_size[axis] // g[axis], dtype=torch.int64, device=device) for axis in range(self.spatial_dims)]
310+
[
311+
torch.tensor(image_size[axis] // g[axis], dtype=torch.int64, device=device)
312+
for axis in range(self.spatial_dims)
313+
]
308314
for g in grid_sizes
309315
]
310316

monai/apps/detection/utils/detector_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def check_training_targets(
8080
for i in range(len(targets)):
8181
target = targets[i]
8282
if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):
83-
raise ValueError(f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}.")
83+
raise ValueError(
84+
f"{target_label_key} and {target_box_key} are expected keys in targets. Got {target.keys()}."
85+
)
8486

8587
boxes = target[target_box_key]
8688
if not isinstance(boxes, torch.Tensor):
@@ -92,7 +94,9 @@ def check_training_targets(
9294
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2 * spatial_dims}])."
9395
)
9496
else:
95-
raise ValueError(f"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}.).")
97+
raise ValueError(
98+
f"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}.)."
99+
)
96100
if not torch.is_floating_point(boxes):
97101
raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
98102
targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims) # type: ignore

0 commit comments

Comments (0)