Generative Adversarial Networks

Summary

  • Introduction and basic example

  • Deep Convolutional GANs

  • The GAN ecosystem

Introduction and basic example

GAN in a nutshell

  • First described in Generative Adversarial Nets by Goodfellow et al. (2014).

  • Simultaneously train two models:

    • One tries to generate realistic data.

    • The other tries to discriminate between real and generated data.

  • Each model is trained to best the other.

  • NIPS 2016 Tutorial.
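Goodfellow et al. (2014) formalize this competition as a two-player minimax game. In the value function below, \(D(x)\) is the discriminator's estimated probability that \(x\) is real, and \(G(z)\) is the sample generated from a noise vector \(z\):

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]\]

The discriminator tries to maximize \(V\) by assigning high probabilities to real data and low ones to generated data. The generator tries to minimize \(V\) by pushing \(D(G(z))\) toward 1.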

GAN overview

GAN process

Training process

Multiple iterations, each consisting of:

  • Phase 1

    • The generator creates images from random noise vectors (typically sampled from a Gaussian distribution).

    • Generated images are mixed with real ones.

    • The discriminator is trained to distinguish fake vs. real images.

  • Phase 2

    • Another batch of images is produced by the generator.

    • The generator is trained so that the discriminator classifies these generated images as real.
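To make the two phases concrete, here is a minimal NumPy sketch of the binary cross-entropy losses computed in each phase, assuming the discriminator outputs the probability that an image is real. The helpers `binary_crossentropy` and `phase_losses` are hypothetical. In the Keras example below, these losses are computed and backpropagated by `train_on_batch`.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy (a NumPy stand-in for the Keras loss)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))

def phase_losses(d_fake, d_real):
    """Return (discriminator loss, generator loss) from discriminator outputs.

    d_fake: probabilities of being real assigned to generated images, shape (batch_size,)
    d_real: probabilities of being real assigned to real images, shape (batch_size,)
    """
    batch_size = len(d_fake)
    # Phase 1: generated images are labeled 0 (fake), real images 1 (real)
    y_discr = np.concatenate([np.zeros(batch_size), np.ones(batch_size)])
    d_loss = binary_crossentropy(y_discr, np.concatenate([d_fake, d_real]))
    # Phase 2: generated images are labeled 1, so the generator's loss is low
    # only when the (frozen) discriminator mistakes them for real ones
    g_loss = binary_crossentropy(np.ones(batch_size), d_fake)
    return d_loss, g_loss

# A discriminator that cannot tell images apart outputs 0.5 everywhere
d_loss, g_loss = phase_losses(np.full(32, 0.5), np.full(32, 0.5))
print(d_loss, g_loss)  # both equal ln(2) ≈ 0.693
```

When the discriminator becomes more confident (fakes near 0, real images near 1), its loss drops while the generator's loss grows, which is the competition described above.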

Example: generating fashion images with a simple GAN

(Heavily inspired by Chapter 17 of Hands-On Machine Learning by Aurélien Géron)

Environment setup

import platform
import sys

print(f"Python version: {platform.python_version()}")
# Compare integers: comparing version strings would wrongly rank "10" below "6"
assert sys.version_info >= (3, 6)

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
Python version: 3.7.5
# Setup plots
%matplotlib inline
plt.rcParams["figure.figsize"] = 10, 8
%config InlineBackend.figure_format = 'retina'
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {tf.keras.__version__}")

from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Dense,
    Flatten,
    Reshape,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    LeakyReLU,
    Dropout
)
from tensorflow.keras.datasets import fashion_mnist
TensorFlow version: 2.3.1
Keras version: 2.4.0

Data loading

# Load training inputs from the Fashion-MNIST dataset
(train_images, _), (_, _) = fashion_mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1]
x_train = train_images.astype("float32") / 255

print(f"x_train: {x_train.shape}")
x_train: (60000, 28, 28)
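def plot_multiple_images(images, n_cols=None):
    """Plot a batch of images in a grid (a single row by default)."""
    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1  # Ceiling division
    # Drop the trailing channel axis of single-channel images, e.g. (28, 28, 1) -> (28, 28)
    if images.shape[-1] == 1:
        images = np.squeeze(images, axis=-1)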
    plt.figure(figsize=(n_cols * 1.5, n_rows * 1.5))
    for index, image in enumerate(images):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(image, cmap="binary")
        plt.axis("off")

Model definition

# Input dimensionality for generator
codings_size = 30

generator = Sequential(
    [
        Dense(100, activation="selu", input_shape=(codings_size,)),
        Dense(150, activation="selu"),
        Dense(28 * 28, activation="sigmoid"),
        Reshape((28, 28)),
    ],
    name="generator"
)
discriminator = Sequential(
    [
        Flatten(input_shape=(28, 28)),
        Dense(150, activation="selu"),
        Dense(100, activation="selu"),
        Dense(1, activation="sigmoid"),
    ],
    name="discriminator"
)
gan = Sequential([generator, discriminator])

# Print GAN model summary
gan.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator (Sequential)       (None, 28, 28)            136634    
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 132951    
=================================================================
Total params: 269,585
Trainable params: 269,585
Non-trainable params: 0
_________________________________________________________________
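The parameter counts in this summary can be checked by hand: a `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. A small sketch (the helper `dense_params` is hypothetical):

```python
def dense_params(n_in, n_out):
    """Number of trainable parameters of a fully connected layer (weights + biases)."""
    return n_in * n_out + n_out

# Generator: codings_size (30) -> 100 -> 150 -> 28 * 28
generator_params = dense_params(30, 100) + dense_params(100, 150) + dense_params(150, 28 * 28)
# Discriminator: 28 * 28 (flattened) -> 150 -> 100 -> 1
discriminator_params = dense_params(28 * 28, 150) + dense_params(150, 100) + dense_params(100, 1)

print(generator_params, discriminator_params, generator_params + discriminator_params)
# 136634 132951 269585
```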

Model training

# The generator is trained through the GAN model: no need to compile it

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")

# The trainable attribute is taken into account only when a model is compiled:
# the discriminator was compiled above while trainable, so its weights are updated
# when it is trained on its own, but they stay frozen when the whole GAN model is trained
discriminator.trainable = False

gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        print(f"Epoch [{epoch+1}/{n_epochs}]...")
        for x_batch in dataset:
            # Phase 1 - training the discriminator
            noise = tf.random.normal(shape=(batch_size, codings_size))
            generated_images = generator(noise)
            # Gather an equal number of generated (y=0) and real (y=1) images
            x_discr = tf.concat([generated_images, x_batch], axis=0)
            y_discr = tf.constant([[0.0]] * batch_size + [[1.0]] * batch_size)
            # The discriminator was compiled while trainable, so this call updates its weights
            # https://stackoverflow.com/a/49100617
            discriminator.train_on_batch(x_discr, y_discr)
            
            # Phase 2 - training the generator
            noise = tf.random.normal(shape=(batch_size, codings_size))
            # Generated images should be labeled "real" by the discriminator
            y_gen = tf.constant([[1.0]] * batch_size)
            # Update only the generator weights (see above)
            gan.train_on_batch(noise, y_gen)
    print("Training complete!")
batch_size = 32

# Load images in batches
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

# Train the GAN model
train_gan(gan, dataset, batch_size, codings_size, n_epochs=1)
Epoch [1/1]...
Training complete!
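The dataset above is batched with `drop_remainder=True`. This matters for `train_gan`, which builds its labels for exactly `batch_size` real images: a smaller final batch would leave `x_discr` and `y_discr` with different lengths. The sketch below mirrors the batching arithmetic of `tf.data.Dataset.batch` to count training steps per epoch; the helper `steps_per_epoch` is hypothetical.

```python
def steps_per_epoch(n_samples, batch_size, drop_remainder=True):
    """Number of batches produced per epoch when batching n_samples items."""
    full_batches, remainder = divmod(n_samples, batch_size)
    if drop_remainder or remainder == 0:
        return full_batches
    # Without drop_remainder, the leftover items form one smaller, final batch
    return full_batches + 1

print(steps_per_epoch(60000, 32))  # 1875
```

Since 60,000 is a multiple of 32, no training image is actually dropped here, and each epoch performs 1,875 discriminator updates and 1,875 generator updates.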

Generating new images

noise = tf.random.normal(shape=(batch_size, codings_size))
generated_images = generator(noise)
plot_multiple_images(generated_images, 8)

GAN specificities and gotchas

  • A GAN is a dynamic system that evolves at each training step.

  • Interestingly, the generator never sees images from the training set directly: all it gets are the gradients coming back from the discriminator.

  • Training can be tricky, with several common pitfalls: noisy generated data, vanishing gradients, lack of diversity in generated data (mode collapse), or one model dominating the other.

  • GAN convergence theory is an active area of research.

  • GAN Open Questions.
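One well-known source of vanishing gradients, already pointed out by Goodfellow et al. (2014): early in training, the discriminator easily rejects generated images, and the original generator objective \(\log(1 - D(G(z)))\) saturates. Labeling generated images as "real" in Phase 2, as `train_gan` does, amounts to minimizing \(-\log D(G(z))\) instead, which keeps strong gradients. The NumPy sketch below compares the gradients of both losses with respect to the discriminator's logit \(s\), assuming \(D = \sigma(s)\); all function names are hypothetical.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def grad_saturating(s):
    """d/ds of log(1 - sigmoid(s)): the generator loss of the original minimax game."""
    return -sigmoid(s)

def grad_non_saturating(s):
    """d/ds of -log(sigmoid(s)): the loss obtained by labeling generated images as real."""
    return -(1.0 - sigmoid(s))

# A confident discriminator gives generated images a very negative logit (D close to 0)
s = -6.0
print(f"D(G(z)) = {sigmoid(s):.4f}")
print(f"saturating gradient:     {grad_saturating(s):.4f}")
print(f"non-saturating gradient: {grad_non_saturating(s):.4f}")
```

With \(D(G(z)) \approx 0\), the saturating gradient is close to zero while the non-saturating one stays close to \(-1\), so the generator still receives a useful learning signal.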

Deep Convolutional GANs

Architecture

In late 2015, Radford et al. succeeded in building GANs based on deeper convolutional models, called Deep Convolutional GANs or DCGANs (paper).

The main guidelines for implementing DCGANs are:

  • Replace pooling layers with strided convolutions (in the discriminator) and transposed convolutions (in the generator).

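  • Use batch normalization in both the generator and the discriminator, except in the generator's output layer and the discriminator's input layer.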

  • Use ReLU in the generator except for the output layer, which should use tanh.

  • Use leaky ReLU in the discriminator.
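Replacing pooling with strides changes spatial sizes in a predictable way. Assuming Keras' `padding="same"` semantics, a strided convolution divides each spatial dimension by the stride (rounding up), while a transposed convolution without `output_padding` multiplies it by the stride. This is how the DCGAN example below goes from 28×28 down to 7×7 in the discriminator, and from 7×7 back up to 28×28 in the generator. A sketch of this arithmetic (helper names are hypothetical):

```python
import math

def conv_same_output(size, stride):
    """Spatial size after a Conv2D layer with padding="same": ceil(size / stride)."""
    return math.ceil(size / stride)

def conv_transpose_same_output(size, stride):
    """Spatial size after a Conv2DTranspose layer with padding="same" and no output_padding."""
    return size * stride

# Discriminator: two stride-2 convolutions shrink 28x28 images to 7x7 feature maps
sizes = [28]
for _ in range(2):
    sizes.append(conv_same_output(sizes[-1], stride=2))
print("discriminator:", sizes)  # discriminator: [28, 14, 7]

# Generator: two stride-2 transposed convolutions grow 7x7 feature maps to 28x28 images
sizes = [7]
for _ in range(2):
    sizes.append(conv_transpose_same_output(sizes[-1], stride=2))
print("generator:", sizes)  # generator: [7, 14, 28]
```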

The leaky ReLU activation function

A leaky version of the Rectified Linear Unit: it allows a small, non-zero gradient when the unit is not active, which prevents the “dying ReLU” problem (neurons that always output zero).

\[\mathrm{LeakyReLU}_\alpha(z) = \max(\alpha z, z), \quad 0 < \alpha < 1\]

Leaky ReLU
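A minimal NumPy sketch of this definition and of its derivative, using \(\alpha = 0.2\) as in the discriminator below (Keras' own `LeakyReLU` layer is what the example actually uses; these helper names are hypothetical). The derivative at exactly \(z = 0\) is set to \(\alpha\) here, which is a convention.

```python
import numpy as np

def leaky_relu(z, alpha=0.2):
    """Leaky ReLU: z for z > 0, alpha * z otherwise (assumes 0 < alpha < 1)."""
    return np.maximum(alpha * z, z)

def leaky_relu_grad(z, alpha=0.2):
    """Derivative of leaky ReLU: 1 for z > 0, alpha otherwise, so it is never exactly zero."""
    return np.where(z > 0, 1.0, alpha)

z = np.array([-2.0, -0.5, 0.0, 1.5])
print(leaky_relu(z))
print(leaky_relu_grad(z))
```

Unlike a plain ReLU, whose gradient is 0 for every negative input, negative inputs here still propagate a gradient of 0.2, so a unit cannot get permanently stuck at zero.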

Example: generating fashion images with a DCGAN

(Heavily inspired by Chapter 17 of Hands-On Machine Learning by Aurélien Géron)

Generator definition

# Input dimensionality for generator
codings_size = 100

dcgan_generator = Sequential(
    [
        Dense(7 * 7 * 128, input_shape=(codings_size,)),
        Reshape((7, 7, 128)),
        BatchNormalization(),
        # SELU is used here instead of the ReLU suggested by the DCGAN guidelines
        Conv2DTranspose(
            64, kernel_size=5, strides=2, padding="same", activation="selu"
        ),
        BatchNormalization(),
        Conv2DTranspose(1, kernel_size=5, strides=2, padding="same", activation="tanh"),
    ],
    name="generator",
)
dcgan_generator.summary()
Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_23 (Dense)             (None, 6272)              633472    
_________________________________________________________________
reshape_12 (Reshape)         (None, 7, 7, 128)         0         
_________________________________________________________________
batch_normalization_20 (Batc (None, 7, 7, 128)         512       
_________________________________________________________________
conv2d_transpose_20 (Conv2DT (None, 14, 14, 64)        204864    
_________________________________________________________________
batch_normalization_21 (Batc (None, 14, 14, 64)        256       
_________________________________________________________________
conv2d_transpose_21 (Conv2DT (None, 28, 28, 1)         1601      
=================================================================
Total params: 840,705
Trainable params: 840,321
Non-trainable params: 384
_________________________________________________________________
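These counts can also be verified by hand. A convolutional (or transposed convolutional) layer has `kernel_size * kernel_size * in_channels` weights per output channel, plus one bias per output channel. A `BatchNormalization` layer has two trainable parameters per channel (scale γ and offset β) and two non-trainable ones (the moving mean and variance), which explains the 384 non-trainable parameters. A sketch with hypothetical helper names:

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def conv_params(kernel_size, in_channels, out_channels):
    """Weights plus biases of a Conv2D or Conv2DTranspose layer with square kernels."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels

def batch_norm_params(channels):
    """(trainable, non-trainable) counts: gamma/beta, then moving mean/variance."""
    return 2 * channels, 2 * channels

bn1_trainable, bn1_frozen = batch_norm_params(128)
bn2_trainable, bn2_frozen = batch_norm_params(64)

trainable = (
    dense_params(100, 7 * 7 * 128)  # codings_size = 100 -> 6272
    + bn1_trainable
    + conv_params(5, 128, 64)  # 7x7x128 -> 14x14x64
    + bn2_trainable
    + conv_params(5, 64, 1)  # 14x14x64 -> 28x28x1
)
non_trainable = bn1_frozen + bn2_frozen
print(trainable, non_trainable, trainable + non_trainable)  # 840321 384 840705
```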

Discriminator definition

dcgan_discriminator = Sequential(
    [
        Conv2D(
            64,
            kernel_size=5,
            strides=2,
            padding="same",
            activation=LeakyReLU(0.2),
            input_shape=(28, 28, 1),
        ),
        Dropout(0.4),
        Conv2D(
            128, kernel_size=5, strides=2, padding="same", activation=LeakyReLU(0.2)
        ),
        Dropout(0.4),
        Flatten(),
        Dense(1, activation="sigmoid"),
    ],
    name="discriminator",
)
dcgan_discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_11 (Conv2D)           (None, 14, 14, 64)        1664      
_________________________________________________________________
dropout_10 (Dropout)         (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 7, 7, 128)         204928    
_________________________________________________________________
dropout_11 (Dropout)         (None, 7, 7, 128)         0         
_________________________________________________________________
flatten_6 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_24 (Dense)             (None, 1)                 6273      
=================================================================
Total params: 212,865
Trainable params: 212,865
Non-trainable params: 0
_________________________________________________________________

Model definition

dcgan = Sequential([dcgan_generator, dcgan_discriminator])

dcgan.summary()
Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator (Sequential)       (None, 28, 28, 1)         840705    
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 212865    
=================================================================
Total params: 1,053,570
Trainable params: 1,053,186
Non-trainable params: 384
_________________________________________________________________

Model training

# The generator is trained through the GAN model: no need to compile it

dcgan_discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")

# The trainable attribute is taken into account only when a model is compiled:
# the discriminator was compiled above while trainable, so its weights are updated
# when it is trained on its own, but they stay frozen when the whole GAN model is trained
dcgan_discriminator.trainable = False

dcgan.compile(loss="binary_crossentropy", optimizer="rmsprop")
# Reshape input into a 4D tensor (adding a channel axis) and rescale values to [-1, 1]
# Needed because the generator's tanh outputs are in this range
x_train_dcgan = x_train.reshape(-1, 28, 28, 1) * 2. - 1.
print(f"x_train_dcgan: {x_train_dcgan.shape}")
x_train_dcgan: (60000, 28, 28, 1)
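Generated images therefore also come out in [-1, 1]. `plt.imshow` rescales grayscale values automatically, so `plot_multiple_images` works as is. Still, mapping images back to [0, 1] can be useful, for example to compare generated and real images on the same scale. A minimal sketch of both mappings (helper names are hypothetical):

```python
import numpy as np

def to_tanh_range(x):
    """Map pixel values from [0, 1] to [-1, 1], the output range of tanh."""
    return x * 2.0 - 1.0

def from_tanh_range(y):
    """Inverse mapping from [-1, 1] back to [0, 1], e.g. before displaying generated images."""
    return (y + 1.0) / 2.0

pixels = np.array([0.0, 0.25, 0.5, 1.0], dtype=np.float32)
scaled = to_tanh_range(pixels)
print(scaled)
print(from_tanh_range(scaled))
```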
batch_size = 32

# Load images in batches
dataset = tf.data.Dataset.from_tensor_slices(x_train_dcgan)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

# Train the DCGAN model
train_gan(dcgan, dataset, batch_size, codings_size, n_epochs=5)
Epoch [1/5]...
Epoch [2/5]...
Epoch [3/5]...
Epoch [4/5]...
Epoch [5/5]...
Training complete!

Generating new images

noise = tf.random.normal(shape=(batch_size, codings_size))
generated_images = dcgan_generator(noise)
plot_multiple_images(generated_images, 8)

The GAN ecosystem

A GAN landscape

GAN flavours

GAN progress on face generation

GAN progress from 2014 to 2018

GAN use cases: not just images!