Skip to content

Commit d2c3b12

Browse files
authored
Update ssim.py
- Add the support of MS-SSIM. - Provides a rich interface to align the two widely used libraries: - https://github.com/VainF/pytorch-msssim - https://github.com/Po-Hsun-Su/pytorch-ssim
1 parent 92bf0ec commit d2c3b12

File tree

1 file changed

+132
-32
lines changed

1 file changed

+132
-32
lines changed

ssim.py

Lines changed: 132 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,78 @@
1+
import math
2+
import warnings
3+
14
import torch
25
import torch.nn as nn
36
import torch.nn.functional as F
47

58

69
class GaussianFilter2D(nn.Module):
7-
def __init__(self, window_size=11, in_channels=1, sigma=1.5) -> None:
10+
def __init__(self, window_size=11, in_channels=1, sigma=1.5, padding=None, ensemble_kernel=True):
811
"""2D Gaussian Filer
912
1013
Args:
1114
window_size (int, optional): The window size of the gaussian filter. Defaults to 11.
1215
in_channels (int, optional): The number of channels of the 4d tensor. Defaults to False.
1316
sigma (float, optional): The sigma of the gaussian filter. Defaults to 1.5.
17+
padding (int, optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0.
18+
ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True.
1419
"""
1520
super().__init__()
1621
self.window_size = window_size
1722
if not (window_size % 2 == 1):
1823
raise ValueError("Window size must be odd.")
19-
self.in_channels = in_channels
20-
self.padding = window_size // 2
24+
self.padding = padding if padding is not None else window_size // 2
2125
self.sigma = sigma
22-
self.register_buffer(name="gaussian_window2d", tensor=self._get_gaussian_window2d())
26+
self.ensemble_kernel = ensemble_kernel
27+
28+
kernel = self._get_gaussian_window1d()
29+
if ensemble_kernel:
30+
kernel = self._get_gaussian_window2d(kernel)
31+
self.register_buffer(name="gaussian_window", tensor=kernel.repeat(in_channels, 1, 1, 1))
2332

