Autoencoders#

Summary#

  • Introduction

  • Stacked autoencoders

  • Variational autoencoders

Introduction#

Autoencoders in a nutshell#

Autoencoders are neural networks trained to encode an input into a latent space and then decode it back into a reconstruction of that input.

Autoencoder principle

Autoencoder architecture#

An autoencoder is composed of an encoding function \(E(x)\) outputting a latent representation \(s\), a decoding function \(D(s)\) computing the reconstructed output \(o\) and a loss function \(\mathcal{L}\) measuring the distance between original and reconstructed data.

Autoencoder architecture
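The three pieces described above can be sketched in a few lines of NumPy. This is a minimal illustration, not a trained model: the linear encoder and decoder, the weight matrices `W_enc` and `W_dec`, and the choice of mean squared error for \(\mathcal{L}\) are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

n_features, n_latent = 3, 2
# Hypothetical, randomly initialised weights (a real autoencoder would learn them)
W_enc = rng.normal(size=(n_features, n_latent))
W_dec = rng.normal(size=(n_latent, n_features))


def encode(x):
    """E(x): map inputs to their latent representation s."""
    return x @ W_enc


def decode(s):
    """D(s): map latent representations back to input space."""
    return s @ W_dec


def reconstruction_loss(x, o):
    """L(x, o): mean squared error between inputs and reconstructions."""
    return np.mean((x - o) ** 2)


x = rng.normal(size=(5, n_features))  # batch of 5 samples
s = encode(x)                         # latent representation, shape (5, 2)
o = decode(s)                         # reconstruction, shape (5, 3)
loss = reconstruction_loss(x, o)
print(s.shape, o.shape, loss)
```

Training would consist of adjusting the encoder and decoder parameters so that this loss decreases.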

What’s the point?#

An autoencoder learns to copy its inputs to its outputs under some constraints: for example, limiting the dimensionality of the latent space, or adding noise to the inputs.

To do its job, it must find efficient ways of representing the data: for example, learning the most relevant features and dropping the others.
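As an illustration of the noise constraint mentioned above, the sketch below builds the training pairs for a denoising setup: the network would receive corrupted inputs and be asked to reproduce the clean ones. The helper name `add_gaussian_noise`, the Gaussian noise model and the noise level are assumptions chosen for the example.

```python
import numpy as np


def add_gaussian_noise(x, stddev=0.1, seed=0):
    """Return a corrupted copy of x (hypothetical helper for illustration)."""
    rng = np.random.default_rng(seed)
    return x + stddev * rng.normal(size=x.shape)


# Toy data: 4 samples with 3 features each
x_clean = np.linspace(0.0, 1.0, 12).reshape(4, 3)
x_noisy = add_gaussian_noise(x_clean)

# Denoising setup: inputs are the noisy samples, targets are the clean ones
inputs, targets = x_noisy, x_clean
print(np.abs(inputs - targets).max())
```

Because the target differs from the input, simply copying the input no longer minimises the loss, which pushes the network towards features that survive the corruption.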

Latent space properties#

The latent space learned by an autoencoder may have interesting properties.

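For example, in a latent space of images of faces, there may be a smile vector \(s\) (not to be confused with the latent representation \(s\) defined earlier) such that, if latent point \(z\) is the representation of a certain face, then latent point \(z + s\) is the representation of the same face, smiling. It then becomes possible to add a smile to existing images by encoding them, shifting their latent representation by \(s\), and decoding the result.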

Smile vector
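The text does not say how such a vector would be obtained. One common heuristic, assumed here, is to take the difference between the mean latent code of smiling faces and that of neutral faces. The sketch below uses random arrays as stand-ins for encoder outputs, so every name in it is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim = 8

# Stand-ins for encoder outputs: latent codes of neutral and smiling faces.
# They are random vectors here; in practice they would come from E(x).
z_neutral = rng.normal(size=(100, latent_dim))
z_smiling = rng.normal(size=(100, latent_dim)) + 0.5

# Estimate the attribute vector as the difference between the two group means
smile_vector = z_smiling.mean(axis=0) - z_neutral.mean(axis=0)

# Shift one latent point along that direction; decoding z_new would,
# ideally, yield the same face with a smile
z = z_neutral[0]
z_new = z + smile_vector
print(smile_vector.round(2))
```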

Autoencoder applications#

  • Dimensionality reduction

  • Denoising

  • Data generation

  • Anomaly detection

    • Network is trained on normal samples only.

    • Outliers will induce a high reconstruction loss and will be flagged as anomalies.
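The anomaly detection recipe above can be sketched as follows. The reconstructions are simulated rather than produced by a trained network, and the 99th-percentile threshold is an assumed rule of thumb, not something prescribed by the text.

```python
import numpy as np


def reconstruction_errors(x, x_reconstructed):
    """Per-sample mean squared reconstruction error."""
    return np.mean((x - x_reconstructed) ** 2, axis=1)


rng = np.random.default_rng(0)

# Simulated case: normal samples are reconstructed almost perfectly
x_normal = rng.normal(size=(200, 3))
recon_normal = x_normal + 0.05 * rng.normal(size=x_normal.shape)
train_errors = reconstruction_errors(x_normal, recon_normal)

# Assumed rule: flag anything above the 99th percentile of the training errors
threshold = np.percentile(train_errors, 99)

# Simulated outlier: the reconstruction is far from the input
x_new = np.array([[4.0, -4.0, 4.0]])
recon_new = np.zeros_like(x_new)
is_anomaly = reconstruction_errors(x_new, recon_new) > threshold
print(threshold, is_anomaly)
```

The percentile controls the trade-off between false alarms and missed anomalies, so in practice it would be tuned on held-out data.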

Example: performing dimensionality reduction with a linear autoencoder#

(Heavily inspired by Chapter 17 of Hands-On Machine Learning by Aurélien Géron)

Environment setup#

import platform
import sys

print(f"Python version: {platform.python_version()}")
# Compare integers: comparing version strings breaks on Python 3.10+ ("10" < "6")
assert sys.version_info >= (3, 6)

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
Python version: 3.7.5
# Setup plots
%matplotlib inline
plt.rcParams["figure.figsize"] = 10, 8
%config InlineBackend.figure_format = 'retina'
sns.set()
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {tf.keras.__version__}")

from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape, Layer, Input
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.metrics import binary_accuracy
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
TensorFlow version: 2.3.1
Keras version: 2.4.0
def plot_loss(history):
    """Plot training loss
    Takes a Keras History object as parameter"""

    loss = history.history["loss"]
    epochs = range(1, len(loss) + 1)

    plt.figure(figsize=(10, 10))

    plt.subplot(2, 1, 1)
    plt.plot(epochs, loss, ".--", label="Training loss")
    final_loss = loss[-1]
    title = "Training loss: {:.4f}".format(final_loss)
    plt.ylabel("Loss")
    if "val_loss" in history.history:
        val_loss = history.history["val_loss"]
        plt.plot(epochs, val_loss, "o-", label="Validation loss")
        final_val_loss = val_loss[-1]
        title += ", Validation loss: {:.4f}".format(final_val_loss)
    plt.title(title)
    plt.legend()

3D data generation#

np.random.seed(4)

def generate_3d_data(m, w1=0.1, w2=0.3, noise=0.1):
    angles = np.random.rand(m) * 3 * np.pi / 2 - 0.5
    data = np.empty((m, 3))
    data[:, 0] = np.cos(angles) + np.sin(angles)/2 + noise * np.random.randn(m) / 2
    data[:, 1] = np.sin(angles) * 0.7 + noise * np.random.randn(m) / 2
    data[:, 2] = data[:, 0] * w1 + data[:, 1] * w2 + noise * np.random.randn(m)
    return data

x_train_3d = generate_3d_data(60)
# Center each feature on zero
x_train_3d = x_train_3d - x_train_3d.mean(axis=0)
print(f"x_train: {x_train_3d.shape}")
x_train: (60, 3)
# Plot 3D data
fig = px.scatter_3d(x_train_3d, x=0, y=1, z=2, labels={"0": "x1", "1": "x2", "2": "x3"})
fig.show()