From b18b98fd3daa27d401bcfd44c382f51e62e43b52 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Sat, 25 Oct 2025 20:06:53 +0100 Subject: [PATCH 01/12] add grandqc tissue model --- requirements/requirements.txt | 1 + tests/models/test_arch_grandqc.py | 52 ++++++++++ tiatoolbox/data/pretrained_model.yaml | 31 ++++-- tiatoolbox/models/architecture/grandqc.py | 117 ++++++++++++++++++++++ tiatoolbox/wsicore/wsireader.py | 6 +- 5 files changed, 200 insertions(+), 7 deletions(-) create mode 100644 tests/models/test_arch_grandqc.py create mode 100644 tiatoolbox/models/architecture/grandqc.py diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 045a4ce4e..d99b62e10 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -28,6 +28,7 @@ requests>=2.28.1 scikit-image>=0.20 scikit-learn>=1.2.0 scipy>=1.8 +segmentation-models-pytorch>=0.5.0 shapely>=2.0.0 SimpleITK>=2.2.1 sphinx>=5.3.0 diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py new file mode 100644 index 000000000..7b7c7ecc5 --- /dev/null +++ b/tests/models/test_arch_grandqc.py @@ -0,0 +1,52 @@ +"""Unit test package for GrandQC Tissue Model.""" + +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import torch + +from tiatoolbox.models.architecture import ( + fetch_pretrained_weights, + get_pretrained_model, +) +from tiatoolbox.models.architecture.grandqc import TissueDetectionModel +from tiatoolbox.models.engine.io_config import IOSegmentorConfig +from tiatoolbox.utils.misc import select_device +from tiatoolbox.wsicore.wsireader import WSIReader + +ON_GPU = False + + +def test_functional_grandqc(remote_sample: Callable) -> None: + """Test for GrandQC model.""" + # test fetch pretrained weights + pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10") + assert pretrained_weights is not None + + # test creation + model = TissueDetectionModel() + assert model is not None + + # load pretrained weights + pretrained = torch.load(pretrained_weights, map_location="cpu") + model.load_state_dict(pretrained) + + # test get pretrained model + model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10") + assert isinstance(model, TissueDetectionModel) + assert isinstance(ioconfig, IOSegmentorConfig) + + # test inference + mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) + reader = WSIReader.open(mini_wsi_svs) + read_kwargs = {"resolution": 10.0, "units": "mpp", "coord_space": "resolution"} + batch = np.array( + [ + reader.read_bounds((0, 0, 512, 512), **read_kwargs), + reader.read_bounds((512, 512, 1024, 1024), **read_kwargs), + ], + ) + batch = torch.from_numpy(batch) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) + assert output.shape == (2, 512, 512, 2) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 880c623fe..434ec936a 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -815,7 +815,7 @@ mapde-crchisto: threshold_abs: 250 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -837,7 +837,7 @@ mapde-conic: threshold_abs: 205 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -860,7 +860,7 @@ sccnn-crchisto: threshold_abs: 0.20 
patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -883,7 +883,7 @@ sccnn-conic: threshold_abs: 0.05 patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -903,7 +903,7 @@ nuclick_original-pannuke: num_input_channels: 5 num_output_channels: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -925,7 +925,7 @@ nuclick_light-pannuke: decoder_block: [3,3] skip_type: "add" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -934,3 +934,22 @@ nuclick_light-pannuke: patch_input_shape: [128, 128] patch_output_shape: [128, 128] save_resolution: {'units': 'baseline', 'resolution': 1.0} + +grandqc_tissue_detection_mpp10: + hf_repo_id: TIACentre/GrandQC_Tissue_Detection + architecture: + class: grandqc.TissueDetectionModel + kwargs: + num_input_channels: 3 + num_output_channels: 2 + ioconfig: + class: io_config.IOSegmentorConfig + kwargs: + input_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + output_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + patch_input_shape: [512, 512] + patch_output_shape: [512, 512] + stride_shape: [256, 256] + save_resolution: {'units': 'mpp', 'resolution': 10.0} diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py new file mode 100644 index 000000000..8c0907c98 --- /dev/null +++ b/tiatoolbox/models/architecture/grandqc.py @@ -0,0 +1,117 @@ +"""Define GrandQC Tissue Detection Model architecture.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Mapping + +import cv2 +import numpy as np +import segmentation_models_pytorch as smp +import torch + +from tiatoolbox.models.models_abc import ModelABC + + +class TissueDetectionModel(ModelABC): + """GrandQC Tissue Detection Model.""" + + def __init__(self: TissueDetectionModel) -> None: + """Initialize TissueDetectionModel.""" + super().__init__() + self._postproc = self.postproc + self._preproc = self.preproc + self.tissue_detection_model = smp.UnetPlusPlus( + encoder_name="timm-efficientnet-b0", + encoder_weights=None, + classes=2, + activation=None, + ) + + @staticmethod + def preproc(image: np.ndarray) -> np.ndarray: + """Apply jpg compression then ImageNet normalise.""" + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80] + _, compressed_image = cv2.imencode(".jpg", image, encode_param) + compressed_image = np.array(cv2.imdecode(compressed_image, 1)) + + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + return (compressed_image / 255.0 - mean) / std + + @staticmethod + def postproc(image: np.ndarray) -> np.ndarray: + """Define post-processing of this class of model. + + This simply applies argmin along last axis of the input. 
+ + """ + return image.argmin(axis=-1) + + def forward( + self: TissueDetectionModel, + imgs: torch.Tensor, + *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 + **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 + ) -> torch.Tensor: + """Forward function for model.""" + return self.tissue_detection_model(imgs) + + @staticmethod + def infer_batch( + model: torch.nn.Module, + batch_data: torch.Tensor, + *, + device: str, + ) -> np.ndarray: + """Run inference on an input batch. + + This contains logic for forward operation as well as i/o + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (:class:`torch.Tensor`): + A batch of data generated by + `torch.utils.data.DataLoader`. + device (str): + Transfers model to the specified device. Default is "cpu". + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (:class:`torch.Tensor`): + A batch of data generated by + `torch.utils.data.DataLoader`. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + np.ndarray: + The inference results as a numpy array. + + """ + model.eval() + + #### + imgs = batch_data + + imgs = imgs.to(device).type(torch.float32) + imgs = imgs.permute(0, 3, 1, 2) # to NCHW + + with torch.inference_mode(): + logits = model(imgs) + probs = torch.nn.functional.softmax(logits, 1) + probs = probs.permute(0, 2, 3, 1) # to NHWC + + return probs.cpu().numpy() + + def load_state_dict( + self: TissueDetectionModel, + state_dict: Mapping[str, Any], + **kwargs: bool, + ) -> torch.nn.modules.module._IncompatibleKeys: + """Load state dict for the TissueDetectionModel.""" + return self.tissue_detection_model.load_state_dict(state_dict, **kwargs) diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index 0791bc4fa..a135b61b8 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -1572,7 +1572,11 @@ def tissue_mask( ) elif method == "otsu": masker = tissuemask.OtsuTissueMasker(**masker_kwargs) - mask_img = masker.fit_transform([thumbnail])[0] + else: # grandqc + masker = tissuemask.GrandQCMasker(**masker_kwargs) + # GrandQC model is trained on 10mpp images + thumbnail = self.slide_thumbnail(resolution=10, units="mpp") + mask_img = masker.fit_transform(np.array([thumbnail]))[0] return VirtualWSIReader(mask_img.astype(np.uint8), info=self.info, mode="bool") def save_tiles( From 899d6cb381e2da5196e6cd6b0c036b7111f3371b Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Sat, 25 Oct 2025 20:30:01 +0100 Subject: [PATCH 02/12] add example --- tiatoolbox/models/architecture/grandqc.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index 8c0907c98..43071d57a 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -16,7 +16,23 @@ class TissueDetectionModel(ModelABC): - """GrandQC Tissue Detection Model.""" + """GrandQC Tissue Detection Model. + + Example: + >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor + >>> semantic_segmentor = SemanticSegmentor( + ... model="grandqc_tissue_detection_mpp10", + ... ) + >>> results = semantic_segmentor.run( + ... ["/example_wsi.svs"], + ... masks=None, + ... auto_get_mask=False, + ... patch_mode=False, + ... save_dir=Path("/tissue_mask/"), + ... output_type="annotationstore", + ... 
) + + """ def __init__(self: TissueDetectionModel) -> None: """Initialize TissueDetectionModel.""" From 8a7295d01b0235d733b819c5931a26aaa60c0e43 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Sat, 25 Oct 2025 20:33:44 +0100 Subject: [PATCH 03/12] fix tests --- tests/models/test_arch_grandqc.py | 4 +++- tiatoolbox/models/architecture/grandqc.py | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py index 7b7c7ecc5..0c31d7f36 100644 --- a/tests/models/test_arch_grandqc.py +++ b/tests/models/test_arch_grandqc.py @@ -25,7 +25,7 @@ def test_functional_grandqc(remote_sample: Callable) -> None: assert pretrained_weights is not None # test creation - model = TissueDetectionModel() + model = TissueDetectionModel(num_input_channels=3, num_output_channels=2) assert model is not None # load pretrained weights @@ -36,6 +36,8 @@ def test_functional_grandqc(remote_sample: Callable) -> None: model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10") assert isinstance(model, TissueDetectionModel) assert isinstance(ioconfig, IOSegmentorConfig) + assert model.num_input_channels == 3 + assert model.num_output_channels == 2 # test inference mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index 43071d57a..16f46aee9 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -34,15 +34,20 @@ class TissueDetectionModel(ModelABC): """ - def __init__(self: TissueDetectionModel) -> None: + def __init__( + self: TissueDetectionModel, num_input_channels: int, num_output_channels: int + ) -> None: """Initialize TissueDetectionModel.""" super().__init__() + self.num_input_channels = num_input_channels + self.num_output_channels = num_output_channels self._postproc = self.postproc self._preproc = self.preproc self.tissue_detection_model = smp.UnetPlusPlus( encoder_name="timm-efficientnet-b0", encoder_weights=None, - classes=2, + in_channels=self.num_input_channels, + classes=self.num_output_channels, activation=None, ) From 5c5bfc4efa89c6d4af417b96e6a7f168a364c2e5 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Sat, 25 Oct 2025 21:30:36 +0100 Subject: [PATCH 04/12] fix error --- tiatoolbox/models/architecture/grandqc.py | 14 +++----------- tiatoolbox/wsicore/wsireader.py | 6 +----- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index 16f46aee9..738938da7 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -64,9 +64,10 @@ def preproc(image: np.ndarray) -> np.ndarray: @staticmethod def postproc(image: np.ndarray) -> np.ndarray: - """Define post-processing of this class of model. + """Define post-processing for this model. - This simply applies argmin along last axis of the input. + This returns the class index with the minimum probability. + In this model, this means selecting tissue class. """ return image.argmin(axis=-1) @@ -91,15 +92,6 @@ def infer_batch( This contains logic for forward operation as well as i/o - Args: - model (nn.Module): - PyTorch defined model. - batch_data (:class:`torch.Tensor`): - A batch of data generated by - `torch.utils.data.DataLoader`. - device (str): - Transfers model to the specified device. Default is "cpu". - Args: model (nn.Module): PyTorch defined model. 
diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index a135b61b8..0791bc4fa 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -1572,11 +1572,7 @@ def tissue_mask( ) elif method == "otsu": masker = tissuemask.OtsuTissueMasker(**masker_kwargs) - else: # grandqc - masker = tissuemask.GrandQCMasker(**masker_kwargs) - # GrandQC model is trained on 10mpp images - thumbnail = self.slide_thumbnail(resolution=10, units="mpp") - mask_img = masker.fit_transform(np.array([thumbnail]))[0] + mask_img = masker.fit_transform([thumbnail])[0] return VirtualWSIReader(mask_img.astype(np.uint8), info=self.info, mode="bool") def save_tiles( From fd692daa554f277a1313fb7226bfdbe8bdfebcad Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Tue, 28 Oct 2025 15:07:33 +0000 Subject: [PATCH 05/12] update docstring --- tiatoolbox/models/architecture/grandqc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index 738938da7..e2f36254b 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -66,8 +66,8 @@ def preproc(image: np.ndarray) -> np.ndarray: def postproc(image: np.ndarray) -> np.ndarray: """Define post-processing for this model. - This returns the class index with the minimum probability. - In this model, this means selecting tissue class. + This simply applies argmin to obtain tissue class. + (Tissue = 0, Background = 1) """ return image.argmin(axis=-1) @@ -108,9 +108,7 @@ def infer_batch( """ model.eval() - #### imgs = batch_data - imgs = imgs.to(device).type(torch.float32) imgs = imgs.permute(0, 3, 1, 2) # to NCHW From d82cc3d91acb9be4c5402b6a28b22cdc5972e0e4 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Tue, 28 Oct 2025 15:22:20 +0000 Subject: [PATCH 06/12] improve test coverage --- tests/models/test_arch_grandqc.py | 32 +++++++++++++++++------ tiatoolbox/models/architecture/grandqc.py | 2 +- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py index 0c31d7f36..1c124fc09 100644 --- a/tests/models/test_arch_grandqc.py +++ b/tests/models/test_arch_grandqc.py @@ -1,8 +1,5 @@ """Unit test package for GrandQC Tissue Model.""" -from collections.abc import Callable -from pathlib import Path - import numpy as np import torch @@ -13,12 +10,12 @@ from tiatoolbox.models.architecture.grandqc import TissueDetectionModel from tiatoolbox.models.engine.io_config import IOSegmentorConfig from tiatoolbox.utils.misc import select_device -from tiatoolbox.wsicore.wsireader import WSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader ON_GPU = False -def test_functional_grandqc(remote_sample: Callable) -> None: +def test_functional_grandqc() -> None: """Test for GrandQC model.""" # test fetch pretrained weights pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10") @@ -40,9 +37,10 @@ def test_functional_grandqc(remote_sample: Callable) -> None: assert model.num_output_channels == 2 # test inference - mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - reader = WSIReader.open(mini_wsi_svs) - read_kwargs = {"resolution": 10.0, "units": "mpp", "coord_space": "resolution"} + generator = np.random.default_rng(1337) + test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8) + reader = VirtualWSIReader.open(test_image) + read_kwargs = {"resolution": 0, "units": "level", 
"coord_space": "resolution"} batch = np.array( [ reader.read_bounds((0, 0, 512, 512), **read_kwargs), @@ -52,3 +50,21 @@ def test_functional_grandqc(remote_sample: Callable) -> None: batch = torch.from_numpy(batch) output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) assert output.shape == (2, 512, 512, 2) + + +def test_grandqc_preproc_postproc() -> None: + """Test GrandQC preproc and postproc functions.""" + model = TissueDetectionModel(num_input_channels=3, num_output_channels=2) + + generator = np.random.default_rng(1337) + # test preproc + dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8) + preproc_image = model.preproc(dummy_image) + assert preproc_image.shape == dummy_image.shape + assert preproc_image.dtype == np.float64 + + # test postproc + dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32) + postproc_image = model.postproc(dummy_output) + assert postproc_image.shape == (512, 512) + assert postproc_image.dtype == np.int64 diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index e2f36254b..a2ffdf2db 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from collections.abc import Mapping import cv2 From 93a24a1ab3fb8f6205f112dbf063256224e9be1a Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Thu, 6 Nov 2025 16:54:39 +0000 Subject: [PATCH 07/12] add unet++ model --- .../models/architecture/timm_universal.py | 247 +++++++++ .../models/architecture/unetplusplus.py | 470 ++++++++++++++++++ 2 files changed, 717 insertions(+) create mode 100644 tiatoolbox/models/architecture/timm_universal.py create mode 100644 tiatoolbox/models/architecture/unetplusplus.py diff --git a/tiatoolbox/models/architecture/timm_universal.py b/tiatoolbox/models/architecture/timm_universal.py new file mode 100644 index 000000000..138b2ef8b --- /dev/null +++ b/tiatoolbox/models/architecture/timm_universal.py @@ -0,0 +1,247 @@ +""" +TimmUniversalEncoder provides a unified feature extraction interface built on the +`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style +models (e.g., Swin Transformer, ConvNeXt). + +This encoder produces consistent multi-level feature maps for semantic segmentation tasks. +It allows configuring the number of feature extraction stages (`depth`) and adjusting +`output_stride` when supported. + +Key Features: +- Flexible model selection using `timm.create_model`. +- Unified multi-level output across different model hierarchies. +- Automatic alignment for inconsistent feature scales: + - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale. + - VGG-style models (include scale-1 features): Align outputs for compatibility. +- Easy access to feature scale information via the `reduction` property. + +Feature Scale Differences: +- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32. +- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale. +- VGG-style models: Include scale-1 features (input resolution). + +Notes: +- `output_stride` is unsupported in some models, especially transformer-based architectures. +- Special handling for models like TResNet and DLA to ensure correct feature indexing. +- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs. 
+""" + +from typing import Any + +import timm +import torch +import torch.nn as nn + + +class TimmUniversalEncoder(nn.Module): + """ + A universal encoder leveraging the `timm` library for feature extraction from + various model architectures, including traditional-style and transformer-style models. + + Features: + - Supports configurable depth and output stride. + - Ensures consistent multi-level feature extraction across diverse models. + - Compatible with convolutional and transformer-like backbones. + """ + + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + + def __init__( + self, + name: str, + pretrained: bool = True, + in_channels: int = 3, + depth: int = 5, + output_stride: int = 32, + **kwargs: dict[str, Any], + ): + """ + Initialize the encoder. + + Args: + name (str): Model name to load from `timm`. + pretrained (bool): Load pretrained weights (default: True). + in_channels (int): Number of input channels (default: 3 for RGB). + depth (int): Number of feature stages to extract (default: 5). + output_stride (int): Desired output stride (default: 32). + **kwargs: Additional arguments passed to `timm.create_model`. + """ + # At the moment we do not support models with more than 5 stages, + # but can be reconfigured in the future. + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + + super().__init__() + self.name = name + + # Default model configuration for feature extraction + common_kwargs = dict( + in_chans=in_channels, + features_only=True, + output_stride=output_stride, + pretrained=pretrained, + out_indices=tuple(range(depth)), + ) + + # Not all models support output stride argument, drop it by default + if output_stride == 32: + common_kwargs.pop("output_stride") + + # Load a temporary model to analyze its feature hierarchy + try: + with torch.device("meta"): + tmp_model = timm.create_model(name, features_only=True) + except Exception: + tmp_model = timm.create_model(name, features_only=True) + + # Check if model output is in channel-last format (NHWC) + self._is_channel_last = getattr(tmp_model, "output_fmt", None) == "NHWC" + + # Determine the model's downsampling pattern and set hierarchy flags + encoder_stage = len(tmp_model.feature_info.reduction()) + reduction_scales = list(tmp_model.feature_info.reduction()) + + if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]: + # Transformer-style downsampling: scales (4, 8, 16, 32) + self._is_transformer_style = True + self._is_vgg_style = False + elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]: + # Traditional-style downsampling: scales (2, 4, 8, 16, 32) + self._is_transformer_style = False + self._is_vgg_style = False + elif reduction_scales == [2**i for i in range(encoder_stage)]: + # Vgg-style models including scale 1: scales (1, 2, 4, 8, 16, 32) + self._is_transformer_style = False + self._is_vgg_style = True + else: + raise ValueError("Unsupported model downsampling pattern.") + + if self._is_transformer_style: + # Transformer-like models (start at scale 4) + if "tresnet" in name: + # 'tresnet' models start feature extraction at stage 1, + # so out_indices=(1, 2, 3, 4) for depth=5. + common_kwargs["out_indices"] = tuple(range(1, depth)) + else: + # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5. 
+ common_kwargs["out_indices"] = tuple(range(depth - 1)) + + timm_model_kwargs = _merge_kwargs_no_duplicates(common_kwargs, kwargs) + self.model = timm.create_model(name, **timm_model_kwargs) + + # Add a dummy output channel (0) to align with traditional encoder structures. + self._out_channels = ( + [in_channels] + [0] + self.model.feature_info.channels() + ) + else: + if "dla" in name: + # For 'dla' models, out_indices starts at 0 and matches the input size. + common_kwargs["out_indices"] = tuple(range(1, depth + 1)) + if self._is_vgg_style: + common_kwargs["out_indices"] = tuple(range(depth + 1)) + + self.model = timm.create_model( + name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) + ) + + if self._is_vgg_style: + self._out_channels = self.model.feature_info.channels() + else: + self._out_channels = [in_channels] + self.model.feature_info.channels() + + self._in_channels = in_channels + self._depth = depth + self._output_stride = output_stride + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Forward pass to extract multi-stage features. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W). + + Returns: + list[torch.Tensor]: List of feature maps at different scales. + """ + features = self.model(x) + + # Convert NHWC to NCHW if needed + if self._is_channel_last: + features = [ + feature.permute(0, 3, 1, 2).contiguous() for feature in features + ] + + # Add dummy feature for scale 1/2 if missing (transformer-style models) + if self._is_transformer_style: + B, _, H, W = x.shape + dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) + features = [dummy] + features + + # Add input tensor as scale 1 feature if `self._is_vgg_style` is False + if not self._is_vgg_style: + features = [x] + features + + return features + + @property + def out_channels(self) -> list[int]: + """ + Returns the number of output channels for each feature stage. + + Returns: + list[int]: A list of channel dimensions at each scale. + """ + return self._out_channels + + @property + def output_stride(self) -> int: + """ + Returns the effective output stride based on the model depth. + + Returns: + int: The effective output stride. + """ + return int(min(self._output_stride, 2**self._depth)) + + def load_state_dict(self, state_dict, **kwargs): + # for compatibility of weights for + # timm- ported encoders with TimmUniversalEncoder + patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] + + is_deprecated_encoder = any( + self.name.startswith(pattern) for pattern in patterns + ) + + if is_deprecated_encoder: + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if not key.startswith("model."): + new_key = "model." + key + if "gernet" in self.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + return super().load_state_dict(state_dict, **kwargs) + + +def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: + """ + Merge two dictionaries, ensuring no duplicate keys exist. + + Args: + a (dict): Base dictionary. + b (dict): Additional parameters to merge. + + Returns: + dict: A merged dictionary. 
+ """ + duplicates = a.keys() & b.keys() + if duplicates: + raise ValueError(f"'{duplicates}' already specified internally") + + return a | b diff --git a/tiatoolbox/models/architecture/unetplusplus.py b/tiatoolbox/models/architecture/unetplusplus.py new file mode 100644 index 000000000..9861484cc --- /dev/null +++ b/tiatoolbox/models/architecture/unetplusplus.py @@ -0,0 +1,470 @@ +"""Define Unet++ architecture from Segmentation Models Pytorch.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Mapping + +import torch +import torch.nn as nn +import torch.nn.functional as F +import warnings + +from tiatoolbox.models.models_abc import ModelABC +from tiatoolbox.models.architecture.timm_universal import TimmUniversalEncoder + + +class ArgMax(nn.Module): + def __init__(self, dim=None): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.argmax(x, dim=self.dim) + + +class Clamp(nn.Module): + def __init__(self, min=0, max=1): + super().__init__() + self.min, self.max = min, max + + def forward(self, x): + return torch.clamp(x, self.min, self.max) + +class Activation(nn.Module): + def __init__(self, name, **params): + super().__init__() + self.activation: nn.Module + if name is None or name == "identity": + self.activation = nn.Identity(**params) + elif name == "sigmoid": + self.activation = nn.Sigmoid() + elif name == "softmax2d": + self.activation = nn.Softmax(dim=1, **params) + elif name == "softmax": + self.activation = nn.Softmax(**params) + elif name == "logsoftmax": + self.activation = nn.LogSoftmax(**params) + elif name == "tanh": + self.activation = nn.Tanh() + elif name == "argmax": + self.activation = ArgMax(**params) + elif name == "argmax2d": + self.activation = ArgMax(dim=1, **params) + elif name == "clamp": + self.activation = Clamp(**params) + else: + self.activation = nn.Identity(**params) + raise ValueError( + f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/" + f"argmax/argmax2d/clamp/None; got {name}" + ) + + def forward(self, x): + return self.activation(x) + + +class SegmentationHead(nn.Sequential): + def __init__( + self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1 + ): + conv2d = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + upsampling = ( + nn.UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else nn.Identity() + ) + activation = Activation(activation) + super().__init__(conv2d, upsampling, activation) + + +class ClassificationHead(nn.Sequential): + def __init__( + self, in_channels, classes, pooling="avg", dropout=0.2, activation=None + ): + if pooling not in ("max", "avg"): + raise ValueError( + "Pooling should be one of ('max', 'avg'), got {}.".format(pooling) + ) + pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1) + flatten = nn.Flatten() + dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() + linear = nn.Linear(in_channels, classes, bias=True) + activation = Activation(activation) + super().__init__(pool, flatten, dropout, linear, activation) + +def get_norm_layer( + use_norm: Union[bool, str, Dict[str, Any]], out_channels: int +) -> nn.Module: + supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm") + + # Step 1. 
Convert tot dict representation + + ## Check boolean + if use_norm is True: + norm_params = {"type": "batchnorm"} + elif use_norm is False: + norm_params = {"type": "identity"} + + ## Check string + elif isinstance(use_norm, str): + norm_str = use_norm.lower() + if norm_str == "inplace": + norm_params = { + "type": "inplace", + "activation": "leaky_relu", + "activation_param": 0.0, + } + elif norm_str in supported_norms: + norm_params = {"type": norm_str} + else: + raise ValueError( + f"Unrecognized normalization type string provided: {use_norm}. Should be in " + f"{supported_norms}" + ) + + ## Check dict + elif isinstance(use_norm, dict): + norm_params = use_norm + + else: + raise ValueError( + f"Invalid type for use_norm should either be a bool (batchnorm/identity), " + f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}" + ) + + # Step 2. Check if the dict is valid + if "type" not in norm_params: + raise ValueError( + f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'." + ) + if norm_params["type"] not in supported_norms: + raise ValueError( + f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}" + ) + + # Step 3. Initialize the norm layer + norm_type = norm_params["type"] + norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"} + + + if norm_type == "batchnorm": + norm = nn.BatchNorm2d(out_channels, **norm_kwargs) + elif norm_type == "identity": + norm = nn.Identity() + elif norm_type == "layernorm": + norm = nn.LayerNorm(out_channels, **norm_kwargs) + elif norm_type == "instancenorm": + norm = nn.InstanceNorm2d(out_channels, **norm_kwargs) + else: + raise ValueError(f"Unrecognized normalization type: {norm_type}") + + return norm + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + stride: int = 1, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): + norm = get_norm_layer(use_norm, out_channels) + + is_identity = isinstance(norm, nn.Identity) + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=is_identity, + ) + + is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN) + activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True) + + super(Conv2dReLU, self).__init__(conv, norm, activation) + +class SCSEModule(nn.Module): + def __init__(self, in_channels, reduction=16): + super().__init__() + self.cSE = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, in_channels // reduction, 1), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels // reduction, in_channels, 1), + nn.Sigmoid(), + ) + self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) + + def forward(self, x): + return x * self.cSE(x) + x * self.sSE(x) + +class Attention(nn.Module): + def __init__(self, name, **params): + super().__init__() + + if name is None: + self.attention = nn.Identity(**params) + elif name == "scse": + self.attention = SCSEModule(**params) + else: + raise ValueError("Attention {} is not implemented".format(name)) + + def forward(self, x): + return self.attention(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + skip_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + attention_type: Optional[str] = None, + interpolation_mode: str = "nearest", + ): + super().__init__() + self.conv1 = Conv2dReLU( + 
in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_norm=use_norm, + ) + self.attention1 = Attention( + attention_type, in_channels=in_channels + skip_channels + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_norm=use_norm, + ) + self.attention2 = Attention(attention_type, in_channels=out_channels) + self.interpolation_mode = interpolation_mode + + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.attention2(x) + return x + + +class CenterBlock(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): + conv1 = Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + use_norm=use_norm, + ) + conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_norm=use_norm, + ) + super().__init__(conv1, conv2) + + +class UnetPlusPlusDecoder(nn.Module): + def __init__( + self, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + n_blocks: int = 5, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + attention_type: Optional[str] = None, + interpolation_mode: str = "nearest", + center: bool = False, + ): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + f"Model depth is {n_blocks}, but you provide `decoder_channels` for {len(decoder_channels)} blocks." + ) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + self.in_channels = [head_channels] + list(decoder_channels[:-1]) + self.skip_channels = list(encoder_channels[1:]) + [0] + self.out_channels = decoder_channels + if center: + self.center = CenterBlock( + head_channels, + head_channels, + use_norm=use_norm, + ) + else: + self.center = nn.Identity() + + # combine decoder keyword arguments + kwargs = dict( + use_norm=use_norm, + attention_type=attention_type, + interpolation_mode=interpolation_mode, + ) + + blocks = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(layer_idx + 1): + if depth_idx == 0: + in_ch = self.in_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1) + out_ch = self.out_channels[layer_idx] + else: + out_ch = self.skip_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * ( + layer_idx + 1 - depth_idx + ) + in_ch = self.skip_channels[layer_idx - 1] + blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( + in_ch, skip_ch, out_ch, **kwargs + ) + blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock( + self.in_channels[-1], 0, self.out_channels[-1], **kwargs + ) + self.blocks = nn.ModuleDict(blocks) + self.depth = len(self.in_channels) - 1 + + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + # start building dense connections + dense_x = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(self.depth - layer_idx): + if layer_idx 
== 0: + output = self.blocks[f"x_{depth_idx}_{depth_idx}"]( + features[depth_idx], features[depth_idx + 1] + ) + dense_x[f"x_{depth_idx}_{depth_idx}"] = output + else: + dense_l_i = depth_idx + layer_idx + cat_features = [ + dense_x[f"x_{idx}_{dense_l_i}"] + for idx in range(depth_idx + 1, dense_l_i + 1) + ] + cat_features = torch.cat( + cat_features + [features[dense_l_i + 1]], dim=1 + ) + dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ + f"x_{depth_idx}_{dense_l_i}" + ](dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features) + dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"]( + dense_x[f"x_{0}_{self.depth - 1}"] + ) + return dense_x[f"x_{0}_{self.depth}"] + + +class UNetPlusPlusModel(ModelABC): + """UNet++ Model.""" + + def __init__( + self, + encoder_name: str = "resnet34", + encoder_depth: int = 5, + encoder_weights: Optional[str] = "imagenet", + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), + decoder_attention_type: Optional[str] = None, + decoder_interpolation: str = "nearest", + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, Callable]] = None, + aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], + ): + super().__init__() + + if encoder_name.startswith("mit_b"): + raise ValueError( + "UnetPlusPlus is not support encoder_name={}".format(encoder_name) + ) + + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + + self.encoder = TimmUniversalEncoder( + name=encoder_name, + in_channels=in_channels, + depth=encoder_depth, + weights=encoder_weights, + **kwargs, + ) + + self.decoder = UnetPlusPlusDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=decoder_channels, + n_blocks=encoder_depth, + use_norm=decoder_use_norm, + center=True if encoder_name.startswith("vgg") else False, + attention_type=decoder_attention_type, + interpolation_mode=decoder_interpolation, + ) + + self.segmentation_head = SegmentationHead( + in_channels=decoder_channels[-1], + out_channels=classes, + activation=activation, + kernel_size=3, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "unetplusplus-{}".format(encoder_name) + + def forward(self, x): + """Sequentially pass `x` trough model`s encoder, decoder and heads""" + + features = self.encoder(x) + decoder_output = self.decoder(features) + + masks = self.segmentation_head(decoder_output) + + if self.classification_head is not None: + labels = self.classification_head(features[-1]) + return masks, labels + + return masks \ No newline at end of file From 94c43eec7e50042c533568cc365a523e367dd076 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Nov 2025 16:53:36 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../models/architecture/timm_universal.py | 23 +++---- .../models/architecture/unetplusplus.py | 63 +++++++++---------- 2 files changed, 38 insertions(+), 48 deletions(-) diff --git a/tiatoolbox/models/architecture/timm_universal.py 
b/tiatoolbox/models/architecture/timm_universal.py index 138b2ef8b..fd8211825 100644 --- a/tiatoolbox/models/architecture/timm_universal.py +++ b/tiatoolbox/models/architecture/timm_universal.py @@ -1,5 +1,4 @@ -""" -TimmUniversalEncoder provides a unified feature extraction interface built on the +"""TimmUniversalEncoder provides a unified feature extraction interface built on the `timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style models (e.g., Swin Transformer, ConvNeXt). @@ -30,12 +29,11 @@ import timm import torch -import torch.nn as nn +from torch import nn class TimmUniversalEncoder(nn.Module): - """ - A universal encoder leveraging the `timm` library for feature extraction from + """A universal encoder leveraging the `timm` library for feature extraction from various model architectures, including traditional-style and transformer-style models. Features: @@ -57,8 +55,7 @@ def __init__( output_stride: int = 32, **kwargs: dict[str, Any], ): - """ - Initialize the encoder. + """Initialize the encoder. Args: name (str): Model name to load from `timm`. @@ -158,8 +155,7 @@ def __init__( self._output_stride = output_stride def forward(self, x: torch.Tensor) -> list[torch.Tensor]: - """ - Forward pass to extract multi-stage features. + """Forward pass to extract multi-stage features. Args: x (torch.Tensor): Input tensor of shape (B, C, H, W). @@ -189,8 +185,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: @property def out_channels(self) -> list[int]: - """ - Returns the number of output channels for each feature stage. + """Returns the number of output channels for each feature stage. Returns: list[int]: A list of channel dimensions at each scale. @@ -199,8 +194,7 @@ def out_channels(self) -> list[int]: @property def output_stride(self) -> int: - """ - Returns the effective output stride based on the model depth. + """Returns the effective output stride based on the model depth. Returns: int: The effective output stride. @@ -230,8 +224,7 @@ def load_state_dict(self, state_dict, **kwargs): def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: - """ - Merge two dictionaries, ensuring no duplicate keys exist. + """Merge two dictionaries, ensuring no duplicate keys exist. Args: a (dict): Base dictionary. 
diff --git a/tiatoolbox/models/architecture/unetplusplus.py b/tiatoolbox/models/architecture/unetplusplus.py index 9861484cc..4611ce56f 100644 --- a/tiatoolbox/models/architecture/unetplusplus.py +++ b/tiatoolbox/models/architecture/unetplusplus.py @@ -2,18 +2,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union - -if TYPE_CHECKING: # pragma: no cover - from collections.abc import Mapping +import warnings +from collections.abc import Callable, Sequence +from typing import Any import torch -import torch.nn as nn import torch.nn.functional as F -import warnings +from torch import nn -from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.models.architecture.timm_universal import TimmUniversalEncoder +from tiatoolbox.models.models_abc import ModelABC class ArgMax(nn.Module): @@ -33,6 +31,7 @@ def __init__(self, min=0, max=1): def forward(self, x): return torch.clamp(x, self.min, self.max) + class Activation(nn.Module): def __init__(self, name, **params): super().__init__() @@ -87,9 +86,7 @@ def __init__( self, in_channels, classes, pooling="avg", dropout=0.2, activation=None ): if pooling not in ("max", "avg"): - raise ValueError( - "Pooling should be one of ('max', 'avg'), got {}.".format(pooling) - ) + raise ValueError(f"Pooling should be one of ('max', 'avg'), got {pooling}.") pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1) flatten = nn.Flatten() dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() @@ -97,8 +94,9 @@ def __init__( activation = Activation(activation) super().__init__(pool, flatten, dropout, linear, activation) + def get_norm_layer( - use_norm: Union[bool, str, Dict[str, Any]], out_channels: int + use_norm: bool | str | dict[str, Any], out_channels: int ) -> nn.Module: supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm") @@ -151,7 +149,6 @@ def get_norm_layer( norm_type = norm_params["type"] norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"} - if norm_type == "batchnorm": norm = nn.BatchNorm2d(out_channels, **norm_kwargs) elif norm_type == "identity": @@ -165,6 +162,7 @@ def get_norm_layer( return norm + class Conv2dReLU(nn.Sequential): def __init__( self, @@ -173,7 +171,7 @@ def __init__( kernel_size: int, padding: int = 0, stride: int = 1, - use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + use_norm: bool | str | dict[str, Any] = "batchnorm", ): norm = get_norm_layer(use_norm, out_channels) @@ -192,6 +190,7 @@ def __init__( super(Conv2dReLU, self).__init__(conv, norm, activation) + class SCSEModule(nn.Module): def __init__(self, in_channels, reduction=16): super().__init__() @@ -206,7 +205,8 @@ def __init__(self, in_channels, reduction=16): def forward(self, x): return x * self.cSE(x) + x * self.sSE(x) - + + class Attention(nn.Module): def __init__(self, name, **params): super().__init__() @@ -216,7 +216,7 @@ def __init__(self, name, **params): elif name == "scse": self.attention = SCSEModule(**params) else: - raise ValueError("Attention {} is not implemented".format(name)) + raise ValueError(f"Attention {name} is not implemented") def forward(self, x): return self.attention(x) @@ -228,8 +228,8 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", - attention_type: Optional[str] = None, + use_norm: bool | str | dict[str, Any] = "batchnorm", + attention_type: str | None = None, interpolation_mode: 
str = "nearest", ): super().__init__() @@ -254,7 +254,7 @@ def __init__( self.interpolation_mode = interpolation_mode def forward( - self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + self, x: torch.Tensor, skip: torch.Tensor | None = None ) -> torch.Tensor: x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) if skip is not None: @@ -271,7 +271,7 @@ def __init__( self, in_channels: int, out_channels: int, - use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + use_norm: bool | str | dict[str, Any] = "batchnorm", ): conv1 = Conv2dReLU( in_channels, @@ -296,8 +296,8 @@ def __init__( encoder_channels: Sequence[int], decoder_channels: Sequence[int], n_blocks: int = 5, - use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", - attention_type: Optional[str] = None, + use_norm: bool | str | dict[str, Any] = "batchnorm", + attention_type: str | None = None, interpolation_mode: str = "nearest", center: bool = False, ): @@ -356,7 +356,7 @@ def __init__( self.blocks = nn.ModuleDict(blocks) self.depth = len(self.in_channels) - 1 - def forward(self, features: List[torch.Tensor]) -> torch.Tensor: + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder @@ -394,23 +394,21 @@ def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, - encoder_weights: Optional[str] = "imagenet", - decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + encoder_weights: str | None = "imagenet", + decoder_use_norm: bool | str | dict[str, Any] = "batchnorm", decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), - decoder_attention_type: Optional[str] = None, + decoder_attention_type: str | None = None, decoder_interpolation: str = "nearest", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, Callable]] = None, - aux_params: Optional[dict] = None, + activation: str | Callable | None = None, + aux_params: dict | None = None, **kwargs: dict[str, Any], ): super().__init__() if encoder_name.startswith("mit_b"): - raise ValueError( - "UnetPlusPlus is not support encoder_name={}".format(encoder_name) - ) + raise ValueError(f"UnetPlusPlus is not support encoder_name={encoder_name}") decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) if decoder_use_batchnorm is not None: @@ -453,11 +451,10 @@ def __init__( else: self.classification_head = None - self.name = "unetplusplus-{}".format(encoder_name) + self.name = f"unetplusplus-{encoder_name}" def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - features = self.encoder(x) decoder_output = self.decoder(features) @@ -467,4 +464,4 @@ def forward(self, x): labels = self.classification_head(features[-1]) return masks, labels - return masks \ No newline at end of file + return masks From 98cef832e6442769250d75455fa12e93b02dd083 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Tue, 18 Nov 2025 18:22:24 +0000 Subject: [PATCH 09/12] remove smp dependency --- requirements/requirements.txt | 1 - tests/models/test_arch_grandqc.py | 8 +- tiatoolbox/data/pretrained_model.yaml | 2 +- tiatoolbox/models/architecture/grandqc.py | 45 +- .../models/architecture/timm_efficientnet.py | 416 ++++++++++++++++ .../models/architecture/timm_universal.py | 240 --------- .../models/architecture/unetplusplus.py | 458 +++++++----------- 7 files changed, 633 insertions(+), 537 deletions(-) create mode 100644 
tiatoolbox/models/architecture/timm_efficientnet.py delete mode 100644 tiatoolbox/models/architecture/timm_universal.py diff --git a/requirements/requirements.txt b/requirements/requirements.txt index d99b62e10..045a4ce4e 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -28,7 +28,6 @@ requests>=2.28.1 scikit-image>=0.20 scikit-learn>=1.2.0 scipy>=1.8 -segmentation-models-pytorch>=0.5.0 shapely>=2.0.0 SimpleITK>=2.2.1 sphinx>=5.3.0 diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py index 1c124fc09..27f6afdd4 100644 --- a/tests/models/test_arch_grandqc.py +++ b/tests/models/test_arch_grandqc.py @@ -7,7 +7,7 @@ fetch_pretrained_weights, get_pretrained_model, ) -from tiatoolbox.models.architecture.grandqc import TissueDetectionModel +from tiatoolbox.models.architecture.grandqc import GrandQCModel from tiatoolbox.models.engine.io_config import IOSegmentorConfig from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import VirtualWSIReader @@ -22,7 +22,7 @@ def test_functional_grandqc() -> None: assert pretrained_weights is not None # test creation - model = TissueDetectionModel(num_input_channels=3, num_output_channels=2) + model = GrandQCModel(num_input_channels=3, num_output_channels=2) assert model is not None # load pretrained weights @@ -31,7 +31,7 @@ def test_functional_grandqc() -> None: # test get pretrained model model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10") - assert isinstance(model, TissueDetectionModel) + assert isinstance(model, GrandQCModel) assert isinstance(ioconfig, IOSegmentorConfig) assert model.num_input_channels == 3 assert model.num_output_channels == 2 @@ -54,7 +54,7 @@ def test_functional_grandqc() -> None: def test_grandqc_preproc_postproc() -> None: """Test GrandQC preproc and postproc functions.""" - model = TissueDetectionModel(num_input_channels=3, num_output_channels=2) + model = GrandQCModel(num_input_channels=3, num_output_channels=2) generator = np.random.default_rng(1337) # test preproc diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index fca703391..f0916dbfc 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -938,7 +938,7 @@ nuclick_light-pannuke: grandqc_tissue_detection_mpp10: hf_repo_id: TIACentre/GrandQC_Tissue_Detection architecture: - class: grandqc.TissueDetectionModel + class: grandqc.GrandQCModel kwargs: num_input_channels: 3 num_output_channels: 2 diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index a2ffdf2db..e7c3b12ce 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -9,14 +9,14 @@ import cv2 import numpy as np -import segmentation_models_pytorch as smp import torch +from tiatoolbox.models.architecture.unetplusplus import UNetPlusPlusModel from tiatoolbox.models.models_abc import ModelABC -class TissueDetectionModel(ModelABC): - """GrandQC Tissue Detection Model. +class GrandQCModel(ModelABC): + """GrandQC Tissue Detection Model [1]. Example: >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor @@ -32,10 +32,15 @@ class TissueDetectionModel(ModelABC): ... output_type="annotationstore", ... ) + References: + [1] Weng Z. et al. "GrandQC: a comprehensive solution to quality control problem + in digital pathology". 
+ Nature Communications 2024 + """ def __init__( - self: TissueDetectionModel, num_input_channels: int, num_output_channels: int + self: GrandQCModel, num_input_channels: int, num_output_channels: int ) -> None: """Initialize TissueDetectionModel.""" super().__init__() @@ -43,17 +48,23 @@ def __init__( self.num_output_channels = num_output_channels self._postproc = self.postproc self._preproc = self.preproc - self.tissue_detection_model = smp.UnetPlusPlus( - encoder_name="timm-efficientnet-b0", - encoder_weights=None, - in_channels=self.num_input_channels, + self.tissue_detection_model = UNetPlusPlusModel( classes=self.num_output_channels, - activation=None, ) @staticmethod def preproc(image: np.ndarray) -> np.ndarray: - """Apply jpg compression then ImageNet normalise.""" + """Apply JPEG compression and ImageNet normalization to the input image. + + Args: + image (np.ndarray): + Input image as a NumPy array (H, W, C) in uint8 format. + + Returns: + np.ndarray: + The preprocessed image. + + """ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80] _, compressed_image = cv2.imencode(".jpg", image, encode_param) compressed_image = np.array(cv2.imdecode(compressed_image, 1)) @@ -69,11 +80,19 @@ def postproc(image: np.ndarray) -> np.ndarray: This simply applies argmin to obtain tissue class. (Tissue = 0, Background = 1) + Args: + image (np.ndarray): + Input probability map as a NumPy array (H, W, C). + + Returns: + np.ndarray: + Tissue mask + """ return image.argmin(axis=-1) def forward( - self: TissueDetectionModel, + self: GrandQCModel, imgs: torch.Tensor, *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 @@ -120,9 +139,9 @@ def infer_batch( return probs.cpu().numpy() def load_state_dict( - self: TissueDetectionModel, + self: GrandQCModel, state_dict: Mapping[str, Any], **kwargs: bool, ) -> torch.nn.modules.module._IncompatibleKeys: - """Load state dict for the TissueDetectionModel.""" + """Load state dict for the GrandQCModel.""" return self.tissue_detection_model.load_state_dict(state_dict, **kwargs) diff --git a/tiatoolbox/models/architecture/timm_efficientnet.py b/tiatoolbox/models/architecture/timm_efficientnet.py new file mode 100644 index 000000000..46a249387 --- /dev/null +++ b/tiatoolbox/models/architecture/timm_efficientnet.py @@ -0,0 +1,416 @@ +"""Defines EfficientNet encoder using timm library.""" + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Mapping, Sequence + +import torch +from timm.layers.activations import Swish +from timm.models._efficientnet_builder import decode_arch_def, round_channels +from timm.models.efficientnet import EfficientNet +from torch import nn + +MAX_DEPTH = 5 +MIN_DEPTH = 1 +DEFAULT_IN_CHANNELS = 3 + + +def patch_first_conv( + model: nn.Module, + new_in_channels: int, + default_in_channels: int = 3, + *, + pretrained: bool = True, +) -> None: + """Change first convolution layer input channels. + + Args: + model: The neural network model to patch. + new_in_channels: Number of input channels for the new first layer. + default_in_channels: Original number of input channels. Defaults to 3. + pretrained: Whether to reuse pretrained weights. Defaults to True. 
+ + Note: + In case: + - in_channels == 1 or in_channels == 2 -> reuse original weights + - in_channels > 3 -> make random kaiming normal initialization + """ + # get first conv + conv_module: nn.Conv2d | None = None + for module in model.modules(): + if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: + conv_module = module + break + + if conv_module is None: + return + + weight = conv_module.weight.detach() + conv_module.in_channels = new_in_channels + + if not pretrained: + conv_module.weight = nn.parameter.Parameter( + torch.Tensor( + conv_module.out_channels, + new_in_channels // conv_module.groups, + *conv_module.kernel_size, + ) + ) + conv_module.reset_parameters() + + elif new_in_channels == 1: + new_weight = weight.sum(1, keepdim=True) + conv_module.weight = nn.parameter.Parameter(new_weight) + + else: + new_weight = torch.Tensor( + conv_module.out_channels, + new_in_channels // conv_module.groups, + *conv_module.kernel_size, + ) + + for i in range(new_in_channels): + new_weight[:, i] = weight[:, i % default_in_channels] + + new_weight = new_weight * (default_in_channels / new_in_channels) + conv_module.weight = nn.parameter.Parameter(new_weight) + + +def replace_strides_with_dilation(module: nn.Module, dilation_rate: int) -> None: + """Patch Conv2d modules replacing strides with dilation. + + Args: + module: The module containing Conv2d layers to patch. + dilation_rate: The dilation rate to apply. + """ + for mod in module.modules(): + if isinstance(mod, nn.Conv2d): + mod.stride = (1, 1) + mod.dilation = (dilation_rate, dilation_rate) + kh, _ = mod.kernel_size + mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) + + # Workaround for EfficientNet + if hasattr(mod, "static_padding"): + mod.static_padding = nn.Identity() # type: ignore[attr-defined] + + +class EncoderMixin: + """Add encoder functionality. + + Such as: + - output channels specification of feature tensors (produced by encoder) + - patching first convolution for arbitrary input channels + """ + + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + + def __init__(self) -> None: + """Initialize EncoderMixin with default parameters.""" + self._depth = 5 + self._in_channels = 3 + self._output_stride = 32 + self._out_channels: list[int] = [] + + @property + def out_channels(self) -> list[int]: + """Return channels dimensions for each tensor of forward output of encoder. + + Returns: + List of output channel dimensions for each depth level. + """ + return self._out_channels[: self._depth + 1] + + @property + def output_stride(self) -> int: + """Return the output stride of the encoder. + + Returns: + The minimum of configured output stride and 2^depth. + """ + return min(self._output_stride, 2**self._depth) + + def set_in_channels(self, in_channels: int, *, pretrained: bool = True) -> None: + """Change first convolution channels. + + Args: + in_channels: Number of input channels. + pretrained: Whether to use pretrained weights. Defaults to True. + """ + if in_channels == DEFAULT_IN_CHANNELS: + return + + self._in_channels = in_channels + if self._out_channels[0] == DEFAULT_IN_CHANNELS: + self._out_channels = [in_channels, *self._out_channels[1:]] + + # Type ignore needed because self is a mixin that will be used with nn.Module + patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) # type: ignore[arg-type] + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Get stages for dilation modification. 
+ + Override this method in your implementation. + + Returns: + Dictionary with keys as output stride and values as list of modules. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError + + def make_dilated(self, output_stride: int) -> None: + """Convert encoder to dilated version. + + Args: + output_stride: Target output stride (8 or 16). + + Raises: + ValueError: If output_stride is not 8 or 16. + """ + if output_stride not in [8, 16]: + msg = f"Output stride should be 16 or 8, got {output_stride}." + raise ValueError(msg) + + stages = self.get_stages() + for stage_stride, stage_modules in stages.items(): + if stage_stride <= output_stride: + continue + + dilation_rate = stage_stride // output_stride + for module in stage_modules: + replace_strides_with_dilation(module, dilation_rate) + + +def get_efficientnet_kwargs( + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, +) -> dict[str, Any]: + """Create EfficientNet model kwargs. + + Reference implementation: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet parameters: + - 'efficientnet-b0': (1.0, 1.0, 224, 0.2) + - 'efficientnet-b1': (1.0, 1.1, 240, 0.2) + - 'efficientnet-b2': (1.1, 1.2, 260, 0.3) + - 'efficientnet-b3': (1.2, 1.4, 300, 0.3) + - 'efficientnet-b4': (1.4, 1.8, 380, 0.4) + - 'efficientnet-b5': (1.6, 2.2, 456, 0.4) + - 'efficientnet-b6': (1.8, 2.6, 528, 0.5) + - 'efficientnet-b7': (2.0, 3.1, 600, 0.5) + - 'efficientnet-b8': (2.2, 3.6, 672, 0.5) + - 'efficientnet-l2': (4.3, 5.3, 800, 0.5) + + Args: + channel_multiplier: Multiplier to number of channels per layer. Defaults to 1.0. + depth_multiplier: Multiplier to number of repeats per stage. Defaults to 1.0. + drop_rate: Dropout rate. Defaults to 0.2. + + Returns: + Dictionary containing model configuration parameters. + """ + arch_def = [ + ["ds_r1_k3_s1_e1_c16_se0.25"], + ["ir_r2_k3_s2_e6_c24_se0.25"], + ["ir_r2_k5_s2_e6_c40_se0.25"], + ["ir_r3_k3_s2_e6_c80_se0.25"], + ["ir_r3_k5_s1_e6_c112_se0.25"], + ["ir_r4_k5_s2_e6_c192_se0.25"], + ["ir_r1_k3_s1_e6_c320_se0.25"], + ] + return { + "block_args": decode_arch_def(arch_def, depth_multiplier), + "num_features": round_channels(1280, channel_multiplier, 8, None), + "stem_size": 32, + "round_chs_fn": partial(round_channels, multiplier=channel_multiplier), + "act_layer": Swish, + "drop_rate": drop_rate, + "drop_path_rate": 0.2, + } + + +class EfficientNetBaseEncoder(EfficientNet, EncoderMixin): + """EfficientNet encoder base class. + + Combines EfficientNet architecture with encoder functionality. + """ + + def __init__( + self, + stage_idxs: list[int], + out_channels: list[int], + depth: int = 5, + output_stride: int = 32, + **kwargs: dict[str, Any], + ) -> None: + """Initialize EfficientNetBaseEncoder. + + Args: + stage_idxs: Indices of stages for feature extraction. + out_channels: Output channels for each depth level. + depth: Encoder depth (1-5). Defaults to 5. + output_stride: Output stride of encoder. Defaults to 32. + **kwargs: Additional keyword arguments for EfficientNet. + + Raises: + ValueError: If depth is not in range [1, 5]. 
+ """ + if depth > MAX_DEPTH or depth < MIN_DEPTH: + msg = f"{self.__class__.__name__} depth should be in range \ + [1, 5], got {depth}" + raise ValueError(msg) + super().__init__(**kwargs) + + self._stage_idxs = stage_idxs + self._depth = depth + self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + + del self.classifier + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Get stages for dilation modification. + + Returns: + Dictionary mapping output strides to corresponding module sequences. + """ + return { + 16: [self.blocks[self._stage_idxs[1] : self._stage_idxs[2]]], # type: ignore[attr-defined] + 32: [self.blocks[self._stage_idxs[2] :]], # type: ignore[attr-defined] + } + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: # type: ignore[override] + """Forward pass through encoder. + + Args: + x: Input tensor. + + Returns: + List of feature tensors from different encoder depths. + """ + features = [x] + + if self._depth >= 1: + x = self.conv_stem(x) # type: ignore[attr-defined] + x = self.bn1(x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 2: # noqa: PLR2004 + x = self.blocks[0](x) # type: ignore[attr-defined] + x = self.blocks[1](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 3: # noqa: PLR2004 + x = self.blocks[2](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 4: # noqa: PLR2004 + x = self.blocks[3](x) # type: ignore[attr-defined] + x = self.blocks[4](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 5: # noqa: PLR2004 + x = self.blocks[5](x) # type: ignore[attr-defined] + x = self.blocks[6](x) # type: ignore[attr-defined] + features.append(x) + + return features + + def load_state_dict( + self, state_dict: Mapping[str, Any], **kwargs: bool + ) -> torch.nn.modules.module._IncompatibleKeys: + """Load state dictionary, excluding classifier weights. + + Args: + state_dict: State dictionary to load. + **kwargs: Additional keyword arguments for load_state_dict. + + Returns: + Result of parent class load_state_dict method. + """ + # Create a mutable copy of the state dict to modify + state_dict_copy = dict(state_dict) + state_dict_copy.pop("classifier.bias", None) + state_dict_copy.pop("classifier.weight", None) + return super().load_state_dict(state_dict_copy, **kwargs) + + +class EfficientNetEncoder(EfficientNetBaseEncoder): + """EfficientNet encoder with configurable scaling parameters. + + Provides a configurable EfficientNet encoder that can be scaled + in terms of depth and channel multipliers. + """ + + def __init__( + self, + stage_idxs: list[int], + out_channels: list[int], + depth: int = 5, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, + output_stride: int = 32, + ) -> None: + """Initialize EfficientNetEncoder. + + Args: + stage_idxs: Indices of stages for feature extraction. + out_channels: Output channels for each depth level. + depth: Encoder depth (1-5). Defaults to 5. + channel_multiplier: Channel scaling factor. Defaults to 1.0. + depth_multiplier: Depth scaling factor. Defaults to 1.0. + drop_rate: Dropout rate. Defaults to 0.2. + output_stride: Output stride of encoder. Defaults to 32. 
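+
+        Example:
+            A minimal sketch using the ``timm-efficientnet-b0`` parameters
+            registered in ``timm_efficientnet_encoders`` at the bottom of
+            this module (input size chosen for illustration only):
+
+            >>> import torch
+            >>> encoder = EfficientNetEncoder(
+            ...     stage_idxs=[2, 3, 5],
+            ...     out_channels=[3, 32, 24, 40, 112, 320],
+            ... )
+            >>> features = encoder(torch.zeros(1, 3, 64, 64))
+            >>> len(features)
+            6
+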
+ """ + kwargs = get_efficientnet_kwargs( + channel_multiplier, depth_multiplier, drop_rate + ) + super().__init__( + stage_idxs=stage_idxs, + depth=depth, + out_channels=out_channels, + output_stride=output_stride, + **kwargs, + ) + + +timm_efficientnet_encoders = { + "timm-efficientnet-b0": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b0.imagenet", + "revision": "8419e9cc19da0b68dcd7bb12f19b7c92407ad7c4", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b0.advprop", + "revision": "a5870af2d24ce79e0cc7fae2bbd8e0a21fcfa6d8", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b0.noisy-student", + "revision": "bea8b0ff726a50e48774d2d360c5fb1ac4815836", + }, + }, + "params": { + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], + "channel_multiplier": 1.0, + "depth_multiplier": 1.0, + "drop_rate": 0.2, + }, + }, +} diff --git a/tiatoolbox/models/architecture/timm_universal.py b/tiatoolbox/models/architecture/timm_universal.py deleted file mode 100644 index fd8211825..000000000 --- a/tiatoolbox/models/architecture/timm_universal.py +++ /dev/null @@ -1,240 +0,0 @@ -"""TimmUniversalEncoder provides a unified feature extraction interface built on the -`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style -models (e.g., Swin Transformer, ConvNeXt). - -This encoder produces consistent multi-level feature maps for semantic segmentation tasks. -It allows configuring the number of feature extraction stages (`depth`) and adjusting -`output_stride` when supported. - -Key Features: -- Flexible model selection using `timm.create_model`. -- Unified multi-level output across different model hierarchies. -- Automatic alignment for inconsistent feature scales: - - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale. - - VGG-style models (include scale-1 features): Align outputs for compatibility. -- Easy access to feature scale information via the `reduction` property. - -Feature Scale Differences: -- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32. -- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale. -- VGG-style models: Include scale-1 features (input resolution). - -Notes: -- `output_stride` is unsupported in some models, especially transformer-based architectures. -- Special handling for models like TResNet and DLA to ensure correct feature indexing. -- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs. -""" - -from typing import Any - -import timm -import torch -from torch import nn - - -class TimmUniversalEncoder(nn.Module): - """A universal encoder leveraging the `timm` library for feature extraction from - various model architectures, including traditional-style and transformer-style models. - - Features: - - Supports configurable depth and output stride. - - Ensures consistent multi-level feature extraction across diverse models. - - Compatible with convolutional and transformer-like backbones. - """ - - _is_torch_scriptable = True - _is_torch_exportable = True - _is_torch_compilable = True - - def __init__( - self, - name: str, - pretrained: bool = True, - in_channels: int = 3, - depth: int = 5, - output_stride: int = 32, - **kwargs: dict[str, Any], - ): - """Initialize the encoder. - - Args: - name (str): Model name to load from `timm`. - pretrained (bool): Load pretrained weights (default: True). 
- in_channels (int): Number of input channels (default: 3 for RGB). - depth (int): Number of feature stages to extract (default: 5). - output_stride (int): Desired output stride (default: 32). - **kwargs: Additional arguments passed to `timm.create_model`. - """ - # At the moment we do not support models with more than 5 stages, - # but can be reconfigured in the future. - if depth > 5 or depth < 1: - raise ValueError( - f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" - ) - - super().__init__() - self.name = name - - # Default model configuration for feature extraction - common_kwargs = dict( - in_chans=in_channels, - features_only=True, - output_stride=output_stride, - pretrained=pretrained, - out_indices=tuple(range(depth)), - ) - - # Not all models support output stride argument, drop it by default - if output_stride == 32: - common_kwargs.pop("output_stride") - - # Load a temporary model to analyze its feature hierarchy - try: - with torch.device("meta"): - tmp_model = timm.create_model(name, features_only=True) - except Exception: - tmp_model = timm.create_model(name, features_only=True) - - # Check if model output is in channel-last format (NHWC) - self._is_channel_last = getattr(tmp_model, "output_fmt", None) == "NHWC" - - # Determine the model's downsampling pattern and set hierarchy flags - encoder_stage = len(tmp_model.feature_info.reduction()) - reduction_scales = list(tmp_model.feature_info.reduction()) - - if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]: - # Transformer-style downsampling: scales (4, 8, 16, 32) - self._is_transformer_style = True - self._is_vgg_style = False - elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]: - # Traditional-style downsampling: scales (2, 4, 8, 16, 32) - self._is_transformer_style = False - self._is_vgg_style = False - elif reduction_scales == [2**i for i in range(encoder_stage)]: - # Vgg-style models including scale 1: scales (1, 2, 4, 8, 16, 32) - self._is_transformer_style = False - self._is_vgg_style = True - else: - raise ValueError("Unsupported model downsampling pattern.") - - if self._is_transformer_style: - # Transformer-like models (start at scale 4) - if "tresnet" in name: - # 'tresnet' models start feature extraction at stage 1, - # so out_indices=(1, 2, 3, 4) for depth=5. - common_kwargs["out_indices"] = tuple(range(1, depth)) - else: - # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5. - common_kwargs["out_indices"] = tuple(range(depth - 1)) - - timm_model_kwargs = _merge_kwargs_no_duplicates(common_kwargs, kwargs) - self.model = timm.create_model(name, **timm_model_kwargs) - - # Add a dummy output channel (0) to align with traditional encoder structures. - self._out_channels = ( - [in_channels] + [0] + self.model.feature_info.channels() - ) - else: - if "dla" in name: - # For 'dla' models, out_indices starts at 0 and matches the input size. 
- common_kwargs["out_indices"] = tuple(range(1, depth + 1)) - if self._is_vgg_style: - common_kwargs["out_indices"] = tuple(range(depth + 1)) - - self.model = timm.create_model( - name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) - ) - - if self._is_vgg_style: - self._out_channels = self.model.feature_info.channels() - else: - self._out_channels = [in_channels] + self.model.feature_info.channels() - - self._in_channels = in_channels - self._depth = depth - self._output_stride = output_stride - - def forward(self, x: torch.Tensor) -> list[torch.Tensor]: - """Forward pass to extract multi-stage features. - - Args: - x (torch.Tensor): Input tensor of shape (B, C, H, W). - - Returns: - list[torch.Tensor]: List of feature maps at different scales. - """ - features = self.model(x) - - # Convert NHWC to NCHW if needed - if self._is_channel_last: - features = [ - feature.permute(0, 3, 1, 2).contiguous() for feature in features - ] - - # Add dummy feature for scale 1/2 if missing (transformer-style models) - if self._is_transformer_style: - B, _, H, W = x.shape - dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) - features = [dummy] + features - - # Add input tensor as scale 1 feature if `self._is_vgg_style` is False - if not self._is_vgg_style: - features = [x] + features - - return features - - @property - def out_channels(self) -> list[int]: - """Returns the number of output channels for each feature stage. - - Returns: - list[int]: A list of channel dimensions at each scale. - """ - return self._out_channels - - @property - def output_stride(self) -> int: - """Returns the effective output stride based on the model depth. - - Returns: - int: The effective output stride. - """ - return int(min(self._output_stride, 2**self._depth)) - - def load_state_dict(self, state_dict, **kwargs): - # for compatibility of weights for - # timm- ported encoders with TimmUniversalEncoder - patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] - - is_deprecated_encoder = any( - self.name.startswith(pattern) for pattern in patterns - ) - - if is_deprecated_encoder: - keys = list(state_dict.keys()) - for key in keys: - new_key = key - if not key.startswith("model."): - new_key = "model." + key - if "gernet" in self.name: - new_key = new_key.replace(".stages.", ".stages_") - state_dict[new_key] = state_dict.pop(key) - - return super().load_state_dict(state_dict, **kwargs) - - -def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: - """Merge two dictionaries, ensuring no duplicate keys exist. - - Args: - a (dict): Base dictionary. - b (dict): Additional parameters to merge. - - Returns: - dict: A merged dictionary. 
- """ - duplicates = a.keys() & b.keys() - if duplicates: - raise ValueError(f"'{duplicates}' already specified internally") - - return a | b diff --git a/tiatoolbox/models/architecture/unetplusplus.py b/tiatoolbox/models/architecture/unetplusplus.py index 4611ce56f..e18309b9e 100644 --- a/tiatoolbox/models/architecture/unetplusplus.py +++ b/tiatoolbox/models/architecture/unetplusplus.py @@ -2,73 +2,39 @@ from __future__ import annotations -import warnings -from collections.abc import Callable, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Sequence + +import numpy as np import torch -import torch.nn.functional as F from torch import nn -from tiatoolbox.models.architecture.timm_universal import TimmUniversalEncoder +from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder from tiatoolbox.models.models_abc import ModelABC -class ArgMax(nn.Module): - def __init__(self, dim=None): - super().__init__() - self.dim = dim - - def forward(self, x): - return torch.argmax(x, dim=self.dim) - - -class Clamp(nn.Module): - def __init__(self, min=0, max=1): - super().__init__() - self.min, self.max = min, max - - def forward(self, x): - return torch.clamp(x, self.min, self.max) - - -class Activation(nn.Module): - def __init__(self, name, **params): - super().__init__() - self.activation: nn.Module - if name is None or name == "identity": - self.activation = nn.Identity(**params) - elif name == "sigmoid": - self.activation = nn.Sigmoid() - elif name == "softmax2d": - self.activation = nn.Softmax(dim=1, **params) - elif name == "softmax": - self.activation = nn.Softmax(**params) - elif name == "logsoftmax": - self.activation = nn.LogSoftmax(**params) - elif name == "tanh": - self.activation = nn.Tanh() - elif name == "argmax": - self.activation = ArgMax(**params) - elif name == "argmax2d": - self.activation = ArgMax(dim=1, **params) - elif name == "clamp": - self.activation = Clamp(**params) - else: - self.activation = nn.Identity(**params) - raise ValueError( - f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/" - f"argmax/argmax2d/clamp/None; got {name}" - ) - - def forward(self, x): - return self.activation(x) - - class SegmentationHead(nn.Sequential): + """Segmentation head for UNet++ model.""" + def __init__( - self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1 - ): + self: SegmentationHead, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + activation: nn.Module | None = None, + upsampling: int = 1, + ) -> None: + """Initialize SegmentationHead. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Convolution kernel size. Defaults to 3. + activation: Activation function. Defaults to None. + upsampling: Upsampling factor. Defaults to 1. 
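+
+        Example:
+            Illustrative sketch; channel counts and spatial size are arbitrary:
+
+            >>> import torch
+            >>> head = SegmentationHead(in_channels=16, out_channels=2, upsampling=2)
+            >>> head(torch.zeros(1, 16, 8, 8)).shape
+            torch.Size([1, 2, 16, 16])
+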
+ """ conv2d = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 ) @@ -77,103 +43,32 @@ def __init__( if upsampling > 1 else nn.Identity() ) - activation = Activation(activation) + if activation is None: + activation = nn.Identity() super().__init__(conv2d, upsampling, activation) -class ClassificationHead(nn.Sequential): - def __init__( - self, in_channels, classes, pooling="avg", dropout=0.2, activation=None - ): - if pooling not in ("max", "avg"): - raise ValueError(f"Pooling should be one of ('max', 'avg'), got {pooling}.") - pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1) - flatten = nn.Flatten() - dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() - linear = nn.Linear(in_channels, classes, bias=True) - activation = Activation(activation) - super().__init__(pool, flatten, dropout, linear, activation) - - -def get_norm_layer( - use_norm: bool | str | dict[str, Any], out_channels: int -) -> nn.Module: - supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm") - - # Step 1. Convert tot dict representation - - ## Check boolean - if use_norm is True: - norm_params = {"type": "batchnorm"} - elif use_norm is False: - norm_params = {"type": "identity"} - - ## Check string - elif isinstance(use_norm, str): - norm_str = use_norm.lower() - if norm_str == "inplace": - norm_params = { - "type": "inplace", - "activation": "leaky_relu", - "activation_param": 0.0, - } - elif norm_str in supported_norms: - norm_params = {"type": norm_str} - else: - raise ValueError( - f"Unrecognized normalization type string provided: {use_norm}. Should be in " - f"{supported_norms}" - ) - - ## Check dict - elif isinstance(use_norm, dict): - norm_params = use_norm - - else: - raise ValueError( - f"Invalid type for use_norm should either be a bool (batchnorm/identity), " - f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}" - ) - - # Step 2. Check if the dict is valid - if "type" not in norm_params: - raise ValueError( - f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'." - ) - if norm_params["type"] not in supported_norms: - raise ValueError( - f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}" - ) - - # Step 3. Initialize the norm layer - norm_type = norm_params["type"] - norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"} - - if norm_type == "batchnorm": - norm = nn.BatchNorm2d(out_channels, **norm_kwargs) - elif norm_type == "identity": - norm = nn.Identity() - elif norm_type == "layernorm": - norm = nn.LayerNorm(out_channels, **norm_kwargs) - elif norm_type == "instancenorm": - norm = nn.InstanceNorm2d(out_channels, **norm_kwargs) - else: - raise ValueError(f"Unrecognized normalization type: {norm_type}") - - return norm - - class Conv2dReLU(nn.Sequential): + """Conv2d + BatchNorm + ReLU block.""" + def __init__( - self, + self: Conv2dReLU, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0, stride: int = 1, - use_norm: bool | str | dict[str, Any] = "batchnorm", - ): - norm = get_norm_layer(use_norm, out_channels) + ) -> None: + """Initialize Conv2dReLU block. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Convolution kernel size. + padding: Padding size. Defaults to 0. + stride: Stride size. Defaults to 1. 
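+
+        Example:
+            Illustrative shapes only:
+
+            >>> import torch
+            >>> block = Conv2dReLU(3, 8, kernel_size=3, padding=1)
+            >>> block(torch.zeros(1, 3, 16, 16)).shape
+            torch.Size([1, 8, 16, 16])
+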
+ """ + norm = nn.BatchNorm2d(out_channels) is_identity = isinstance(norm, nn.Identity) conv = nn.Conv2d( @@ -185,128 +80,118 @@ def __init__( bias=is_identity, ) - is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN) - activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True) + activation = nn.ReLU(inplace=True) - super(Conv2dReLU, self).__init__(conv, norm, activation) - - -class SCSEModule(nn.Module): - def __init__(self, in_channels, reduction=16): - super().__init__() - self.cSE = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(in_channels, in_channels // reduction, 1), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels // reduction, in_channels, 1), - nn.Sigmoid(), - ) - self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) - - def forward(self, x): - return x * self.cSE(x) + x * self.sSE(x) - - -class Attention(nn.Module): - def __init__(self, name, **params): - super().__init__() - - if name is None: - self.attention = nn.Identity(**params) - elif name == "scse": - self.attention = SCSEModule(**params) - else: - raise ValueError(f"Attention {name} is not implemented") - - def forward(self, x): - return self.attention(x) + super().__init__(conv, norm, activation) class DecoderBlock(nn.Module): + """Decoder block for UNet++ architecture.""" + def __init__( - self, + self: DecoderBlock, in_channels: int, skip_channels: int, out_channels: int, - use_norm: bool | str | dict[str, Any] = "batchnorm", - attention_type: str | None = None, - interpolation_mode: str = "nearest", - ): + ) -> None: + """Initialize DecoderBlock. + + Args: + in_channels: Number of input channels. + skip_channels: Number of skip connection channels. + out_channels: Number of output channels. + """ super().__init__() self.conv1 = Conv2dReLU( in_channels + skip_channels, out_channels, kernel_size=3, padding=1, - use_norm=use_norm, - ) - self.attention1 = Attention( - attention_type, in_channels=in_channels + skip_channels ) + self.attention1 = nn.Identity() self.conv2 = Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_norm=use_norm, ) - self.attention2 = Attention(attention_type, in_channels=out_channels) - self.interpolation_mode = interpolation_mode + self.attention2 = nn.Identity() def forward( self, x: torch.Tensor, skip: torch.Tensor | None = None ) -> torch.Tensor: - x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) + """Forward pass through decoder block. + + Args: + x: Input tensor. + skip: Skip connection tensor. Defaults to None. + + Returns: + torch.Tensor: Output tensor after decoding. + """ + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if skip is not None: x = torch.cat([x, skip], dim=1) x = self.attention1(x) x = self.conv1(x) x = self.conv2(x) - x = self.attention2(x) - return x + return self.attention2(x) class CenterBlock(nn.Sequential): + """Center block for UNet++ architecture.""" + def __init__( - self, + self: CenterBlock, in_channels: int, out_channels: int, - use_norm: bool | str | dict[str, Any] = "batchnorm", - ): + ) -> None: + """Initialize CenterBlock. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. 
+ """ conv1 = Conv2dReLU( in_channels, out_channels, kernel_size=3, padding=1, - use_norm=use_norm, ) conv2 = Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_norm=use_norm, ) super().__init__(conv1, conv2) class UnetPlusPlusDecoder(nn.Module): + """UNet++ decoder with dense connections.""" + def __init__( self, encoder_channels: Sequence[int], decoder_channels: Sequence[int], n_blocks: int = 5, - use_norm: bool | str | dict[str, Any] = "batchnorm", - attention_type: str | None = None, - interpolation_mode: str = "nearest", - center: bool = False, - ): + ) -> None: + """Initialize UnetPlusPlusDecoder. + + Args: + encoder_channels: List of encoder output channels. + decoder_channels: List of decoder output channels. + n_blocks: Number of decoder blocks. Defaults to 5. + + Raises: + ValueError: If model depth doesn't match decoder_channels length. + """ super().__init__() if n_blocks != len(decoder_channels): - raise ValueError( - f"Model depth is {n_blocks}, but you provide `decoder_channels` for {len(decoder_channels)} blocks." - ) + msg = f"Model depth is {n_blocks}, but you provide \ + `decoder_channels` for {len(decoder_channels)} blocks." + raise ValueError(msg) # remove first skip with same spatial resolution encoder_channels = encoder_channels[1:] @@ -315,24 +200,11 @@ def __init__( # computing blocks input and output channels head_channels = encoder_channels[0] - self.in_channels = [head_channels] + list(decoder_channels[:-1]) - self.skip_channels = list(encoder_channels[1:]) + [0] + self.in_channels = [head_channels, *list(decoder_channels[:-1])] + self.skip_channels = [*list(encoder_channels[1:]), 0] self.out_channels = decoder_channels - if center: - self.center = CenterBlock( - head_channels, - head_channels, - use_norm=use_norm, - ) - else: - self.center = nn.Identity() - - # combine decoder keyword arguments - kwargs = dict( - use_norm=use_norm, - attention_type=attention_type, - interpolation_mode=interpolation_mode, - ) + + self.center = nn.Identity() blocks = {} for layer_idx in range(len(self.in_channels) - 1): @@ -348,15 +220,23 @@ def __init__( ) in_ch = self.skip_channels[layer_idx - 1] blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( - in_ch, skip_ch, out_ch, **kwargs + in_ch, skip_ch, out_ch ) blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock( - self.in_channels[-1], 0, self.out_channels[-1], **kwargs + self.in_channels[-1], 0, self.out_channels[-1] ) self.blocks = nn.ModuleDict(blocks) self.depth = len(self.in_channels) - 1 def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Forward pass through UNet++ decoder. + + Args: + features: List of encoder feature maps. + + Returns: + torch.Tensor: Decoded output tensor. 
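+
+        Example:
+            A sketch with a tiny three-block decoder and a matching feature
+            pyramid (channel counts are arbitrary):
+
+            >>> import torch
+            >>> decoder = UnetPlusPlusDecoder(
+            ...     encoder_channels=[1, 2, 4, 8],
+            ...     decoder_channels=[8, 4, 2],
+            ...     n_blocks=3,
+            ... )
+            >>> features = [
+            ...     torch.randn(1, 1, 32, 32),
+            ...     torch.randn(1, 2, 16, 16),
+            ...     torch.randn(1, 4, 8, 8),
+            ...     torch.randn(1, 8, 4, 4),
+            ... ]
+            >>> decoder(features).shape
+            torch.Size([1, 2, 32, 32])
+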
+ """ features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder @@ -376,7 +256,7 @@ def forward(self, features: list[torch.Tensor]) -> torch.Tensor: for idx in range(depth_idx + 1, dense_l_i + 1) ] cat_features = torch.cat( - cat_features + [features[dense_l_i + 1]], dim=1 + [*cat_features, features[dense_l_i + 1]], dim=1 ) dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ f"x_{depth_idx}_{dense_l_i}" @@ -391,77 +271,99 @@ class UNetPlusPlusModel(ModelABC): """UNet++ Model.""" def __init__( - self, - encoder_name: str = "resnet34", + self: UNetPlusPlusModel, encoder_depth: int = 5, - encoder_weights: str | None = "imagenet", - decoder_use_norm: bool | str | dict[str, Any] = "batchnorm", decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), - decoder_attention_type: str | None = None, - decoder_interpolation: str = "nearest", - in_channels: int = 3, classes: int = 1, - activation: str | Callable | None = None, - aux_params: dict | None = None, - **kwargs: dict[str, Any], - ): + ) -> None: + """Initialize UNet++ model. + + Args: + encoder_depth: Depth of the encoder. Defaults to 5. + decoder_channels: Number of channels in decoder layers. + Defaults to (256, 128, 64, 32, 16). + classes: Number of output classes. Defaults to 1. + """ super().__init__() - if encoder_name.startswith("mit_b"): - raise ValueError(f"UnetPlusPlus is not support encoder_name={encoder_name}") - - decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) - if decoder_use_batchnorm is not None: - warnings.warn( - "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", - DeprecationWarning, - stacklevel=2, - ) - decoder_use_norm = decoder_use_batchnorm - - self.encoder = TimmUniversalEncoder( - name=encoder_name, - in_channels=in_channels, - depth=encoder_depth, - weights=encoder_weights, - **kwargs, + self.encoder = EfficientNetEncoder( + out_channels=[3, 32, 24, 40, 112, 320], + stage_idxs=[2, 3, 5], + channel_multiplier=1.0, + depth_multiplier=1.0, + drop_rate=0.2, ) - self.decoder = UnetPlusPlusDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_norm=decoder_use_norm, - center=True if encoder_name.startswith("vgg") else False, - attention_type=decoder_attention_type, - interpolation_mode=decoder_interpolation, ) self.segmentation_head = SegmentationHead( in_channels=decoder_channels[-1], out_channels=classes, - activation=activation, kernel_size=3, ) - if aux_params is not None: - self.classification_head = ClassificationHead( - in_channels=self.encoder.out_channels[-1], **aux_params - ) - else: - self.classification_head = None + self.name = "unetplusplus-efficientnetb0" - self.name = f"unetplusplus-{encoder_name}" + def forward( + self: UNetPlusPlusModel, + x: torch.Tensor, + *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 + **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 + ) -> torch.Tensor: + """Sequentially pass `x` through model's encoder, decoder and heads. - def forward(self, x): - """Sequentially pass `x` trough model`s encoder, decoder and heads""" + Args: + x: Input tensor. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + torch.Tensor: Segmentation output. 
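+
+        Example:
+            A rough sketch with randomly initialised weights; the 64 x 64
+            input is for illustration only:
+
+            >>> import torch
+            >>> model = UNetPlusPlusModel(classes=2)
+            >>> model(torch.zeros(1, 3, 64, 64)).shape
+            torch.Size([1, 2, 64, 64])
+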
+ """ features = self.encoder(x) decoder_output = self.decoder(features) - masks = self.segmentation_head(decoder_output) - - if self.classification_head is not None: - labels = self.classification_head(features[-1]) - return masks, labels - - return masks + return self.segmentation_head(decoder_output) + + @staticmethod + def infer_batch( + model: torch.nn.Module, + batch_data: torch.Tensor | np.ndarray, + *, + device: str, + ) -> np.ndarray: + """Run inference on an input batch. + + This contains logic for forward operation as well as i/o + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (:class:`torch.Tensor`): + A batch of data generated by + `torch.utils.data.DataLoader`. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + np.ndarray: + The inference results as a numpy array. + + """ + model.eval() + + imgs = batch_data + if isinstance(imgs, np.ndarray): + imgs = torch.from_numpy(imgs) + imgs = imgs.to(device).type(torch.float32) + imgs = imgs.permute(0, 3, 1, 2) # to NCHW + + with torch.inference_mode(): + logits = model(imgs) + probs = torch.nn.functional.softmax(logits, 1) + probs = probs.permute(0, 2, 3, 1) # to NHWC + + return probs.cpu().numpy() From d47fa0ac92033088d69870e75d6d6de41fcc5663 Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Tue, 18 Nov 2025 18:48:30 +0000 Subject: [PATCH 10/12] refactor code --- tests/models/test_arch_grandqc.py | 6 +- tiatoolbox/data/pretrained_model.yaml | 1 - tiatoolbox/models/architecture/grandqc.py | 332 ++++++++++++++-- .../models/architecture/unetplusplus.py | 369 ------------------ 4 files changed, 307 insertions(+), 401 deletions(-) delete mode 100644 tiatoolbox/models/architecture/unetplusplus.py diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py index 27f6afdd4..609abad5d 100644 --- a/tests/models/test_arch_grandqc.py +++ b/tests/models/test_arch_grandqc.py @@ -22,7 +22,7 @@ def test_functional_grandqc() -> None: assert pretrained_weights is not None # test creation - model = GrandQCModel(num_input_channels=3, num_output_channels=2) + model = GrandQCModel(num_output_channels=2) assert model is not None # load pretrained weights @@ -33,8 +33,8 @@ def test_functional_grandqc() -> None: model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10") assert isinstance(model, GrandQCModel) assert isinstance(ioconfig, IOSegmentorConfig) - assert model.num_input_channels == 3 assert model.num_output_channels == 2 + assert model.decoder_channels == (256, 128, 64, 32, 16) # test inference generator = np.random.default_rng(1337) @@ -54,7 +54,7 @@ def test_functional_grandqc() -> None: def test_grandqc_preproc_postproc() -> None: """Test GrandQC preproc and postproc functions.""" - model = GrandQCModel(num_input_channels=3, num_output_channels=2) + model = GrandQCModel(num_output_channels=2) generator = np.random.default_rng(1337) # test preproc diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index f0916dbfc..dbddd60ef 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -940,7 +940,6 @@ grandqc_tissue_detection_mpp10: architecture: class: grandqc.GrandQCModel kwargs: - num_input_channels: 3 num_output_channels: 2 ioconfig: class: io_config.IOSegmentorConfig diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index e7c3b12ce..8e9864a81 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ 
b/tiatoolbox/models/architecture/grandqc.py @@ -5,16 +5,269 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: # pragma: no cover - from collections.abc import Mapping + from collections.abc import Sequence import cv2 import numpy as np import torch +from torch import nn -from tiatoolbox.models.architecture.unetplusplus import UNetPlusPlusModel +from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder from tiatoolbox.models.models_abc import ModelABC +class SegmentationHead(nn.Sequential): + """Segmentation head for UNet++ model.""" + + def __init__( + self: SegmentationHead, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + activation: nn.Module | None = None, + upsampling: int = 1, + ) -> None: + """Initialize SegmentationHead. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Convolution kernel size. Defaults to 3. + activation: Activation function. Defaults to None. + upsampling: Upsampling factor. Defaults to 1. + """ + conv2d = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + upsampling = ( + nn.UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else nn.Identity() + ) + if activation is None: + activation = nn.Identity() + super().__init__(conv2d, upsampling, activation) + + +class Conv2dReLU(nn.Sequential): + """Conv2d + BatchNorm + ReLU block.""" + + def __init__( + self: Conv2dReLU, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + stride: int = 1, + ) -> None: + """Initialize Conv2dReLU block. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Convolution kernel size. + padding: Padding size. Defaults to 0. + stride: Stride size. Defaults to 1. + """ + norm = nn.BatchNorm2d(out_channels) + + is_identity = isinstance(norm, nn.Identity) + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=is_identity, + ) + + activation = nn.ReLU(inplace=True) + + super().__init__(conv, norm, activation) + + +class DecoderBlock(nn.Module): + """Decoder block for UNet++ architecture.""" + + def __init__( + self: DecoderBlock, + in_channels: int, + skip_channels: int, + out_channels: int, + ) -> None: + """Initialize DecoderBlock. + + Args: + in_channels: Number of input channels. + skip_channels: Number of skip connection channels. + out_channels: Number of output channels. + """ + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.attention1 = nn.Identity() + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.attention2 = nn.Identity() + + def forward( + self, x: torch.Tensor, skip: torch.Tensor | None = None + ) -> torch.Tensor: + """Forward pass through decoder block. + + Args: + x: Input tensor. + skip: Skip connection tensor. Defaults to None. + + Returns: + torch.Tensor: Output tensor after decoding. + """ + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + return self.attention2(x) + + +class CenterBlock(nn.Sequential): + """Center block for UNet++ architecture.""" + + def __init__( + self: CenterBlock, + in_channels: int, + out_channels: int, + ) -> None: + """Initialize CenterBlock. 
+ + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + conv1 = Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + ) + conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + ) + super().__init__(conv1, conv2) + + +class UnetPlusPlusDecoder(nn.Module): + """UNet++ decoder with dense connections.""" + + def __init__( + self, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + n_blocks: int = 5, + ) -> None: + """Initialize UnetPlusPlusDecoder. + + Args: + encoder_channels: List of encoder output channels. + decoder_channels: List of decoder output channels. + n_blocks: Number of decoder blocks. Defaults to 5. + + Raises: + ValueError: If model depth doesn't match decoder_channels length. + """ + super().__init__() + + if n_blocks != len(decoder_channels): + msg = f"Model depth is {n_blocks}, but you provide \ + `decoder_channels` for {len(decoder_channels)} blocks." + raise ValueError(msg) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + self.in_channels = [head_channels, *list(decoder_channels[:-1])] + self.skip_channels = [*list(encoder_channels[1:]), 0] + self.out_channels = decoder_channels + + self.center = nn.Identity() + + blocks = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(layer_idx + 1): + if depth_idx == 0: + in_ch = self.in_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1) + out_ch = self.out_channels[layer_idx] + else: + out_ch = self.skip_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * ( + layer_idx + 1 - depth_idx + ) + in_ch = self.skip_channels[layer_idx - 1] + blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( + in_ch, skip_ch, out_ch + ) + blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock( + self.in_channels[-1], 0, self.out_channels[-1] + ) + self.blocks = nn.ModuleDict(blocks) + self.depth = len(self.in_channels) - 1 + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Forward pass through UNet++ decoder. + + Args: + features: List of encoder feature maps. + + Returns: + torch.Tensor: Decoded output tensor. + """ + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + # start building dense connections + dense_x = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(self.depth - layer_idx): + if layer_idx == 0: + output = self.blocks[f"x_{depth_idx}_{depth_idx}"]( + features[depth_idx], features[depth_idx + 1] + ) + dense_x[f"x_{depth_idx}_{depth_idx}"] = output + else: + dense_l_i = depth_idx + layer_idx + cat_features = [ + dense_x[f"x_{idx}_{dense_l_i}"] + for idx in range(depth_idx + 1, dense_l_i + 1) + ] + cat_features = torch.cat( + [*cat_features, features[dense_l_i + 1]], dim=1 + ) + dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ + f"x_{depth_idx}_{dense_l_i}" + ](dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features) + dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"]( + dense_x[f"x_{0}_{self.depth - 1}"] + ) + return dense_x[f"x_{0}_{self.depth}"] + + class GrandQCModel(ModelABC): """GrandQC Tissue Detection Model [1]. 
@@ -39,18 +292,58 @@ class GrandQCModel(ModelABC): """ - def __init__( - self: GrandQCModel, num_input_channels: int, num_output_channels: int - ) -> None: - """Initialize TissueDetectionModel.""" + def __init__(self: GrandQCModel, num_output_channels: int = 2) -> None: + """Initialize UNet++ model. + + Args: + encoder_depth: Depth of the encoder. Defaults to 5. + num_output_channels: Number of output classes. Defaults to 2. + """ super().__init__() - self.num_input_channels = num_input_channels self.num_output_channels = num_output_channels - self._postproc = self.postproc - self._preproc = self.preproc - self.tissue_detection_model = UNetPlusPlusModel( - classes=self.num_output_channels, + self.decoder_channels = (256, 128, 64, 32, 16) + + self.encoder = EfficientNetEncoder( + out_channels=[3, 32, 24, 40, 112, 320], + stage_idxs=[2, 3, 5], + channel_multiplier=1.0, + depth_multiplier=1.0, + drop_rate=0.2, ) + self.decoder = UnetPlusPlusDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=self.decoder_channels, + n_blocks=5, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder_channels[-1], + out_channels=num_output_channels, + kernel_size=3, + ) + + self.name = "unetplusplus-efficientnetb0" + + def forward( + self: GrandQCModel, + x: torch.Tensor, + *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 + **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 + ) -> torch.Tensor: + """Sequentially pass `x` through model's encoder, decoder and heads. + + Args: + x: Input tensor. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + torch.Tensor: Segmentation output. + """ + features = self.encoder(x) + decoder_output = self.decoder(features) + + return self.segmentation_head(decoder_output) @staticmethod def preproc(image: np.ndarray) -> np.ndarray: @@ -91,15 +384,6 @@ def postproc(image: np.ndarray) -> np.ndarray: """ return image.argmin(axis=-1) - def forward( - self: GrandQCModel, - imgs: torch.Tensor, - *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 - **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 - ) -> torch.Tensor: - """Forward function for model.""" - return self.tissue_detection_model(imgs) - @staticmethod def infer_batch( model: torch.nn.Module, @@ -137,11 +421,3 @@ def infer_batch( probs = probs.permute(0, 2, 3, 1) # to NHWC return probs.cpu().numpy() - - def load_state_dict( - self: GrandQCModel, - state_dict: Mapping[str, Any], - **kwargs: bool, - ) -> torch.nn.modules.module._IncompatibleKeys: - """Load state dict for the GrandQCModel.""" - return self.tissue_detection_model.load_state_dict(state_dict, **kwargs) diff --git a/tiatoolbox/models/architecture/unetplusplus.py b/tiatoolbox/models/architecture/unetplusplus.py deleted file mode 100644 index e18309b9e..000000000 --- a/tiatoolbox/models/architecture/unetplusplus.py +++ /dev/null @@ -1,369 +0,0 @@ -"""Define Unet++ architecture from Segmentation Models Pytorch.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: # pragma: no cover - from collections.abc import Sequence - -import numpy as np -import torch -from torch import nn - -from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder -from tiatoolbox.models.models_abc import ModelABC - - -class SegmentationHead(nn.Sequential): - """Segmentation head for UNet++ model.""" - - def __init__( - self: SegmentationHead, - in_channels: int, - out_channels: int, - kernel_size: int = 3, - activation: 
nn.Module | None = None, - upsampling: int = 1, - ) -> None: - """Initialize SegmentationHead. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: Convolution kernel size. Defaults to 3. - activation: Activation function. Defaults to None. - upsampling: Upsampling factor. Defaults to 1. - """ - conv2d = nn.Conv2d( - in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 - ) - upsampling = ( - nn.UpsamplingBilinear2d(scale_factor=upsampling) - if upsampling > 1 - else nn.Identity() - ) - if activation is None: - activation = nn.Identity() - super().__init__(conv2d, upsampling, activation) - - -class Conv2dReLU(nn.Sequential): - """Conv2d + BatchNorm + ReLU block.""" - - def __init__( - self: Conv2dReLU, - in_channels: int, - out_channels: int, - kernel_size: int, - padding: int = 0, - stride: int = 1, - ) -> None: - """Initialize Conv2dReLU block. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: Convolution kernel size. - padding: Padding size. Defaults to 0. - stride: Stride size. Defaults to 1. - """ - norm = nn.BatchNorm2d(out_channels) - - is_identity = isinstance(norm, nn.Identity) - conv = nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - bias=is_identity, - ) - - activation = nn.ReLU(inplace=True) - - super().__init__(conv, norm, activation) - - -class DecoderBlock(nn.Module): - """Decoder block for UNet++ architecture.""" - - def __init__( - self: DecoderBlock, - in_channels: int, - skip_channels: int, - out_channels: int, - ) -> None: - """Initialize DecoderBlock. - - Args: - in_channels: Number of input channels. - skip_channels: Number of skip connection channels. - out_channels: Number of output channels. - """ - super().__init__() - self.conv1 = Conv2dReLU( - in_channels + skip_channels, - out_channels, - kernel_size=3, - padding=1, - ) - self.attention1 = nn.Identity() - self.conv2 = Conv2dReLU( - out_channels, - out_channels, - kernel_size=3, - padding=1, - ) - self.attention2 = nn.Identity() - - def forward( - self, x: torch.Tensor, skip: torch.Tensor | None = None - ) -> torch.Tensor: - """Forward pass through decoder block. - - Args: - x: Input tensor. - skip: Skip connection tensor. Defaults to None. - - Returns: - torch.Tensor: Output tensor after decoding. - """ - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if skip is not None: - x = torch.cat([x, skip], dim=1) - x = self.attention1(x) - x = self.conv1(x) - x = self.conv2(x) - return self.attention2(x) - - -class CenterBlock(nn.Sequential): - """Center block for UNet++ architecture.""" - - def __init__( - self: CenterBlock, - in_channels: int, - out_channels: int, - ) -> None: - """Initialize CenterBlock. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - """ - conv1 = Conv2dReLU( - in_channels, - out_channels, - kernel_size=3, - padding=1, - ) - conv2 = Conv2dReLU( - out_channels, - out_channels, - kernel_size=3, - padding=1, - ) - super().__init__(conv1, conv2) - - -class UnetPlusPlusDecoder(nn.Module): - """UNet++ decoder with dense connections.""" - - def __init__( - self, - encoder_channels: Sequence[int], - decoder_channels: Sequence[int], - n_blocks: int = 5, - ) -> None: - """Initialize UnetPlusPlusDecoder. - - Args: - encoder_channels: List of encoder output channels. - decoder_channels: List of decoder output channels. 
- n_blocks: Number of decoder blocks. Defaults to 5. - - Raises: - ValueError: If model depth doesn't match decoder_channels length. - """ - super().__init__() - - if n_blocks != len(decoder_channels): - msg = f"Model depth is {n_blocks}, but you provide \ - `decoder_channels` for {len(decoder_channels)} blocks." - raise ValueError(msg) - - # remove first skip with same spatial resolution - encoder_channels = encoder_channels[1:] - # reverse channels to start from head of encoder - encoder_channels = encoder_channels[::-1] - - # computing blocks input and output channels - head_channels = encoder_channels[0] - self.in_channels = [head_channels, *list(decoder_channels[:-1])] - self.skip_channels = [*list(encoder_channels[1:]), 0] - self.out_channels = decoder_channels - - self.center = nn.Identity() - - blocks = {} - for layer_idx in range(len(self.in_channels) - 1): - for depth_idx in range(layer_idx + 1): - if depth_idx == 0: - in_ch = self.in_channels[layer_idx] - skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1) - out_ch = self.out_channels[layer_idx] - else: - out_ch = self.skip_channels[layer_idx] - skip_ch = self.skip_channels[layer_idx] * ( - layer_idx + 1 - depth_idx - ) - in_ch = self.skip_channels[layer_idx - 1] - blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( - in_ch, skip_ch, out_ch - ) - blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock( - self.in_channels[-1], 0, self.out_channels[-1] - ) - self.blocks = nn.ModuleDict(blocks) - self.depth = len(self.in_channels) - 1 - - def forward(self, features: list[torch.Tensor]) -> torch.Tensor: - """Forward pass through UNet++ decoder. - - Args: - features: List of encoder feature maps. - - Returns: - torch.Tensor: Decoded output tensor. - """ - features = features[1:] # remove first skip with same spatial resolution - features = features[::-1] # reverse channels to start from head of encoder - - # start building dense connections - dense_x = {} - for layer_idx in range(len(self.in_channels) - 1): - for depth_idx in range(self.depth - layer_idx): - if layer_idx == 0: - output = self.blocks[f"x_{depth_idx}_{depth_idx}"]( - features[depth_idx], features[depth_idx + 1] - ) - dense_x[f"x_{depth_idx}_{depth_idx}"] = output - else: - dense_l_i = depth_idx + layer_idx - cat_features = [ - dense_x[f"x_{idx}_{dense_l_i}"] - for idx in range(depth_idx + 1, dense_l_i + 1) - ] - cat_features = torch.cat( - [*cat_features, features[dense_l_i + 1]], dim=1 - ) - dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ - f"x_{depth_idx}_{dense_l_i}" - ](dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features) - dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"]( - dense_x[f"x_{0}_{self.depth - 1}"] - ) - return dense_x[f"x_{0}_{self.depth}"] - - -class UNetPlusPlusModel(ModelABC): - """UNet++ Model.""" - - def __init__( - self: UNetPlusPlusModel, - encoder_depth: int = 5, - decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), - classes: int = 1, - ) -> None: - """Initialize UNet++ model. - - Args: - encoder_depth: Depth of the encoder. Defaults to 5. - decoder_channels: Number of channels in decoder layers. - Defaults to (256, 128, 64, 32, 16). - classes: Number of output classes. Defaults to 1. 
- """ - super().__init__() - - self.encoder = EfficientNetEncoder( - out_channels=[3, 32, 24, 40, 112, 320], - stage_idxs=[2, 3, 5], - channel_multiplier=1.0, - depth_multiplier=1.0, - drop_rate=0.2, - ) - self.decoder = UnetPlusPlusDecoder( - encoder_channels=self.encoder.out_channels, - decoder_channels=decoder_channels, - n_blocks=encoder_depth, - ) - - self.segmentation_head = SegmentationHead( - in_channels=decoder_channels[-1], - out_channels=classes, - kernel_size=3, - ) - - self.name = "unetplusplus-efficientnetb0" - - def forward( - self: UNetPlusPlusModel, - x: torch.Tensor, - *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 - **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 - ) -> torch.Tensor: - """Sequentially pass `x` through model's encoder, decoder and heads. - - Args: - x: Input tensor. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - torch.Tensor: Segmentation output. - """ - features = self.encoder(x) - decoder_output = self.decoder(features) - - return self.segmentation_head(decoder_output) - - @staticmethod - def infer_batch( - model: torch.nn.Module, - batch_data: torch.Tensor | np.ndarray, - *, - device: str, - ) -> np.ndarray: - """Run inference on an input batch. - - This contains logic for forward operation as well as i/o - - Args: - model (nn.Module): - PyTorch defined model. - batch_data (:class:`torch.Tensor`): - A batch of data generated by - `torch.utils.data.DataLoader`. - device (str): - Transfers model to the specified device. Default is "cpu". - - Returns: - np.ndarray: - The inference results as a numpy array. - - """ - model.eval() - - imgs = batch_data - if isinstance(imgs, np.ndarray): - imgs = torch.from_numpy(imgs) - imgs = imgs.to(device).type(torch.float32) - imgs = imgs.permute(0, 3, 1, 2) # to NCHW - - with torch.inference_mode(): - logits = model(imgs) - probs = torch.nn.functional.softmax(logits, 1) - probs = probs.permute(0, 2, 3, 1) # to NHWC - - return probs.cpu().numpy() From d2a66ca79c0242488876061242821ecc8737bdfe Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 21 Nov 2025 17:43:18 +0000 Subject: [PATCH 11/12] add tests --- tests/models/test_arch_grandqc.py | 54 ++++++- tests/models/test_arch_timm_effcientnet.py | 177 +++++++++++++++++++++ 2 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 tests/models/test_arch_timm_effcientnet.py diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py index 609abad5d..0179b5020 100644 --- a/tests/models/test_arch_grandqc.py +++ b/tests/models/test_arch_grandqc.py @@ -2,12 +2,18 @@ import numpy as np import torch +from torch import nn from tiatoolbox.models.architecture import ( fetch_pretrained_weights, get_pretrained_model, ) -from tiatoolbox.models.architecture.grandqc import GrandQCModel +from tiatoolbox.models.architecture.grandqc import ( + CenterBlock, + GrandQCModel, + SegmentationHead, + UnetPlusPlusDecoder, +) from tiatoolbox.models.engine.io_config import IOSegmentorConfig from tiatoolbox.utils.misc import select_device from tiatoolbox.wsicore.wsireader import VirtualWSIReader @@ -68,3 +74,49 @@ def test_grandqc_preproc_postproc() -> None: postproc_image = model.postproc(dummy_output) assert postproc_image.shape == (512, 512) assert postproc_image.dtype == np.int64 + + +def test_segmentation_head_behaviour() -> None: + """Verify SegmentationHead defaults and upsampling.""" + head = SegmentationHead(3, 5, activation=None, upsampling=1) + assert isinstance(head[1], nn.Identity) + assert 
isinstance(head[2], nn.Identity) + + x = torch.randn(1, 3, 6, 8) + out = head(x) + assert out.shape == (1, 5, 6, 8) + + head = SegmentationHead(3, 2, activation=nn.Sigmoid(), upsampling=2) + x = torch.ones(1, 3, 4, 4) + out = head(x) + assert out.shape == (1, 2, 8, 8) + assert torch.all(out >= 0) + assert torch.all(out <= 1) + + +def test_unetplusplus_decoder_forward_shapes() -> None: + """Ensure UnetPlusPlusDecoder handles dense connections.""" + decoder = UnetPlusPlusDecoder( + encoder_channels=[1, 2, 4, 8], + decoder_channels=[8, 4, 2], + n_blocks=3, + ) + + features = [ + torch.randn(1, 1, 32, 32), + torch.randn(1, 2, 16, 16), + torch.randn(1, 4, 8, 8), + torch.randn(1, 8, 4, 4), + ] + + output = decoder(features) + assert output.shape == (1, 2, 32, 32) + + +def test_center_block_behavior() -> None: + """Test CenterBlock behavior in UnetPlusPlusDecoder.""" + center_block = CenterBlock(in_channels=8, out_channels=8) + + x = torch.randn(1, 8, 4, 4) + out = center_block(x) + assert out.shape == (1, 8, 4, 4) diff --git a/tests/models/test_arch_timm_effcientnet.py b/tests/models/test_arch_timm_effcientnet.py new file mode 100644 index 000000000..9c62bda26 --- /dev/null +++ b/tests/models/test_arch_timm_effcientnet.py @@ -0,0 +1,177 @@ +"""Unit tests for timm EfficientNet encoder helpers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence + +import pytest +import torch +from torch import nn + +from tiatoolbox.models.architecture import timm_efficientnet as effnet_mod +from tiatoolbox.models.architecture.timm_efficientnet import ( + DEFAULT_IN_CHANNELS, + EfficientNetEncoder, + EncoderMixin, + replace_strides_with_dilation, +) + + +class DummyEncoder(nn.Module, EncoderMixin): + """Lightweight encoder for testing mixin behavior.""" + + def __init__(self) -> None: + """Initialize EncoderMixin for testing.""" + nn.Module.__init__(self) + EncoderMixin.__init__(self) + self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1) + self.conv32 = nn.Conv2d(4, 4, 3) + self._out_channels = [DEFAULT_IN_CHANNELS, 4, 8] + self._depth = 2 + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Get stages for dilation modification. + + Returns: + Dictionary with keys as output stride and values as list of modules. 
+ """ + return {16: [self.conv], 32: [self.conv32]} + + +def test_patch_first_conv() -> None: + """patch_first_conv should reduce or expand correctly.""" + # create simple conv + model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False)) + conv = model[0] + + # collapsing 3 channels into 1 + effnet_mod.patch_first_conv(model, new_in_channels=1, pretrained=True) + assert conv.in_channels == 1 + + # expanding to 5 channels + model = nn.Sequential(nn.Conv2d(3, 2, kernel_size=1, bias=False)) + conv = model[0] + + effnet_mod.patch_first_conv(model, new_in_channels=5, pretrained=True) + assert conv.in_channels == 5 + + +def test_patch_first_conv_reset_weights_when_not_pretrained() -> None: + """Ensure random reinit happens when pretrained flag is False.""" + # start from known weights + model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1, bias=False)) + original = model[0].weight.clone() + # changing channel count without pretrained should reinit parameters + effnet_mod.patch_first_conv(model, new_in_channels=4, pretrained=False) + assert model[0].in_channels == 4 + assert model[0].weight.shape[1] == 4 + # Almost surely changed due to reset_parameters + assert not torch.equal(original, model[0].weight[:1, :3]) + + +def test_patch_first_conv_no_matching_layer_is_safe() -> None: + """The function should silently exit when no suitable conv exists.""" + model = nn.Sequential(nn.Conv2d(5, 1, kernel_size=1)) + original = model[0].weight.clone() + # no conv with default channel count, so weights stay unchanged + effnet_mod.patch_first_conv(model, new_in_channels=3, pretrained=True) + assert torch.equal(original, model[0].weight) + + +def test_replace_strides_with_dilation_applies_to_nested_convs() -> None: + """Strides become dilation and static padding gets removed.""" + module = nn.Sequential( + nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1), + ) + # attach static_padding to mirror EfficientNet convs + module[0].static_padding = nn.Conv2d(1, 1, 1) + + # applying dilation should also strip static padding + replace_strides_with_dilation(module, dilation_rate=3) + conv = module[0] + assert conv.stride == (1, 1) + assert conv.dilation == (3, 3) + assert conv.padding == (3, 3) + assert isinstance(conv.static_padding, nn.Identity) + + +def test_encoder_mixin_properties_and_set_in_channels() -> None: + """EncoderMixin should expose out_channels/output_stride and patch convs.""" + # use dummy encoder to check property logic + encoder = DummyEncoder() + assert encoder.out_channels == [3, 4, 8] + # adjust internals to check min logic in output_stride + encoder._output_stride = 4 + encoder._depth = 3 + assert encoder.output_stride == 4 # min(output_stride, 2**depth) + + # calling set_in_channels should patch first conv and update bookkeeping + encoder.set_in_channels(5, pretrained=False) + assert encoder._in_channels == 5 + assert encoder.out_channels[0] == 5 + assert encoder.conv.in_channels == 5 + + +def test_encoder_mixin_make_dilated_and_validation() -> None: + """make_dilated should error on invalid stride and patch convs otherwise.""" + encoder = DummyEncoder() + + # invalid stride raises + with pytest.raises(ValueError, match="Output stride should be 16 or 8"): + encoder.make_dilated(output_stride=4) + + # valid stride should touch both stage groups + encoder.make_dilated(output_stride=8) + conv16, conv32 = encoder.get_stages()[16][0], encoder.get_stages()[32][0] + assert conv16.stride == (1, 1) + assert conv16.dilation == (2, 2) + assert conv32.stride == (1, 1) + assert conv32.dilation == 
(4, 4) + + +def test_get_efficientnet_kwargs_shapes_and_values() -> None: + """get_efficientnet_kwargs should produce expected keys and scaling.""" + # confirm output contains decoded blocks and scaled channels + kwargs = effnet_mod.get_efficientnet_kwargs( + channel_multiplier=1.2, depth_multiplier=1.4, drop_rate=0.3 + ) + assert kwargs.get("block_args") + assert kwargs["num_features"] == effnet_mod.round_channels(1280, 1.2, 8, None) + assert kwargs["drop_rate"] == 0.3 + + +def test_efficientnet_encoder_depth_validation_and_forward() -> None: + """EfficientNetEncoder should validate depth and run forward returning features.""" + # invalid depth should fail fast + with pytest.raises( + ValueError, match=r"EfficientNetEncoder depth should be in range\s+\[1, 5\]" + ): + EfficientNetEncoder( + stage_idxs=[2, 3, 5], + out_channels=[3, 32, 24, 40, 112, 320], + depth=6, + ) + + # build shallow encoder and run a forward pass + encoder = EfficientNetEncoder( + stage_idxs=[2, 3, 5], + out_channels=[3, 32, 24, 40, 112, 320], + depth=3, + channel_multiplier=0.5, + depth_multiplier=0.5, + ) + x = torch.randn(1, 3, 32, 32) + features = encoder(x) + assert len(features) == encoder._depth + 1 + assert torch.equal(features[0], x) + + # ensure classifier keys are dropped before loading into the model + extended_state = dict(encoder.state_dict()) + extended_state["classifier.bias"] = torch.tensor([1.0]) + extended_state["classifier.weight"] = torch.tensor([[1.0]]) + load_result = encoder.load_state_dict(extended_state, strict=True) + assert not load_result.missing_keys + assert not load_result.unexpected_keys From 19cca904a7ed82c99f886598b0dfc31066548d6a Mon Sep 17 00:00:00 2001 From: Jiaqi Lv Date: Fri, 21 Nov 2025 18:19:34 +0000 Subject: [PATCH 12/12] address comments --- ...imm_effcientnet.py => test_arch_timm_efficientnet.py} | 0 tiatoolbox/models/architecture/grandqc.py | 9 ++++----- 2 files changed, 4 insertions(+), 5 deletions(-) rename tests/models/{test_arch_timm_effcientnet.py => test_arch_timm_efficientnet.py} (100%) diff --git a/tests/models/test_arch_timm_effcientnet.py b/tests/models/test_arch_timm_efficientnet.py similarity index 100% rename from tests/models/test_arch_timm_effcientnet.py rename to tests/models/test_arch_timm_efficientnet.py diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py index 8e9864a81..e81f2a4b3 100644 --- a/tiatoolbox/models/architecture/grandqc.py +++ b/tiatoolbox/models/architecture/grandqc.py @@ -39,14 +39,14 @@ def __init__( conv2d = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 ) - upsampling = ( + upsampling_layer = ( nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() ) if activation is None: activation = nn.Identity() - super().__init__(conv2d, upsampling, activation) + super().__init__(conv2d, upsampling_layer, activation) class Conv2dReLU(nn.Sequential): @@ -190,7 +190,7 @@ def __init__( super().__init__() if n_blocks != len(decoder_channels): - msg = f"Model depth is {n_blocks}, but you provide \ + msg = f"Model depth is {n_blocks}, but you provide \ `decoder_channels` for {len(decoder_channels)} blocks." raise ValueError(msg) @@ -293,10 +293,9 @@ class GrandQCModel(ModelABC): """ def __init__(self: GrandQCModel, num_output_channels: int = 2) -> None: - """Initialize UNet++ model. + """Initialize GrandQC model. Args: - encoder_depth: Depth of the encoder. Defaults to 5. num_output_channels: Number of output classes. Defaults to 2. 
""" super().__init__()