1+ import math
2+ import warnings
3+
14import torch
25import torch .nn as nn
36import torch .nn .functional as F
47
58
class GaussianFilter2D(nn.Module):
    def __init__(self, window_size=11, in_channels=1, sigma=1.5, padding=None, ensemble_kernel=True):
        """2D Gaussian Filter

        Args:
            window_size (int, optional): The window size of the gaussian filter. Must be odd. Defaults to 11.
            in_channels (int, optional): The number of channels of the 4d tensor. Defaults to 1.
            sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5.
            padding (int, optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0.
            ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True.

        Raises:
            ValueError: If ``window_size`` is even.
        """
        super().__init__()
        if window_size % 2 != 1:
            raise ValueError("Window size must be odd.")
        self.window_size = window_size
        # Kept as an attribute so code written against the previous version still finds it.
        self.in_channels = in_channels
        self.padding = padding if padding is not None else window_size // 2
        self.sigma = sigma
        self.ensemble_kernel = ensemble_kernel

        # Buffer shape: (in_channels, 1, W, W) when fused, (in_channels, 1, 1, W) when split.
        kernel = self._get_gaussian_window1d()
        if ensemble_kernel:
            kernel = self._get_gaussian_window2d(kernel)
        self.register_buffer(name="gaussian_window", tensor=kernel.repeat(in_channels, 1, 1, 1))

    def _get_gaussian_window1d(self):
        """Return the normalized (sums to 1) 1d gaussian window with shape (1, 1, 1, window_size)."""
        sigma2 = self.sigma * self.sigma
        x = torch.arange(-(self.window_size // 2), self.window_size // 2 + 1)
        w = torch.exp(-0.5 * x**2 / sigma2)
        w = w / w.sum()
        return w.reshape(1, 1, 1, self.window_size)

    def _get_gaussian_window2d(self, gaussian_window_1d):
        """Return the outer product of the 1d window with itself, shape (1, 1, window_size, window_size)."""
        return torch.matmul(gaussian_window_1d.transpose(dim0=-1, dim1=-2), gaussian_window_1d)

    def forward(self, x):
        """Smooth ``x`` channel by channel with the gaussian window.

        Args:
            x (Tensor): 4d tensor whose channel count matches ``in_channels``
                (the convolution is grouped with ``groups=x.shape[1]``).

        Returns:
            Tensor: the filtered tensor.
        """
        if self.ensemble_kernel:
            # ensemble kernel: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/3add4532d3f633316cba235da1c69e90f0dfb952/pytorch_ssim/__init__.py#L11-L15
            return F.conv2d(input=x, weight=self.gaussian_window, stride=1, padding=self.padding, groups=x.shape[1])

        # splitted kernel: https://github.com/VainF/pytorch-msssim/blob/2398f4db0abf44bcd3301cfadc1bf6c94788d416/pytorch_msssim/ssim.py#L48
        for i, d in enumerate(x.shape[2:], start=2):
            if d < self.window_size:
                warnings.warn(
                    f"Skipping Gaussian Smoothing at dimension {i} for x: {x.shape} and window size: {self.window_size}"
                )
                continue
            # Move the window onto spatial axis i: (C, 1, W, 1) for i=2, (C, 1, 1, W) for i=3.
            w = self.gaussian_window.transpose(dim0=-1, dim1=i)
            # Pad only the axis being filtered. Padding the other axis as well would
            # enlarge the output by 2 * padding and break parity with the fused kernel.
            axis_padding = self.padding[i - 2] if isinstance(self.padding, (tuple, list)) else self.padding
            pad = [0, 0]
            pad[i - 2] = axis_padding
            x = F.conv2d(input=x, weight=w, stride=1, padding=tuple(pad), groups=x.shape[1])
        return x
4059
4160
4261class SSIM (nn .Module ):
4362 def __init__ (
44- self , window_size = 11 , in_channels = 1 , sigma = 1.5 , K1 = 0.01 , K2 = 0.03 , L = 1 , keep_batch_dim = False , return_log = False
63+ self ,
64+ window_size = 11 ,
65+ in_channels = 1 ,
66+ sigma = 1.5 ,
67+ * ,
68+ K1 = 0.01 ,
69+ K2 = 0.03 ,
70+ L = 1 ,
71+ keep_batch_dim = False ,
72+ return_log = False ,
73+ return_msssim = False ,
74+ padding = None ,
75+ ensemble_kernel = True ,
4576 ):
4677 """Calculate the mean SSIM (MSSIM) between two 4D tensors.
4778
@@ -54,6 +85,9 @@ def __init__(
5485 L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1.
5586 keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
5687 return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.
88+ return_msssim (bool, optional): Whether to return the MS-SSIM score. Defaults to False, which will return the original MSSIM score.
89+ padding (int, optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0.
90+ ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True.
5791
5892 ```
5993 # setting 0: for 4d float tensors with the data range [0, 1] and 1 channel
@@ -66,6 +100,8 @@ def __init__(
66100 ssim_caller = SSIM(L=255, in_channels=3, return_log=True).cuda()
67101 # setting 4: for 4d float tensors with the data range [0, 1] and 1 channel,return the logarithmic form, and keep the batch dim
68102 ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda()
103+ # setting 5: for 4d float tensors with the data range [0, 1] and 1 channel, padding=0 and the splitted kernels.
104+ ssim_caller = SSIM(return_log=True, keep_batch_dim=True, padding=0, ensemble_kernel=False).cuda()
69105
70106 # two 4d tensors
71107 x = torch.randn(3, 1, 100, 100).cuda()
@@ -76,15 +112,26 @@ def __init__(
76112 ssim_score_1 = ssim_caller(x, y)
77113 assert torch.isclose(ssim_score_0, ssim_score_1)
78114 ```
115+
116+ Reference:
117+ [1] SSIM: Wang, Zhou et al. “Image quality assessment: from error visibility to structural similarity.” IEEE Transactions on Image Processing 13 (2004): 600-612.
118+ [2] MS-SSIM: Wang, Zhou et al. “Multi-scale structural similarity for image quality assessment.” (2003).
79119 """
80120 super ().__init__ ()
81121 self .window_size = window_size
82122 self .C1 = (K1 * L ) ** 2 # equ 7 in ref1
83123 self .C2 = (K2 * L ) ** 2 # equ 7 in ref1
84124 self .keep_batch_dim = keep_batch_dim
85125 self .return_log = return_log
126+ self .return_msssim = return_msssim
86127
87- self .gaussian_filter = GaussianFilter2D (window_size = window_size , in_channels = in_channels , sigma = sigma )
128+ self .gaussian_filter = GaussianFilter2D (
129+ window_size = window_size ,
130+ in_channels = in_channels ,
131+ sigma = sigma ,
132+ padding = padding ,
133+ ensemble_kernel = ensemble_kernel ,
134+ )
88135
89136 @torch .cuda .amp .autocast (enabled = False )
90137 def forward (self , x , y ):
@@ -95,41 +142,89 @@ def forward(self, x, y):
95142 y (Tensor): 4d tensor
96143
97144 Returns:
98- Tensor: MSSIM
145+ Tensor: MSSIM or MS-SSIM
99146 """
100147 assert x .shape == y .shape , f"x: { x .shape } and y: { y .shape } must be the same"
101148 assert x .ndim == y .ndim == 4 , f"x: { x .ndim } and y: { y .ndim } must be 4"
102- if x .type () != self .gaussian_filter .gaussian_window2d .type ():
103- x = x .type_as (self .gaussian_filter .gaussian_window2d )
104- if y .type () != self .gaussian_filter .gaussian_window2d .type ():
105- y = y .type_as (self .gaussian_filter .gaussian_window2d )
149+ if x .type () != self .gaussian_filter .gaussian_window .type ():
150+ x = x .type_as (self .gaussian_filter .gaussian_window )
151+ if y .type () != self .gaussian_filter .gaussian_window .type ():
152+ y = y .type_as (self .gaussian_filter .gaussian_window )
153+
154+ if self .return_msssim :
155+ return self .msssim (x , y )
156+ else :
157+ return self .ssim (x , y )
158+
def ssim(self, x, y):
    """Reduce the per-pixel SSIM map of ``x`` and ``y`` to a mean score.

    When ``self.return_log`` is set, the map is min-max rescaled over the whole
    tensor and turned into ``-log(map + 1e-8)`` before averaging.

    Args:
        x (Tensor): 4d tensor
        y (Tensor): 4d tensor

    Returns:
        Tensor: one value per sample if ``self.keep_batch_dim`` is set, otherwise a scalar.
    """
    ssim_map = self._ssim(x, y)[0]

    if self.return_log:
        # https://github.com/xuebinqin/BASNet/blob/56393818e239fed5a81d06d2a1abfe02af33e461/pytorch_ssim/__init__.py#L81-L83
        # NOTE: a constant SSIM map makes the rescale below 0 / 0, so the result is NaN.
        shifted = ssim_map - ssim_map.min()
        rescaled = shifted / shifted.max()
        ssim_map = -torch.log(rescaled + 1e-8)

    return ssim_map.mean(dim=(1, 2, 3)) if self.keep_batch_dim else ssim_map.mean()
171+
def msssim(self, x, y):
    """Compute the multi-scale SSIM (MS-SSIM) of ``x`` and ``y`` over five scales.

    Every scale but the last contributes its mean contrast-structure term; the
    last scale contributes its mean full SSIM. Each term is raised to its
    per-scale exponent and the terms are multiplied together. Between scales
    both inputs are downsampled by 2x2 average pooling.

    Args:
        x (Tensor): 4d tensor
        y (Tensor): 4d tensor

    Returns:
        Tensor: one value per sample if ``self.keep_batch_dim`` is set, otherwise a scalar.
    """
    # Per-scale exponents; presumably the values from the MS-SSIM paper (ref2) -- confirm.
    scale_weights = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
    last_scale = len(scale_weights) - 1

    ms_components = []
    for i, w in enumerate(scale_weights):
        ssim, cs = self._ssim(x, y)

        if self.keep_batch_dim:
            ssim = ssim.mean(dim=(1, 2, 3))
            cs = cs.mean(dim=(1, 2, 3))
        else:
            ssim = ssim.mean()
            cs = cs.mean()

        # The averaged terms can be negative (e.g. anti-correlated inputs). A negative
        # base with a fractional exponent is NaN and would poison the whole product,
        # so clamp to zero first.
        if i == last_scale:
            ms_components.append(torch.relu(ssim) ** w)
        else:
            ms_components.append(torch.relu(cs) ** w)
            # Pad odd-sized spatial dims by one so the 2x2 pooling covers the last row/column.
            padding = [s % 2 for s in x.shape[2:]]  # spatial padding
            x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=padding)
            y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=padding)
    return math.prod(ms_components)  # equ 7 in ref2
193+
194+ def _ssim (self , x , y ):
107195 mu_x = self .gaussian_filter (x ) # equ 14
108196 mu_y = self .gaussian_filter (y ) # equ 14
109197 sigma2_x = self .gaussian_filter (x * x ) - mu_x * mu_x # equ 15
110198 sigma2_y = self .gaussian_filter (y * y ) - mu_y * mu_y # equ 15
111199 sigma_xy = self .gaussian_filter (x * y ) - mu_x * mu_y # equ 16
112200
113- # equ 13 in ref1
114201 A1 = 2 * mu_x * mu_y + self .C1
115202 A2 = 2 * sigma_xy + self .C2
116203 B1 = mu_x * mu_x + mu_y * mu_y + self .C1
117204 B2 = sigma2_x + sigma2_y + self .C2
118- S = (A1 * A2 ) / (B1 * B2 )
119205
120- if self .return_log :
121- S = S - S .min ()
122- S = S / S .max ()
123- S = - torch .log (S + 1e-8 )
124-
125- if self .keep_batch_dim :
126- return S .mean (dim = (1 , 2 , 3 ))
127- else :
128- return S .mean ()
206+ # equ 12, 13 in ref1
207+ l = A1 / B1
208+ cs = A2 / B2
209+ ssim = l * cs
210+ return ssim , cs
129211
130212
131213def ssim (
132- x , y , * , window_size = 11 , in_channels = 1 , sigma = 1.5 , K1 = 0.01 , K2 = 0.03 , L = 1 , keep_batch_dim = False , return_log = False
214+ x ,
215+ y ,
216+ * ,
217+ window_size = 11 ,
218+ in_channels = 1 ,
219+ sigma = 1.5 ,
220+ K1 = 0.01 ,
221+ K2 = 0.03 ,
222+ L = 1 ,
223+ keep_batch_dim = False ,
224+ return_log = False ,
225+ return_msssim = False ,
226+ padding = None ,
227+ ensemble_kernel = True ,
133228):
134229 """Calculate the mean SSIM (MSSIM) between two 4D tensors.
135230
@@ -144,10 +239,12 @@ def ssim(
144239 L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1.
145240 keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
146241 return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.
147-
242+ return_msssim (bool, optional): Whether to return the MS-SSIM score. Defaults to False, which will return the original MSSIM score.
243+ padding (int, optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0.
244+ ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True.
148245
149246 Returns:
150- Tensor: MSSIM
247+ Tensor: MSSIM or MS-SSIM
151248 """
152249 ssim_obj = SSIM (
153250 window_size = window_size ,
@@ -158,5 +255,8 @@ def ssim(
158255 L = L ,
159256 keep_batch_dim = keep_batch_dim ,
160257 return_log = return_log ,
258+ return_msssim = return_msssim ,
259+ padding = padding ,
260+ ensemble_kernel = ensemble_kernel ,
161261 ).to (device = x .device )
162262 return ssim_obj (x , y )
0 commit comments