Variational Inference


Many probabilistic models are difficult to train because inference in them is hard. In the context of deep learning, we usually have a set of visible variables $x$ and a set of latent variables $z$. The challenge of inference refers to the hard problem of computing the posterior $p(z \mid x)$, or taking expectations with respect to it. Such operations are often necessary for tasks like maximum likelihood learning.

In deep learning, the posterior generally means: given a visible variable (e.g. an input), what is the probability distribution of the latent variables (e.g. hidden layer activations)? So the posterior takes the following form.

$$p(z \mid x) = \frac{p(x \mid z)\, p(z)}{p(x)} = \frac{p(z, x)}{p(x)}$$

However, the problem is that we cannot compute the denominator, i.e. the marginal probability $p(x)$.

$$p(x) = \int_{-\infty}^{\infty} p(z, x)\, dz$$

This integral requires us to integrate over all possible values of $z$. In general there is no closed-form solution for this integral over the joint distribution, so we would have to enumerate all possible values of $z$, which becomes infeasible when $z$ is a high-dimensional vector. Thus, we need a way to approximate the posterior $p(z \mid x)$.
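To make the blow-up concrete, here is a toy sketch of my own (assuming a simple linear-Gaussian model, not part of the original text) that integrates the joint density over $z$ on a grid. With $k$ grid points per latent dimension, the cost is $k^d$ joint-density evaluations, which is hopeless for the high-dimensional $z$ used in deep models.

import numpy as np

# Toy model: p(z) = N(0, I), p(x|z) = N(z, I), so p(x, z) = p(x|z) p(z).
# Grid integration of p(x) = ∫ p(x, z) dz needs k**d density evaluations.
def log_gaussian(u, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (u - mean) ** 2 / var)

def marginal_by_grid(x, d, k=50, lim=5.0):
    grid = np.linspace(-lim, lim, k)                  # k points per dimension
    axes = np.meshgrid(*[grid] * d)                   # k**d grid nodes in total
    z = np.stack([a.ravel() for a in axes], axis=1)   # shape (k**d, d)
    log_joint = (log_gaussian(z, 0.0, 1.0).sum(axis=1)
                 + log_gaussian(x[None, :], z, 1.0).sum(axis=1))
    dz = (2 * lim / (k - 1)) ** d                     # volume element per node
    return np.exp(log_joint).sum() * dz

x = np.zeros(2)
print(marginal_by_grid(x, d=2))  # feasible for d=2: 50**2 = 2,500 evaluations
# d=100 would require 50**100 evaluations -- utterly intractable.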

Inference as Optimization

Many approaches to confronting the problem of difficult inference make use of the observation that exact inference can be described as an optimization problem.

ELBO


Assume that we have a probabilistic model with parameters $\theta$. It takes an observed input $x$ and generates a latent output $z$. We want to update $\theta$ such that it maximizes the likelihood of our model producing the observed data distribution. We will never know the true distribution of our inputs because we are only given snapshots of it. For example, if the $x$'s are pixels of images, it is impossible for us to know the true distribution of pixels across all images. At best, we can come up with a modeled distribution that aims to maximize the likelihood of the observed data.

$$\operatorname{argmax}_\theta \sum_{i=1}^{N} \log p_\theta(x_i)$$

It is too difficult to calculate the distribution of $x$ directly because we would need to marginalize out $z$. Instead, we can compute a lower bound on $\log p_\theta(x)$. This bound is called the evidence lower bound (ELBO).

$$\mathbb{L}(x, \theta, q) = \log p_\theta(x) - D_{KL}\left[ q(z \mid x) \,\|\, p_\theta(z \mid x) \right]$$

where $q$ is an arbitrary probability distribution over $z$.

The difference between $\log p_\theta(x)$ and $\mathbb{L}$ is the KL divergence term, which is always non-negative. The lower bound becomes tight (equal to $\log p_\theta(x)$) when the KL divergence term is driven to 0, which happens exactly when $q(z \mid x)$ is the same distribution as $p_\theta(z \mid x)$.
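As a quick numerical sanity check of the non-negativity claim (an illustrative snippet of my own, not from the original text), the discrete form of the KL divergence can be evaluated directly; it is zero only when the two distributions match.

import numpy as np

def kl(q, p):
    # Discrete KL divergence D_KL(q || p) = sum_i q_i * log(q_i / p_i)
    q, p = np.asarray(q, float), np.asarray(p, float)
    return np.sum(q * np.log(q / p))

p = np.array([0.5, 0.3, 0.2])
print(kl([0.5, 0.3, 0.2], p))  # 0.0  -- q == p
print(kl([0.7, 0.2, 0.1], p))  # ~0.085 > 0  -- q != p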

We can rearrange $\mathbb{L}$ algebraically.

$$\begin{aligned}
\mathbb{L}(x, \theta, q) &= \log p_\theta(x) - D_{KL}\left[ q(z \mid x) \,\|\, p_\theta(z \mid x) \right] \\
&= \log p_\theta(x) - \mathbb{E}_{z\sim q} \log \frac{q(z \mid x)}{p_\theta(z \mid x)} \\
&= \log p_\theta(x) - \mathbb{E}_{z\sim q} \log \left( q(z \mid x)\, \frac{p_\theta(x)}{p_\theta(x, z)} \right) \\
&= \log p_\theta(x) - \mathbb{E}_{z\sim q} \left[ \log q(z \mid x) - \log p_\theta(x, z) + \log p_\theta(x) \right]
\end{aligned}$$

Since $\log p_\theta(x)$ does not depend on $z$, its expected value with respect to $z \sim q$ is just the constant itself, so it cancels the first term. Here we have the final form.

$$\mathbb{L}(x, \theta, q) = - \mathbb{E}_{z\sim q} \left[ \log q(z \mid x) - \log p_\theta(x, z)\right]$$

Inference can be thought of as the procedure for finding the $q$ that maximizes the lower bound $\mathbb{L}$. Whether the lower bound is tight (a close approximation to $\log p_\theta(x)$) or loose depends on the choice of $q$. $\mathbb{L}$ is significantly easier to compute when we choose a simple distribution for $q$, e.g. a Gaussian parameterized only by its mean and variance.
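To see the bound in action, here is a small numeric sketch of my own (a toy conjugate model, not from the original text) where everything is known in closed form: $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(z, 1)$, so $p(x) = \mathcal{N}(0, 2)$ and the true posterior is $p(z \mid x) = \mathcal{N}(x/2, 1/2)$. A Monte Carlo estimate of $\mathbb{L}$ sits below $\log p(x)$ for a mismatched $q$ and becomes tight when $q$ equals the true posterior.

import numpy as np

rng = np.random.default_rng(0)

def log_normal(u, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (u - mean) ** 2 / var)

def elbo(x, q_mean, q_var, n=100_000):
    # L = E_{z~q}[log p(x, z) - log q(z|x)], estimated by sampling from q
    z = q_mean + np.sqrt(q_var) * rng.standard_normal(n)
    log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
    return np.mean(log_joint - log_normal(z, q_mean, q_var))

x = 1.5
print('log p(x)       ', log_normal(x, 0.0, 2.0))  # true evidence, ~-1.83
print('ELBO, bad q    ', elbo(x, 0.0, 1.0))        # loose bound, ~-2.54
print('ELBO, optimal q', elbo(x, x / 2, 0.5))      # tight, matches log p(x)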

TensorFlow Example

Given the ELBO definition from the section above, we can define a lower bound on our objective as

$$\begin{aligned}
\log p_\theta(x) &\geq \mathbb{E}_{z\sim q}\left[ \log p_\theta(x, z) - \log q(z \mid x) \right] \\
&= \mathbb{E}_{z\sim q}\left[ \log p_\theta(x \mid z) + \log p_\theta(z) - \log q(z \mid x) \right]
\end{aligned}$$

The second line follows because the joint distribution can be factored into a conditional and a prior.

$$p_\theta(x, z) = p_\theta(x \mid z)\, p_\theta(z)$$

In practice, we cannot compute the exact expected value of this expression. Instead, we rely on a single sample of $z$ per forward propagation to perform a Monte Carlo estimate.

$$L = \log p(x \mid z) + \log p(z) - \log q(z \mid x)$$

This is quite similar to the idea behind cross entropy, where we try to learn the distribution of labels by measuring the difference between the target and predicted distributions. There, too, we cannot know the exact expected value under the target distribution, so we approximate it with samples using the same technique.

import tensorflow as tf
import numpy as np

from tensorflow.keras import layers, Sequential


class CVAE(tf.keras.Model):
    """Convolutional Variational Autoencoder."""
    def __init__(self, latent_dim):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        # The encoder maps x to the parameters of q(z|x): a mean and a
        # log-variance, concatenated into one vector of size 2 * latent_dim.
        self.encoder = Sequential([
            layers.InputLayer(input_shape=(28, 28, 1)),
            layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            layers.Flatten(),
            # No activation
            layers.Dense(latent_dim + latent_dim),
        ])

        # The decoder maps a latent vector z back to logits over pixels.
        self.decoder = Sequential([
            layers.InputLayer(input_shape=(latent_dim,)),
            layers.Dense(units=7*7*32, activation=tf.nn.relu),
            layers.Reshape(target_shape=(7, 7, 32)),
            layers.Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding='same', activation='relu'),
            layers.Conv2DTranspose(filters=32, kernel_size=3, strides=2, padding='same', activation='relu'),
            # No activation
            layers.Conv2DTranspose(filters=1, kernel_size=3, strides=1, padding='same'),
        ])

    @tf.function
    def sample(self, z=None):
        # Draw z from the prior N(0, I) and decode it into pixel probabilities.
        if z is None:
            z = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(z, apply_sigmoid=True)

    def encode(self, x):
        # Split the encoder output into the mean and log-variance of q(z|x).
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        # z = mean + sigma * eps with eps ~ N(0, I), so gradients can flow
        # through mean and logvar.
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

The VAE has an encoder and a decoder. The encoder consumes an observed variable $x$ and produces a latent variable $z$ via reparametrization. We choose a Gaussian distribution as our $q(z \mid x)$.

Why Reparametrization?

$z$ is supposed to be sampled from a Gaussian distribution, but gradients cannot flow through a tf.random.normal op. We need to reparametrize $z$ so that the gradient path does not pass through tf.random.normal.

We can draw a sample from a unit Gaussian with tf.random.normal and redefine $z$ as follows.

$$z = \mu + \sigma \cdot \epsilon$$

where $\epsilon$ is sampled from a unit Gaussian distribution.
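The snippet below (an illustrative check of my own, not from the original text) shows why this matters: with reparametrization, tf.GradientTape can compute gradients of a loss on $z$ with respect to the mean and log-variance, because the randomness enters only through the externally sampled eps.

import tensorflow as tf

mean = tf.Variable([0.0])
logvar = tf.Variable([0.0])

with tf.GradientTape() as tape:
    eps = tf.random.normal(shape=mean.shape)   # randomness is an input
    z = mean + tf.exp(logvar * 0.5) * eps      # deterministic in mean/logvar
    loss = tf.reduce_sum(z ** 2)

grads = tape.gradient(loss, [mean, logvar])
print(grads)  # both gradients are defined (not None)

With reparametrization in place, the loss below assembles the three Monte Carlo terms of the bound.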

def log_normal_pdf(sample, mean, logvar, raxis=1):
    # Log density of a diagonal Gaussian N(mean, exp(logvar)) at sample.
    log_2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(-.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log_2pi), axis=raxis)


def compute_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    # log p(x|z): reconstruction term, one Bernoulli per pixel.
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    log_p_x_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
    # log p(z): prior, a standard Gaussian (mean=0, logvar=0).
    log_p_z = log_normal_pdf(z, 0., 0.)
    # log q(z|x): density of the encoder's Gaussian at the sampled z.
    log_q_z_x = log_normal_pdf(z, mean, logvar)
    # Negative single-sample Monte Carlo estimate of the ELBO.
    return -tf.reduce_mean(log_p_x_z + log_p_z - log_q_z_x)

This example uses logvar for numerical stability. The $\sigma$ needed by the reparametrization is recovered with tf.exp(logvar / 2) because

$$\log(\text{variance}) = \log(\sigma^2) = 2 \log(\sigma)$$

$$\sigma = e^{\log(\text{variance}) / 2}$$
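For completeness, a minimal training step could look like the sketch below. It follows the usual TensorFlow pattern of computing the negative ELBO under tf.GradientTape and applying gradients; the optimizer choice and learning rate here are my assumptions, not part of the original text.

optimizer = tf.keras.optimizers.Adam(1e-4)  # assumed optimizer and learning rate


@tf.function
def train_step(model, x, optimizer):
    # Minimize the negative ELBO returned by compute_loss.
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss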
We can sanity-check the tensor shapes with a quick forward pass:

latent_dim = 10
model = CVAE(latent_dim)
x = np.random.randn(1, 28, 28, 1).astype(np.float32)
print('Shape of input x', x.shape)
mean, logvar = model.encode(x)
z = model.reparameterize(mean, logvar)
print('Shape of latent vector z', z.shape)
y = model.decode(z)
print('Shape of output y', y.shape)

print(tf.reduce_sum(z))
tf.exp(log_normal_pdf(z, 0., 0.))

Output:

Shape of input x (1, 28, 28, 1)
Shape of latent vector z (1, 10)
Shape of output y (1, 28, 28, 1)
tf.Tensor(-0.26499367, shape=(), dtype=float32)
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.3094783e-06], dtype=float32)>

Loss Explained

Let $x$ be our input and $\hat{x}$ be our output. Our objective is to make $x \approx \hat{x}$.

log_p_x_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])

$\log p(x \mid z)$ asks how far the distribution of the output $\hat{x}$, decoded from $z$ (which itself comes from the input $x$), is from the distribution of $x$. This is known as the reconstruction loss.

log_p_z = log_normal_pdf(z, 0., 0.)
log_q_z_x = log_normal_pdf(z, mean, logvar)

When we set mean=0 and logvar=0 in log_normal_pdf(z, 0., 0.), we evaluate $p(z)$ as a Gaussian with mean 0 and variance 1, i.e. we assert that $z$ is sampled from a standard Gaussian probability density function.

If this part of the loss is minimized, the encoder's distribution $q(z \mid x)$ is pushed toward the enforced Gaussian prior $p(z)$. We deliberately chose a simple form for $q$; minimizing this term encourages the model to keep $q$ close to the chosen prior.

Another way to write the KL divergence loss, without the log normal probability density function, is the closed form for the KL divergence between a diagonal Gaussian and a standard Gaussian:

kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)

This example is taken from the Convolutional Variational Autoencoder tutorial (see the resources below).

The two terms log_p_z and log_q_z_x represent the KL divergence. It asks how far the encoder's output distribution over $z$ is from the expected Gaussian distribution of $z$. The divergence is expressed as a log density ratio, a derivation of which can be found in Density Ratio Estimation for KL Divergence Minimization Between Implicit Distributions (listed below).
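To connect the two formulations, here is an illustrative check of my own (not from the original text) that reuses log_normal_pdf from above: averaged over many reparametrized samples, the Monte Carlo estimate log_q_z_x - log_p_z converges to the analytic kl_loss.

mean = tf.constant([[0.5, -0.3]])
logvar = tf.constant([[0.2, -0.1]])

# Analytic KL between q = N(mean, exp(logvar)) and the prior p = N(0, I).
kl_analytic = -0.5 * tf.reduce_sum(
    1 + logvar - tf.square(mean) - tf.exp(logvar), axis=1)

# Monte Carlo estimate: average log q(z|x) - log p(z) over samples from q.
n = 100000
eps = tf.random.normal(shape=(n, 2))
z = mean + tf.exp(logvar * 0.5) * eps
kl_mc = tf.reduce_mean(log_normal_pdf(z, mean, logvar) - log_normal_pdf(z, 0., 0.))

print(kl_analytic.numpy(), kl_mc.numpy())  # agree up to sampling noise (~0.183)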

Resources

Variational Inference, Part 1
Variational Inference, Part 2
Variational Inference, Part 3
Deep Learning Book
Convolutional Variational Autoencoder
Density Ratio Estimation for KL Divergence Minimization Between Implicit Distributions