Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,29 @@ def get_vgg_features(input, layers, input_shape):
vgg = VGG19(input, input_shape)
outputs = [layer.output for layer in vgg.layers if layer.name in layers]
return outputs


def calculate_content_loss(content_image, reconstructed_image,
                           content_weight, image_shape, batch_size):
    """Return the weighted content loss between two images.

    Both images are run through ``get_vgg_features`` restricted to
    ``CONTENT_LAYERS`` and only the first returned feature map of each is
    compared.

    Args:
        content_image: The reference (content) image input.
        reconstructed_image: The generated image to compare against it.
        content_weight: Scalar multiplier applied to the final loss.
        image_shape: Input shape forwarded to ``get_vgg_features``.
        batch_size: Multiplied into the normalising size.

    Returns:
        ``content_weight * (2 * l2_loss(diff) / size)``, where ``diff`` is the
        reconstructed features minus the content features and ``size`` is
        ``tensor_size(content features) * batch_size``.
    """
    target_features = get_vgg_features(
        content_image, CONTENT_LAYERS, image_shape)[0]
    generated_features = get_vgg_features(
        reconstructed_image, CONTENT_LAYERS, image_shape)[0]

    # Normalise by the feature-map size scaled by the batch size.
    normalizer = tensor_size(target_features) * batch_size
    feature_gap = generated_features - target_features

    return content_weight * (2 * l2_loss(feature_gap) / normalizer)
def calculate_style_loss(style_image, reconstructed_image,

def calculate_style_loss(style_vgg_features, reconstructed_image,
style_weight, style_image_shape, content_image_shape,
batch_size):
# Get outputs of style and content images at VGG layers
style_vgg_features = get_vgg_features(
style_image, STYLE_LAYERS, style_image_shape)

reconstructed_style_vgg_features = get_vgg_features(
reconstructed_image, STYLE_LAYERS, content_image_shape)

# Calculate the style features of the style image and output image
# Style features are the gram matrices of the VGG feature maps
style_grams = []
Expand All @@ -70,7 +69,7 @@ def calculate_style_loss(style_image, reconstructed_image,
features_T = tf.transpose(features, perm=[0,2,1])
gram = tf.matmul(features_T, features) / features_size
style_grams.append(gram)

# Output image style features
for features in reconstructed_style_vgg_features:
_, h, w, filters = K.int_shape(features)
Expand All @@ -80,20 +79,20 @@ def calculate_style_loss(style_image, reconstructed_image,
features = K.reshape(features, np.array((batch_size, h * w, filters)))
features_T = tf.transpose(features, perm=[0,2,1])
gram = tf.matmul(features_T, features) / size
style_rec_grams.append(gram)
style_rec_grams.append(gram)

# Calculate style loss
style_losses = []
for style_gram, style_rec_gram in zip(style_grams, style_rec_grams):
style_gram_size = tensor_size(style_gram)
l2 = l2_loss(style_rec_gram - style_gram)
style_losses.append(2 * l2 / style_gram_size)

style_loss = style_weight * reduce(tf.add, style_losses) / batch_size

return style_loss


def calculate_tv_loss(x, tv_weight, batch_size):
tv_y_size = tensor_size(x[:,1:,:,:])
tv_x_size = tensor_size(x[:,:,1:,:])
Expand All @@ -106,6 +105,8 @@ def calculate_tv_loss(x, tv_weight, batch_size):
def create_loss_fn(style_image, content_weight,
style_weight, tv_weight, batch_size):
style_image = tf.convert_to_tensor(style_image)
style_vgg_features = get_vgg_features(
style_image, STYLE_LAYERS, K.int_shape(style_image))

def style_transfer_loss(y_true, y_pred):
"""
Expand All @@ -115,13 +116,13 @@ def style_transfer_loss(y_true, y_pred):

content_image = y_true
reconstructed_image = y_pred

content_loss = calculate_content_loss(content_image,
reconstructed_image,
content_weight,
CONTENT_TRAINING_SIZE,
batch_size)
style_loss = calculate_style_loss(style_image,
style_loss = calculate_style_loss(style_vgg_features,
reconstructed_image,
style_weight,
K.int_shape(style_image),
Expand Down
17 changes: 11 additions & 6 deletions vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,20 @@ def vgg_layers(img_input, input_shape):

return x

# Module-level cache for the opened weights file; populated on first use.
weights_file = None


def cached_file_load():
    """Return the VGG19 weights HDF5 file, opening it only on the first call.

    The first call resolves the weights path through ``get_file`` and opens
    it with ``h5py``; the handle is stored in the module-level
    ``weights_file`` and returned unchanged by every later call.

    Returns:
        The shared ``h5py.File`` handle, opened read-only.
    """
    global weights_file
    if weights_file is None:
        # presumably get_file is the Keras download/cache helper -- confirm
        # against this file's imports.
        weights_path = get_file(
            'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
            WEIGHTS_PATH_NO_TOP,
            cache_subdir='models',
            file_hash='253f8cb515780f3b799900260a226db6')
        # Open read-only explicitly: the weights are only read from, and
        # older h5py releases otherwise fall back to a read/write default
        # mode, which is unsafe for a long-lived shared handle.
        weights_file = h5py.File(weights_path, 'r')
    return weights_file

def load_weights(model):
weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
file_hash='253f8cb515780f3b799900260a226db6')
f = h5py.File(weights_path)
f = cached_file_load()
layer_names = [name for name in f.attrs['layer_names']]

for layer in model.layers:
if layer.name in layer_names:
g = f[layer.name]
Expand Down