Merged
Changes from all commits
Commits
63 commits
75dd1ef
Bump actions/download-artifact from 4 to 5 (#8590)
dependabot[bot] Oct 1, 2025
42901bb
Bump actions/checkout from 4 to 5 (#8588)
dependabot[bot] Oct 1, 2025
bb5b425
Bump actions/setup-python from 5 to 6 (#8589)
dependabot[bot] Oct 10, 2025
e465d66
Replace `pyupgrade` with builtin Ruff's UP rule
Borda Oct 28, 2025
cf427fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2025
fc10d4d
--unsafe-fixes
Borda Oct 28, 2025
12ec249
Apply suggestions from code review
Borda Oct 28, 2025
30c85d9
8564 fourier positional encoding (#8570)
NabJa Oct 31, 2025
5e66c73
Include more-itertools in build env (#8611)
KumoLiu Oct 31, 2025
1e5a492
Bump peter-evans/create-or-update-comment from 4 to 5 (#8612)
dependabot[bot] Nov 1, 2025
9e9b5d0
Bump actions/download-artifact from 5 to 6 (#8614)
dependabot[bot] Nov 1, 2025
d0b78af
Bump actions/upload-artifact from 4 to 5 (#8615)
dependabot[bot] Nov 3, 2025
bdea238
Bump github/codeql-action from 3 to 4 (#8616)
dependabot[bot] Nov 3, 2025
f6b72c2
added ReduceTrait and FlattenSequence (#8531)
lukas-folle-snkeos Nov 3, 2025
0f5a188
Fix box_iou returning 0 for floating-point results less than 1. #8369…
reworld223 Nov 3, 2025
b3ccf8d
8620 ModuleNotFoundError: No module named 'onnxscript' in test-py3x…
garciadias Nov 6, 2025
af4580b
Update monai/losses/perceptual.py
Borda Nov 7, 2025
502c402
Update monai/losses/sure_loss.py
Borda Nov 7, 2025
dbac655
Update monai/losses/sure_loss.py
Borda Nov 7, 2025
22bcd9b
Update monai/losses/sure_loss.py
Borda Nov 7, 2025
0334fc0
Update monai/losses/adversarial_loss.py
Borda Nov 7, 2025
83aa7d5
Apply suggestions from code review
Borda Nov 9, 2025
9a92dd3
--unsafe-fixes
Borda Nov 9, 2025
bcd3471
./runtests.sh --autofix
Borda Nov 9, 2025
0e7a2b5
timestep scheduling with np.linspace (#8623)
ytl0623 Nov 11, 2025
b93479c
Correct H&E stain ordering heuristic in ExtractHEStains (#8551)
iyassou Nov 13, 2025
5dae9c7
feat: add activation checkpointing to unet (#8554)
ferreirafabio80 Nov 14, 2025
3291f48
Fix index using tuple for image cropping operation (#8633)
johnzielke Nov 18, 2025
e6f2e1f
Fix #8599: Add track_meta and weights_only arguments to PersistentDat…
mccle Nov 19, 2025
6cf36b6
Update documentation links (#8637)
KumoLiu Nov 21, 2025
a0e4889
Fix #8350: Clarify LocalNormalizedCrossCorrelationLoss docstring (#8639)
engmohamedsalah Nov 24, 2025
935f1cc
8620 modulenotfounderror no module named onnxscript in test py3x 311 …
garciadias Nov 25, 2025
34f93b7
Generate heatmap transforms (#8579)
eclipse0922 Nov 25, 2025
02c5e0d
Added an optional_import check for onnxruntime and applied the @unitt…
ytl0623 Nov 27, 2025
3bfcc43
Bump peter-evans/slash-command-dispatch from 4.0.0 to 5.0.0 (#8650)
dependabot[bot] Dec 1, 2025
1bb3e63
Bump al-cheb/configure-pagefile-action from 1.4 to 1.5 (#8648)
dependabot[bot] Dec 2, 2025
3beefbd
Bump actions/checkout from 4 to 6 (#8649)
dependabot[bot] Dec 2, 2025
c8bb696
Update monai/losses/perceptual.py
Borda Dec 30, 2025
9b04647
resolves CI's drive-out-of-space by prune caching and use torch fro C…
Borda Jan 5, 2026
21d2289
fix linter
Borda Jan 5, 2026
93497a4
linting
Borda Jan 5, 2026
8c08d3e
linting
Borda Jan 5, 2026
03a0a4a
linting
Borda Jan 5, 2026
02f2b1a
linting
Borda Jan 5, 2026
4b759ab
B902
Borda Jan 5, 2026
0ba2bf6
Consolidating Version Bumps (#8681)
ericspod Jan 5, 2026
8af4f6d
Update monai/losses/sure_loss.py
Borda Jan 5, 2026
7bf51b4
Fix Zip Slip vulnerability in NGC private bundle download (#8682)
yueyueL Jan 5, 2026
5e23203
lru_cache(maxsize=None)
Borda Jan 6, 2026
ce7cfe5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2026
3392955
yesqa
Borda Jan 6, 2026
ecd6bb2
linting
Borda Jan 6, 2026
11f6d09
Fix channel-first indices buffer for distance_transform_edt (return_…
alexanderjaus Jan 6, 2026
f64575e
linting
Borda Jan 6, 2026
93861bd
linting
Borda Jan 6, 2026
8372011
linting
Borda Jan 6, 2026
2d0014a
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
Borda Jan 6, 2026
6e7cf2c
functools.cache (added in Python 3.9) is a convenience alias for func…
Borda Jan 6, 2026
dab1c5c
Merge branch 'dev' into ruff/UP
Borda Jan 6, 2026
5387c5e
linting
Borda Jan 6, 2026
ad5f385
fix
Borda Jan 6, 2026
83f351b
linting
Borda Jan 6, 2026
c248641
Merge branch 'dev' into ruff/UP
KumoLiu Jan 7, 2026
15 changes: 2 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,11 @@ repos:
rev: v0.7.0
hooks:
- id: ruff
args:
- --fix

- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
hooks:
- id: pyupgrade
args: [--py39-plus, --keep-runtime-typing]
name: Upgrade code with exceptions
args: ["--fix"]
exclude: |
(?x)(
^versioneer.py|
^monai/_version.py|
^monai/networks/| # avoid typing rewrites
^monai/apps/detection/utils/anchor_utils.py| # avoid typing rewrites
^tests/test_compute_panoptic_quality.py # avoid typing rewrites
^monai/_version.py
)

- repo: https://github.com/asottile/yesqa
Expand Down
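The hunk above drops the standalone `pyupgrade` hook and relies on Ruff's built-in UP rule family instead. As a rough sketch (not MONAI code), this is the class of rewrite those rules produce when run with `ruff check --fix`: legacy `typing.List`/`typing.Optional` annotations become the modern 3.9+ forms, matching the `--py39-plus` target the removed hook used.

```python
# Hypothetical before/after of a Ruff UP-rule rewrite (UP006 + UP045).
# Before:
#   from typing import List, Optional
#   def names(items: List[str]) -> Optional[str]: ...
# After — builtin generics and PEP 604 unions; the future import keeps
# the annotations lazy, so this parses on Python 3.9 as well:
from __future__ import annotations


def names(items: list[str]) -> str | None:
    """Join names with '; ', or return None when the list is empty."""
    return "; ".join(items) if items else None
```

The remainder of the PR is largely the mechanical fallout of this swap across the code base.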
2 changes: 1 addition & 1 deletion monai/apps/auto3dseg/bundle_gen.py
Expand Up @@ -396,7 +396,7 @@ def _download_algos_url(url: str, at_path: str) -> dict[str, dict[str, str]]:
try:
download_and_extract(url=url, filepath=algo_compressed_file, output_dir=os.path.dirname(at_path))
except Exception as e:
msg = f"Download and extract of {url} failed, attempt {i+1}/{download_attempts}."
msg = f"Download and extract of {url} failed, attempt {i + 1}/{download_attempts}."
if i < download_attempts - 1:
warnings.warn(msg)
time.sleep(i)
Expand Down
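The `{i+1}` to `{i + 1}` edit above reflects formatting the expressions inside f-strings like ordinary code. The spacing is purely cosmetic; with hypothetical values, both forms render identically:

```python
# Spacing inside f-string replacement fields does not change the output.
i, download_attempts = 0, 3
compact = f"attempt {i+1}/{download_attempts}"
spaced = f"attempt {i + 1}/{download_attempts}"
assert compact == spaced
```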
22 changes: 6 additions & 16 deletions monai/apps/deepgrow/dataset.py
Expand Up @@ -201,14 +201,9 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
logging.warning(f"Unique labels {unique_labels_count} exceeds 20. Please check if this is correct.")

logging.info(
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
vol_idx,
vol_image.shape,
image_count,
vol_label.shape if vol_label is not None else None,
label_count,
unique_labels_count,
)
f"{vol_idx} => Image Shape: {vol_image.shape} => {image_count};"
f" Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count};"
f" Unique Labels: {unique_labels_count}"
)
return data_list

Expand Down Expand Up @@ -259,13 +254,8 @@ def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
logging.warning(f"Unique labels {unique_labels_count} exceeds 20. Please check if this is correct.")

logging.info(
"{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
vol_idx,
vol_image.shape,
image_count,
vol_label.shape if vol_label is not None else None,
label_count,
unique_labels_count,
)
f"{vol_idx} => Image Shape: {vol_image.shape} => {image_count};"
f" Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count};"
f" Unique Labels: {unique_labels_count}"
)
return data_list
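The two hunks above replace multi-argument `str.format` calls with f-strings, which is what Ruff's UP032 rule does. A minimal sketch with hypothetical values (not real shapes from the dataset code) shows the rewrite is behavior-preserving:

```python
# str.format -> f-string (Ruff UP032), illustrated on a shortened
# version of the logging message above; values are made up.
vol_idx, image_count, unique_labels = 3, 24, 5
shape = (1, 64, 64)

before = "{} => Image Shape: {} => {}; Unique Labels: {}".format(
    vol_idx, shape, image_count, unique_labels
)
after = f"{vol_idx} => Image Shape: {shape} => {image_count}; Unique Labels: {unique_labels}"
assert before == after
```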
2 changes: 1 addition & 1 deletion monai/apps/detection/networks/retinanet_detector.py
Expand Up @@ -787,7 +787,7 @@ def compute_anchor_matched_idxs(
)

if self.debug:
print(f"Max box overlap between anchors and gt boxes: {torch.max(match_quality_matrix,dim=1)[0]}.")
print(f"Max box overlap between anchors and gt boxes: {torch.max(match_quality_matrix, dim=1)[0]}.")

if torch.max(matched_idxs_per_image) < 0:
warnings.warn(
Expand Down
10 changes: 5 additions & 5 deletions monai/apps/detection/utils/anchor_utils.py
Expand Up @@ -39,7 +39,7 @@

from __future__ import annotations

from typing import List, Sequence
from collections.abc import Sequence

import torch
from torch import Tensor, nn
Expand Down Expand Up @@ -106,7 +106,7 @@ class AnchorGenerator(nn.Module):
anchor_generator = AnchorGenerator(sizes, aspect_ratios)
"""

__annotations__ = {"cell_anchors": List[torch.Tensor]}
__annotations__ = {"cell_anchors": list[torch.Tensor]}

def __init__(
self,
Expand Down Expand Up @@ -174,13 +174,13 @@ def generate_anchors(
if (self.spatial_dims >= 3) and (len(aspect_ratios_t.shape) != 2):
raise ValueError(
f"In {self.spatial_dims}-D image, aspect_ratios for each level should be \
{len(aspect_ratios_t.shape)-1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
{len(aspect_ratios_t.shape) - 1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
)

if (self.spatial_dims >= 3) and (aspect_ratios_t.shape[1] != self.spatial_dims - 1):
raise ValueError(
f"In {self.spatial_dims}-D image, aspect_ratios for each level should has \
shape (_,{self.spatial_dims-1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
shape (_,{self.spatial_dims - 1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
)

# if 2d, w:h = 1:aspect_ratios
Expand Down Expand Up @@ -364,7 +364,7 @@ class AnchorGeneratorWithAnchorShape(AnchorGenerator):
anchor_generator = AnchorGeneratorWithAnchorShape(feature_map_scales, base_anchor_shapes)
"""

__annotations__ = {"cell_anchors": List[torch.Tensor]}
__annotations__ = {"cell_anchors": list[torch.Tensor]}

def __init__(
self,
Expand Down
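The `List[torch.Tensor]` to `list[torch.Tensor]` change above sits inside an `__annotations__` dict, which is evaluated eagerly at import time rather than lazily as an annotation. That is safe here because builtin containers are subscriptable at runtime from Python 3.9 (PEP 585), the project's floor. A small check (using `list[int]` as a stand-in for `list[torch.Tensor]`):

```python
# PEP 585: builtin generics work in eagerly evaluated expressions on 3.9+,
# so Ruff's UP006 can rewrite typing.List -> list even outside annotations.
import typing

cell_annotations = {"cell_anchors": list[int]}
assert typing.get_origin(cell_annotations["cell_anchors"]) is list
assert typing.get_args(cell_annotations["cell_anchors"]) == (int,)
```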
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/detector_utils.py
Expand Up @@ -91,11 +91,11 @@ def check_training_targets(
if boxes.numel() == 0:
warnings.warn(
f"Warning: Given target boxes has shape of {boxes.shape}. "
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2* spatial_dims}])."
f"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2 * spatial_dims}])."
)
else:
raise ValueError(
f"Expected target boxes to be a tensor of shape [N, {2* spatial_dims}], got {boxes.shape}.)."
f"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}.)."
)
if not torch.is_floating_point(boxes):
raise ValueError(f"Expected target boxes to be a float tensor, got {boxes.dtype}.")
Expand Down
18 changes: 9 additions & 9 deletions monai/apps/nnunet/nnunet_bundle.py
Expand Up @@ -13,7 +13,7 @@
import os
import shutil
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any

import numpy as np
import torch
Expand All @@ -36,17 +36,17 @@


def get_nnunet_trainer(
dataset_name_or_id: Union[str, int],
dataset_name_or_id: str | int,
configuration: str,
fold: Union[int, str],
fold: int | str,
trainer_class_name: str = "nnUNetTrainer",
plans_identifier: str = "nnUNetPlans",
use_compressed_data: bool = False,
continue_training: bool = False,
only_run_validation: bool = False,
disable_checkpointing: bool = False,
device: str = "cuda",
pretrained_model: Optional[str] = None,
pretrained_model: str | None = None,
) -> Any: # type: ignore
"""
Get the nnUNet trainer instance based on the provided configuration.
Expand Down Expand Up @@ -166,7 +166,7 @@ class ModelnnUNetWrapper(torch.nn.Module):
restoring network architecture, and setting up the predictor for inference.
"""

def __init__(self, predictor: object, model_folder: Union[str, Path], model_name: str = "model.pt"): # type: ignore
def __init__(self, predictor: object, model_folder: str | Path, model_name: str = "model.pt"): # type: ignore
super().__init__()
self.predictor = predictor

Expand Down Expand Up @@ -294,7 +294,7 @@ def forward(self, x: MetaTensor) -> MetaTensor:
return MetaTensor(out_tensor, meta=x.meta)


def get_nnunet_monai_predictor(model_folder: Union[str, Path], model_name: str = "model.pt") -> ModelnnUNetWrapper:
def get_nnunet_monai_predictor(model_folder: str | Path, model_name: str = "model.pt") -> ModelnnUNetWrapper:
"""
Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
The model folder should contain the following files, created during training:
Expand Down Expand Up @@ -426,9 +426,9 @@ def get_network_from_nnunet_plans(
plans_file: str,
dataset_file: str,
configuration: str,
model_ckpt: Optional[str] = None,
model_ckpt: str | None = None,
model_key_in_ckpt: str = "model",
) -> Union[torch.nn.Module, Any]:
) -> torch.nn.Module | Any:
"""
Load and initialize a nnUNet network based on nnUNet plans and configuration.

Expand Down Expand Up @@ -518,7 +518,7 @@ def convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str,
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name

def subfiles(
folder: Union[str, Path], prefix: Optional[str] = None, suffix: Optional[str] = None, sort: bool = True
folder: str | Path, prefix: str | None = None, suffix: str | None = None, sort: bool = True
) -> list[str]:
res = [
i.name
Expand Down
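This file's hunks all apply the same pattern: `Union[X, Y]` becomes `X | Y` and `Optional[X]` becomes `X | None` (Ruff UP007/UP045). Because these modules carry `from __future__ import annotations`, the annotations are stored lazily as strings, so the PEP 604 syntax is safe even below Python 3.10. A toy stand-in (`load` is not a real MONAI function) mimicking the rewritten signatures:

```python
# typing.Union / typing.Optional -> PEP 604 "|" unions, kept 3.9-safe
# by the future import that defers annotation evaluation.
from __future__ import annotations

from pathlib import Path


def load(model_folder: str | Path, model_ckpt: str | None = None) -> str:
    """Accept a str or Path folder and an optional checkpoint name."""
    return f"{Path(model_folder).name}::{model_ckpt or 'latest'}"
```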
12 changes: 7 additions & 5 deletions monai/auto3dseg/analyzer.py
Expand Up @@ -285,7 +285,7 @@ def __call__(self, data):
d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get image stats spent {time.time()-start}")
logger.debug(f"Get image stats spent {time.time() - start}")
return d


Expand Down Expand Up @@ -366,7 +366,7 @@ def __call__(self, data: Mapping) -> dict:
d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get foreground image stats spent {time.time()-start}")
logger.debug(f"Get foreground image stats spent {time.time() - start}")
return d


Expand Down Expand Up @@ -535,7 +535,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
d[self.stats_name] = report # type: ignore[assignment]

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get label stats spent {time.time()-start}")
logger.debug(f"Get label stats spent {time.time() - start}")
return d # type: ignore[return-value]


Expand Down Expand Up @@ -913,9 +913,11 @@ def __init__(
for i, hist_params in enumerate(zip(self.hist_bins, self.hist_range)):
_hist_bins, _hist_range = hist_params
if not isinstance(_hist_bins, int) or _hist_bins < 0:
raise ValueError(f"Expected {i+1}. hist_bins value to be positive integer but got {_hist_bins}")
raise ValueError(f"Expected {i + 1}. hist_bins value to be positive integer but got {_hist_bins}")
if not isinstance(_hist_range, list) or len(_hist_range) != 2:
raise ValueError(f"Expected {i+1}. hist_range values to be list of length 2 but received {_hist_range}")
raise ValueError(
f"Expected {i + 1}. hist_range values to be list of length 2 but received {_hist_range}"
)

def __call__(self, data: dict) -> dict:
"""
Expand Down
4 changes: 3 additions & 1 deletion monai/data/wsi_reader.py
Expand Up @@ -210,7 +210,9 @@ def get_valid_level(
# Set the default value if no resolution parameter is provided.
level = 0
if level >= n_levels:
raise ValueError(f"The maximum level of this image is {n_levels-1} while level={level} is requested)!")
raise ValueError(
f"The maximum level of this image is {n_levels - 1} while level={level} is requested)!"
)

return level

Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/stats_handler.py
Expand Up @@ -260,7 +260,7 @@ def _default_iteration_print(self, engine: Engine) -> None:
"ignoring non-scalar output in StatsHandler,"
" make sure `output_transform(engine.state.output)` returns"
" a scalar or dictionary of key and scalar pairs to avoid this warning."
" {}:{}".format(name, type(value))
f" {name}:{type(value)}"
)
continue # not printing multi dimensional output
out_str += self.key_var_format.format(name, value.item() if isinstance(value, torch.Tensor) else value)
Expand All @@ -273,7 +273,7 @@ def _default_iteration_print(self, engine: Engine) -> None:
"ignoring non-scalar output in StatsHandler,"
" make sure `output_transform(engine.state.output)` returns"
" a scalar or a dictionary of key and scalar pairs to avoid this warning."
" {}".format(type(loss))
f" {type(loss)}"
)

if not out_str:
Expand Down
4 changes: 2 additions & 2 deletions monai/handlers/tensorboard_handlers.py
Expand Up @@ -257,7 +257,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter | Summ
"ignoring non-scalar output in TensorBoardStatsHandler,"
" make sure `output_transform(engine.state.output)` returns"
" a scalar or dictionary of key and scalar pairs to avoid this warning."
" {}:{}".format(name, type(value))
f" {name}:{type(value)}"
)
continue # not plot multi dimensional output
self._write_scalar(
Expand All @@ -280,7 +280,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter | Summ
"ignoring non-scalar output in TensorBoardStatsHandler,"
" make sure `output_transform(engine.state.output)` returns"
" a scalar or a dictionary of key and scalar pairs to avoid this warning."
" {}".format(type(loss))
f" {type(loss)}"
)
writer.flush()

Expand Down
3 changes: 1 addition & 2 deletions monai/losses/adversarial_loss.py
Expand Up @@ -57,8 +57,7 @@ def __init__(

if criterion.lower() not in list(AdversarialCriterions):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
% ", ".join(AdversarialCriterions)
f"Unrecognised criterion entered for Adversarial Loss. Must be one in: {', '.join(AdversarialCriterions)}"
)

# Depending on the criterion, a different activation layer is used.
Expand Down
2 changes: 1 addition & 1 deletion monai/losses/dice.py
Expand Up @@ -494,7 +494,7 @@ def __init__(
raise ValueError(f"dist_matrix must be C x C, got {dist_matrix.shape[0]} x {dist_matrix.shape[1]}.")

if weighting_mode not in ["default", "GDL"]:
raise ValueError("weighting_mode must be either 'default' or 'GDL, got %s." % weighting_mode)
raise ValueError(f"weighting_mode must be either 'default' or 'GDL', got {weighting_mode}.")

self.m = dist_matrix
if isinstance(self.m, np.ndarray):
Expand Down
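The `dice.py` and `adversarial_loss.py` hunks replace printf-style `%` interpolation with f-strings (Ruff UP031). A quick sketch with a hypothetical value confirms the two spellings produce the same message:

```python
# %-formatting -> f-string (Ruff UP031); the value is made up.
weighting_mode = "GDL2"

before = "weighting_mode must be either 'default' or 'GDL', got %s." % weighting_mode
after = f"weighting_mode must be either 'default' or 'GDL', got {weighting_mode}."
assert before == after
```

Note the PR also fixes a stray quote in the original message (`'GDL,` to `'GDL',`) while converting it.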
4 changes: 1 addition & 3 deletions monai/losses/ds_loss.py
Expand Up @@ -11,8 +11,6 @@

from __future__ import annotations

from typing import Union

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
Expand Down Expand Up @@ -70,7 +68,7 @@ def get_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode)
return self.loss(input, target) # type: ignore[no-any-return]

def forward(self, input: Union[None, torch.Tensor, list[torch.Tensor]], target: torch.Tensor) -> torch.Tensor:
def forward(self, input: None | torch.Tensor | list[torch.Tensor], target: torch.Tensor) -> torch.Tensor:
if isinstance(input, (list, tuple)):
weights = self.get_weights(levels=len(input))
loss = torch.tensor(0, dtype=torch.float, device=target.device)
Expand Down
7 changes: 3 additions & 4 deletions monai/losses/focal_loss.py
Expand Up @@ -13,7 +13,6 @@

import warnings
from collections.abc import Sequence
from typing import Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -153,7 +152,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

loss: Optional[torch.Tensor] = None
loss: torch.Tensor | None = None
input = input.float()
target = target.float()
if self.use_softmax:
Expand Down Expand Up @@ -203,7 +202,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


def softmax_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
Expand All @@ -225,7 +224,7 @@ def softmax_focal_loss(


def sigmoid_focal_loss(
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: Optional[float] = None
input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | None = None
) -> torch.Tensor:
"""
FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
Expand Down
4 changes: 1 addition & 3 deletions monai/losses/perceptual.py
Expand Up @@ -95,10 +95,8 @@ def __init__(

if network_type.lower() not in list(PercetualNetworkType):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
% ", ".join(PercetualNetworkType)
f"Unrecognised criterion entered for Perceptual Loss. Must be one in: {', '.join(PercetualNetworkType)}"
)

if cache_dir:
torch.hub.set_dir(cache_dir)
# raise a warning that this may change the default cache dir for all torch.hub calls
Expand Down
4 changes: 2 additions & 2 deletions monai/losses/spatial_mask.py
Expand Up @@ -14,7 +14,7 @@
import inspect
import warnings
from collections.abc import Callable
from typing import Any, Optional
from typing import Any

import torch
from torch.nn.modules.loss import _Loss
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
if not callable(self.loss):
raise ValueError("The loss function is not callable.")

def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD].
Expand Down