VAEs from a generative perspective
This is my attempt to tell a story[1] about how you might invent variational autoencoders (VAEs). There are already great introductory posts doing this and if you haven’t seen VAEs before, I would strongly recommend you start with one of those. These introductions often start with autoencoders and then extend them to VAEs. In contrast, we will start by asking ourselves how to generate new data that matches a training distribution and then motivate VAEs from there. We won’t assume an autoencoder-like architecture a priori, instead it will arise naturally from this motivation.
Of course, the motivation of generating new samples that match a training distribution won’t uniquely lead to VAEs — after all, there are other good options for generative models. So at some points we will need to make design decisions, but hopefully they won’t come out of the blue.
One pedagogical note before we start: if this derivation of VAEs seems unnecessarily long and convoluted, that’s because it is. The goal is not to arrive at the VAE framework as quickly as possible, but rather to make each step seem natural and to avoid any unmotivated “magical” jumps. It’s probably best if you forget for a moment what you know about VAEs, in particular that they consist of an encoder and a decoder. We will get there at the very end but initially this preconception might just be confusing.
Generative models
The goal in generative modeling is the following: we have some family of probability distributions. Given a set of training examples $x_1, \ldots, x_N$ (assumed to be i.i.d.), we now want to pick the distribution $p$ from our family that maximizes the likelihood $p(x_1, \ldots, x_N) = \prod_{i=1}^N p(x_i)$. Equivalently, we can maximize the log-likelihood:

$$\sum_{i=1}^N \log p(x_i).$$
For now, we will consider the simpler special case where we only have a single datapoint $x$ and want to maximize $\log p(x)$ (we will get back to the general case at the end).
Optimizing over a family of probability distributions is very abstract. To turn this into a problem we can actually solve numerically, we will use a parameterized family $p_\theta$ and optimize over the parameter $\theta$. $p_\theta(x)$ should be differentiable with respect to $\theta$; then we can at least find a local optimum for our problem using gradient ascent.
This still leaves the question of which parameterized family we should use. This is the largest crossroads we’ll face in this post: there are many good options to choose from. The challenge is to find a good trade-off between having a flexible family of distributions and keeping the number of parameters manageable. For example, if $x$ takes on discrete values, we could in principle use the full categorical distribution over all possible values of $x$. This would be as flexible as possible but the number of parameters might be huge. If $x$ describes a $28 \times 28$ binary image, there are already $2^{784}$ possible values that $x$ can take, meaning we’d need about that many parameters.
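To spell that count out (assuming the $28 \times 28$ image size of this example):

$$2^{28 \times 28} = 2^{784} \approx 10^{236},$$

and a full categorical distribution over these values would need on the order of $2^{784} - 1$ free parameters.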
The way we will deal with this problem is to use a continuous mixture of simple distributions. We will introduce a new latent variable $z$ on which we define a very simple distribution $p(z)$, for example a unit normal, $p(z) = \mathcal{N}(0, I)$. Then we parameterize a distribution $p_\theta(x|z)$, which gives us

$$p_\theta(x) = \int p_\theta(x|z)\, p(z)\, dz.$$
The important point is that for a fixed $z$, $p_\theta(x|z)$ may be an extremely simple distribution. In the example above, we could use an independent Bernoulli distribution for each of the 784 pixels, which requires only 784 parameters. But because we additionally have a dependency on $z$, the marginal distribution $p_\theta(x)$ can be much more complex (in particular, the pixels are typically not independent). Of course the dependency on $z$ will require some additional parameters, but this could just be a reasonably sized neural network, which gives us far fewer parameters than the $2^{784}$ that a full categorical distribution would require.
This already describes our model. Sampling from this model is easy: we sample $z \sim p(z)$ and then $x \sim p_\theta(x|z)$ for this sample $z$. By assumption, both of those distributions are very simple (and we can also choose them to be easy to sample from).
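As a minimal sketch of this two-step sampling procedure, assuming the Bernoulli-pixel model from above and a hypothetical `decoder` network that maps $z$ to one logit per pixel:

```python
import torch

def sample_image(decoder, latent_dim=32):
    # Step 1: sample the latent variable from the simple prior p(z) = N(0, I).
    z = torch.randn(latent_dim)
    # Step 2: the network turns z into the parameters of p_theta(x|z),
    # here a probability for each of the 784 pixels.
    pixel_probs = torch.sigmoid(decoder(z))
    # Sample each pixel independently from its Bernoulli distribution.
    x = torch.bernoulli(pixel_probs)
    return x.reshape(28, 28)
```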
But evaluating the likelihood $p_\theta(x)$ of a datapoint $x$ is intractable for most choices of $p(z)$ and $p_\theta(x|z)$ because it requires calculating a complicated integral. Even if we only care about generating samples, this is a problem: to train the model, we want to maximize $p_\theta(x)$, but we can’t even evaluate it (nor its gradient, for the same reason).
The cleverness of VAEs lies in using the right approximations to make this optimization problem tractable, and that is what the remainder of this post is about.
Variational inference
First, we expand the log-likelihood a bit. For any value of $z$, we have

$$\log p(x) = \log \frac{p(x, z)}{p(z|x)} = \log p(x, z) - \log p(z|x)$$
(not writing out the dependency on $\theta$ for now). The first term is easy to evaluate. So if we could evaluate the second term, $\log p(z|x)$, our problem would be solved (sidenote on motivation[2]).
This is where variational inference comes into play (here is a tutorial if you want to dive a bit deeper, but that’s not necessary for this post). The idea of variational inference is that you have some distribution $p(z|x)$ that you care about, but which is intractable to work with. So you define a family of simpler distributions $q_\phi(z)$ and then find

$$\phi^* = \arg\min_\phi D_{KL}\big(q_\phi(z)\,\|\,p(z|x)\big),$$
where $D_{KL}$ is the Kullback-Leibler divergence (which measures “distances” between probability distributions, though it is not a metric in the mathematical sense[3]). We can then use $q_{\phi^*}(z)$ in place of $p(z|x)$ whenever we need to evaluate it.
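For reference, the definition of the KL divergence, written as an expectation since that is the form that matters for everything below:

$$D_{KL}\big(q(z)\,\|\,p(z)\big) = \mathbb{E}_{z \sim q}\left[\log \frac{q(z)}{p(z)}\right] = \mathbb{E}_{z \sim q}\big[\log q(z) - \log p(z)\big].$$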
This may sound like an enormous amount of computational overhead: just to evaluate our objective, we would have to solve an entire optimization problem each time! We will later find a way to alleviate this issue, but for now, let’s just ignore it and understand how we would solve the problem naively.
To apply this to our problem, we will approximate $p_\theta(z|x)$ with a simpler distribution $q_\phi(z)$, parameterized by a new parameter $\phi$. For example, $q_\phi$ could be a Gaussian and $\phi$ would be its mean and covariance matrix. Note that while $q_\phi$ does not explicitly depend on $x$, the optimal parameter $\phi^*$ does depend on $x$ because we minimize the Kullback-Leibler divergence between $q_\phi(z)$ and $p_\theta(z|x)$.
The variational inference problem is now minimizing

$$D_{KL}\big(q_\phi(z)\,\|\,p_\theta(z|x)\big) = \mathbb{E}_{z \sim q_\phi}\big[\log q_\phi(z) - \log p_\theta(z|x)\big].$$
This still contains the term $\log p_\theta(z|x)$ that we can’t evaluate. But we can get rid of that by writing $\log p_\theta(z|x) = \log p_\theta(x|z) + \log p(z) - \log p_\theta(x)$, which gives

$$D_{KL}\big(q_\phi(z)\,\|\,p_\theta(z|x)\big) = \log p_\theta(x) + \mathbb{E}_{z \sim q_\phi}\big[\log q_\phi(z) - \log p_\theta(x|z) - \log p(z)\big].$$
We have reintroduced $\log p_\theta(x)$, which is intractable, but crucially, it doesn’t depend on $\phi$. So to solve the minimization problem above, we can also just minimize the expected value on the right. Usually, we instead maximize the negative of that:

$$\phi^* = \arg\max_\phi \; \mathbb{E}_{z \sim q_\phi}\big[\log p_\theta(x|z) + \log p(z) - \log q_\phi(z)\big].$$
The objective

$$\mathcal{L}(\theta, \phi) := \mathbb{E}_{z \sim q_\phi}\big[\log p_\theta(x|z) + \log p(z) - \log q_\phi(z)\big]$$
is called the evidence lower bound (ELBO) because it is a lower bound on the log evidence $\log p_\theta(x)$:

$$\log p_\theta(x) = \mathcal{L}(\theta, \phi) + D_{KL}\big(q_\phi(z)\,\|\,p_\theta(z|x)\big) \geq \mathcal{L}(\theta, \phi).$$
For now this fact isn’t really interesting, but it will become relevant later.
Maximizing the ELBO is finally a tractable problem: we can write

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{z \sim q_\phi}\left[\log p_\theta(x|z) + \log \frac{p(z)}{q_\phi(z)}\right],$$
which is something we can easily evaluate. The expectation is over $z \sim q_\phi(z)$, which also doesn’t pose a problem[4].
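Concretely, since $q_\phi$ is a simple distribution we can sample from, the ELBO can be estimated by Monte Carlo, for example with a single sample:

$$\mathcal{L}(\theta, \phi) \approx \log p_\theta(x|z_1) + \log \frac{p(z_1)}{q_\phi(z_1)}, \qquad z_1 \sim q_\phi(z).$$

(Getting gradients with respect to $\phi$ through this sampling step takes a bit more care; that is where the reparameterization trick mentioned later comes in.)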
Combining the optimization problems
Let’s briefly recap our progress so far. We originally wanted to find

$$\max_\theta \log p_\theta(x),$$
which we rewrote as

$$\max_\theta \big[\log p_\theta(x, z) - \log p_\theta(z|x)\big]$$
for an arbitrarily chosen $z$. We then used variational inference to approximate the intractable term as

$$\log p_\theta(z|x) \approx \log q_{\phi^*}(z),$$
where $\phi^*$ is the solution to the variational problem:

$$\phi^* = \arg\min_\phi D_{KL}\big(q_\phi(z)\,\|\,p_\theta(z|x)\big) = \arg\max_\phi \mathcal{L}(\theta, \phi).$$
So we could now in principle plug in this approximation and solve

$$\max_\theta \big[\log p_\theta(x, z) - \log q_{\phi^*}(z)\big],$$
but there are problems with this. First, note that $\phi^*$ depends on $\theta$. If we, for example, use gradient ascent to optimize over $\theta$, we would need to find the new optimal $\phi^*$ after each gradient step. Second, using an arbitrarily chosen $z$ is kind of silly: we optimized $\phi$ such that the entire distribution $q_\phi(z)$ approximates $p_\theta(z|x)$ well, so we should make use of this entire distribution.
So let’s go back. We know that the solution to

$$\max_\theta \big[\log p_\theta(x, z) - \log p_\theta(z|x)\big]$$
is the same for any $z$ (the expression inside the brackets equals $\log p_\theta(x)$ for every $z$). So we can also maximize

$$\max_\theta \; \mathbb{E}_{z \sim q(z)}\big[\log p_\theta(x, z) - \log p_\theta(z|x)\big]$$
instead, for an arbitrary distribution $q(z)$. Plugging in our approximation, we get

$$\max_\theta \; \mathbb{E}_{z \sim q(z)}\big[\log p_\theta(x, z) - \log q_{\phi^*}(z)\big].$$
The question now is which distribution $q$ to use. But note that by using $q = q_{\phi^*}$, we again get the ELBO, this time as the objective for our original optimization problem. This is a good choice for two reasons:
- The ELBO is a lower bound on the evidence, $\log p_\theta(x)$. If we used another distribution $q$, we wouldn’t have any guarantee that we’re optimizing for the right thing if the approximation $q_{\phi^*}(z) \approx p_\theta(z|x)$ became bad enough. This way, we’re at least optimizing a lower bound on what we really care about.
- We saw above that we would need to find the new $\phi^*$ after each update to $\theta$, which is very inefficient. But the ELBO is already our objective for $\phi$, so now we have the same optimization objective for both parameters and can optimize them jointly.
With this choice of $q$, the joint optimization problem becomes

$$\max_{\theta, \phi} \; \mathcal{L}(\theta, \phi) = \max_{\theta, \phi} \; \mathbb{E}_{z \sim q_\phi(z)}\left[\log p_\theta(x|z) + \log \frac{p(z)}{q_\phi(z)}\right].$$
We can use the reparameterization trick / pathwise gradients to optimize this efficiently with gradient ascent.
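The difficulty is that the expectation is taken with respect to $q_\phi$, which itself depends on the parameters we are differentiating. The reparameterization trick rewrites a sample $z \sim \mathcal{N}(\mu_\phi, \sigma_\phi^2)$ as $z = \mu_\phi + \sigma_\phi \cdot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so that gradients can flow to $\mu_\phi$ and $\sigma_\phi$. A minimal sketch, assuming a Gaussian $q_\phi$ with diagonal covariance:

```python
import torch

def sample_z_reparameterized(mu, log_var):
    # z ~ N(mu, sigma^2), written as a deterministic function of the parameters
    # plus parameter-free noise, so gradients can flow to mu and log_var.
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)  # eps ~ N(0, I), independent of phi
    return mu + std * eps
```

Without this rewriting, the sampling step would block gradients with respect to $\phi$; with it, a single sample of $z$ already gives an unbiased and typically low-variance gradient estimate of the ELBO.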
Using an encoder
For a single datapoint $x$, we now have an efficiently solvable problem. But now we get back to the more interesting setting of an entire dataset $x_1, \ldots, x_N$. We then want to optimize the log-likelihood of the entire dataset:

$$\max_\theta \sum_{i=1}^N \log p_\theta(x_i).$$
The problem is that $q_\phi(z)$ is supposed to approximate $p_\theta(z|x)$, so the optimal $\phi^*$ is different for each datapoint $x_i$. Full variational inference would mean using a separate parameter for each datapoint. The optimization would then be over $\theta, \phi_1, \ldots, \phi_N$, where $\phi_i$ is the parameter for the $i$-th datapoint. This again gets us into the realm of a huge number of parameters and computational infeasibility.
So instead, we use amortized variational inference. This means that instead of optimizing parameters $\phi_i$ for each $x_i$, we learn a single function that maps a datapoint $x$ to variational parameters $\phi$. This function is trained to approximate the optimal solution $\phi^*(x)$. The downside is that we’re introducing yet another approximation, which can only worsen how well we maximize the likelihood $p_\theta(x_1, \ldots, x_N)$. But the big advantage is that evaluating this function is much cheaper than solving an entire optimization problem.
In practice, this means we train a neural network $f_\lambda$ with weights $\lambda$ to find the best $\phi$ (in terms of the objective above) for a given $x$. We then use $q_{f_\lambda(x)}(z)$ in place of $q_{\phi^*}(z)$. To make the notation a bit nicer, we write this as

$$q_\lambda(z|x) := q_{f_\lambda(x)}(z).$$
Then we finally get the VAE objective:

$$\max_{\theta, \lambda} \; \sum_{i=1}^N \mathbb{E}_{z \sim q_\lambda(z|x_i)}\left[\log p_\theta(x_i|z) + \log \frac{p(z)}{q_\lambda(z|x_i)}\right].$$
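In practice, the expectation for each datapoint is again estimated with a single reparameterized sample. As a sketch, assuming a Gaussian $q_\lambda$ whose mean $\mu_\lambda(x_i)$ and standard deviation $\sigma_\lambda(x_i)$ are produced by the network, the stochastic objective actually optimized is

$$\sum_{i=1}^N \left[\log p_\theta(x_i|z_i) + \log \frac{p(z_i)}{q_\lambda(z_i|x_i)}\right], \qquad z_i = \mu_\lambda(x_i) + \sigma_\lambda(x_i) \odot \epsilon_i, \;\; \epsilon_i \sim \mathcal{N}(0, I).$$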
Connection to VAEs in practice
As you’ve probably guessed by now, $p_\theta(x|z)$ is the decoder of a VAE and $q_\lambda(z|x)$ is the encoder. The ELBO can be rewritten as

$$\mathcal{L}(\theta, \lambda) = \mathbb{E}_{z \sim q_\lambda(z|x)}\big[\log p_\theta(x|z)\big] - D_{KL}\big(q_\lambda(z|x)\,\|\,p(z)\big),$$
which gives us the interpretation as “reconstruction + regularization loss” that you may have encountered elsewhere (to treat this as a loss that is minimized, you would multiply everything by $-1$).
$q_\lambda(z|x)$ is typically chosen as a normal distribution, because that makes the KL divergence in the ELBO easy to calculate in closed form if $p(z)$ is chosen as a unit normal. The choice of $p_\theta(x|z)$ depends on the type of data. As mentioned, for binary images we might use independent Bernoulli distributions for each pixel; for continuous output, a normal distribution is a common choice.
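To make the pieces concrete, here is a minimal sketch of the per-batch loss under these choices (a Gaussian encoder with diagonal covariance, a Bernoulli decoder, and the closed-form Gaussian KL). The `encoder` and `decoder` networks are hypothetical stand-ins for whatever architectures you would actually use:

```python
import torch
import torch.nn.functional as F

def vae_loss(encoder, decoder, x):
    """Negative ELBO for a batch x of flattened binary images."""
    # Encoder: q_lambda(z|x) = N(mu, diag(sigma^2)), parameterized by the network.
    mu, log_var = encoder(x)

    # Reparameterized sample z = mu + sigma * eps, eps ~ N(0, I).
    std = torch.exp(0.5 * log_var)
    z = mu + std * torch.randn_like(std)

    # Decoder: p_theta(x|z) is an independent Bernoulli per pixel.
    logits = decoder(z)
    # Reconstruction term: -log p_theta(x|z), estimated with the single sample z.
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction="sum")

    # Regularization term: KL(q_lambda(z|x) || p(z)) with p(z) = N(0, I), which has
    # the closed form 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2).
    kl = 0.5 * torch.sum(std.pow(2) + mu.pow(2) - 1.0 - log_var)

    # Minimizing this sum is equivalent to maximizing the ELBO.
    return recon + kl
```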
Conclusion
We saw how to arrive at VAEs starting from a purely generative motivation, without assuming an autoencoder architecture a priori. Interestingly, this gives a very different impression than the “autoencoder perspective”: what we really care about is the decoder, whereas the encoder is just a useful trick to be able to train the decoder efficiently.
This doesn’t mean that the autoencoder perspective is wrong, of course. Having an encoder can be intrinsically useful for some applications, and that motivation is missing from this post. But I think the perspective we took here demonstrates that the VAE architecture is far less arbitrary than it may seem when starting from autoencoders.
Further reading
- The CS 228 lecture notes on VAEs take a somewhat similar approach to this post in terms of emphasizing the variational inference perspective. They also contain details on some points that I basically ignored, for example the reparameterization trick.
- Carl Doersch’s tutorial on VAEs contains much more detail and also has a different motivation for why we want to approximate $p(z|x)$ (namely to use that to estimate the integral $\int p_\theta(x|z)\, p(z)\, dz$ by sampling the values of $z$ that contribute the most).
- There is also a tutorial by Kingma and Welling, the authors who introduced VAEs. You could also look at their original paper, but that’s a lot terser.
- [1] Michael Nielsen calls this “discovery fiction”, mentioned for example here.
- [2] If you intrinsically care about $p_\theta(z|x)$, for example because you hope the latent variables will have an interesting meaning, then this and parts of the remaining post are unnecessary and you’ll get to VAEs more directly. But my point is that you don’t need that motivation — VAEs arise pretty naturally even if you only care about finding a good model of the training data.
- [3] In particular, the Kullback-Leibler divergence is not symmetric, which raises the question why we use $D_{KL}(q_\phi(z)\,\|\,p(z|x))$ and not $D_{KL}(p(z|x)\,\|\,q_\phi(z))$. The reason is that the latter would itself lead to an intractable optimization problem, and so we wouldn’t have made any progress.
- [4] If you’ve read the previous footnote: this is the point where using the other KL divergence would mean we’re stuck, because we would have an expectation with respect to $p_\theta(z|x)$, which is exactly the intractable distribution we were trying to avoid.