Skip to content

Commit 1584aad

Browse files
authored
Update README.md
1 parent 0095ee2 commit 1584aad

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

README.md

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,61 @@ assert torch.allclose(ssim_score_0, ssim_score_1)
102102
print(ssim_score_0.shape, ssim_score_1.shape)
103103
```
104104

105+
## As A Loss
106+
107+
![prediction](https://user-images.githubusercontent.com/26847524/174814849-f80ec67c-5397-4ce6-bf4e-8b0aa568ed6f.png)
108+
109+
```python
110+
import matplotlib.pyplot as plt
111+
import torch
112+
from pytorch_ssim import SSIM
113+
from skimage import data
114+
from torch.optim import Adam
115+
116+
117+
original_image = data.camera() / 255
118+
target_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()
119+
predicted_image = torch.rand_like(
120+
target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True
121+
)
122+
initial_image = predicted_image.clone()
123+
124+
ssim = SSIM().cuda()
125+
initial_ssim_value = ssim(predicted_image, target_image)
126+
print(f"Initial ssim: {initial_ssim_value.item():.4f}")
127+
ssim_value = initial_ssim_value
128+
129+
optimizer = Adam([predicted_image], lr=0.01)
130+
loss_curves = []
131+
while ssim_value < 0.95:
132+
ssim_out = 1 - ssim(predicted_image, target_image)
133+
loss_curves.append(ssim_out.item())
134+
ssim_value = 1 - ssim_out.item()
135+
ssim_out.backward()
136+
optimizer.step()
137+
optimizer.zero_grad()
138+
139+
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(8, 2))
140+
ax = axes.ravel()
141+
142+
ax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)
143+
ax[0].set_title("Original Image")
144+
145+
ax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
146+
ax[1].set_xlabel(f"SSIM: {initial_ssim_value:.4f}")
147+
ax[1].set_title("Initial Image")
148+
149+
ax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)
150+
ax[2].set_xlabel(f"SSIM: {ssim_value:.4f}")
151+
ax[2].set_title("Predicted Image")
152+
153+
ax[3].plot(loss_curves)
154+
ax[3].set_title("SSIM Loss Curve")
155+
156+
plt.tight_layout()
157+
plt.savefig("prediction.png")
158+
```
159+
105160
## Reference
106161

107162
- https://github.com/Po-Hsun-Su/pytorch-ssim

0 commit comments

Comments
 (0)