Neural Style Transfer

Summary

  • Introduction

  • Principle

  • Example: apply style transfer to an image

Introduction

Neural style transfer in a nutshell

  • Reproduce an image with a new artistic style provided by another image.

  • Blend a content image and a style reference image into a stylized output image.

  • First described in A Neural Algorithm of Artistic Style by Gatys et al. (2015), with many refinements and variations since.

Example

Prisma style transfer example

Principle

Loss definition

The resulting image \(\pmb{x}^*\) minimizes a weighted sum of three losses: the content loss \(\mathcal{L}_{content}(\pmb{c}, \pmb{x})\), the style loss \(\mathcal{L}_{style}(\pmb{s}, \pmb{x})\) and the total variation loss \(\mathcal{L}_{total\_variation}(\pmb{x})\).

\[\pmb{x}^* = \underset{\pmb{x}}{\mathrm{argmin}}(\alpha\mathcal{L}_{content}(\pmb{c}, \pmb{x}) + \beta\mathcal{L}_{style}(\pmb{s}, \pmb{x}) + \gamma\mathcal{L}_{total\_variation}(\pmb{x}))\]
  • \(\pmb{x}\): generated image.

  • \(\pmb{c}\): content image.

  • \(\pmb{s}\): reference style image.

  • \(\alpha\), \(\beta\) and \(\gamma\): weights controlling the relative importance of each loss.

The content loss

  • Content = high-level structure of an image.

  • Can be captured by the activations of the upper layers of a convolutional neural network pretrained on image classification.

  • Content loss for a layer = squared Euclidean distance between the feature maps of the content and generated images.
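A minimal NumPy sketch of the content loss for a single layer, assuming the feature maps have already been extracted as arrays of shape (height, width, channels). Like the `content_loss()` function of the example below, it sums the squared differences (Gatys et al. additionally scale this sum by 1/2). The random array is only a placeholder for real VGG activations.

```python
import numpy as np

def content_loss(content_features, generated_features):
    """Sum of squared differences between two feature maps of identical shape."""
    return np.sum(np.square(generated_features - content_features))

# Random array standing in for the (height, width, channels) feature maps of one layer
rng = np.random.default_rng(0)
content_features = rng.normal(size=(4, 4, 8))

# Identical feature maps give a zero loss
print(content_loss(content_features, content_features))  # 0.0

# Shifting each of the 4 * 4 * 8 = 128 activations by 0.1 adds about 0.1 ** 2 per element
shifted = content_features + 0.1
print(content_loss(content_features, shifted))
```

The loss only depends on how far the generated activations are from the content activations, position by position, which is why it preserves the spatial structure of the content image.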

The style loss

  • Style = low-level features of an image (textures, colors, visual patterns).

  • Can be captured by using correlations across the different feature maps (filter responses) of a convnet.

  • Feature correlations are computed via a Gram matrix: for a given layer, entry \((i, j)\) is the inner product of the flattened feature maps \(i\) and \(j\).

  • Style loss for a layer = distance between the Gram matrices of the feature maps for the style and generated images.
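A minimal NumPy sketch of the Gram matrix and of a per-layer style loss, assuming a (height, width, channels) feature map as input. The normalization \(1/(4N_l^2M_l^2)\), with \(N_l\) filters and \(M_l\) spatial positions, follows Gatys et al.; the example below uses a fixed constant instead. The random arrays are placeholders for real activations. Shuffling the spatial positions leaves the Gram matrix unchanged, which illustrates why it captures textures and colors rather than the layout of the image.

```python
import numpy as np

def gram_matrix(features):
    """Gram matrix of a (height, width, channels) feature map, of shape (channels, channels)."""
    channels = features.shape[-1]
    # Flatten spatial dimensions: one row per filter, one column per position
    F = features.reshape(-1, channels).T  # shape (channels, height * width)
    # Entry (i, j) is the inner product of flattened feature maps i and j
    return F @ F.T

def layer_style_loss(style_features, generated_features):
    """Per-layer style loss with the normalization of Gatys et al. (2015)."""
    height, width, channels = style_features.shape
    n, m = channels, height * width  # N_l filters, M_l positions
    diff = gram_matrix(style_features) - gram_matrix(generated_features)
    return np.sum(np.square(diff)) / (4.0 * n**2 * m**2)

rng = np.random.default_rng(0)
style_features = rng.normal(size=(8, 8, 16))

G = gram_matrix(style_features)
print(G.shape)  # (16, 16)

# Shuffle the spatial positions: the layout changes, but the Gram matrix does not
positions = style_features.reshape(-1, 16)
shuffled = positions[rng.permutation(positions.shape[0])].reshape(8, 8, 16)
print(np.allclose(gram_matrix(shuffled), G))  # True

# Feature maps with different statistics give a positive style loss
other = rng.normal(size=(8, 8, 16))
print(layer_style_loss(style_features, other) > 0)  # True
```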

The total variation loss

  • Sum of the absolute differences between neighboring pixel values in an image. Measures how much noise the image contains.

  • Encourages local spatial continuity in the generated image (denoising).

  • Acts as a regularization loss.
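A minimal NumPy sketch of the definition above (the absolute-difference form, sometimes called anisotropic total variation), assuming a (height, width, channels) array. A uniform image has zero total variation, while a checkerboard, where every pixel differs from all its neighbors, scores high. Note that the `total_variation_loss()` function of the example below uses a smooth variant (squared differences raised to the power 1.25).

```python
import numpy as np

def total_variation(image):
    """Sum of absolute differences between vertically and horizontally adjacent pixels."""
    vertical = np.abs(image[1:, :, :] - image[:-1, :, :])
    horizontal = np.abs(image[:, 1:, :] - image[:, :-1, :])
    return vertical.sum() + horizontal.sum()

# Uniform gray 4x4 single-channel image
flat = np.full((4, 4, 1), 0.5)
# 4x4 checkerboard of 0s and 1s: every pair of neighbors differs by 1
checkerboard = (np.indices((4, 4)).sum(axis=0) % 2).astype(float)[..., None]

print(total_variation(flat))          # 0.0
print(total_variation(checkerboard))  # 24.0 (12 vertical + 12 horizontal pairs)
```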

Gradient descent

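  • Objective: minimize the total loss with respect to the pixels of the generated image. The network weights remain fixed.

  • Optimizer: L-BFGS (the original choice of Gatys et al.), or a first-order method such as Adam or SGD.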
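A toy NumPy sketch of this optimization, under strong simplifying assumptions: the feature extractor is the identity, and both the content and style losses are replaced by squared pixel distances to hypothetical arrays `c` and `s`. It only illustrates that gradient descent updates the image itself (starting here from the content image, as the example below does) and that the weights trade off the two targets: the minimizer of this toy objective is the weighted average \((\alpha c + \beta s)/(\alpha + \beta)\).

```python
import numpy as np

rng = np.random.default_rng(0)
c = rng.uniform(0, 255, size=(4, 4, 3))  # hypothetical stand-in "content" image
s = rng.uniform(0, 255, size=(4, 4, 3))  # hypothetical stand-in "style" image
alpha, beta = 1.0, 3.0                   # toy loss weights
learning_rate = 0.1

def toy_loss(x):
    """Toy objective: weighted squared pixel distances to c and s."""
    return alpha * np.sum((x - c) ** 2) + beta * np.sum((x - s) ** 2)

def toy_grad(x):
    """Analytic gradient of toy_loss with respect to the pixels of x."""
    return 2 * alpha * (x - c) + 2 * beta * (x - s)

x = c.copy()  # start from the content image
for _ in range(100):
    x -= learning_rate * toy_grad(x)  # plain gradient descent on the pixels

expected = (alpha * c + beta * s) / (alpha + beta)
print(np.allclose(x, expected))  # True
```

Here each step shrinks the distance to the optimum by a constant factor of \(1 - 2\eta(\alpha + \beta) = 0.2\) (with \(\eta\) the learning rate), so 100 steps are far more than enough. With real CNN features the objective is non-convex and has no closed-form minimizer, hence the iterative optimizers listed above.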

Animation of style transfer

Example: apply style transfer to an image

(Heavily inspired by this Keras example)

Environment setup

import platform
import sys

print(f"Python version: {platform.python_version()}")
# Compare integer version components (comparing strings would misorder e.g. "10" < "6")
assert sys.version_info >= (3, 6)

from IPython.display import Image, display

import numpy as np
Python version: 3.8.2
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {tf.keras.__version__}")
print('GPU found :)' if tf.config.list_physical_devices("GPU") else 'No GPU :(')

from tensorflow.keras import Model
from tensorflow.keras.applications import vgg19
from tensorflow.keras.utils import get_file
from tensorflow.keras.preprocessing.image import load_img, save_img, img_to_array
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers.schedules import ExponentialDecay
TensorFlow version: 2.2.0
Keras version: 2.3.0-tf
GPU found :)

Image loading

# Download content and style images
base_image_path = get_file("paris.jpg", "https://i.imgur.com/F28w3Ac.jpg")
style_reference_image_path = get_file(
    "starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg"
)

width, height = load_img(base_image_path).size
print(f"Input image dimensions: {width, height}")
Input image dimensions: (910, 607)
# Show content image
display(Image(base_image_path))
../_images/eeec368e92239026a6b7f2dada212daf74d8ae0bb191e99e4ae6c5d41ad58fc9.jpg
# Show style reference image
display(Image(style_reference_image_path))
../_images/0f0a513035709bfd1da42901a889160c4c864f09f32c9c35c17d641cb8c181bb.jpg

Utility functions

def preprocess_image(image_path, height, width):
    """Open, resize and format a picture into appropriate tensors"""
    img = load_img(
        image_path, target_size=(height, width)
    )
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    # Convert image from RGB to BGR and zero-center each color channel w.r.t. the ImageNet dataset
    img = vgg19.preprocess_input(img)
    return tf.convert_to_tensor(img)


def deprocess_image(x, height, width):
    """Convert a tensor into a valid image"""
    x = x.reshape((height, width, 3))
    # Undo the zero-centering applied by vgg19.preprocess_input()
    # The following values are the mean pixel values of each color channel (BGR order) for the ImageNet dataset
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype("uint8")
    return x
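A minimal NumPy sketch checking that `deprocess_image()` undoes the preprocessing, assuming that `vgg19.preprocess_input()` follows the "caffe" convention: flip RGB to BGR, then subtract the ImageNet per-channel means (the same constants added back above). The helper names `caffe_preprocess` and `caffe_deprocess` are hypothetical, and the final `uint8` cast is left out so the round trip can be compared exactly.

```python
import numpy as np

# ImageNet mean pixel values in BGR order, as used in deprocess_image()
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68])

def caffe_preprocess(rgb):
    """Assumed NumPy equivalent of vgg19.preprocess_input() for a (height, width, 3) array."""
    bgr = rgb[..., ::-1]            # RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR  # zero-center each channel

def caffe_deprocess(x):
    """Inverse transformation, mirroring deprocess_image() without reshape and uint8 cast."""
    bgr = x + IMAGENET_MEAN_BGR     # undo zero-centering
    rgb = bgr[..., ::-1]            # BGR -> RGB
    return np.clip(rgb, 0, 255)

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(2, 3, 3)).astype("float32")
restored = caffe_deprocess(caffe_preprocess(image))
print(np.allclose(restored, image))  # True
```

Clipping matters only for the generated image: after optimization its pixels can drift outside the valid [0, 255] range.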

Loss functions

# The Gram matrix of a 3D tensor (correlations between the feature maps of a convolutional layer)
def gram_matrix(x):
    # Transpose feature maps tensor so that the channel (3rd) dimension becomes the 1st
    x = tf.transpose(x, (2, 0, 1))
    # Reshape feature maps tensor into a matrix: one row per filter/map, one column per pixel position
    features = tf.reshape(x, (tf.shape(x)[0], -1))
    # Gram matrix F F^T: inner products between every pair of flattened feature maps
    gram = tf.matmul(features, tf.transpose(features))
    return gram


# The style loss is designed to maintain the style of the reference image in the generated image
# It is based on the gram matrices (which capture style) of feature maps from the style reference image
# and from the generated image
def style_loss(style, combination, height, width):
    S = gram_matrix(style)
    C = gram_matrix(combination)
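    # Fixed normalization constant (image channels and pixel count), identical for every layer,
    # so it only rescales style_weight. Gatys et al. use each layer's filter count and map size instead
    channels = 3
    size = height * width
    # Compute distance between Gram matrices of style and generated images
    return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))


# The content loss is designed to maintain the "content" of the base image in the generated image
def content_loss(base, combination):
    return tf.reduce_sum(tf.square(combination - base))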

# The total variation loss is designed to keep the generated image locally coherent
def total_variation_loss(x, height, width):
    a = tf.square(
        x[:, : height - 1, : width - 1, :] - x[:, 1:, : width - 1, :]
    )
    b = tf.square(
        x[:, : height - 1, : width - 1, :] - x[:, : height - 1, 1:, :]
    )
    return tf.reduce_sum(tf.pow(a + b, 1.25))

Model definition

We use a VGG19 model pretrained on the ImageNet dataset.

# Using the convolutional base of VGG19, loaded with pre-trained ImageNet weights
vgg = vgg19.VGG19(weights="imagenet", include_top=False)

# Map the name of each layer to its symbolic output
outputs_dict = dict([(layer.name, layer.output) for layer in vgg.layers])

# Set up a model that returns the activation values for every layer in VGG19 (as a dict)
feature_extractor = Model(inputs=vgg.inputs, outputs=outputs_dict)
vgg.summary()
Model: "vgg19"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, None, None, 256)   295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv4 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, None, None, 256)   0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, None, None, 512)   1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv4 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, None, None, 512)   0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv4 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, None, None, 512)   0         
=================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
_________________________________________________________________

Loss computation

# List of layers to use for the style loss
style_layer_names = [
    "block1_conv1",
    "block2_conv1",
    "block3_conv1",
    "block4_conv1",
    "block5_conv1",
]
# The layer to use for the content loss
content_layer_name = "block5_conv2"

# Weights of the different loss components
total_variation_weight = 1e-6
style_weight = 1e-6
content_weight = 2.5e-8

def compute_loss(combination_image, base_image, style_reference_image, height, width):
    input_tensor = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    features = feature_extractor(input_tensor)

    # Initialize the loss
    loss = tf.zeros(shape=())

    # Add content loss
    layer_features = features[content_layer_name]
    base_image_features = layer_features[0, :, :, :]
    combination_features = layer_features[2, :, :, :]
    loss = loss + content_weight * content_loss(
        base_image_features, combination_features
    )
    # Add style loss
    for layer_name in style_layer_names:
        layer_features = features[layer_name]
        style_reference_features = layer_features[1, :, :, :]
        combination_features = layer_features[2, :, :, :]
        sl = style_loss(style_reference_features, combination_features, height, width)
        loss += (style_weight / len(style_layer_names)) * sl

    # Add total variation loss
    loss += total_variation_weight * total_variation_loss(combination_image, height, width)
    return loss


@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image, height, width):
    with tf.GradientTape() as tape:
        loss = compute_loss(combination_image, base_image, style_reference_image, height, width)
    grads = tape.gradient(loss, combination_image)
    return loss, grads

Training loop

# Generated image height
gen_height = 400
# Compute generated width so that input and generated images have the same aspect ratio
gen_width = int(width * gen_height / height)
print(f"Generated image dimensions: {gen_width, gen_height}")
Generated image dimensions: (599, 400)
optimizer = SGD(
    ExponentialDecay(
        initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96
    )
)

base_image = preprocess_image(base_image_path, gen_height, gen_width)
style_reference_image = preprocess_image(style_reference_image_path, gen_height, gen_width)
combination_image = tf.Variable(preprocess_image(base_image_path, gen_height, gen_width))

# Training loop
# Each "epoch" is a single gradient step on the pixels of the generated image
n_epochs = 4000
for epoch in range(1, n_epochs + 1):
    loss, grads = compute_loss_and_grads(
        combination_image, base_image, style_reference_image, gen_height, gen_width
    )
    optimizer.apply_gradients([(grads, combination_image)])
    if epoch % 100 == 0:
        print(f"Epoch [{epoch}/{n_epochs}], loss: {loss:.2f}")

# Save final image
final_img = deprocess_image(combination_image.numpy(), gen_height, gen_width)
result_image_path = f"generated_image_epoch_{n_epochs}.png"
save_img(result_image_path, final_img)
Epoch [100/4000], loss: 11018.34
Epoch [200/4000], loss: 8514.22
Epoch [300/4000], loss: 7571.62
Epoch [400/4000], loss: 7064.01
Epoch [500/4000], loss: 6736.22
Epoch [600/4000], loss: 6501.81
Epoch [700/4000], loss: 6323.29
Epoch [800/4000], loss: 6181.49
Epoch [900/4000], loss: 6065.40
Epoch [1000/4000], loss: 5967.84
Epoch [1100/4000], loss: 5884.72
Epoch [1200/4000], loss: 5812.83
Epoch [1300/4000], loss: 5750.32
Epoch [1400/4000], loss: 5695.63
Epoch [1500/4000], loss: 5647.32
Epoch [1600/4000], loss: 5604.25
Epoch [1700/4000], loss: 5565.57
Epoch [1800/4000], loss: 5530.69
Epoch [1900/4000], loss: 5499.09
Epoch [2000/4000], loss: 5470.43
Epoch [2100/4000], loss: 5444.34
Epoch [2200/4000], loss: 5420.45
Epoch [2300/4000], loss: 5398.52
Epoch [2400/4000], loss: 5378.34
Epoch [2500/4000], loss: 5359.73
Epoch [2600/4000], loss: 5342.55
Epoch [2700/4000], loss: 5326.65
Epoch [2800/4000], loss: 5311.85
Epoch [2900/4000], loss: 5298.08
Epoch [3000/4000], loss: 5285.24
Epoch [3100/4000], loss: 5273.25
Epoch [3200/4000], loss: 5262.03
Epoch [3300/4000], loss: 5251.54
Epoch [3400/4000], loss: 5241.72
Epoch [3500/4000], loss: 5232.52
Epoch [3600/4000], loss: 5223.90
Epoch [3700/4000], loss: 5215.81
Epoch [3800/4000], loss: 5208.19
Epoch [3900/4000], loss: 5201.01
Epoch [4000/4000], loss: 5194.25
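The SGD optimizer above uses an exponentially decaying learning rate. A minimal pure-Python sketch of that schedule, assuming the non-staircase formula of Keras' `ExponentialDecay` (its default, `staircase=False`): the rate at a given step is `initial_learning_rate * decay_rate ** (step / decay_steps)`. The function name `exponential_decay` is hypothetical.

```python
def exponential_decay(step, initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96):
    """Learning rate at a given step for a continuous (non-staircase) exponential decay."""
    return initial_learning_rate * decay_rate ** (step / decay_steps)

for step in (0, 100, 1000, 4000):
    print(step, round(exponential_decay(step), 2))
```

With these settings the rate shrinks by 4% every 100 steps and ends at roughly a fifth of its initial value (about 19.5) after 4000 steps, which is consistent with the slowing decrease of the loss in the log above.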

Generated image display

# Show final generated image
display(Image(result_image_path))
../_images/654530bd45667dedaf57fcf392f81ee50fa0b490289f517e9ec9565f24f3e626.png