@@ -102,6 +102,61 @@ assert torch.allclose(ssim_score_0, ssim_score_1)
102102print (ssim_score_0.shape, ssim_score_1.shape)
103103```
104104
105+ ## As A Loss
106+
107+ ![ prediction] ( https://user-images.githubusercontent.com/26847524/174814849-f80ec67c-5397-4ce6-bf4e-8b0aa568ed6f.png )
108+
109+ ``` python
110+ import matplotlib.pyplot as plt
111+ import torch
112+ from pytorch_ssim import SSIM
113+ from skimage import data
114+ from torch.optim import Adam
115+
116+
117+ original_image = data.camera() / 255
118+ target_image = torch.from_numpy(original_image).unsqueeze(0 ).unsqueeze(0 ).float().cuda()
119+ predicted_image = torch.rand_like(
120+ target_image, device = target_image.device, dtype = target_image.dtype, requires_grad = True
121+ )
122+ initial_image = predicted_image.clone()
123+
124+ ssim = SSIM().cuda()
125+ initial_ssim_value = ssim(predicted_image, target_image)
126+ print (f " Initial ssim: { initial_ssim_value.item():.4f } " )
127+ ssim_value = initial_ssim_value
128+
129+ optimizer = Adam([predicted_image], lr = 0.01 )
130+ loss_curves = []
131+ while ssim_value < 0.95 :
132+ ssim_out = 1 - ssim(predicted_image, target_image)
133+ loss_curves.append(ssim_out.item())
134+ ssim_value = 1 - ssim_out.item()
135+ ssim_out.backward()
136+ optimizer.step()
137+ optimizer.zero_grad()
138+
139+ fig, axes = plt.subplots(nrows = 1 , ncols = 4 , figsize = (8 , 2 ))
140+ ax = axes.ravel()
141+
142+ ax[0 ].imshow(original_image, cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
143+ ax[0 ].set_title(" Original Image" )
144+
145+ ax[1 ].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
146+ ax[1 ].set_xlabel(f " SSIM: { initial_ssim_value:.4f } " )
147+ ax[1 ].set_title(" Initial Image" )
148+
149+ ax[2 ].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap = plt.cm.gray, vmin = 0 , vmax = 1 )
150+ ax[2 ].set_xlabel(f " SSIM: { ssim_value:.4f } " )
151+ ax[2 ].set_title(" Predicted Image" )
152+
153+ ax[3 ].plot(loss_curves)
154+ ax[3 ].set_title(" SSIM Loss Curve" )
155+
156+ plt.tight_layout()
157+ plt.savefig(" prediction.png" )
158+ ```
159+
105160## Reference
106161
107162- https://github.com/Po-Hsun-Su/pytorch-ssim
0 commit comments