-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVQVisualiser.py
More file actions
99 lines (72 loc) · 3.17 KB
/
VQVisualiser.py
File metadata and controls
99 lines (72 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
def show_subplot(original, reconstructed):
    """Display an original image beside its reconstruction in one figure.

    Both arguments are squeezed with tf.squeeze and converted to NumPy
    before plotting, so singleton batch/channel dims are tolerated.
    """
    panels = ((1, original, "Original"), (2, reconstructed, "Reconstructed"))
    for position, image, label in panels:
        plt.subplot(1, 2, position)
        plt.imshow(tf.squeeze(image).numpy())
        plt.title(label)
        plt.axis("off")
    plt.show()
def plotImageInputOutput(original, reconstructed, n_samples=5):
    """Plot input and reconstructed output pairs, one figure per sample.

    NOTE(review): this definition is shadowed by a later function of the
    same name in this file, so callers actually get the later version.
    Kept here with its dead inner loop removed: the original iterated
    `for frame_idx in range(1)`, which ran exactly once and re-used the
    same two axes, so it was a no-op wrapper around a single plot.

    Args:
        original: indexable collection of input images.
        reconstructed: indexable collection of reconstructed images,
            aligned with `original`.
        n_samples: number of leading samples to plot (one figure each).
    """
    for sample_idx in range(n_samples):
        # 2 rows: input on top, reconstruction below.
        fig, axes = plt.subplots(2, 1, figsize=(15, 6))

        ax_input = axes[0]
        ax_input.imshow(original[sample_idx])
        ax_input.set_title("Input Frame 1")  # was f"Input Frame {1}" — constant
        ax_input.axis('off')

        ax_output = axes[1]
        ax_output.imshow(reconstructed[sample_idx])
        ax_output.set_title("Output Frame 1")  # was f"Output Frame {1}" — constant
        ax_output.axis('off')

        plt.tight_layout()
        plt.show()
def outputVisualiser(vqnvaeTrainer, dataset):
    """Visualise VQ-NVAE reconstructions and encoder outputs on a dataset.

    NOTE: shadowed by a later definition of the same name in this file.

    Takes 15 batches from `dataset`, concatenates them, then plots
    (a) full-model reconstructions and (b) the encoder layer's first
    predict output against the inputs.
    """
    n_batches = 15
    batches = [batch.numpy() for batch in dataset.take(n_batches)]
    samples = np.concatenate(batches, axis=0)

    # Full model: input -> reconstruction.
    full_model = vqnvaeTrainer.vqnvae
    reconstructions = full_model.predict(samples)
    plotImageInputOutput(samples, reconstructions, n_samples=10)

    # Encoder only — its predict returns a pair; second element unused here.
    encoder = vqnvaeTrainer.vqnvae.get_layer(name="VQNVAE-Encoder")
    encoded, _ = encoder.predict(samples)
    plotImageInputOutput(samples, encoded, n_samples=10)
def plotImageInputOutput(original, reconstructed, n_samples=5):
    """Show up to `n_samples` input/output image pairs on one 2-row grid.

    Top row: inputs; bottom row: reconstructions. `n_samples` is clamped
    to the shorter of the two collections. Grayscale colormap is used.
    """
    n_samples = min(len(original), len(reconstructed), n_samples)
    fig, axes = plt.subplots(2, n_samples, figsize=(3 * n_samples, 6))
    # With a single column, plt.subplots returns a 1-D axes array.
    single_column = not n_samples > 1

    for col in range(n_samples):
        top = axes[0] if single_column else axes[0, col]
        bottom = axes[1] if single_column else axes[1, col]

        top.imshow(original[col], cmap='gray')
        top.set_title(f"Input {col+1}")
        top.axis('off')

        bottom.imshow(reconstructed[col], cmap='gray')
        bottom.set_title(f"Output {col+1}")
        bottom.axis('off')

    plt.tight_layout()
    plt.show()
def outputVisualiser(vqnvaeTrainer, dataset, n_samples=10):
    """Plot VQ-NVAE inputs against their reconstructions.

    Pulls `n_samples` batches from `dataset`, concatenates them into one
    array, runs the trainer's full model over it, and displays the pairs
    in a single figure via plotImageInputOutput.
    """
    collected = []
    for batch in dataset.take(n_samples):
        collected.append(batch.numpy())
    samples = np.concatenate(collected, axis=0)

    # Reconstructions from the full VQ-NVAE model held by the trainer.
    model = vqnvaeTrainer.vqnvae
    reconstructions = model.predict(samples)

    # One figure: inputs on top, reconstructions below.
    plotImageInputOutput(samples, reconstructions, n_samples=n_samples)