Variational Autoencoder

from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time
We model each pixel with a Bernoulli distribution, so the images are binarized: each pixel takes a value of either 0 or 1.

def preprocess_images(images):
    """Reshape raw images to (N, 28, 28, 1) and binarize them.

    Pixel values are divided by 255 and thresholded: anything strictly
    greater than 0.5 becomes 1.0, everything else 0.0.

    Args:
        images: array whose first axis indexes images and whose remaining
            elements can be reshaped to 28 x 28 x 1 per image.

    Returns:
        A float32 array of shape (N, 28, 28, 1) containing only 0.0 and 1.0.
    """
    count = images.shape[0]
    scaled = images.reshape((count, 28, 28, 1)) / 255.
    is_on = scaled > .5
    return is_on.astype('float32')

# Load MNIST; the labels (second element of each pair) are discarded.
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
# Reshape and binarize both splits with the same preprocessing.
train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)

# Sanity check: show the shape and dtype of the preprocessed training array.
print(train_images.shape, train_images.dtype)
(60000, 28, 28, 1) float32
# Number of examples in each split; used as the shuffle buffer size so the
# whole split is shuffled.
train_size = 60000
batch_size = 32
test_size = 10000

# Batched, shuffled input pipelines. Each element yielded by these datasets
# is one batch of images, which is what the training loop and the
# `test_dataset.take(1)` sample selection iterate over.
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                .shuffle(test_size).batch(batch_size))
Reparameterization Trick


class ConvVAE(tf.keras.Model):
    """Convolutional variational autoencoder.

    The encoder maps a 28 x 28 x 1 image to ``2 * latent_dim`` numbers, which
    ``encode`` splits into a mean and a log-variance. The decoder maps a
    latent vector of size ``latent_dim`` back to 28 x 28 x 1 logits.
    """

    def __init__(self, latent_dim):
        """Build the encoder and decoder networks.

        Args:
            latent_dim: size of the latent vector z.
        """
        # The base initializer must run before layers are assigned as
        # attributes, otherwise Keras cannot track them.
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            # Collapse the feature maps so the Dense layer emits one vector
            # per image; `encode` splits that vector along axis 1.
            tf.keras.layers.Flatten(),
            # No activation
            tf.keras.layers.Dense(latent_dim + latent_dim),
        ])

        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=(7, 7, 32)),
            tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same', activation='relu'),
            tf.keras.layers.Conv2DTranspose(filters=32, kernel_size=3, strides=2, padding='same', activation='relu'),
            # No activation
            tf.keras.layers.Conv2DTranspose(filters=1, kernel_size=3, strides=1, padding='same'),
        ])

    def sample(self, z=None):
        """Decode latent vectors into pixel probabilities.

        Args:
            z: latent vectors to decode. When None, 100 vectors are drawn
                from a standard normal distribution.

        Returns:
            Sigmoid-activated decoder output.
        """
        if z is None:
            z = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(z, apply_sigmoid=True)

    def encode(self, x):
        """Return ``(mean, logvar)`` obtained by halving the encoder output along axis 1."""
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample z = mean + std * eps with eps drawn from a standard normal.

        Drawing the noise separately keeps z a differentiable function of
        ``mean`` and ``logvar``.
        """
        eps = tf.random.normal(shape=mean.shape)
        # exp(0.5 * logvar) is the standard deviation.
        return mean + tf.exp(logvar * 0.5) * eps

    def decode(self, z, apply_sigmoid=False):
        """Run the decoder on ``z``.

        Args:
            z: latent vectors.
            apply_sigmoid: when True, return sigmoid probabilities instead
                of raw logits.

        Returns:
            Decoder logits, or their sigmoid when ``apply_sigmoid`` is True.
        """
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits


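A VAE is trained by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:

$$\log p(x) \ge \mathrm{ELBO} = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x, z)}{q(z|x)}\right].$$

In practice, we optimize the single-sample Monte Carlo estimate of this expectation:

$$\log p(x|z) + \log p(z) - \log q(z|x),$$

where $z$ is sampled from $q(z|x)$.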

# Adam optimizer with a learning rate of 1e-4; passed to train_step below.
optimizer = tf.keras.optimizers.Adam(1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of a diagonal Gaussian, summed over ``raxis``.

    Working in log space removes the exponential from the Gaussian density.

    Args:
        sample: points at which to evaluate the density.
        mean: mean of the Gaussian.
        logvar: log-variance of the Gaussian.
        raxis: axis that is summed over.

    Returns:
        The element-wise log-densities reduced by summation along ``raxis``.
    """
    log2pi = tf.math.log(2 * np.pi)
    squared_error = (sample - mean) ** 2.
    per_element = -0.5 * (squared_error * tf.exp(-logvar) + logvar + log2pi)
    return tf.reduce_sum(per_element, axis=raxis)

def compute_loss(model, x):
    """Return the negative single-sample ELBO estimate, averaged over the batch.

    Args:
        model: object providing ``encode``, ``reparameterize`` and ``decode``.
        x: batch of images; also used as the cross-entropy labels.

    Returns:
        Scalar loss: minus the batch mean of
        log p(x|z) + log p(z) - log q(z|x).
    """
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    logits = model.decode(z)

    # Reconstruction term: negative cross-entropy summed over axes 1-3,
    # i.e. every non-batch axis.
    pixel_nll = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x)
    log_likelihood = -tf.reduce_sum(pixel_nll, axis=[1, 2, 3])

    # Prior uses mean 0 and log-variance 0; posterior uses the encoder output.
    log_prior = log_normal_pdf(z, 0., 0.)
    log_posterior = log_normal_pdf(z, mean, logvar)

    elbo = log_likelihood + log_prior - log_posterior
    return -tf.reduce_mean(elbo)

def train_step(model, x, optimizer):
    """Run one optimization step on the batch ``x``.

    Computes the loss under a gradient tape, then applies the gradients to
    the model's trainable variables through ``optimizer``. Returns nothing.
    """
    with tf.GradientTape() as tape:
        batch_loss = compute_loss(model, x)
    # Read the variable list after the forward pass, as the original does.
    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))


# Number of full passes over the training dataset.
epochs = 10
# set the dimensionality of the latent space to a 2D plane for visualization later
latent_dim = 2
# Number of images drawn in each progress figure.
num_examples_to_generate = 16

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
# NOTE(review): this vector is not referenced by any other visible code.
random_vector_for_generation = tf.random.normal(shape=[num_examples_to_generate, latent_dim])
model = ConvVAE(latent_dim)

Model: "sequential"
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 13, 13, 32)        320       
 conv2d_1 (Conv2D)           (None, 6, 6, 64)          18496     
 flatten (Flatten)           (None, 2304)              0         
 dense (Dense)               (None, 4)                 9220      
Total params: 28,036
Trainable params: 28,036
Non-trainable params: 0
Model: "sequential_1"
 Layer (type)                Output Shape              Param #   
 dense_1 (Dense)             (None, 1568)              4704      
 reshape (Reshape)           (None, 7, 7, 32)          0         
 conv2d_transpose (Conv2DTra  (None, 14, 14, 64)       18496     
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 32)       18464     
 conv2d_transpose_2 (Conv2DT  (None, 28, 28, 1)        289       
Total params: 41,953
Trainable params: 41,953
Non-trainable params: 0
def generate_and_save_images(model, epoch, test_sample):
    """Reconstruct ``test_sample`` with ``model`` and save the result as a PNG grid.

    Args:
        model: object providing ``encode``, ``reparameterize`` and ``sample``.
        epoch: epoch number, embedded zero-padded in the output file name.
        test_sample: batch of images to reconstruct. The figure is a 4 x 4
            grid, so it has room for at most 16 images.

    Side effects:
        Writes ``image_at_epoch_NNNN.png`` to the current directory and
        shows the figure.
    """
    mean, logvar = model.encode(test_sample)
    z = model.reparameterize(mean, logvar)
    predictions = model.sample(z)
    plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0], cmap='gray')
        plt.axis('off')

    # The file name must match the 'image*.png' glob used to build the GIF;
    # zero padding keeps the sorted file list in epoch order.
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
# Pick a sample of the test set for generating output images
# A single batch must contain enough images to fill the sample.
assert batch_size >= num_examples_to_generate
# Take one batch and keep its first num_examples_to_generate images.
for test_batch in test_dataset.take(1):
    test_sample = test_batch[0:num_examples_to_generate, :, :, :]
# Epoch 0: reconstructions from the model before any training step.
generate_and_save_images(model, 0, test_sample)
# Train for `epochs` passes; after each pass, evaluate on the test set and
# render reconstructions of the fixed test sample.
for epoch in range(1, epochs + 1):
    # Only the training pass is timed; evaluation below is excluded.
    start_time = time.time()
    for train_x in train_dataset:
        train_step(model, train_x, optimizer)
    end_time = time.time()

    # Running mean of the per-batch loss over the whole test set.
    loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
        loss(compute_loss(model, test_x))
    # compute_loss returns the negative ELBO, so negate it back for reporting.
    elbo = -loss.result()
    print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'.format(epoch, elbo, end_time - start_time))
    generate_and_save_images(model, epoch, test_sample)
Epoch: 10, Test set ELBO: -156.0026397705078, time elapse for current epoch: 10.945546627044678
def display_image(epoch_no):
    """Open ``image_at_epoch_NNNN.png`` for the given epoch number.

    Args:
        epoch_no: epoch number; formatted zero-padded to four digits.

    Returns:
        The opened PIL image.

    Raises:
        FileNotFoundError: if no such file exists in the current directory.
    """
    # Import the submodule explicitly: a bare `import PIL` does not
    # guarantee that `PIL.Image` is loaded.
    from PIL import Image
    return Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))


# Show the figure of the last epoch run by the training loop.
plt.imshow(display_image(epoch))
plt.axis('off')  # Display images
(-0.5, 399.5, 399.5, -0.5)

Visualize Training Progression

anim_file = 'cvae.gif'

# Assemble every saved progress image into one animated GIF.
with imageio.get_writer(anim_file, mode='I') as writer:
    # Sort so frames are appended in file-name order.
    filenames = sorted(glob.glob('image*.png'))
    for filename in filenames:
        image = imageio.imread(filename)
        # Each frame has to be handed to the writer; reading alone writes nothing.
        writer.append_data(image)
    # Append the final frame a second time. Guarded so an empty file list
    # does not fail on an undefined loop variable.
    if filenames:
        writer.append_data(image)
import tensorflow_docs.vis.embed as embed

Visualize Latent Space

We defined our latent dimension to be 2, which means we can simply visualize the z vectors as a 2D image. To do so, we choose a region of the latent space: we take a standard Gaussian and create z vectors from its quantiles.

# Normal distribution with loc 0 and scale 1.
norm = tfp.distributions.Normal(0, 1)
# Density evaluated at the quantiles of ten evenly spaced probabilities in [0, 1].
print("Probs of Quantiles", norm.prob(norm.quantile(np.linspace(0., 1., 10))))
# Density evaluated directly at ten evenly spaced points in [0, 1].
print("Probs of Lin Space", norm.prob(np.linspace(0., 1., 10)))
Probs of Quantiles tf.Tensor(
[0.         0.18939508 0.29780126 0.3635998  0.39506775 0.39506775
 0.36359975 0.29780126 0.18939508 0.        ], shape=(10,), dtype=float32)
Probs of Lin Space tf.Tensor(
[0.3989423  0.39648727 0.3892125  0.37738323 0.36142382 0.3418923
 0.31944802 0.29481488 0.26874286 0.24197073], shape=(10,), dtype=float32)

The quantile function returns the value x of the random variable X such that the probability of X being less than or equal to x equals the given probability.

For example, norm.quantile(0.1) returns the value x for which a sample drawn from norm is less than or equal to x with probability 0.1.

# Density of `norm` on 100 evenly spaced points over [-5, 5].
x = np.linspace(-5, 5, 100)
plt.plot(x, norm.prob(x))
# Density evaluated at 100 quantile points, drawn against the same x grid
# so the two curves can be compared.
plt.plot(x, norm.prob(norm.quantile(np.linspace(0., 1., 100))))
[<matplotlib.lines.Line2D at 0x7f500fb32f10>]
def plot_latent_images(model, n, digit_size=28):
    """Plot an n x n grid of images decoded from a 2-D latent grid.

    Grid coordinates are the quantiles of a Normal(0, 1) distribution at
    ``n`` evenly spaced probabilities between 0.05 and 0.95.

    Args:
        model: object whose ``sample(z)`` decodes a (1, 2) latent array.
        n: number of grid points per axis.
        digit_size: side length in pixels of each decoded image.
    """
    standard_normal = tfp.distributions.Normal(0, 1)
    levels = np.linspace(0.05, 0.95, n)
    grid_x = standard_normal.quantile(levels)
    grid_y = standard_normal.quantile(levels)

    side = digit_size * n
    canvas = np.zeros((side, side))

    # Fill the canvas tile by tile; each tile is one decoded latent point.
    for row, yi in enumerate(grid_x):
        top = row * digit_size
        for col, xi in enumerate(grid_y):
            left = col * digit_size
            decoded = model.sample(np.array([[xi, yi]]))
            tile = tf.reshape(decoded[0], (digit_size, digit_size))
            canvas[top: top + digit_size, left: left + digit_size] = tile.numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(canvas, cmap='Greys_r')
# Render a 20 x 20 grid of decoded images from the trained model.
plot_latent_images(model, 20)

