@@ -12,21 +12,22 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
+from typing import Optional
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm
 
-from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
+from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding
 from monai.networks.layers import Conv, trunc_normal_
 from monai.utils import ensure_tuple_rep, optional_import
 from monai.utils.module import look_up_option
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
-SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}
 
 
 class PatchEmbeddingBlock(nn.Module):
@@ -53,6 +54,7 @@ def __init__(
         pos_embed_type: str = "learnable",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
+        pos_embed_kwargs: Optional[dict] = None,
     ) -> None:
         """
         Args:
@@ -65,6 +67,8 @@ def __init__(
             pos_embed_type: position embedding layer type.
             dropout_rate: fraction of the input units to drop.
             spatial_dims: number of spatial dimensions.
+            pos_embed_kwargs: additional arguments for the position embedding. For `sincos`, it can contain
+                `temperature`, and for `fourier` it can contain `scales`.
         """
6973
         super().__init__()
@@ -105,6 +109,8 @@ def __init__(
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
 
+        pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs
+
         if self.pos_embed_type == "none":
             pass
         elif self.pos_embed_type == "learnable":
@@ -114,7 +120,17 @@ def __init__(
             for in_size, pa_size in zip(img_size, patch_size):
                 grid_size.append(in_size // pa_size)
 
-            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
+            self.position_embeddings = build_sincos_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
+        elif self.pos_embed_type == "fourier":
+            grid_size = []
+            for in_size, pa_size in zip(img_size, patch_size):
+                grid_size.append(in_size // pa_size)
+
+            self.position_embeddings = build_fourier_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
         else:
             raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
 
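For reference, here is a hedged usage sketch of the new option. It assumes the current `PatchEmbeddingBlock` signature (including the `proj_type` argument, which is not shown in this diff) and that `scales` accepts a plain float; the values are illustrative, not recommendations:

```python
import torch

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock

# Fourier position embedding; pos_embed_kwargs is forwarded verbatim to
# build_fourier_position_embedding.
block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,  # must be divisible by num_heads
    num_heads=12,
    proj_type="conv",  # assumption: projection argument of the current signature
    pos_embed_type="fourier",
    pos_embed_kwargs={"scales": 1.0},  # assumption: a float is accepted here
    spatial_dims=3,
)

x = torch.randn(2, 1, 32, 32, 32)
print(block(x).shape)  # torch.Size([2, 64, 96]); (32 // 8) ** 3 = 64 patches

# The same hook serves sincos, e.g.:
#   pos_embed_type="sincos", pos_embed_kwargs={"temperature": 10000.0}
```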
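For intuition about what `scales` controls, below is a minimal, self-contained sketch of Gaussian Fourier-feature position embeddings in the style of Tancik et al. (2020). `fourier_position_embedding_sketch` is a hypothetical stand-in written for illustration; MONAI's actual `build_fourier_position_embedding` may differ in detail:

```python
import math

import torch


def fourier_position_embedding_sketch(grid_size: list[int], embed_dim: int, scales: float = 1.0) -> torch.Tensor:
    """Return a (1, n_patches, embed_dim) tensor of random Fourier features."""
    spatial_dims = len(grid_size)
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even to hold sin/cos pairs.")
    # Normalized patch-grid coordinates, one row per patch: (n_patches, spatial_dims).
    axes = [torch.linspace(0, 1, steps=g) for g in grid_size]
    coords = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, spatial_dims)
    # Random Gaussian frequency matrix; `scales` widens or narrows the frequency band.
    freqs = torch.randn(spatial_dims, embed_dim // 2) * scales
    # Project coordinates onto the frequencies, then take sin/cos pairs.
    proj = 2 * math.pi * coords @ freqs  # (n_patches, embed_dim // 2)
    return torch.cat([proj.sin(), proj.cos()], dim=-1).unsqueeze(0)


print(fourier_position_embedding_sketch([4, 4, 4], 96).shape)  # torch.Size([1, 64, 96])
```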