diff --git a/loss.py b/loss.py index 2bfadbd..9bbca5b 100644 --- a/loss.py +++ b/loss.py @@ -29,7 +29,7 @@ def get_vgg_features(input, layers, input_shape): vgg = VGG19(input, input_shape) outputs = [layer.output for layer in vgg.layers if layer.name in layers] return outputs - + def calculate_content_loss(content_image, reconstructed_image, content_weight, image_shape, batch_size): @@ -37,22 +37,21 @@ def calculate_content_loss(content_image, reconstructed_image, content_image, CONTENT_LAYERS, image_shape)[0] reconstructed_content_features = get_vgg_features( reconstructed_image, CONTENT_LAYERS, image_shape)[0] - + content_size = tensor_size(content_features) * batch_size content_loss = content_weight * (2 * l2_loss( reconstructed_content_features - content_features) / content_size) - + return content_loss - -def calculate_style_loss(style_image, reconstructed_image, + +def calculate_style_loss(style_vgg_features, reconstructed_image, style_weight, style_image_shape, content_image_shape, batch_size): # Get outputs of style and content images at VGG layers - style_vgg_features = get_vgg_features( - style_image, STYLE_LAYERS, style_image_shape) + reconstructed_style_vgg_features = get_vgg_features( reconstructed_image, STYLE_LAYERS, content_image_shape) - + # Calculate the style features of the style image and output image # Style features are the gram matrices of the VGG feature maps style_grams = [] @@ -70,7 +69,7 @@ def calculate_style_loss(style_image, reconstructed_image, features_T = tf.transpose(features, perm=[0,2,1]) gram = tf.matmul(features_T, features) / features_size style_grams.append(gram) - + # Output image style features for features in reconstructed_style_vgg_features: _, h, w, filters = K.int_shape(features) @@ -80,20 +79,20 @@ def calculate_style_loss(style_image, reconstructed_image, features = K.reshape(features, np.array((batch_size, h * w, filters))) features_T = tf.transpose(features, perm=[0,2,1]) gram = tf.matmul(features_T, features) / size - style_rec_grams.append(gram) - + style_rec_grams.append(gram) + # Calculate style loss style_losses = [] for style_gram, style_rec_gram in zip(style_grams, style_rec_grams): style_gram_size = tensor_size(style_gram) l2 = l2_loss(style_rec_gram - style_gram) style_losses.append(2 * l2 / style_gram_size) - + style_loss = style_weight * reduce(tf.add, style_losses) / batch_size - + return style_loss - - + + def calculate_tv_loss(x, tv_weight, batch_size): tv_y_size = tensor_size(x[:,1:,:,:]) tv_x_size = tensor_size(x[:,:,1:,:]) @@ -106,6 +105,8 @@ def calculate_tv_loss(x, tv_weight, batch_size): def create_loss_fn(style_image, content_weight, style_weight, tv_weight, batch_size): style_image = tf.convert_to_tensor(style_image) + style_vgg_features = get_vgg_features( + style_image, STYLE_LAYERS, K.int_shape(style_image)) def style_transfer_loss(y_true, y_pred): """ @@ -115,13 +116,13 @@ def style_transfer_loss(y_true, y_pred): content_image = y_true reconstructed_image = y_pred - + content_loss = calculate_content_loss(content_image, reconstructed_image, content_weight, CONTENT_TRAINING_SIZE, batch_size) - style_loss = calculate_style_loss(style_image, + style_loss = calculate_style_loss(style_vgg_features, reconstructed_image, style_weight, K.int_shape(style_image), diff --git a/vgg.py b/vgg.py index f06a356..87608ba 100644 --- a/vgg.py +++ b/vgg.py @@ -46,15 +46,20 @@ def vgg_layers(img_input, input_shape): return x +weights_file = None +def cached_file_load(): + global weights_file + if(weights_file is None): + weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5', + WEIGHTS_PATH_NO_TOP, + cache_subdir='models', + file_hash='253f8cb515780f3b799900260a226db6') + weights_file = h5py.File(weights_path) + return weights_file def load_weights(model): - weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5', - WEIGHTS_PATH_NO_TOP, - cache_subdir='models', - file_hash='253f8cb515780f3b799900260a226db6') - f = h5py.File(weights_path) + f = cached_file_load() layer_names = [name for name in f.attrs['layer_names']] - for layer in model.layers: if layer.name in layer_names: g = f[layer.name]