2433
def _get_gaussian_window1d(self):
2534
sigma2 = self.sigma * self.sigma
2635
x = torch.arange(-(self.window_size // 2), self.window_size // 2 + 1)
2736
w = torch.exp(-0.5 * x ** 2 / sigma2)
2837
w = w / w.sum()
29-
return w.reshape(1, 1, self.window_size, 1)
38+
return w.reshape(1, 1, 1, self.window_size)
3039

31-
def _get_gaussian_window2d(self):
32-
gaussian_window_1d = self._get_gaussian_window1d()
33-
w = torch.matmul(gaussian_window_1d, gaussian_window_1d.transpose(dim0=-1, dim1=-2))
34-
w.reshape(1, 1, self.window_size, self.window_size)
35-
return w.repeat(self.in_channels, 1, 1, 1)
40+
def _get_gaussian_window2d(self, gaussian_window_1d):
41+
w = torch.matmul(gaussian_window_1d.transpose(dim0=-1, dim1=-2), gaussian_window_1d)
42+
return w
3643

3744
def forward(self, x):
38-
x = F.conv2d(input=x, weight=self.gaussian_window2d, padding=self.padding, groups=x.shape[1])
45+
if self.ensemble_kernel:
46+
# ensemble kernel: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/3add4532d3f633316cba235da1c69e90f0dfb952/pytorch_ssim/__init__.py#L11-L15
47+
x = F.conv2d(input=x, weight=self.gaussian_window, stride=1, padding=self.padding, groups=x.shape[1])
48+
else:
49+
# splitted kernel: https://github.com/VainF/pytorch-msssim/blob/2398f4db0abf44bcd3301cfadc1bf6c94788d416/pytorch_msssim/ssim.py#L48
50+
for i, d in enumerate(x.shape[2:], start=2):
51+
if d >= self.window_size:
52+
w = self.gaussian_window.transpose(dim0=-1, dim1=i)
53+
x = F.conv2d(input=x, weight=w, stride=1, padding=self.padding, groups=x.shape[1])
54+
else:
55+
warnings.warn(
56+
f"Skipping Gaussian Smoothing at dimension {i} for x: {x.shape} and window size: {self.window_size}"
57+
)
3958
return x
4059

4160

4261
class SSIM(nn.Module):
4362
def __init__(
44-
self, window_size=11, in_channels=1, sigma=1.5, K1=0.01, K2=0.03, L=1, keep_batch_dim=False, return_log=False
63+
self,
64+
window_size=11,
65+
in_channels=1,
66+
sigma=1.5,
67+
*,
68+
K1=0.01,
69+
K2=0.03,
70+
L=1,
71+
keep_batch_dim=False,
72+
return_log=False,
73+
return_msssim=False,
74+
padding=None,
75+
ensemble_kernel=True,
4576
):
4677
"""Calculate the mean SSIM (MSSIM) between two 4D tensors.
4778
@@ -54,6 +85,9 @@ def __init__(
5485
L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1.
5586
keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
5687
return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.
88+
return_msssim (bool, optional): Whether to return the MS-SSIM score. Defaults to False, which will return the original MSSIM score.
89+
padding (int, optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0.
90+
ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True.
5791
5892
```
5993
# setting 0: for 4d float tensors with the data range [0, 1] and 1 channel
@@ -66,6 +100,8 @@ def __init__(
66100
ssim_caller = SSIM(L=255, in_channels=3, return_log=True).cuda()
67101
# setting 4: for 4d float tensors with the data range [0, 1] and 1 channel,return the logarithmic form, and keep the batch dim
68102
ssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda()
103+
# setting 5: for 4d float tensors with the data range [0, 1] and 1 channel, padding=0 and the splitted kernels.
104+
ssim_caller = SSIM(return_log=True, keep_batch_dim=True, padding=0, ensemble_kernel=False).cuda()
69105
70106
# two 4d tensors
71107
x = torch.randn(3, 1, 100, 100).cuda()
@@ -76,15 +112,26 @@ def __init__(
76112
ssim_score_1 = ssim_caller(x, y)
77113
assert torch.isclose(ssim_score_0, ssim_score_1)
78114
```
115+
116+
Reference:
117+
[1] SSIM: Wang, Zhou et al. “Image quality assessment: from error visibility to structural similarity.” IEEE Transactions on Image Processing 13 (2004): 600-612.
118+
[2] MS-SSIM: Wang, Zhou et al. “Multi-scale structural similarity for image quality assessment.” (2003).
79119
"""
80120
super().__init__()
81121
self.window_size = window_size
82122
self.C1 = (K1 * L) ** 2 # equ 7 in ref1
83123
self.C2 = (K2 * L) ** 2 # equ 7 in ref1
84124
self.keep_batch_dim = keep_batch_dim
85125
self.return_log = return_log
126+
self.return_msssim = return_msssim
86127

87-
self.gaussian_filter = GaussianFilter2D(window_size=window_size, in_channels=in_channels, sigma=sigma)
128+
self.gaussian_filter = GaussianFilter2D(
129+
window_size=window_size,
130+
in_channels=in_channels,
131+
sigma=sigma,
132+
padding=padding,
133+
ensemble_kernel=ensemble_kernel,
134+
)
88135

89136
@torch.cuda.amp.autocast(enabled=False)
90137
def forward(self, x, y):
@@ -95,41 +142,89 @@ def forward(self, x, y):
95142
y (Tensor): 4d tensor
96143
97144
Returns:
98-
Tensor: MSSIM
145+
Tensor: MSSIM or MS-SSIM
99146
"""
100147
assert x.shape == y.shape, f"x: {x.shape} and y: {y.shape} must be the same"
101148
assert x.ndim == y.ndim == 4, f"x: {x.ndim} and y: {y.ndim} must be 4"
102-
if x.type() != self.gaussian_filter.gaussian_window2d.type():
103-
x = x.type_as(self.gaussian_filter.gaussian_window2d)
104-
if y.type() != self.gaussian_filter.gaussian_window2d.type():
105-
y = y.type_as(self.gaussian_filter.gaussian_window2d)
149+
if x.type() != self.gaussian_filter.gaussian_window.type():
150+
x = x.type_as(self.gaussian_filter.gaussian_window)
151+
if y.type() != self.gaussian_filter.gaussian_window.type():
152+
y = y.type_as(self.gaussian_filter.gaussian_window)
153+
154+
if self.return_msssim:
155+
return self.msssim(x, y)
156+
else:
157+
return self.ssim(x, y)
158+
159+
def ssim(self, x, y):
160+
ssim, _ = self._ssim(x, y)
161+
if self.return_log:
162+
# https://github.com/xuebinqin/BASNet/blob/56393818e239fed5a81d06d2a1abfe02af33e461/pytorch_ssim/__init__.py#L81-L83
163+
ssim = ssim - ssim.min()
164+
ssim = ssim / ssim.max()
165+
ssim = -torch.log(ssim + 1e-8)
106166

167+
if self.keep_batch_dim:
168+
return ssim.mean(dim=(1, 2, 3))
169+
else:
170+
return ssim.mean()
171+
172+
def msssim(self, x, y):
173+
ms_components = []
174+
for i, w in enumerate((0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
175+
ssim, cs = self._ssim(x, y)
176+
177+
if self.keep_batch_dim:
178+
ssim = ssim.mean(dim=(1, 2, 3))
179+
cs = cs.mean(dim=(1, 2, 3))
180+
else:
181+
ssim = ssim.mean()
182+
cs = cs.mean()
183+
184+
if i == 4:
185+
ms_components.append(ssim ** w)
186+
else:
187+
ms_components.append(cs ** w)
188+
padding = [s % 2 for s in x.shape[2:]] # spatial padding
189+
x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=padding)
190+
y = F.avg_pool2d(y, kernel_size=2, stride=2, padding=padding)
191+
msssim = math.prod(ms_components) # equ 7 in ref2
192+
return msssim
193+
194+
def _ssim(self, x, y):
107195
mu_x = self.gaussian_filter(x) # equ 14
108196
mu_y = self.gaussian_filter(y) # equ 14
109197
sigma2_x = self.gaussian_filter(x * x) - mu_x * mu_x # equ 15
110198
sigma2_y = self.gaussian_filter(y * y) - mu_y * mu_y # equ 15
111199
sigma_xy = self.gaussian_filter(x * y) - mu_x * mu_y # equ 16
112200

113-
# equ 13 in ref1
114201
A1 = 2 * mu_x * mu_y + self.C1
115202
A2 = 2 * sigma_xy + self.C2
116203
B1 = mu_x * mu_x + mu_y * mu_y + self.C1
117204
B2 = sigma2_x + sigma2_y + self.C2
118-
S = (A1 * A2) / (B1 * B2)
119205

120-
if self.return_log:
121-
S = S - S.min()
122-
S = S / S.max()
123-
S = -torch.log(S + 1e-8)
124-
125-
if self.keep_batch_dim:
126-
return S.mean(dim=(1, 2, 3))
127-
else:
128-
return S.mean()
206+
# equ 12, 13 in ref1
207+
l = A1 / B1
208+
cs = A2 / B2
209+
ssim = l * cs
210+
return ssim, cs
129211

130212

131213
def ssim(
132-
x, y, *, window_size=11, in_channels=1, sigma=1.5, K1=0.01, K2=0.03, L=1, keep_batch_dim=False, return_log=False
214+
x,
215+
y,
216+
*,
217+
window_size=11,
218+
in_channels=1,
219+
sigma=1.5,
220+
K1=0.01,
221+
K2=0.03,
222+
L=1,
223+
keep_batch_dim=False,
224+
return_log=False,
225+
return_msssim=False,
226+
padding=None,
227+
ensemble_kernel=True,
133228
):
134229
"""Calculate the mean SSIM (MSSIM) between two 4D tensors.
135230
@@ -144,10 +239,12 @@ def ssim(
144239
L (int, optional): The dynamic range of the pixel values (255 for 8-bit grayscale images). Defaults to 1.
145240
keep_batch_dim (bool, optional): Whether to keep the batch dim. Defaults to False.
146241
return_log (bool, optional): Whether to return the logarithmic form. Defaults to False.
147-
242+
return_msssim (bool, optional): Whether to return the MS-SSIM score. Defaults to False, which will return the original MSSIM score.
243+
padding (int, optional): The padding of the gaussian filter. Defaults to None. If it is set to None, the filter will use window_size//2 as the padding. Another common setting is 0.
244+
ensemble_kernel (bool, optional): Whether to fuse the two cascaded 1d kernel into a 2d kernel. Defaults to True.
148245
149246
Returns:
150-
Tensor: MSSIM
247+
Tensor: MSSIM or MS-SSIM
151248
"""
152249
ssim_obj = SSIM(
153250
window_size=window_size,
@@ -158,5 +255,8 @@ def ssim(
158255
L=L,
159256
keep_batch_dim=keep_batch_dim,
160257
return_log=return_log,
258+
return_msssim=return_msssim,
259+
padding=padding,
260+
ensemble_kernel=ensemble_kernel,
161261
).to(device=x.device)
162262
return ssim_obj(x, y)

0 commit comments

Comments
 (0)