diff --git a/tests/models/test_arch_grandqc.py b/tests/models/test_arch_grandqc.py new file mode 100644 index 000000000..609abad5d --- /dev/null +++ b/tests/models/test_arch_grandqc.py @@ -0,0 +1,70 @@ +"""Unit test package for GrandQC Tissue Model.""" + +import numpy as np +import torch + +from tiatoolbox.models.architecture import ( + fetch_pretrained_weights, + get_pretrained_model, +) +from tiatoolbox.models.architecture.grandqc import GrandQCModel +from tiatoolbox.models.engine.io_config import IOSegmentorConfig +from tiatoolbox.utils.misc import select_device +from tiatoolbox.wsicore.wsireader import VirtualWSIReader + +ON_GPU = False + + +def test_functional_grandqc() -> None: + """Test for GrandQC model.""" + # test fetch pretrained weights + pretrained_weights = fetch_pretrained_weights("grandqc_tissue_detection_mpp10") + assert pretrained_weights is not None + + # test creation + model = GrandQCModel(num_output_channels=2) + assert model is not None + + # load pretrained weights + pretrained = torch.load(pretrained_weights, map_location="cpu") + model.load_state_dict(pretrained) + + # test get pretrained model + model, ioconfig = get_pretrained_model("grandqc_tissue_detection_mpp10") + assert isinstance(model, GrandQCModel) + assert isinstance(ioconfig, IOSegmentorConfig) + assert model.num_output_channels == 2 + assert model.decoder_channels == (256, 128, 64, 32, 16) + + # test inference + generator = np.random.default_rng(1337) + test_image = generator.integers(0, 256, size=(2048, 2048, 3), dtype=np.uint8) + reader = VirtualWSIReader.open(test_image) + read_kwargs = {"resolution": 0, "units": "level", "coord_space": "resolution"} + batch = np.array( + [ + reader.read_bounds((0, 0, 512, 512), **read_kwargs), + reader.read_bounds((512, 512, 1024, 1024), **read_kwargs), + ], + ) + batch = torch.from_numpy(batch) + output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) + assert output.shape == (2, 512, 512, 2) + + +def test_grandqc_preproc_postproc() -> None: + """Test GrandQC preproc and postproc functions.""" + model = GrandQCModel(num_output_channels=2) + + generator = np.random.default_rng(1337) + # test preproc + dummy_image = generator.integers(0, 256, size=(512, 512, 3), dtype=np.uint8) + preproc_image = model.preproc(dummy_image) + assert preproc_image.shape == dummy_image.shape + assert preproc_image.dtype == np.float64 + + # test postproc + dummy_output = generator.random(size=(512, 512, 2), dtype=np.float32) + postproc_image = model.postproc(dummy_output) + assert postproc_image.shape == (512, 512) + assert postproc_image.dtype == np.int64 diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 8ab9a998f..dbddd60ef 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -815,7 +815,7 @@ mapde-crchisto: threshold_abs: 250 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -837,7 +837,7 @@ mapde-conic: threshold_abs: 205 num_classes: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.5 } @@ -860,7 +860,7 @@ sccnn-crchisto: threshold_abs: 0.20 patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.25 } @@ -883,7 +883,7 @@ sccnn-conic: threshold_abs: 0.05 patch_output_shape: [ 13, 13 ] ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - { "units": "mpp", "resolution": 0.25 } @@ -903,7 +903,7 @@ nuclick_original-pannuke: num_input_channels: 5 num_output_channels: 1 ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -925,7 +925,7 @@ nuclick_light-pannuke: decoder_block: [3,3] skip_type: "add" ioconfig: - class: semantic_segmentor.IOSegmentorConfig + class: io_config.IOSegmentorConfig kwargs: input_resolutions: - {'units': 'baseline', 'resolution': 0.25} @@ -934,3 +934,21 @@ nuclick_light-pannuke: patch_input_shape: [128, 128] patch_output_shape: [128, 128] save_resolution: {'units': 'baseline', 'resolution': 1.0} + +grandqc_tissue_detection_mpp10: + hf_repo_id: TIACentre/GrandQC_Tissue_Detection + architecture: + class: grandqc.GrandQCModel + kwargs: + num_output_channels: 2 + ioconfig: + class: io_config.IOSegmentorConfig + kwargs: + input_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + output_resolutions: + - {'units': 'mpp', 'resolution': 10.0} + patch_input_shape: [512, 512] + patch_output_shape: [512, 512] + stride_shape: [256, 256] + save_resolution: {'units': 'mpp', 'resolution': 10.0} diff --git a/tiatoolbox/models/architecture/grandqc.py b/tiatoolbox/models/architecture/grandqc.py new file mode 100644 index 000000000..8e9864a81 --- /dev/null +++ b/tiatoolbox/models/architecture/grandqc.py @@ -0,0 +1,423 @@ +"""Define GrandQC Tissue Detection Model architecture.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Sequence + +import cv2 +import numpy as np +import torch +from torch import nn + +from tiatoolbox.models.architecture.timm_efficientnet import EfficientNetEncoder +from tiatoolbox.models.models_abc import ModelABC + + +class SegmentationHead(nn.Sequential): + """Segmentation head for UNet++ model.""" + + def __init__( + self: SegmentationHead, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + activation: nn.Module | None = None, + upsampling: int = 1, + ) -> None: + """Initialize SegmentationHead. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Convolution kernel size. Defaults to 3. + activation: Activation function. Defaults to None. + upsampling: Upsampling factor. Defaults to 1. + """ + conv2d = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + upsampling = ( + nn.UpsamplingBilinear2d(scale_factor=upsampling) + if upsampling > 1 + else nn.Identity() + ) + if activation is None: + activation = nn.Identity() + super().__init__(conv2d, upsampling, activation) + + +class Conv2dReLU(nn.Sequential): + """Conv2d + BatchNorm + ReLU block.""" + + def __init__( + self: Conv2dReLU, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + stride: int = 1, + ) -> None: + """Initialize Conv2dReLU block. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Convolution kernel size. + padding: Padding size. Defaults to 0. + stride: Stride size. Defaults to 1. + """ + norm = nn.BatchNorm2d(out_channels) + + is_identity = isinstance(norm, nn.Identity) + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=is_identity, + ) + + activation = nn.ReLU(inplace=True) + + super().__init__(conv, norm, activation) + + +class DecoderBlock(nn.Module): + """Decoder block for UNet++ architecture.""" + + def __init__( + self: DecoderBlock, + in_channels: int, + skip_channels: int, + out_channels: int, + ) -> None: + """Initialize DecoderBlock. + + Args: + in_channels: Number of input channels. + skip_channels: Number of skip connection channels. + out_channels: Number of output channels. + """ + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.attention1 = nn.Identity() + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + ) + self.attention2 = nn.Identity() + + def forward( + self, x: torch.Tensor, skip: torch.Tensor | None = None + ) -> torch.Tensor: + """Forward pass through decoder block. + + Args: + x: Input tensor. + skip: Skip connection tensor. Defaults to None. + + Returns: + torch.Tensor: Output tensor after decoding. + """ + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + x = self.conv1(x) + x = self.conv2(x) + return self.attention2(x) + + +class CenterBlock(nn.Sequential): + """Center block for UNet++ architecture.""" + + def __init__( + self: CenterBlock, + in_channels: int, + out_channels: int, + ) -> None: + """Initialize CenterBlock. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + """ + conv1 = Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + ) + conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + ) + super().__init__(conv1, conv2) + + +class UnetPlusPlusDecoder(nn.Module): + """UNet++ decoder with dense connections.""" + + def __init__( + self, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + n_blocks: int = 5, + ) -> None: + """Initialize UnetPlusPlusDecoder. + + Args: + encoder_channels: List of encoder output channels. + decoder_channels: List of decoder output channels. + n_blocks: Number of decoder blocks. Defaults to 5. + + Raises: + ValueError: If model depth doesn't match decoder_channels length. + """ + super().__init__() + + if n_blocks != len(decoder_channels): + msg = f"Model depth is {n_blocks}, but you provide \ + `decoder_channels` for {len(decoder_channels)} blocks." + raise ValueError(msg) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + self.in_channels = [head_channels, *list(decoder_channels[:-1])] + self.skip_channels = [*list(encoder_channels[1:]), 0] + self.out_channels = decoder_channels + + self.center = nn.Identity() + + blocks = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(layer_idx + 1): + if depth_idx == 0: + in_ch = self.in_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1) + out_ch = self.out_channels[layer_idx] + else: + out_ch = self.skip_channels[layer_idx] + skip_ch = self.skip_channels[layer_idx] * ( + layer_idx + 1 - depth_idx + ) + in_ch = self.skip_channels[layer_idx - 1] + blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( + in_ch, skip_ch, out_ch + ) + blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock( + self.in_channels[-1], 0, self.out_channels[-1] + ) + self.blocks = nn.ModuleDict(blocks) + self.depth = len(self.in_channels) - 1 + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + """Forward pass through UNet++ decoder. + + Args: + features: List of encoder feature maps. + + Returns: + torch.Tensor: Decoded output tensor. + """ + features = features[1:] # remove first skip with same spatial resolution + features = features[::-1] # reverse channels to start from head of encoder + + # start building dense connections + dense_x = {} + for layer_idx in range(len(self.in_channels) - 1): + for depth_idx in range(self.depth - layer_idx): + if layer_idx == 0: + output = self.blocks[f"x_{depth_idx}_{depth_idx}"]( + features[depth_idx], features[depth_idx + 1] + ) + dense_x[f"x_{depth_idx}_{depth_idx}"] = output + else: + dense_l_i = depth_idx + layer_idx + cat_features = [ + dense_x[f"x_{idx}_{dense_l_i}"] + for idx in range(depth_idx + 1, dense_l_i + 1) + ] + cat_features = torch.cat( + [*cat_features, features[dense_l_i + 1]], dim=1 + ) + dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ + f"x_{depth_idx}_{dense_l_i}" + ](dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features) + dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"]( + dense_x[f"x_{0}_{self.depth - 1}"] + ) + return dense_x[f"x_{0}_{self.depth}"] + + +class GrandQCModel(ModelABC): + """GrandQC Tissue Detection Model [1]. + + Example: + >>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor + >>> semantic_segmentor = SemanticSegmentor( + ... model="grandqc_tissue_detection_mpp10", + ... ) + >>> results = semantic_segmentor.run( + ... ["/example_wsi.svs"], + ... masks=None, + ... auto_get_mask=False, + ... patch_mode=False, + ... save_dir=Path("/tissue_mask/"), + ... output_type="annotationstore", + ... ) + + References: + [1] Weng Z. et al. "GrandQC: a comprehensive solution to quality control problem + in digital pathology". + Nature Communications 2024 + + """ + + def __init__(self: GrandQCModel, num_output_channels: int = 2) -> None: + """Initialize UNet++ model. + + Args: + encoder_depth: Depth of the encoder. Defaults to 5. + num_output_channels: Number of output classes. Defaults to 2. + """ + super().__init__() + self.num_output_channels = num_output_channels + self.decoder_channels = (256, 128, 64, 32, 16) + + self.encoder = EfficientNetEncoder( + out_channels=[3, 32, 24, 40, 112, 320], + stage_idxs=[2, 3, 5], + channel_multiplier=1.0, + depth_multiplier=1.0, + drop_rate=0.2, + ) + self.decoder = UnetPlusPlusDecoder( + encoder_channels=self.encoder.out_channels, + decoder_channels=self.decoder_channels, + n_blocks=5, + ) + + self.segmentation_head = SegmentationHead( + in_channels=self.decoder_channels[-1], + out_channels=num_output_channels, + kernel_size=3, + ) + + self.name = "unetplusplus-efficientnetb0" + + def forward( + self: GrandQCModel, + x: torch.Tensor, + *args: tuple[Any, ...], # skipcq: PYL-W0613 # noqa: ARG002 + **kwargs: dict, # skipcq: PYL-W0613 # noqa: ARG002 + ) -> torch.Tensor: + """Sequentially pass `x` through model's encoder, decoder and heads. + + Args: + x: Input tensor. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + torch.Tensor: Segmentation output. + """ + features = self.encoder(x) + decoder_output = self.decoder(features) + + return self.segmentation_head(decoder_output) + + @staticmethod + def preproc(image: np.ndarray) -> np.ndarray: + """Apply JPEG compression and ImageNet normalization to the input image. + + Args: + image (np.ndarray): + Input image as a NumPy array (H, W, C) in uint8 format. + + Returns: + np.ndarray: + The preprocessed image. + + """ + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80] + _, compressed_image = cv2.imencode(".jpg", image, encode_param) + compressed_image = np.array(cv2.imdecode(compressed_image, 1)) + + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + return (compressed_image / 255.0 - mean) / std + + @staticmethod + def postproc(image: np.ndarray) -> np.ndarray: + """Define post-processing for this model. + + This simply applies argmin to obtain tissue class. + (Tissue = 0, Background = 1) + + Args: + image (np.ndarray): + Input probability map as a NumPy array (H, W, C). + + Returns: + np.ndarray: + Tissue mask + + """ + return image.argmin(axis=-1) + + @staticmethod + def infer_batch( + model: torch.nn.Module, + batch_data: torch.Tensor, + *, + device: str, + ) -> np.ndarray: + """Run inference on an input batch. + + This contains logic for forward operation as well as i/o + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (:class:`torch.Tensor`): + A batch of data generated by + `torch.utils.data.DataLoader`. + device (str): + Transfers model to the specified device. Default is "cpu". + + Returns: + np.ndarray: + The inference results as a numpy array. + + """ + model.eval() + + imgs = batch_data + imgs = imgs.to(device).type(torch.float32) + imgs = imgs.permute(0, 3, 1, 2) # to NCHW + + with torch.inference_mode(): + logits = model(imgs) + probs = torch.nn.functional.softmax(logits, 1) + probs = probs.permute(0, 2, 3, 1) # to NHWC + + return probs.cpu().numpy() diff --git a/tiatoolbox/models/architecture/timm_efficientnet.py b/tiatoolbox/models/architecture/timm_efficientnet.py new file mode 100644 index 000000000..46a249387 --- /dev/null +++ b/tiatoolbox/models/architecture/timm_efficientnet.py @@ -0,0 +1,416 @@ +"""Defines EfficientNet encoder using timm library.""" + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Mapping, Sequence + +import torch +from timm.layers.activations import Swish +from timm.models._efficientnet_builder import decode_arch_def, round_channels +from timm.models.efficientnet import EfficientNet +from torch import nn + +MAX_DEPTH = 5 +MIN_DEPTH = 1 +DEFAULT_IN_CHANNELS = 3 + + +def patch_first_conv( + model: nn.Module, + new_in_channels: int, + default_in_channels: int = 3, + *, + pretrained: bool = True, +) -> None: + """Change first convolution layer input channels. + + Args: + model: The neural network model to patch. + new_in_channels: Number of input channels for the new first layer. + default_in_channels: Original number of input channels. Defaults to 3. + pretrained: Whether to reuse pretrained weights. Defaults to True. + + Note: + In case: + - in_channels == 1 or in_channels == 2 -> reuse original weights + - in_channels > 3 -> make random kaiming normal initialization + """ + # get first conv + conv_module: nn.Conv2d | None = None + for module in model.modules(): + if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: + conv_module = module + break + + if conv_module is None: + return + + weight = conv_module.weight.detach() + conv_module.in_channels = new_in_channels + + if not pretrained: + conv_module.weight = nn.parameter.Parameter( + torch.Tensor( + conv_module.out_channels, + new_in_channels // conv_module.groups, + *conv_module.kernel_size, + ) + ) + conv_module.reset_parameters() + + elif new_in_channels == 1: + new_weight = weight.sum(1, keepdim=True) + conv_module.weight = nn.parameter.Parameter(new_weight) + + else: + new_weight = torch.Tensor( + conv_module.out_channels, + new_in_channels // conv_module.groups, + *conv_module.kernel_size, + ) + + for i in range(new_in_channels): + new_weight[:, i] = weight[:, i % default_in_channels] + + new_weight = new_weight * (default_in_channels / new_in_channels) + conv_module.weight = nn.parameter.Parameter(new_weight) + + +def replace_strides_with_dilation(module: nn.Module, dilation_rate: int) -> None: + """Patch Conv2d modules replacing strides with dilation. + + Args: + module: The module containing Conv2d layers to patch. + dilation_rate: The dilation rate to apply. + """ + for mod in module.modules(): + if isinstance(mod, nn.Conv2d): + mod.stride = (1, 1) + mod.dilation = (dilation_rate, dilation_rate) + kh, _ = mod.kernel_size + mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) + + # Workaround for EfficientNet + if hasattr(mod, "static_padding"): + mod.static_padding = nn.Identity() # type: ignore[attr-defined] + + +class EncoderMixin: + """Add encoder functionality. + + Such as: + - output channels specification of feature tensors (produced by encoder) + - patching first convolution for arbitrary input channels + """ + + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + + def __init__(self) -> None: + """Initialize EncoderMixin with default parameters.""" + self._depth = 5 + self._in_channels = 3 + self._output_stride = 32 + self._out_channels: list[int] = [] + + @property + def out_channels(self) -> list[int]: + """Return channels dimensions for each tensor of forward output of encoder. + + Returns: + List of output channel dimensions for each depth level. + """ + return self._out_channels[: self._depth + 1] + + @property + def output_stride(self) -> int: + """Return the output stride of the encoder. + + Returns: + The minimum of configured output stride and 2^depth. + """ + return min(self._output_stride, 2**self._depth) + + def set_in_channels(self, in_channels: int, *, pretrained: bool = True) -> None: + """Change first convolution channels. + + Args: + in_channels: Number of input channels. + pretrained: Whether to use pretrained weights. Defaults to True. + """ + if in_channels == DEFAULT_IN_CHANNELS: + return + + self._in_channels = in_channels + if self._out_channels[0] == DEFAULT_IN_CHANNELS: + self._out_channels = [in_channels, *self._out_channels[1:]] + + # Type ignore needed because self is a mixin that will be used with nn.Module + patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained) # type: ignore[arg-type] + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Get stages for dilation modification. + + Override this method in your implementation. + + Returns: + Dictionary with keys as output stride and values as list of modules. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError + + def make_dilated(self, output_stride: int) -> None: + """Convert encoder to dilated version. + + Args: + output_stride: Target output stride (8 or 16). + + Raises: + ValueError: If output_stride is not 8 or 16. + """ + if output_stride not in [8, 16]: + msg = f"Output stride should be 16 or 8, got {output_stride}." + raise ValueError(msg) + + stages = self.get_stages() + for stage_stride, stage_modules in stages.items(): + if stage_stride <= output_stride: + continue + + dilation_rate = stage_stride // output_stride + for module in stage_modules: + replace_strides_with_dilation(module, dilation_rate) + + +def get_efficientnet_kwargs( + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, +) -> dict[str, Any]: + """Create EfficientNet model kwargs. + + Reference implementation: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet parameters: + - 'efficientnet-b0': (1.0, 1.0, 224, 0.2) + - 'efficientnet-b1': (1.0, 1.1, 240, 0.2) + - 'efficientnet-b2': (1.1, 1.2, 260, 0.3) + - 'efficientnet-b3': (1.2, 1.4, 300, 0.3) + - 'efficientnet-b4': (1.4, 1.8, 380, 0.4) + - 'efficientnet-b5': (1.6, 2.2, 456, 0.4) + - 'efficientnet-b6': (1.8, 2.6, 528, 0.5) + - 'efficientnet-b7': (2.0, 3.1, 600, 0.5) + - 'efficientnet-b8': (2.2, 3.6, 672, 0.5) + - 'efficientnet-l2': (4.3, 5.3, 800, 0.5) + + Args: + channel_multiplier: Multiplier to number of channels per layer. Defaults to 1.0. + depth_multiplier: Multiplier to number of repeats per stage. Defaults to 1.0. + drop_rate: Dropout rate. Defaults to 0.2. + + Returns: + Dictionary containing model configuration parameters. + """ + arch_def = [ + ["ds_r1_k3_s1_e1_c16_se0.25"], + ["ir_r2_k3_s2_e6_c24_se0.25"], + ["ir_r2_k5_s2_e6_c40_se0.25"], + ["ir_r3_k3_s2_e6_c80_se0.25"], + ["ir_r3_k5_s1_e6_c112_se0.25"], + ["ir_r4_k5_s2_e6_c192_se0.25"], + ["ir_r1_k3_s1_e6_c320_se0.25"], + ] + return { + "block_args": decode_arch_def(arch_def, depth_multiplier), + "num_features": round_channels(1280, channel_multiplier, 8, None), + "stem_size": 32, + "round_chs_fn": partial(round_channels, multiplier=channel_multiplier), + "act_layer": Swish, + "drop_rate": drop_rate, + "drop_path_rate": 0.2, + } + + +class EfficientNetBaseEncoder(EfficientNet, EncoderMixin): + """EfficientNet encoder base class. + + Combines EfficientNet architecture with encoder functionality. + """ + + def __init__( + self, + stage_idxs: list[int], + out_channels: list[int], + depth: int = 5, + output_stride: int = 32, + **kwargs: dict[str, Any], + ) -> None: + """Initialize EfficientNetBaseEncoder. + + Args: + stage_idxs: Indices of stages for feature extraction. + out_channels: Output channels for each depth level. + depth: Encoder depth (1-5). Defaults to 5. + output_stride: Output stride of encoder. Defaults to 32. + **kwargs: Additional keyword arguments for EfficientNet. + + Raises: + ValueError: If depth is not in range [1, 5]. + """ + if depth > MAX_DEPTH or depth < MIN_DEPTH: + msg = f"{self.__class__.__name__} depth should be in range \ + [1, 5], got {depth}" + raise ValueError(msg) + super().__init__(**kwargs) + + self._stage_idxs = stage_idxs + self._depth = depth + self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + + del self.classifier + + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Get stages for dilation modification. + + Returns: + Dictionary mapping output strides to corresponding module sequences. + """ + return { + 16: [self.blocks[self._stage_idxs[1] : self._stage_idxs[2]]], # type: ignore[attr-defined] + 32: [self.blocks[self._stage_idxs[2] :]], # type: ignore[attr-defined] + } + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: # type: ignore[override] + """Forward pass through encoder. + + Args: + x: Input tensor. + + Returns: + List of feature tensors from different encoder depths. + """ + features = [x] + + if self._depth >= 1: + x = self.conv_stem(x) # type: ignore[attr-defined] + x = self.bn1(x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 2: # noqa: PLR2004 + x = self.blocks[0](x) # type: ignore[attr-defined] + x = self.blocks[1](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 3: # noqa: PLR2004 + x = self.blocks[2](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 4: # noqa: PLR2004 + x = self.blocks[3](x) # type: ignore[attr-defined] + x = self.blocks[4](x) # type: ignore[attr-defined] + features.append(x) + + if self._depth >= 5: # noqa: PLR2004 + x = self.blocks[5](x) # type: ignore[attr-defined] + x = self.blocks[6](x) # type: ignore[attr-defined] + features.append(x) + + return features + + def load_state_dict( + self, state_dict: Mapping[str, Any], **kwargs: bool + ) -> torch.nn.modules.module._IncompatibleKeys: + """Load state dictionary, excluding classifier weights. + + Args: + state_dict: State dictionary to load. + **kwargs: Additional keyword arguments for load_state_dict. + + Returns: + Result of parent class load_state_dict method. + """ + # Create a mutable copy of the state dict to modify + state_dict_copy = dict(state_dict) + state_dict_copy.pop("classifier.bias", None) + state_dict_copy.pop("classifier.weight", None) + return super().load_state_dict(state_dict_copy, **kwargs) + + +class EfficientNetEncoder(EfficientNetBaseEncoder): + """EfficientNet encoder with configurable scaling parameters. + + Provides a configurable EfficientNet encoder that can be scaled + in terms of depth and channel multipliers. + """ + + def __init__( + self, + stage_idxs: list[int], + out_channels: list[int], + depth: int = 5, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, + output_stride: int = 32, + ) -> None: + """Initialize EfficientNetEncoder. + + Args: + stage_idxs: Indices of stages for feature extraction. + out_channels: Output channels for each depth level. + depth: Encoder depth (1-5). Defaults to 5. + channel_multiplier: Channel scaling factor. Defaults to 1.0. + depth_multiplier: Depth scaling factor. Defaults to 1.0. + drop_rate: Dropout rate. Defaults to 0.2. + output_stride: Output stride of encoder. Defaults to 32. + """ + kwargs = get_efficientnet_kwargs( + channel_multiplier, depth_multiplier, drop_rate + ) + super().__init__( + stage_idxs=stage_idxs, + depth=depth, + out_channels=out_channels, + output_stride=output_stride, + **kwargs, + ) + + +timm_efficientnet_encoders = { + "timm-efficientnet-b0": { + "encoder": EfficientNetEncoder, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b0.imagenet", + "revision": "8419e9cc19da0b68dcd7bb12f19b7c92407ad7c4", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b0.advprop", + "revision": "a5870af2d24ce79e0cc7fae2bbd8e0a21fcfa6d8", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b0.noisy-student", + "revision": "bea8b0ff726a50e48774d2d360c5fb1ac4815836", + }, + }, + "params": { + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], + "channel_multiplier": 1.0, + "depth_multiplier": 1.0, + "drop_rate": 0.2, + }, + }, +}