# From expectation maximization to stochastic variational inference

April 3, 2018

## Introduction

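Given a probabilistic model $p(\mathbf{x};\boldsymbol\theta)$ and some observations $\mathbf{x}$, we often want to estimate optimal parameter values $\boldsymbol{\hat{\theta}}$. If point estimates of $\boldsymbol\theta$ are sufficient, this can be done via maximum likelihood (ML) estimation, which maximizes the data likelihood, or via maximum a posteriori (MAP) estimation, which additionally takes a prior over $\boldsymbol\theta$ into account. For ML estimation:

$$
\boldsymbol{\hat{\theta}} = \underset{\boldsymbol\theta}{\mathrm{argmax}}\ p(\mathbf{x};\boldsymbol\theta) \tag{1}
$$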

In many cases, direct computation and optimization of the likelihood function $p(\mathbf{x};\boldsymbol\theta)$ is complex or impossible. One option to ease computation is the introduction of latent variables $\mathbf{t}$ so that we have a complete data likelihood $p(\mathbf{x},\mathbf{t};\boldsymbol\theta)$ which can be decomposed into a conditional likelihood $p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)$ and a prior $p(\mathbf{t})$:

$$
p(\mathbf{x},\mathbf{t};\boldsymbol\theta) = p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)\,p(\mathbf{t}) \tag{2}
$$

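Latent variables are not observed directly but are assumed to cause the observations $\mathbf{x}$. Their choice is problem-dependent. To obtain the marginal likelihood $p(\mathbf{x};\boldsymbol\theta)$, we have to integrate over, i.e. marginalize out, the latent variables:

$$
p(\mathbf{x};\boldsymbol\theta) = \int p(\mathbf{x},\mathbf{t};\boldsymbol\theta)\,d\mathbf{t} = \int p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)\,p(\mathbf{t})\,d\mathbf{t} \tag{3}
$$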
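The marginalization can be checked numerically in a case where the integral is known in closed form. The sketch below assumes a toy model that is not part of this article: a scalar latent variable with prior $p(t) = \mathcal{N}(0, 1)$ and conditional likelihood $p(x \lvert t) = \mathcal{N}(x \lvert t, 1)$, for which the marginal is $\mathcal{N}(x \lvert 0, 2)$. The integral is approximated with a simple Riemann sum on a grid.

```python
import numpy as np

def normal_pdf(x, mean, var):
    """Density of a univariate Gaussian N(x | mean, var)."""
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2.0 * np.pi * var)

x = 0.7                                   # a single (made-up) observation

# Grid over the latent variable t; the integrand is negligible outside [-10, 10].
t = np.linspace(-10.0, 10.0, 20001)
dt = t[1] - t[0]

# Integrand p(x | t) p(t) of the marginal likelihood
integrand = normal_pdf(x, t, 1.0) * normal_pdf(t, 0.0, 1.0)

marginal_numeric = np.sum(integrand) * dt  # Riemann-sum approximation
marginal_exact = normal_pdf(x, 0.0, 2.0)   # closed form for this toy model
```

For a one-dimensional latent variable such a grid is cheap, but the number of grid points grows exponentially with the dimension of $\mathbf{t}$, which is one reason why the integral is usually intractable in practice.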

Usually, we choose a latent variable model such that parameter estimation for the conditional likelihood $p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)$ is easier than for the marginal likelihood $p(\mathbf{x};\boldsymbol\theta)$. For example, the conditional likelihood of a Gaussian mixture model (GMM) is a single Gaussian, for which parameter estimation is easier than for the marginal likelihood, which is a mixture of Gaussians. The latent variable $\mathbf{t}$ in a GMM determines the assignment to mixture components and follows a categorical distribution. If we can solve the integral in Eq. 3, we can also compute the posterior distribution of the latent variables by using Bayes’ theorem:

$$
p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta) = \frac{p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)\,p(\mathbf{t})}{p(\mathbf{x};\boldsymbol\theta)} \tag{4}
$$

With the posterior, inference for the latent variables becomes possible. Note that in this article the term estimation is used to refer to (point) estimation of parameters via ML or MAP and inference to refer to Bayesian inference of random variables by computing the posterior.
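For the GMM mentioned above, the latent variable is discrete, so the integral in Eq. 3 becomes a finite sum and the posterior follows directly from Bayes’ theorem. The following is a minimal NumPy sketch for a one-dimensional mixture with two components. The mixture weights, means, standard deviations and observations are made-up values for illustration.

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of a univariate Gaussian N(x | mean, std^2)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

# Hypothetical 1-D mixture with two components (illustrative values only).
weights = np.array([0.3, 0.7])    # prior p(t = k)
means = np.array([-2.0, 1.0])     # component means
stds = np.array([0.5, 1.0])       # component standard deviations

x = np.array([-2.1, 0.0, 1.5])    # three observations

# Joint p(x, t = k) = p(x | t = k) p(t = k), shape (n_observations, n_components)
joint = gaussian_pdf(x[:, None], means, stds) * weights

# Marginal likelihood p(x): sum over the discrete latent variable
marginal = joint.sum(axis=1)

# Posterior p(t = k | x) via Bayes' theorem; each row sums to one
posterior = joint / marginal[:, None]
```

Each row of `posterior` is a categorical distribution over the two components for one observation.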

A major challenge in Bayesian inference is that the integral in Eq. 3 is often impossible or very difficult to compute in closed form. Therefore, many techniques exist to approximate the posterior in Eq. 4. They can be classified into numerical approximations (Monte Carlo techniques) and deterministic approximations. This article is about deterministic approximations only, and their stochastic variants.

## Expectation maximization (EM)

The expectation-maximization (EM) algorithm is the basis for many inference methods. It is an iterative algorithm for estimating the parameters of latent variable models, often with closed-form updates at each step. We start with a rather general view of the EM algorithm that also serves as a basis for discussing variational inference methods later. It is straightforward to show[2] that the marginal log likelihood can be written as

$$
\log p(\mathbf{x};\boldsymbol\theta) = \mathcal{L}(q, \boldsymbol\theta) + \mathrm{KL}(q \mid\mid p) \tag{5}
$$

with

$$
\mathcal{L}(q, \boldsymbol\theta) = \int q(\mathbf{t}) \log \frac{p(\mathbf{x},\mathbf{t};\boldsymbol\theta)}{q(\mathbf{t})}\,d\mathbf{t} \tag{6}
$$

and

$$
\mathrm{KL}(q \mid\mid p) = - \int q(\mathbf{t}) \log \frac{p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)}{q(\mathbf{t})}\,d\mathbf{t} \tag{7}
$$

where $q(\mathbf{t})$ is any probability density function. $\mathrm{KL}(q \mid\mid p)$ is the Kullback-Leibler divergence between $q(\mathbf{t})$ and $p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)$ that measures how much $q$ diverges from $p$. The Kullback-Leibler divergence is zero for identical distributions and greater than zero otherwise. Thus, $\mathcal{L}(q, \boldsymbol\theta)$ is a lower bound of the log likelihood. It is equal to the log likelihood if $q(\mathbf{t}) = p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)$. In the E-step of the EM algorithm, $q(\mathbf{t})$ is therefore set to $p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)$ using the parameter values of the previous iteration $l-1$:

$$
q^{l}(\mathbf{t}) = p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta^{l-1}) \tag{8}
$$
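The decomposition of the log likelihood into a lower bound and a KL divergence can be verified numerically when the latent variable is discrete. The sketch below uses a made-up joint distribution over a latent variable with three states and one fixed observation. The function names `lower_bound` and `kl` are chosen for this example only.

```python
import numpy as np

# Hypothetical values of the joint p(x, t = k) for one fixed observation x
# and a latent variable t with three states.
joint = np.array([0.10, 0.25, 0.05])

log_marginal = np.log(joint.sum())   # log p(x)
posterior = joint / joint.sum()      # p(t | x)

def lower_bound(q):
    """L(q) = sum_t q(t) log( p(x, t) / q(t) )."""
    return np.sum(q * (np.log(joint) - np.log(q)))

def kl(q, p):
    """KL(q || p) = sum_t q(t) log( q(t) / p(t) )."""
    return np.sum(q * (np.log(q) - np.log(p)))

q = np.array([0.5, 0.3, 0.2])        # an arbitrary distribution over t

bound_plus_kl = lower_bound(q) + kl(q, posterior)   # equals log p(x)
bound_at_posterior = lower_bound(posterior)         # bound is tight here
```

For any valid `q`, `bound_plus_kl` matches `log_marginal`, and the bound becomes tight when `q` is set to the posterior, which is what the E-step does.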

Note that this requires that $p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)$ is known, as in the GMM case where the posterior is a categorical distribution, as mentioned above. In the M-step, $\mathcal{L}(q, \boldsymbol\theta)$ is optimized w.r.t. $\boldsymbol\theta$ using $q(\mathbf{t})$ from the E-step:

$$
\boldsymbol\theta^{l} = \underset{\boldsymbol\theta}{\mathrm{argmax}}\ \mathcal{L}(q^{l}, \boldsymbol\theta) \tag{9}
$$

In general, this is much simpler than optimizing $p(\mathbf{x};\boldsymbol\theta)$ directly. E and M steps are repeated until convergence. However, the requirement that the posterior $p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)$ must be known is rather restrictive and there are many cases where the posterior is intractable. In these cases, further approximations must be made.
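As an illustration of alternating E and M steps, here is a minimal sketch of EM for a one-dimensional GMM with two components. The synthetic data, the initial parameter values and the function name `em_gmm_1d` are assumptions made for this example. The M-step uses the standard closed-form GMM updates.

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of a univariate Gaussian N(x | mean, std^2)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def em_gmm_1d(x, weights, means, stds, n_iter=50):
    """Run EM for a 1-D GMM and record the log likelihood before each update."""
    log_likelihoods = []
    for _ in range(n_iter):
        # E-step: set q(t) to the posterior p(t | x; theta^{l-1})
        joint = gaussian_pdf(x[:, None], means, stds) * weights
        marginal = joint.sum(axis=1)
        resp = joint / marginal[:, None]
        log_likelihoods.append(np.log(marginal).sum())

        # M-step: closed-form maximizers of the lower bound w.r.t. theta
        nk = resp.sum(axis=0)
        weights = nk / len(x)
        means = (resp * x[:, None]).sum(axis=0) / nk
        stds = np.sqrt((resp * (x[:, None] - means) ** 2).sum(axis=0) / nk)
    return weights, means, stds, np.array(log_likelihoods)

# Synthetic data from a known mixture (illustrative values only).
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2.0, 0.5, 300), rng.normal(1.0, 1.0, 700)])

weights, means, stds, log_likelihoods = em_gmm_1d(
    x,
    weights=np.array([0.5, 0.5]),
    means=np.array([-1.0, 0.5]),
    stds=np.array([1.0, 1.0]),
)
```

The recorded `log_likelihoods` should never decrease from one iteration to the next, which follows from the lower bound argument above.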

## Variational EM

If the posterior is unknown, we have to assume specific forms of $q(\mathbf{t})$ and maximize the lower bound $\mathcal{L}(q, \boldsymbol\theta)$ w.r.t. these functions. The area of mathematics related to these optimization problems is called calculus of variations[3], hence the name variational EM, or variational inference in general. A widely used approximation for the unknown posterior is the mean-field approximation[2][3] which factorizes $q(\mathbf{t})$ into $M$ partitions:

$$
q(\mathbf{t}) = \prod_{i=1}^{M} q_i(\mathbf{t}_i) \tag{10}
$$

For example, if $\mathbf{t}$ is 10-dimensional, we can factorize $q(\mathbf{t})$ into a product of 10 $q_i(\mathbf{t}_i)$, one for each dimension, assuming independence between dimensions. The approximate posterior $q(\mathbf{t})$ can be obtained by minimizing $\mathrm{KL}(q \mid\mid \tilde{p})$ which is the KL divergence between the factorized distribution $q(\mathbf{t})$ and the unnormalized posterior $\tilde{p}(\mathbf{t};\boldsymbol\theta) = p(\mathbf{x},\mathbf{t};\boldsymbol\theta)$. This leads to the following update formula for $q_i(\mathbf{t}_i)$:

$$
\log q_i(\mathbf{t}_i) = \mathbb{E}_{-q_i}\left[\log p(\mathbf{x},\mathbf{t};\boldsymbol\theta^{l-1})\right] + \mathrm{const} \tag{11}
$$

where $\mathbb{E}_{-q_i}$ denotes the expectation with respect to all variables of $\mathbf{t}$ except $\mathbf{t}_i$. $\boldsymbol\theta^{l-1}$ are the parameters from the previous iteration. This is repeated for all $q_i$ until convergence. The E-step of the variational EM algorithm is therefore

$$
q^{l}(\mathbf{t}) = \prod_{i=1}^{M} q_i^{l}(\mathbf{t}_i) \tag{12}
$$

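and the M-step uses the posterior approximation $q^l(\mathbf{t})$ from the E-step to estimate parameters $\boldsymbol\theta^l$:

$$
\boldsymbol\theta^{l} = \underset{\boldsymbol\theta}{\mathrm{argmax}}\ \mathcal{L}(q^{l}, \boldsymbol\theta) \tag{13}
$$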
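To illustrate the coordinate-wise mean-field updates, consider a case that is not from this article but is a standard textbook example (see e.g. [3], Chapter 10): approximating a correlated bivariate Gaussian $\mathcal{N}(\boldsymbol\mu, \boldsymbol\Lambda^{-1})$ with a factorized $q(\mathbf{t}) = q_1(t_1)\,q_2(t_2)$. In this case each factor is itself Gaussian with variance $1/\Lambda_{ii}$, and its mean depends on the current mean of the other factor. The target mean and covariance below are made-up values.

```python
import numpy as np

# Hypothetical target distribution: a correlated bivariate Gaussian.
mu = np.array([1.0, -1.0])
cov = np.array([[1.0, 0.8],
                [0.8, 1.0]])
prec = np.linalg.inv(cov)          # precision matrix Lambda

# Means of the factors q_1 and q_2, initialized arbitrarily.
m = np.array([0.0, 0.0])

# Coordinate-wise updates: each factor is updated given the current
# expectation of the other one, repeated until convergence.
for _ in range(100):
    m[0] = mu[0] - prec[0, 1] / prec[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - prec[1, 0] / prec[1, 1] * (m[0] - mu[0])

# Variances of the factors are fixed by the precision matrix.
q_var = 1.0 / np.diag(prec)
```

The factor means converge to the true mean, while the factor variances come out smaller than the true marginal variances. This underestimation of variance is a known property of the mean-field approximation.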

The mean-field approach allows inference for many interesting latent variable models, but it requires analytical solutions for the approximate posterior, which are not always available. In particular, in the context of deep learning, where the approximate posterior $q(\mathbf{t})$ and the conditional likelihood $p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)$ are neural networks with at least one non-linear hidden layer, the mean-field approach is no longer applicable[4]. Further approximations are required.

## Stochastic variational inference

Let’s assume we have a latent variable model with one latent variable $\mathbf{t}^{(i)}$ for each observation $\mathbf{x}^{(i)}$. Observations $\mathbf{x}^{(i)}$ come from an i.i.d. dataset. To make the following more concrete let’s say that $\mathbf{x}^{(i)}$ are images and $\mathbf{t}^{(i)}$ are $D$-dimensional latent vectors that cause the generation of $\mathbf{x}^{(i)}$ under the generative model $p(\mathbf{x},\mathbf{t};\boldsymbol\theta) = p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)p(\mathbf{t})$.

Our goal is to find optimal parameter values for the marginal likelihood $p(\mathbf{x};\boldsymbol\theta)$ by maximizing its variational lower bound. Here, we neither know the true posterior $p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)$ nor can we apply the mean-field approximation[4], so we have to make further approximations. We start by assuming that $q(\mathbf{t})$ is a factorized Gaussian, i.e. a Gaussian with a diagonal covariance matrix, and that we have a separate distribution $q^{(i)}$ for each latent variable $\mathbf{t}^{(i)}$:

$$
q^{(i)}(\mathbf{t}^{(i)}) = \mathcal{N}\big(\mathbf{t}^{(i)} \lvert \mathbf{m}^{(i)}, \mathrm{diag}\big((\mathbf{s}^{(i)})^2\big)\big) \tag{14}
$$

The problem here is that we have to estimate too many parameters. For example, if the latent space is 50-dimensional, we have to estimate 100 parameters (50 means and 50 variances) per training object! This is not what we want. Another option is that all $q^{(i)}$ share their parameters $\mathbf{m}$ and $\mathbf{s}^2$, i.e. all $q^{(i)}$ are identical. This would keep the number of parameters constant but would be too restrictive. If we want to support different $q^{(i)}$ for different $\mathbf{t}^{(i)}$ but with a limited number of parameters, we should consider using parameters for $q$ that are functions of $\mathbf{x}^{(i)}$. These functions are themselves parametric functions that share a set of parameters $\boldsymbol\phi$:

$$
q^{(i)}(\mathbf{t}^{(i)}) = \mathcal{N}\big(\mathbf{t}^{(i)} \lvert m(\mathbf{x}^{(i)},\boldsymbol\phi), \mathrm{diag}\big(s^2(\mathbf{x}^{(i)},\boldsymbol\phi)\big)\big) \tag{15}
$$

So we finally have a variational distribution $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$ with a fixed number of parameters $\boldsymbol\phi$ as an approximation for the true but unknown posterior $p(\mathbf{t} \lvert \mathbf{x};\boldsymbol\theta)$. To implement the (complex) functions $m$ and $s$ that map from an input image to the mean and the variance of that distribution, we can use a convolutional neural network (CNN) that is parameterized by $\boldsymbol\phi$. Similarly, for implementing $p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)$ we can use another neural network, parameterized by $\boldsymbol\theta$, that maps a latent vector $\mathbf{t}$ to the sufficient statistics of that probability distribution. Since $\mathbf{t}$ is often a lower-dimensional embedding or code of image $\mathbf{x}$, $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$ is referred to as a probabilistic encoder and $p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)$ as a probabilistic decoder.
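The idea of sharing parameters $\boldsymbol\phi$ across all training objects can be sketched with a tiny encoder. The example below is an illustration under simplifying assumptions: it uses a single dense hidden layer instead of a CNN, randomly initialized and untrained weights, and made-up layer sizes. The network predicts the log variance so that the variance is always positive, which is a common implementation choice rather than something prescribed by the text above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, n_hidden, latent_dim = 784, 64, 50   # illustrative sizes

# Shared encoder parameters phi (randomly initialized, untrained).
phi = {
    "W_h": rng.normal(0.0, 0.01, (n_pixels, n_hidden)),
    "b_h": np.zeros(n_hidden),
    "W_m": rng.normal(0.0, 0.01, (n_hidden, latent_dim)),
    "b_m": np.zeros(latent_dim),
    "W_s": rng.normal(0.0, 0.01, (n_hidden, latent_dim)),
    "b_s": np.zeros(latent_dim),
}

def encode(x, phi):
    """Map a batch of inputs to the mean and variance of q(t | x; phi)."""
    h = np.tanh(x @ phi["W_h"] + phi["b_h"])
    mean = h @ phi["W_m"] + phi["b_m"]
    log_var = h @ phi["W_s"] + phi["b_s"]   # log s^2, so that s^2 > 0
    return mean, np.exp(log_var)

x = rng.random((8, n_pixels))               # a batch of 8 fake "images"
mean, var = encode(x, phi)

# The number of parameters depends on the layer sizes only,
# not on the number of training objects.
n_params = sum(p.size for p in phi.values())
```

Each input gets its own mean and variance vector, yet `n_params` stays the same whether the dataset contains 8 or 8 million images.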

### Variational auto-encoder

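The encoder and decoder can be combined into a variational auto-encoder[4] that is trained with the variational lower bound $\mathcal{L}$ as the optimization objective, using standard stochastic gradient ascent methods. For our model, the variational lower bound for a single training object $\mathbf{x}^{(i)}$ can also be formulated as:

$$
\mathcal{L}(\boldsymbol\theta, \boldsymbol\phi, \mathbf{x}^{(i)}) = \mathbb{E}_{q(\mathbf{t} \lvert \mathbf{x}^{(i)};\boldsymbol\phi)}\left[\log p(\mathbf{x}^{(i)} \lvert \mathbf{t};\boldsymbol\theta)\right] - \mathrm{KL}\big(q(\mathbf{t} \lvert \mathbf{x}^{(i)};\boldsymbol\phi) \mid\mid p(\mathbf{t})\big) \tag{16}
$$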

The first term is the expected negative reconstruction error of an image $\mathbf{x}^{(i)}$. This term is maximized when the reconstructed image is as close as possible to the original image. It is computed by first feeding an input image $\mathbf{x}^{(i)}$ through the encoder to compute the mean and the variance of the variational distribution $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$. To compute an approximate value of the expected negative reconstruction error, we sample from the variational distribution. Since this is a Gaussian distribution, sampling is very efficient. To compute $p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)$ we feed the samples through the decoder. A single sample per training object is usually sufficient[4] if the mini-batch size during training is large enough, e.g. 100 or more.

The second term in Eq. 16, the negative KL divergence, is maximized when the approximate posterior $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$ is equal to the prior $p(\mathbf{t})$. The prior is usually chosen to be the standard normal distribution $\mathcal{N}(\mathbf{0},\mathbf{I})$. This term therefore acts as a regularization term that prevents the variance of $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$ from becoming zero. Otherwise, $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$ would degenerate to a delta function and the variational auto-encoder to a usual auto-encoder. Regularizing $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$ to have non-zero variance makes the decoder more robust against small changes in $\mathbf{t}$ and makes the latent space a continuous space of codes that can be decoded to realistic images.

### Gradient of the variational lower bound

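To be able to use the variational lower bound as an optimization objective or loss function in tools like TensorFlow, we have to make sure that it is differentiable. This is easy to achieve for the regularization term, which can be integrated analytically in the Gaussian case

$$
-\mathrm{KL}\big(q(\mathbf{t} \lvert \mathbf{x}^{(i)};\boldsymbol\phi) \mid\mid p(\mathbf{t})\big) = \frac{1}{2} \sum_{j=1}^{D} \left(1 + \log s_j^2 - m_j^2 - s_j^2\right) \tag{17}
$$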

where $m_j$ and $s_j$ denote the $j$-th elements of the vectors computed with functions $m$ and $s$ (see Eq. 15). $D$ is the dimensionality of these vectors. The computation of the expected negative reconstruction error, on the other hand, involves sampling from $q(\mathbf{t} \lvert \mathbf{x};\boldsymbol\phi)$, which is not differentiable. However, a simple reparameterization trick allows us to express the random variable $\mathbf{t}$ as a deterministic function $\mathbf{t} = g(\mathbf{m}, \mathbf{s}, \mathbf\epsilon) = \mathbf{m} + \mathbf{s} \mathbf\epsilon$ of $\mathbf{m}$, $\mathbf{s}$ and a random (noise) variable $\mathbf\epsilon \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ that doesn’t depend on any parameters to be optimized. With this trick we can easily compute the gradient of function $g$ w.r.t. $\mathbf{m}$ and $\mathbf{s}$ and treat $\mathbf\epsilon$, i.e. the sampling procedure, as a constant during back-propagation.
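Both ideas can be checked numerically with a small NumPy sketch. It compares the closed-form KL divergence between a diagonal Gaussian and the standard normal prior with a Monte Carlo estimate, and it uses reparameterized samples to estimate gradients of an expectation. The values of `m` and `s` and the test function $f(\mathbf{t}) = \sum_j t_j^2$ are assumptions made for this example. For this $f$ the exact gradients of $\mathbb{E}[f(\mathbf{t})]$ are $2\mathbf{m}$ and $2\mathbf{s}$.

```python
import numpy as np

rng = np.random.default_rng(0)

def neg_kl_to_standard_normal(m, s):
    """-KL( N(m, diag(s^2)) || N(0, I) ) in closed form."""
    return 0.5 * np.sum(1.0 + np.log(s ** 2) - m ** 2 - s ** 2)

def reparameterize(m, s, eps):
    """t = g(m, s, eps) = m + s * eps with eps ~ N(0, I)."""
    return m + s * eps

# Hypothetical encoder outputs for one input (mean and standard deviation).
m = np.array([0.5, -1.0])
s = np.array([0.8, 1.5])

# Reparameterized samples from q: the noise does not depend on m or s.
eps = rng.standard_normal((200_000, 2))
t = reparameterize(m, s, eps)

# Monte Carlo estimate of KL(q || p) = E_q[ log q(t) - log p(t) ]
log_q = np.sum(-0.5 * np.log(2.0 * np.pi * s ** 2) - 0.5 * ((t - m) / s) ** 2, axis=1)
log_p = np.sum(-0.5 * np.log(2.0 * np.pi) - 0.5 * t ** 2, axis=1)
kl_monte_carlo = np.mean(log_q - log_p)
kl_closed_form = -neg_kl_to_standard_normal(m, s)

# Gradients through the sampling step for f(t) = sum_j t_j^2:
# dt/dm = 1 and dt/ds = eps, so the chain rule gives 2 t and 2 t * eps.
grad_m = np.mean(2.0 * t, axis=0)         # estimates 2 m
grad_s = np.mean(2.0 * t * eps, axis=0)   # estimates 2 s
```

In an automatic differentiation framework the same gradients are obtained by back-propagating through `reparameterize` while treating `eps` as a constant input.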

We haven’t defined the functional form of the probabilistic decoder $p(\mathbf{x} \lvert \mathbf{t};\boldsymbol\theta)$ yet. If we train the variational auto-encoder with grey-scale MNIST images, for example, it makes sense to use a multivariate Bernoulli distribution. In this case, the decoder network outputs the parameters of this distribution, one per pixel, each defining the probability of that pixel being white. These probabilities are then simply mapped to values from 0 to 255 to generate grey-scale images. In the output layer of the decoder network there is one node per pixel with a sigmoid activation function. Hence, we compute the binary cross-entropy between the input image and the decoder output to estimate the expected reconstruction error.
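The relation between the Bernoulli log likelihood and the binary cross-entropy can be made explicit with a short sketch. The four-pixel "image" and the decoder logits below are made-up values, and `bernoulli_log_likelihood` is a helper defined for this example only. Probabilities are clipped slightly away from 0 and 1 to avoid taking the logarithm of zero.

```python
import numpy as np

def sigmoid(a):
    """Logistic function used in the decoder output layer."""
    return 1.0 / (1.0 + np.exp(-a))

def bernoulli_log_likelihood(x, probs, tiny=1e-7):
    """log p(x | t) for independent Bernoulli pixels, summed over pixels."""
    probs = np.clip(probs, tiny, 1.0 - tiny)
    return np.sum(x * np.log(probs) + (1.0 - x) * np.log(1.0 - probs), axis=-1)

# A hypothetical 4-pixel binary "image" and decoder pre-activations (logits).
x = np.array([[0.0, 1.0, 1.0, 0.0]])
logits = np.array([[-3.0, 2.0, 4.0, -1.0]])

probs = sigmoid(logits)                        # per-pixel probability of white
log_lik = bernoulli_log_likelihood(x, probs)   # negative reconstruction error
binary_cross_entropy = -log_lik                # the loss that is minimized

# Map probabilities to grey-scale values for visualization.
grey = np.round(probs * 255).astype(np.uint8)
```

Maximizing the first term of the lower bound is therefore the same as minimizing the binary cross-entropy summed over pixels.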

You can find a complete implementation example of a variational auto-encoder in this notebook. It is part of the bayesian-machine-learning repository.

Stochastic variational inference algorithms implemented as variational auto-encoders scale to very large datasets as they can be trained based on mini-batches. Furthermore, they can also be used for data other than image data. For example, Gómez-Bombarelli et al.[5] use a sequential representation of chemical compounds together with an RNN-based auto-encoder to infer a continuous latent space of chemical compounds that can be used e.g. for generating new chemical compounds with properties that are desirable for drug discovery. I’ll cover that in another blog post.

## References

[1] Dimitris G. Tzikas, Aristidis C. Likas, Nikolaos P. Galatsanos. The Variational Approximation for Bayesian Inference.
[2] Kevin P. Murphy. Machine Learning, A Probabilistic Perspective, Chapters 11 and 21.
[3] Christopher M. Bishop. Pattern Recognition and Machine Learning, Chapters 9 and 10.
[4] Diederik P. Kingma, Max Welling. Auto-Encoding Variational Bayes.
[5] Gómez-Bombarelli et al. Automatic chemical design using a data-driven continuous representation of molecules.