Variational Inference
Many probabilistic models are difficult to train because it is difficult to perform inference in them. In the context of deep learning, we usually have a set of visible variables $x$ and a set of latent variables $z$. The challenge of inference refers to the difficult problem of computing the posterior $p(z \mid x)$, or taking expectations with respect to it. Such operations are often necessary for tasks like maximum likelihood learning.
In deep learning, the posterior generally means: given a visible variable, e.g. an input, what is the probability distribution of the latent variables, e.g. the hidden layer activations? So we have a posterior of the following form.

$$p(z \mid x) = \frac{p(x \mid z)\,p(z)}{p(x)}$$

However, the problem is that we can't compute the denominator, a.k.a. the marginal probability $p(x)$.
$$p(x) = \int p(x, z)\,dz$$

This integral requires us to sum over all possible values of $z$. There is no closed-form solution to this integral over the joint distribution; we would have to iterate through all possible values of $z$, which becomes infeasible if $z$ is a high-dimensional vector. Thus, we need a way to approximate this posterior $p(z \mid x)$.
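To make the blow-up concrete, here is a tiny, hypothetical illustration (the numbers are made up) with a discrete latent variable:

```python
import itertools

# Even for a *discrete* latent z with d binary dimensions, computing
# p(x) = sum_z p(x, z) exactly touches every configuration of z.
d = 4  # tiny toy value; the term count doubles with each extra dimension
configs = list(itertools.product([0, 1], repeat=d))
print(len(configs))  # 2**4 = 16 terms; at d = 100 it would be 2**100
# For continuous z, the sum becomes an integral with no general closed form.
```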
Many approaches to confronting the problem of difficult inference make use of the observation that exact inference can be described as an optimization problem.
Resources
This example is taken from the TensorFlow tutorial Convolutional Variational Autoencoder.
Given the ELBO definition from the section above, we can define a lower bound for our loss as

$$\log p(x) \geq \mathbb{E}_{z \sim q(z \mid x)}\left[\log p(x \mid z) + \log p(z) - \log q(z \mid x)\right]$$

This is because the joint distribution can be written as a conditional probability times a prior: $p(x, z) = p(x \mid z)\,p(z)$.
In practice, we cannot compute the exact expected value of the expression. We will rely on a single sample per forward propagation to perform a Monte Carlo estimate.
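Putting it together, here is a sketch of the single-sample loss in the spirit of the tutorial's `compute_loss`; the `model` object with `encode`, `reparameterize`, and `decode` methods is an assumed API:

```python
import numpy as np
import tensorflow as tf

def log_normal_pdf(sample, mean, logvar, raxis=1):
    # Log density of a diagonal Gaussian parametrized by (mean, logvar).
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)

def compute_loss(model, x):
    # `model` is assumed to expose encode / reparameterize / decode.
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)     # single Monte Carlo sample
    x_logit = model.decode(z)
    # log p(x|z): Bernoulli likelihood via sigmoid cross-entropy
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
    logpz = log_normal_pdf(z, 0., 0.)          # prior p(z) = N(0, I)
    logqz_x = log_normal_pdf(z, mean, logvar)  # encoder distribution q(z|x)
    # Negate: maximizing the ELBO == minimizing this loss.
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)
```

The negation turns ELBO maximization into a minimization problem that a standard optimizer can handle.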
This is quite similar to the idea in cross entropy, where we try to learn the distribution of labels by measuring the difference between the target and predicted distributions. We cannot know the exact expected value of the target distribution, so we approximate it using the same technique.
Why Reparametrization?
$z$ is supposed to be sampled from a Gaussian distribution, but gradients cannot flow through a `tf.random.normal` op. We need to reparametrize $z$ such that the gradient does not depend on `tf.random.normal`.
We can generate a unit Gaussian from `tf.random.normal` and redefine $z$ as follows.

$$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
This example uses `logvar` for numerical stability. We can get rid of the `log` from `tf.exp(logvar / 2)` because

$$\exp\left(\frac{\log \sigma^2}{2}\right) = \exp(\log \sigma) = \sigma$$

so `tf.exp(logvar / 2)` yields the standard deviation $\sigma$ directly.
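This is essentially the tutorial's `reparameterize` step; a minimal standalone sketch:

```python
import tensorflow as tf

def reparameterize(mean, logvar):
    # eps carries all the randomness; gradients flow through mean and logvar
    # only, never through tf.random.normal itself.
    eps = tf.random.normal(shape=mean.shape)
    return eps * tf.exp(logvar * .5) + mean  # z = mu + sigma * eps
```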
The two terms $\log p(z)$ and $\log q(z \mid x)$ represent the KL divergence. It asks how far the model's encoder output distribution $q(z \mid x)$ is from the expected Gaussian distribution of $z$. This divergence is expressed in terms of a log density ratio, whose derivation can be found in Density Ratio Estimation for KL Divergence Minimization Between Implicit Distributions.
Another way to write the KL divergence loss, without the log normal probability density function, is the closed-form expression between two Gaussians:

$$D_{KL}\left(q(z \mid x)\,\|\,\mathcal{N}(0, I)\right) = -\frac{1}{2}\sum_{j}\left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$
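As a sketch with made-up `mean` and `logvar` values, this closed form is a one-liner in TensorFlow:

```python
import tensorflow as tf

# Hypothetical encoder outputs for a batch of one example, two latent dims.
mean = tf.constant([[0.5, -0.3]])
logvar = tf.constant([[0.1, -0.2]])

# Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q; no sampling
# or log-pdf evaluation is needed.
kl = -0.5 * tf.reduce_sum(1. + logvar - tf.square(mean) - tf.exp(logvar), axis=1)
print(kl)
```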
Assume that we have a probabilistic model with parameters $\theta$. It takes an observed input $x$ and generates a latent output $z$. We want to update $\theta$ such that it maximizes the likelihood of our model producing the observed data distribution. We will never know the true distribution of our inputs because we are only given snapshots of it. For example, if the $x$'s are pixels of images, it's impossible for us to know the true distribution of pixels across all images. At best, we can only come up with a modeled distribution $p_\theta(x)$ that aims to maximize the likelihood of the observed data.
It is too difficult to calculate the distribution $p_\theta(x)$ directly because we need to marginalize out $z$. We can get away with computing a lower bound $\mathcal{L}$ on $\log p_\theta(x)$ instead. This bound is called the evidence lower bound (ELBO).

$$\mathcal{L}(x, \theta, q) = \log p_\theta(x) - D_{KL}\left(q(z \mid x)\,\|\,p_\theta(z \mid x)\right)$$
The difference between $\log p_\theta(x)$ and $\mathcal{L}$ is the KL divergence term. KL divergence is always non-negative, so the lower bound becomes the true log likelihood if we can drive the KL divergence term to 0. It goes to zero when $q(z \mid x)$ is the same distribution as $p_\theta(z \mid x)$.
We can re-arrange $\mathcal{L}$ algebraically.

$$
\begin{aligned}
\mathcal{L}(x, \theta, q) &= \log p_\theta(x) - \mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log p_\theta(z \mid x)\right] \\
&= \log p_\theta(x) - \mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log \frac{p_\theta(x, z)}{p_\theta(x)}\right] \\
&= \log p_\theta(x) - \mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log p_\theta(x, z) + \log p_\theta(x)\right]
\end{aligned}
$$

Since we take the expected value of $\log p_\theta(x)$ with respect to $z \sim q$, and $\log p_\theta(x)$ does not depend on $z$, this is the expected value of a constant. It cancels out the first term. Here we have the final form.

$$\mathcal{L}(x, \theta, q) = \mathbb{E}_{z \sim q}\left[\log p_\theta(x, z) - \log q(z \mid x)\right]$$
Inference can be thought of as the procedure for finding the $q$ that maximizes the lower bound $\mathcal{L}$. Whether the lower bound is tight (a close approximation to $\log p_\theta(x)$) or loose depends on the choice of $q$ we pick. $\mathcal{L}$ is significantly easier to compute when we choose an easy distribution for $q$, e.g. a Gaussian distribution with mean and variance as the only parameters.
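As a tiny sanity check of the bound, here is a toy discrete example (the probabilities are made up) showing that $\mathcal{L} \leq \log p(x)$, with equality exactly when $q$ is the true posterior:

```python
import numpy as np

# Made-up joint probabilities p(x, z) for one fixed x and a binary latent z.
p_xz = np.array([0.3, 0.1])      # p(x, z=0), p(x, z=1)
log_px = np.log(p_xz.sum())      # log p(x) by exact marginalization

def elbo(q):
    # L = E_{z~q}[log p(x, z) - log q(z)]
    return np.sum(q * (np.log(p_xz) - np.log(q)))

q_guess = np.array([0.5, 0.5])   # an arbitrary q: the bound holds but is loose
q_exact = p_xz / p_xz.sum()      # the true posterior p(z|x): the bound is tight
print(elbo(q_guess) <= log_px)             # True
print(np.isclose(elbo(q_exact), log_px))   # True
```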
The VAE has an encoder and a decoder. The encoder consumes an observable variable $x$ and produces a latent variable $z$ via reparametrization. We choose a Gaussian distribution to be our $q$ distribution.

$$z = \mu + \sigma \odot \epsilon$$

where $\epsilon$ is sampled from a unit Gaussian distribution.
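A minimal encoder sketch that outputs the parameters of $q(z \mid x)$; the input shape and layer sizes are placeholder choices, not the tutorial's exact architecture:

```python
import tensorflow as tf

latent_dim = 2  # hypothetical latent size

# The encoder outputs (mean, logvar) of the Gaussian q(z|x), concatenated.
encoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(latent_dim * 2),  # first half: mean, second: logvar
])

x = tf.random.uniform((1, 28, 28, 1))  # stand-in input batch
mean, logvar = tf.split(encoder(x), num_or_size_splits=2, axis=1)
# mean and logvar then feed the reparameterize step sketched earlier.
```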
Let $x$ be our input and $\hat{x}$ be our output. Our objective is to maximize $\log p(x \mid z)$. This is equivalent to asking how far the distribution of the output, given $z$ (which comes from input $x$), is from the distribution of $x$. This is known as the reconstruction loss.
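A sketch of the reconstruction term, treating each pixel as a Bernoulli variable; the stand-in tensors and shapes are hypothetical:

```python
import tensorflow as tf

# Stand-in tensors; in a real model x_logit would come from the decoder.
x = tf.random.uniform((8, 28, 28, 1))  # hypothetical input batch
x_logit = tf.random.normal(x.shape)    # hypothetical decoder logits

# -log p(x|z) is the sigmoid cross-entropy between decoder logits and the
# input pixels, summed over each example's dimensions.
cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])  # per-example log p(x|z)
```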
When we set `mean=0` and `logvar=0` for `log_normal_pdf(z, 0, 0)`, we obtain $\mu = 0$ and $\sigma^2 = e^0 = 1$, which is saying that $z$ is sampled from a standard Gaussian probability density function.
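A quick NumPy check (assuming the tutorial's `log_normal_pdf` formula) that `mean=0, logvar=0` reduces to the standard normal log density:

```python
import numpy as np

def log_normal_pdf(sample, mean, logvar):
    # Same diagonal-Gaussian formula as the TF helper, in NumPy.
    log2pi = np.log(2. * np.pi)
    return -.5 * ((sample - mean) ** 2. * np.exp(-logvar) + logvar + log2pi)

z = np.array([-1.5, 0.0, 0.7])
# mean=0, logvar=0  =>  sigma^2 = exp(0) = 1, i.e. a standard Gaussian
print(log_normal_pdf(z, 0., 0.))
print(-.5 * (z ** 2 + np.log(2. * np.pi)))  # identical: N(0, 1) log density
```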
If the loss is minimized, the model will match the enforced Gaussian distribution $p(z)$. We selectively chose a Gaussian distribution for $q$, and this loss minimization encourages the model to learn the selected distribution.