VQAutoEncoder.py — 129 lines (94 loc) · 4.23 KB
#IMPORTS SECTION
import tensorflow as tf
from tensorflow.keras import layers
import keras
import numpy as np
import random
import pdb
from tensorflow.keras.layers import MultiHeadAttention
import tensorflow_probability as tfp
import tensorflow_addons as tfa
import os
import urllib.request
import json
import yaml
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import LearningRateScheduler
from VQModels import gpuChecker, lr_schedule, modelSave, VectorQuantizer, createEncoderDecoder, VQNVAE, VQNVAETrainer
from VQVisualiser import plotImageInputOutput, outputVisualiser
from VQData import process_path, imagePreprocessorLoader, processPathRandom, imageRandomisedPreprocessorLoader, imagePreprocessorLoader, unisonShuffle, dataCreator
if __name__ == "__main__":
    # ---- Configuration ------------------------------------------------
    # All hyperparameters, dataset paths, and run flags come from cfg.yml
    # in the working directory.
    with open("cfg.yml", "r") as f:
        config = yaml.safe_load(f)

    # Hyperparameters
    epochs = config['epochs']
    batches = config['batches']
    learningRate = config['learningRate']
    latentDim = config['latentDim']
    imageSize = config['imageSize']
    cropSize = config['cropSize']
    imageChannels = config['imageChannels']
    numEmbeddings = config['numEmbeddings']
    strides = config['strides']
    filters = config['filters']
    encoderString = config['encoderString']
    decoderString = config['decoderString']

    # COCO dataset locations
    baseDir = config['baseDir']
    valDataDir = config['valDataDir']
    trainDataDir = config['trainDataDir']
    testDataDir = config['testDataDir']
    annDir = config['annDir']
    diffDir = config['diffDir']
    valAnnDir = config['valAnnDir']
    trainAnnDir = config['trainAnnDir']

    # Model visualising params
    outputView = config['outputView']
    # Model saving/loading params
    load = config['load']
    save = config['save']
    modelOutputType = config['modelOutputType']

    # GPU section: verify device availability before building the model.
    gpuChecker()

    # ---- Load COCO annotations ----------------------------------------
    with open(baseDir + valAnnDir, 'r') as file:
        cocoValAnn = json.load(file)
    with open(baseDir + trainAnnDir, 'r') as file:
        cocoTrainAnn = json.load(file)
    # Print the keys of the loaded JSON to verify
    print(cocoValAnn.keys())
    print(cocoTrainAnn.keys())

    valInp = dataCreator(cocoValAnn, baseDir + valDataDir)
    print("Entry Inp: ", valInp[0])
    trainInp = dataCreator(cocoTrainAnn, baseDir + trainDataDir)
    print("Entry Inp: ", trainInp[0])
    trainInp = unisonShuffle(trainInp)

    # ---- Build tf.data pipelines --------------------------------------
    # Validation: deterministic preprocessing, batch size 1 (per-image viewing).
    valDataset = tf.data.Dataset.from_tensor_slices(valInp)
    valDataset = valDataset.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    valDataset = valDataset.batch(1).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    # Training: randomised preprocessing (augmentation) and the configured batch size.
    trainDataset = tf.data.Dataset.from_tensor_slices(trainInp)
    trainDataset = trainDataset.map(processPathRandom, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    trainDataset = trainDataset.batch(batches).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    # ---- Define / load the VQ-VAE model -------------------------------
    if load:
        # BUG FIX: the original passed the undefined names encoderPathString /
        # decoderPathString; the values loaded from config are encoderString /
        # decoderString (keyword names kept — they match trainerLoad's signature).
        # NOTE(review): trainerLoad is not imported at the top of this file —
        # presumably it lives in VQModels; confirm and add it to that import.
        vqnvaeTrainer = trainerLoad(fileType=modelOutputType,
                                    encoderPathString=encoderString,
                                    decoderPathString=decoderString,
                                    latentDim=latentDim, numEmbeddings=numEmbeddings,
                                    imageSize=imageSize, channels=imageChannels,
                                    strides=strides, filters=filters)
    else:
        vqnvaeTrainer = VQNVAETrainer(latentDim=latentDim, numEmbeddings=numEmbeddings,
                                      imageSize=imageSize, channels=imageChannels,
                                      strides=strides, filters=filters)
    vqnvaeTrainer.vqnvae.summary()

    # ---- Compile and train --------------------------------------------
    lr_scheduler = LearningRateScheduler(lr_schedule)
    vqnvaeTrainer.compile(optimizer=keras.optimizers.Adam(learning_rate=learningRate),
                          metrics=['accuracy'])
    # batch_size is omitted: Keras ignores it (and warns) when the input is an
    # already-batched tf.data.Dataset — batching happens in the pipeline above.
    vqnvaeTrainer.fit(trainDataset, epochs=epochs, callbacks=[lr_scheduler])

    # ---- Visualise reconstructions ------------------------------------
    if outputView:
        outputVisualiser(vqnvaeTrainer, valDataset)

    # ---- Save models ---------------------------------------------------
    # BUG FIX: the original tested `modelSave == True` — but modelSave is the
    # imported save *function* (always truthy), so the config flag `save` was
    # silently ignored and saving always ran. Also fixed the `vqvnaeTrainer`
    # typo on the argument value, which was a guaranteed NameError.
    # (The keyword name `vqvnaeTrainer` is kept as spelled in the original
    # call — presumably it matches modelSave's parameter name; verify.)
    if save:
        modelSave(vqvnaeTrainer=vqnvaeTrainer, modelOutputType=modelOutputType,
                  encoderPathString=encoderString, decoderPathString=decoderString)