Many probabilistic models are difficult to train because it is difficult to perform inference in them. In the context of deep learning, we usually have a set of visible variables x and a set of latent variables z. The challenge of inference is the difficult problem of computing the posterior $p(z \mid x)$ or taking expectations with respect to it. Such operations are often necessary for tasks like maximum likelihood learning.
In deep learning, the posterior is the probability distribution of the latent variables (e.g. hidden-layer activations) given a visible variable (e.g. an input). By Bayes' rule, it has the following form:
$$p(z \mid x) = \frac{p(x \mid z)\, p(z)}{p(x)} = \frac{p_{\text{joint}}(z, x)}{p(x)}$$
The problem is that we usually cannot compute the denominator, i.e. the marginal probability $p(x)$:
$$p(x) = \int_{-\infty}^{\infty} p(z, x)\, dz$$
This integral requires us to sum over all possible values of z. In general it has no closed-form solution, so we would have to evaluate the joint distribution at every possible value of z, which becomes infeasible when z is a high-dimensional vector. Thus, we need a way to approximate the posterior $p(z \mid x)$.
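To make the cost concrete, here is a minimal NumPy sketch under an assumed toy model, $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(z, 1)$, chosen only because its marginal $p(x) = \mathcal{N}(0, 2)$ is known exactly. The helpers `gauss_pdf` and `marginal_by_grid` are hypothetical names for this illustration. A grid over one latent dimension recovers $p(x)$ well, but a grid with k points per dimension needs $k^d$ joint evaluations in d dimensions.
```python
import numpy as np

# Toy model (an assumption for illustration): p(z) = N(0, 1), p(x | z) = N(z, 1).
# Its marginal is known in closed form: p(x) = N(0, 2).

def gauss_pdf(value, mean, var):
    """Density of a 1-D Gaussian N(mean, var) evaluated at value."""
    return np.exp(-0.5 * (value - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

def marginal_by_grid(x, num_points=2001, lo=-10.0, hi=10.0):
    """Approximate p(x) = integral of p(x | z) p(z) dz with a Riemann sum over a grid of z."""
    z = np.linspace(lo, hi, num_points)
    dz = z[1] - z[0]
    joint = gauss_pdf(x, z, 1.0) * gauss_pdf(z, 0.0, 1.0)  # p(x, z) at every grid point
    return float(np.sum(joint) * dz)

x = 0.7
approx = marginal_by_grid(x)
exact = float(gauss_pdf(x, 0.0, 2.0))
print(f"grid: {approx:.6f}  closed form: {exact:.6f}")

# The same grid resolution in d latent dimensions needs num_points ** d evaluations.
num_points = 2001
for d in (1, 2, 3, 10):
    print(f"{d} dims -> {num_points ** d:.3e} joint evaluations")
```
In one dimension the grid is cheap and accurate; the exponential growth of the evaluation count is what rules this approach out for realistic latent sizes.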
Inference as Optimization
Many approaches to confronting the problem of difficult inference make use of the observation that exact inference can be described as an optimization problem.
Assume that we have a probabilistic model with parameters θ. The model takes an observed input x and produces a latent variable z. We want to update θ so that it maximizes the likelihood of the model producing the observed data. We never know the true distribution of our inputs because we are only given samples from it. For example, if x is a vector of pixel values, it is impossible for us to know the true distribution of pixels across all images. At best, we can build a model distribution that maximizes the likelihood of the observed data:
$$\arg\max_\theta \sum_{i=1}^{N} \log p_\theta(x_i)$$
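As a minimal sketch of this objective alone, assume a model with no latent variables, $p_\theta(x) = \mathcal{N}(\theta, 1)$, and synthetic data standing in for the observed samples. A grid search over θ (the names `log_likelihood` and `thetas` are illustrative) finds the value that maximizes the summed log-likelihood; for this particular model it lands next to the sample mean, which is the known maximum likelihood estimate.
```python
import numpy as np

# Hypothetical model for illustration: p_theta(x) = N(theta, 1), no latent variables.
def log_likelihood(theta, xs):
    """Sum over the data of log N(x_i; theta, 1)."""
    return float(np.sum(-0.5 * np.log(2.0 * np.pi) - 0.5 * (xs - theta) ** 2))

rng = np.random.default_rng(0)
xs = rng.normal(loc=3.0, scale=1.0, size=500)  # the observed samples

thetas = np.linspace(0.0, 6.0, 601)  # candidate parameter values, spacing 0.01
scores = [log_likelihood(t, xs) for t in thetas]
theta_hat = thetas[int(np.argmax(scores))]
print(f"argmax theta = {theta_hat:.2f}, sample mean = {xs.mean():.3f}")
```
With latent variables, each $\log p_\theta(x_i)$ would itself require the intractable marginalization over z, which is what motivates the lower bound below.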
Computing $p_\theta(x)$ directly is intractable because we need to marginalize out z. Instead, we can compute a lower bound on $\log p_\theta(x)$. This bound is called the evidence lower bound (ELBO):
$$\mathcal{L}(x, \theta, q) = \log p_\theta(x) - D_{\mathrm{KL}}\left[q(z \mid x) \,\|\, p_\theta(z \mid x)\right]$$
where q is an arbitrary probability distribution over z.
The difference between $\log p_\theta(x)$ and $\mathcal{L}$ is the KL divergence term. Because KL divergence is always non-negative, $\mathcal{L}$ can never exceed $\log p_\theta(x)$, and the bound equals the log-likelihood when the KL divergence term is 0. This happens exactly when q is the same distribution as $p_\theta(z \mid x)$.
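A quick numerical check under an assumed toy model, $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(z, 1)$, for which $\log p(x)$ and the exact posterior $p(z \mid x) = \mathcal{N}(x/2, 1/2)$ are known. The sketch evaluates $\mathcal{L} = \log p(x) - D_{\mathrm{KL}}$ for a few hand-picked Gaussian choices of q; the helper names are illustrative. Every choice gives a value at or below $\log p(x)$, and the gap closes when q equals the posterior.
```python
import numpy as np

def gauss_logpdf(value, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var)."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (value - mean) ** 2 / var)

def kl_gauss(m1, v1, m2, v2):
    """Closed-form KL[N(m1, v1) || N(m2, v2)] for 1-D Gaussians."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

# Toy model (assumed): p(z) = N(0, 1), p(x | z) = N(z, 1).
x = 0.7
log_px = gauss_logpdf(x, 0.0, 2.0)   # exact log p(x), since p(x) = N(0, 2)
post_mean, post_var = x / 2.0, 0.5   # exact posterior p(z | x)

# Hand-picked choices of q(z | x) = N(m, v); the last one equals the posterior.
candidates = [(0.0, 1.0), (1.0, 0.2), (0.3, 0.6), (post_mean, post_var)]
elbos = []
for m, v in candidates:
    elbo = log_px - kl_gauss(m, v, post_mean, post_var)
    elbos.append(elbo)
    print(f"q = N({m:.2f}, {v:.2f}): ELBO = {elbo:.4f}, log p(x) = {log_px:.4f}")
```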
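Expanding the KL divergence and substituting $\log p_\theta(z \mid x) = \log p_\theta(x, z) - \log p_\theta(x)$ produces a term $\mathbb{E}_{z \sim q}[\log p_\theta(x)]$. Since $\log p_\theta(x)$ does not depend on z, this is the expected value of a constant, which is simply $\log p_\theta(x)$, and it cancels the first term. Here we have the final form:
$$\mathcal{L}(x, \theta, q) = -\mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log p_\theta(x, z)\right]$$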
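For reference, the cancellation can be written out step by step, using only the definition of the KL divergence and $\log p_\theta(z \mid x) = \log p_\theta(x, z) - \log p_\theta(x)$:
$$
\begin{aligned}
\mathcal{L}(x, \theta, q) &= \log p_\theta(x) - \mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log p_\theta(z \mid x)\right] \\
&= \log p_\theta(x) - \mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log p_\theta(x, z) + \log p_\theta(x)\right] \\
&= \log p_\theta(x) - \log p_\theta(x) - \mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log p_\theta(x, z)\right] \\
&= -\mathbb{E}_{z \sim q}\left[\log q(z \mid x) - \log p_\theta(x, z)\right]
\end{aligned}
$$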
Inference can be thought of as the procedure of finding the q that maximizes the lower bound $\mathcal{L}$. Whether the bound is tight (a close approximation to $\log p_\theta(x)$) or loose depends on the choice of q. $\mathcal{L}$ is much easier to compute when we choose a simple distribution for q, e.g. a Gaussian whose only parameters are its mean and variance.
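The joint term $\log p_\theta(x, z)$ is also tractable, because the joint distribution can be written as a conditional distribution times a prior:
$$p_\theta(x, z) = p_\theta(x \mid z)\, p_\theta(z)$$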
In practice, we cannot compute this expectation exactly. Instead, we draw a single sample $z \sim q(z \mid x)$ per forward pass and use it as a Monte Carlo estimate:
$$\mathcal{L} \approx \log p(x \mid z) + \log p(z) - \log q(z \mid x)$$
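The sketch below applies this one-sample estimate to an assumed toy model, $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(z, 1)$, with a hand-picked $q(z \mid x) = \mathcal{N}(m, v)$. The helpers `elbo_samples` and `elbo_exact` are illustrative names. A single sample is noisy, but the average of many such estimates approaches the exact ELBO, which for this model is available in closed form (expected log-likelihood plus expected log-prior plus the entropy of q).
```python
import numpy as np

def gauss_logpdf(value, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var)."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (value - mean) ** 2 / var)

def elbo_samples(x, q_mean, q_var, eps):
    """One Monte Carlo ELBO estimate per noise draw: log p(x|z) + log p(z) - log q(z|x)."""
    z = q_mean + np.sqrt(q_var) * eps            # z ~ q(z | x)
    return (gauss_logpdf(x, z, 1.0)              # log p(x | z)
            + gauss_logpdf(z, 0.0, 1.0)          # log p(z)
            - gauss_logpdf(z, q_mean, q_var))    # log q(z | x)

def elbo_exact(x, q_mean, q_var):
    """Closed form for the toy model: E_q[log p(x|z)] + E_q[log p(z)] + entropy of q."""
    expected_log_lik = -0.5 * np.log(2.0 * np.pi) - 0.5 * ((x - q_mean) ** 2 + q_var)
    expected_log_prior = -0.5 * np.log(2.0 * np.pi) - 0.5 * (q_mean ** 2 + q_var)
    entropy = 0.5 * np.log(2.0 * np.pi * np.e * q_var)
    return expected_log_lik + expected_log_prior + entropy

rng = np.random.default_rng(0)
x, q_mean, q_var = 0.7, 0.1, 0.8
estimates = elbo_samples(x, q_mean, q_var, rng.standard_normal(100_000))
single = estimates[0]        # what a single forward pass would use
average = estimates.mean()   # many samples, for comparison only
exact = elbo_exact(x, q_mean, q_var)
print(f"single sample: {single:.4f}  mean of 1e5 samples: {average:.4f}  exact: {exact:.4f}")
```
In training, the noise of the single-sample estimate is tolerated because its expectation is the true ELBO and it is averaged over many minibatch steps.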
This is similar to the idea behind cross-entropy, where we learn the distribution of labels by measuring the difference between the target and predicted distributions. We cannot compute the exact expectation under the target distribution, so we approximate it with samples (the training examples) using the same technique.
A variational autoencoder (VAE) has an encoder and a decoder. The encoder consumes an observed vector x and produces a latent variable z via reparametrization. We choose a Gaussian distribution for $q(z \mid x)$.
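A minimal NumPy sketch of this forward pass. It assumes hypothetical single-layer linear weights (`W_enc`, `W_dec`) in place of trained networks, binary inputs, and the common convention of splitting the encoder output into a mean and a log-variance for $q(z \mid x)$. z is then sampled via reparametrization, and the decoder maps z to per-pixel probabilities for $\hat{x}$.
```python
import numpy as np

rng = np.random.default_rng(0)
x_dim, z_dim = 6, 2

# Hypothetical single-layer weights; a real VAE would use trained neural networks.
W_enc = rng.normal(scale=0.1, size=(x_dim, 2 * z_dim))  # produces [mean, logvar]
W_dec = rng.normal(scale=0.1, size=(z_dim, x_dim))

def encode(x):
    """Map x to the parameters of q(z | x) = N(mean, exp(logvar))."""
    h = x @ W_enc
    mean, logvar = np.split(h, 2, axis=-1)
    return mean, logvar

def reparameterize(mean, logvar):
    """z = mean + sigma * eps, with eps drawn from a unit Gaussian."""
    eps = rng.standard_normal(mean.shape)
    return mean + np.exp(0.5 * logvar) * eps

def decode(z):
    """Return Bernoulli probabilities for each pixel of x_hat."""
    logits = z @ W_dec
    return 1.0 / (1.0 + np.exp(-logits))

x = rng.integers(0, 2, size=(3, x_dim)).astype(float)  # batch of 3 binary inputs
mean, logvar = encode(x)
z = reparameterize(mean, logvar)
x_hat = decode(z)
print(mean.shape, logvar.shape, z.shape, x_hat.shape)
```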
Why Reparametrization?
z is supposed to be sampled from a Gaussian distribution, but gradients cannot flow through the random sampling in tf.random.normal back to the mean and variance. We need to reparametrize z so that the gradient with respect to these parameters does not depend on tf.random.normal.
We can draw noise from a unit Gaussian with tf.random.normal and redefine z as follows:
$$z = \mu + \sigma \cdot \epsilon$$
where ϵ is sampled from a unit Gaussian distribution.
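The following sketch illustrates why this works. It uses NumPy rather than TensorFlow and writes the gradients out with the chain rule instead of automatic differentiation; the test function $f(z) = z^2$ is an arbitrary choice. With $z = \mu + \sigma \cdot \epsilon$, the derivative of f with respect to μ and σ passes through z, while ε is just an input that does not depend on either parameter. The sample averages approximate the exact gradients $\partial_\mu \mathbb{E}[z^2] = 2\mu$ and $\partial_\sigma \mathbb{E}[z^2] = 2\sigma$.
```python
import numpy as np

# Estimate gradients of E[f(z)], f(z) = z**2, with respect to mu and sigma
# using z = mu + sigma * eps. Exact values: 2 * mu and 2 * sigma.
rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8
eps = rng.standard_normal(200_000)  # noise: independent of mu and sigma
z = mu + sigma * eps                # differentiable in mu and sigma

df_dz = 2.0 * z                     # derivative of f(z) = z**2
grad_mu = np.mean(df_dz * 1.0)      # chain rule with dz/dmu = 1
grad_sigma = np.mean(df_dz * eps)   # chain rule with dz/dsigma = eps

print(f"grad_mu ~ {grad_mu:.3f} (exact {2 * mu}), grad_sigma ~ {grad_sigma:.3f} (exact {2 * sigma})")
```
In a TensorFlow model the same structure lets automatic differentiation compute these derivatives, since the only random operation produces ε, which needs no gradient.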
The term $\log p(x \mid z)$ asks how likely the input x is under the decoder's output distribution given z (which was computed from x), i.e. how far the reconstruction $\hat{x}$ is from x. Its negative is known as the reconstruction loss.
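The text does not fix the decoder's distribution. As an assumption, the sketch below uses binary inputs and a Bernoulli decoder that outputs logits, in which case $\log p(x \mid z)$ is the negative binary cross-entropy; `bernoulli_log_likelihood` is an illustrative name. A reconstruction close to x gives a smaller reconstruction loss than one far from it. With a Gaussian decoder of fixed variance, the same term would instead be a scaled squared error plus a constant.
```python
import numpy as np

def bernoulli_log_likelihood(x, logits):
    """log p(x | z) for independent Bernoulli pixels with decoder logits.

    Uses the numerically stable identity
    x*log(sigmoid(l)) + (1 - x)*log(1 - sigmoid(l)) = x*l - log(1 + exp(l)).
    """
    return np.sum(x * logits - np.logaddexp(0.0, logits), axis=-1)

x = np.array([[1.0, 0.0, 1.0, 1.0]])
good_logits = np.array([[4.0, -4.0, 4.0, 4.0]])   # reconstruction close to x
bad_logits = np.array([[-4.0, 4.0, -4.0, -4.0]])  # reconstruction far from x

# Reconstruction loss = -log p(x | z)
recon_loss_good = -bernoulli_log_likelihood(x, good_logits)
recon_loss_bad = -bernoulli_log_likelihood(x, bad_logits)
print(recon_loss_good, recon_loss_bad)
```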
For the prior we compute log p(z) = log_normal_pdf(z, 0, 0). Setting mean=0 and logvar=0 gives a mean of 0 and a variance of $e^0 = 1$, which says that z is drawn from a standard Gaussian distribution.
The remaining terms, $\log q(z \mid x) - \log p(z)$, form a single-sample estimate of the KL divergence between $q(z \mid x)$ and $p(z)$. Minimizing this loss pushes the encoder's distribution $q(z \mid x)$ toward the enforced Gaussian prior $p(z)$. Since we deliberately chose the forms of $q$ and $p(z)$, this loss encourages the model to produce latent codes that follow the selected distribution.
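A NumPy sketch of these two terms. `log_normal_pdf` here is our own NumPy function taking `(sample, mean, logvar)` like the call above; its body is an assumption about what that helper computes (a diagonal Gaussian log-density summed over the latent dimensions), and `kl_term_mc` is a hypothetical helper. With mean=0 and logvar=0 the function matches the standard Gaussian log-density, and averaging $\log q(z \mid x) - \log p(z)$ over samples from q shrinks toward 0 as q approaches the prior.
```python
import numpy as np

def log_normal_pdf(sample, mean, logvar, raxis=-1):
    """Log-density of a diagonal Gaussian N(mean, exp(logvar)), summed over raxis."""
    log2pi = np.log(2.0 * np.pi)
    return np.sum(-0.5 * ((sample - mean) ** 2 * np.exp(-logvar) + logvar + log2pi), axis=raxis)

# Prior term: mean = 0 and logvar = 0 (variance exp(0) = 1) is the standard Gaussian.
z = np.array([[0.0, 0.0], [1.0, -2.0]])
log_pz = log_normal_pdf(z, 0.0, 0.0)
standard = np.sum(-0.5 * (z ** 2 + np.log(2.0 * np.pi)), axis=-1)
print(np.allclose(log_pz, standard))

def kl_term_mc(mean, logvar, rng, num_samples=100_000):
    """Average of log q(z|x) - log p(z) over z ~ q, with prior p(z) = N(0, I)."""
    eps = rng.standard_normal((num_samples, mean.shape[-1]))
    z_q = mean + np.exp(0.5 * logvar) * eps  # reparametrized samples from q
    return float(np.mean(log_normal_pdf(z_q, mean, logvar) - log_normal_pdf(z_q, 0.0, 0.0)))

rng = np.random.default_rng(0)
far = kl_term_mc(np.array([2.0, -1.0]), np.array([-1.0, 0.5]), rng)    # q far from the prior
near = kl_term_mc(np.array([0.2, -0.1]), np.array([-0.1, 0.05]), rng)  # q close to the prior
same = kl_term_mc(np.zeros(2), np.zeros(2), rng)                       # q equal to the prior
print(f"far: {far:.3f}  near: {near:.3f}  same: {same:.3f}")
```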
Another way to write the KL divergence loss without the log normal probability density function is