Generative models: Autoencoder

Chun-Hao Yang

Statistical modeling

  • In general, statistical modeling represents the data by a probability distribution.
  • We can roughly categorize statistical models into two types (non-exclusively):
    • Descriptive models: describe the data by the probability density/mass function.
    • Generative models: describe the data by the generating process of data points.
  • For example, the following hierarchical model is considered a generative model: \[ \begin{aligned} \eta & \sim N(0, 1),\\ \tau^2 & \sim \text{Gamma}(1, 1),\\ \theta \mid \eta, \tau^2 & \sim N(\eta, \tau^2),\\ \sigma^2 & \sim \text{Gamma}(1, 1),\\ x \mid \theta, \sigma^2 & \sim N(\theta, \sigma^2). \end{aligned} \]
  • It is easy to generate \(x\) by following the generative process. Yet the marginal distribution of \(x\) is very complicated.
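
As a concrete illustration, here is a minimal NumPy sketch that draws samples of \(x\) by following the generative process (interpreting \(\text{Gamma}(1, 1)\) as shape 1 and scale 1; the sample size, the seed, and the name `sample_x` are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_x(n, rng):
    """Draw n samples of x by following the generative process."""
    eta = rng.normal(0.0, 1.0, size=n)           # eta ~ N(0, 1)
    tau2 = rng.gamma(1.0, 1.0, size=n)           # tau^2 ~ Gamma(1, 1)
    theta = rng.normal(eta, np.sqrt(tau2))       # theta | eta, tau^2 ~ N(eta, tau^2)
    sigma2 = rng.gamma(1.0, 1.0, size=n)         # sigma^2 ~ Gamma(1, 1)
    return rng.normal(theta, np.sqrt(sigma2))    # x | theta, sigma^2 ~ N(theta, sigma^2)

x = sample_x(10_000, rng)  # sampling is easy, even though p(x) has no simple closed form
```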

Descriptive vs. Generative models

  • The main advantage of generative models is that they can be very flexible.

  • However, the flexibility comes at a cost: it is computationally expensive to perform inference.

  • In contrast, descriptive models are usually simpler and computationally efficient, but require strong assumptions.

  • Most generative models face at least one of the following issues:

    • Strong assumptions: Some models make strong assumptions about the structure in the data, for example, the linearity assumption in PCA.
    • Approximations: Some models make severe approximations, leading to sub-optimal models.
    • Expensive computation: Some models rely on computationally expensive inference procedures like Markov Chain Monte Carlo.

Sampling-based statistical inference

  • Inference for generative models is usually done by sampling-based methods.
  • If \(x\) follows a complicated generative model, we can use the following procedure to perform inference:
    1. Generate a large number of samples \(\{x^{(1)}, \ldots, x^{(N)}\}\) from the generative model.
    2. Compute the empirical distribution of the samples, i.e., \(\hat{p}(x) = \frac{1}{N}\sum_{i=1}^N \delta_{x^{(i)}}(x)\).
    3. Use \(\hat{p}(x)\) as an approximation of the true distribution \(p(x)\).
  • For example, we can use \(\frac{1}{N}\sum_{i=1}^N f(x^{(i)})\) to approximate \(\mathbb{E}[f(x)]\); see the sketch after this list.
  • This type of inference is very common in Bayesian statistics, where the posterior distribution is usually complicated.
  • For DL models, the forward propagation can be viewed as a sampling process.
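
For example, a minimal sketch of the three-step procedure for the hierarchical model above (the choices \(f(x) = x^2\), \(N\), and the event \(\{x \le 1\}\) are arbitrary illustrations):

```python
import numpy as np

rng = np.random.default_rng(0)
N = 100_000

# Step 1: generate a large number of samples from the generative model
# (the hierarchical model from the earlier slide, repeated so the sketch is self-contained).
eta = rng.normal(0.0, 1.0, size=N)
tau2 = rng.gamma(1.0, 1.0, size=N)
theta = rng.normal(eta, np.sqrt(tau2))
sigma2 = rng.gamma(1.0, 1.0, size=N)
x = rng.normal(theta, np.sqrt(sigma2))

# Steps 2-3: use the empirical distribution of the samples in place of p(x).
mc_mean = np.mean(x ** 2)       # approximates E[f(x)] with f(x) = x^2
mc_prob = np.mean(x <= 1.0)     # approximates P(x <= 1)
```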

Supervised vs. Unsupervised learning

  • In supervised learning, we have a set of input-output pairs \(\{(\boldsymbol{x}_1, y_1), \ldots, (\boldsymbol{x}_n, y_n)\}\).
  • The typical model is a conditional model \(p(\boldsymbol{y} \mid \boldsymbol{x})\); for example, the Gaussian linear regression model is \(p(y \mid \boldsymbol{x}) = N(\boldsymbol{x}^T\boldsymbol{\beta}, \sigma^2)\).
  • For unsupervised learning, we only have a set of input data points \(\{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}\).
  • We need to learn the distribution of the data points, i.e., \(p(\boldsymbol{x})\); for example, we can use kernel density estimation to estimate \(p(\boldsymbol{x})\) (see the sketch after this list).
  • However, the data contains noise, and a more important goal is to remove the noise and extract the underlying structure in the data.
  • As usual, we will start with linear models.
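
For instance, a minimal kernel density estimation sketch with SciPy (the toy data, dimensions, and query points are arbitrary choices):

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2))      # toy unsupervised data: n = 500 points in R^2

kde = gaussian_kde(X.T)            # scipy expects an array of shape (p, n)
queries = rng.normal(size=(5, 2))  # a few query points
density = kde(queries.T)           # estimated values of p(x) at the query points
```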

Outline

  • Linear Factor model
    • Principal Component Analysis (PCA)
    • Probabilistic PCA (PPCA)
  • Autoencoder
    • Linear Autoencoder
    • Variational Autoencoder (VAE)
    • Training the VAE

Linear Factor Model

Linear Factor model

  • A linear factor model assumes that the data points are generated by a linear transformation of a low-dimensional latent variable, i.e., \[ \boldsymbol{x} = \boldsymbol{W}\boldsymbol{z} + \boldsymbol{\epsilon}, \] where
    • \(\boldsymbol{x} \in \mathbb{R}^p\) is the observed data point,
    • \(\boldsymbol{z} \in \mathbb{R}^q\) is the latent variable with \(q \ll p\),
    • \(\boldsymbol{W} \in \mathbb{R}^{p \times q}\) is the transformation matrix, and
    • \(\boldsymbol{\epsilon} \in \mathbb{R}^p\) is the noise term.
  • The latent variable \(\boldsymbol{z}\) captures the underlying structure in the data points.
  • For simplicity, we assume that \(\boldsymbol{x}\) and \(\boldsymbol{z}\) are centered (zero mean).

Analysis for linear factor model

  • Factor analysis:
    • the latent variable \(\boldsymbol{z}\) represents common factors, typically independent, that describe the shared structure in the observations
    • the matrix \(\boldsymbol{W}\) represents the factor loadings that describe how the factors are combined to generate the observations
  • Principal component analysis (PCA):
    • transform the observations to a new coordinate system where the axes are the directions of maximum variance
    • the transformation matrix \(\boldsymbol{W}\) is assumed to be orthogonal
    • it is mainly used to reduce the dimensionality of observations

Principal Component Analysis (PCA)

  • Let \(\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n \in \mathbb{R}^p\) be the data points and \(\boldsymbol{X} = [\boldsymbol{x}_1, \boldsymbol{x}_2, \ldots, \boldsymbol{x}_n]^T \in \mathbb{R}^{n\times p}\) be the data matrix. Assume that the data points are centered, i.e., \(\boldsymbol{X}^T \boldsymbol{1} = \boldsymbol{0}\).
  • The goal of PCA is to find a set of \(q\) orthogonal vectors \(\boldsymbol{w}_1, \boldsymbol{w}_2, \ldots, \boldsymbol{w}_q \in \mathbb{R}^p\) such that the variance of the projected data points \(\boldsymbol{z}_i = \boldsymbol{W}^T \boldsymbol{x}_i\) is maximized, where \(\boldsymbol{W} = [\boldsymbol{w}_1, \ldots, \boldsymbol{w}_q] \in \mathbb{R}^{p \times q}\).
  • Let \(\boldsymbol{Z} = [\boldsymbol{z}_1, \boldsymbol{z}_2, \ldots, \boldsymbol{z}_n]^T \in \mathbb{R}^{n \times q}\) be the projected data matrix. Then \(\boldsymbol{Z} = \boldsymbol{X} \boldsymbol{W}\).
  • It is easy to see that the projected data points have zero mean, since \(\boldsymbol{Z}^T \boldsymbol{1} = \boldsymbol{W}^T \boldsymbol{X}^T \boldsymbol{1} = \boldsymbol{0}\).
  • The sum of variances of the projected data points is \[ \sum_{i=1}^n \|\boldsymbol{z}_i\|^2 = \sum_{i=1}^n \boldsymbol{z}_i^T \boldsymbol{z}_i = {\rm tr}(\boldsymbol{Z}^T\boldsymbol{Z}) = {\rm tr}(\boldsymbol{W}^T\boldsymbol{X}^T\boldsymbol{X}\boldsymbol{W}). \]
  • The optimization problem is \[ \max_{\boldsymbol{W}}\; {\rm tr}(\boldsymbol{W}^T\boldsymbol{X}^T\boldsymbol{X}\boldsymbol{W}) \quad \text{subject to} \quad \boldsymbol{W}^T\boldsymbol{W} = \boldsymbol{I}. \]

Decomposition of variance

  • Note that the total variance of the data points is \(\sum_{i=1}^n \|\boldsymbol{x}_i\|^2 = {\rm tr}(\boldsymbol{X}^T\boldsymbol{X})\).
  • Let \(P_{\boldsymbol{W}} \in \mathbb{R}^{p \times p}\) be the projection matrix onto the subspace spanned by \(\boldsymbol{W}\), i.e., \(P_{\boldsymbol{W}} = \boldsymbol{W}(\boldsymbol{W}^T\boldsymbol{W})^{-1}\boldsymbol{W}^T\).
  • Since \(\boldsymbol{W}\) is orthogonal, \(P_{\boldsymbol{W}} = \boldsymbol{W}\boldsymbol{W}^T\). A projection matrix is idempotent, i.e., \(P_{\boldsymbol{W}}^2 = P_{\boldsymbol{W}}\).
  • We can decompose the data matrix \(\boldsymbol{X}\) as \(\boldsymbol{X} = \boldsymbol{X}(\boldsymbol{I} - P_{\boldsymbol{W}}) + \boldsymbol{X}P_{\boldsymbol{W}}\).
  • Expanding \(\boldsymbol{X}^T\boldsymbol{X} = \big[(\boldsymbol{I} - P_{\boldsymbol{W}})\boldsymbol{X}^T + P_{\boldsymbol{W}}\boldsymbol{X}^T\big]\big[\boldsymbol{X}(\boldsymbol{I} - P_{\boldsymbol{W}}) + \boldsymbol{X}P_{\boldsymbol{W}}\big]\), the cross terms have zero trace because \(P_{\boldsymbol{W}}\) is symmetric and idempotent: \[ {\rm tr}\big((\boldsymbol{I} - P_{\boldsymbol{W}})\boldsymbol{X}^T\boldsymbol{X}P_{\boldsymbol{W}}\big) = {\rm tr}\big(\boldsymbol{X}^T\boldsymbol{X}(P_{\boldsymbol{W}} - P_{\boldsymbol{W}}^2)\big) = 0. \]
  • Therefore the total variance can be decomposed as \[ {\rm tr}(\boldsymbol{X}^T\boldsymbol{X}) = {\rm tr}((\boldsymbol{I} - P_{\boldsymbol{W}})\boldsymbol{X}^T\boldsymbol{X}(\boldsymbol{I} - P_{\boldsymbol{W}})) + {\rm tr}(P_{\boldsymbol{W}}\boldsymbol{X}^T\boldsymbol{X}P_{\boldsymbol{W}}). \]
  • The second term is the variance of the projected data points \[ {\rm tr}(P_{\boldsymbol{W}}\boldsymbol{X}^T\boldsymbol{X}P_{\boldsymbol{W}}) = {\rm tr}(\boldsymbol{W}\boldsymbol{W}^T\boldsymbol{X}^T\boldsymbol{X}\boldsymbol{W}\boldsymbol{W}^T) = {\rm tr}(\boldsymbol{W}^T\boldsymbol{X}^T\boldsymbol{X}\boldsymbol{W}). \]

Reconstruction error

  • Let \(\tilde{\boldsymbol{x}}_i = \boldsymbol{W}\boldsymbol{z}_i = \boldsymbol{W}\boldsymbol{W}^T\boldsymbol{x}_i \in \mathbb{R}^p\) be the reconstructed data point.
  • The reconstructed data matrix is \(\widetilde{\boldsymbol{X}} = \boldsymbol{Z}\boldsymbol{W}^T = \boldsymbol{X}\boldsymbol{W}\boldsymbol{W}^T \in \mathbb{R}^{n \times p}\).
  • The reconstruction error is \[\begin{align*} \sum_{i=1}^n \|\boldsymbol{x}_i - \tilde{\boldsymbol{x}}_i\|^2 & = \|\boldsymbol{X} - \widetilde{\boldsymbol{X}}\|^2 = {\rm tr}((\boldsymbol{X} - \widetilde{\boldsymbol{X}})^T(\boldsymbol{X} - \widetilde{\boldsymbol{X}}))\\ & = {\rm tr}\left[(\boldsymbol{X} - \boldsymbol{X}\boldsymbol{W}\boldsymbol{W}^T)^T(\boldsymbol{X} - \boldsymbol{X}\boldsymbol{W}\boldsymbol{W}^T)\right]\\ & = {\rm tr}\left[(\boldsymbol{I} - \boldsymbol{W}\boldsymbol{W}^T)^T\boldsymbol{X}^T\boldsymbol{X}(\boldsymbol{I} - \boldsymbol{W}\boldsymbol{W}^T)\right]\\ & = {\rm tr}((\boldsymbol{I} - P_{\boldsymbol{W}})\boldsymbol{X}^T\boldsymbol{X}(\boldsymbol{I} - P_{\boldsymbol{W}})). \end{align*}\]
  • That is, the first term in the decomposition of the total variance is the reconstruction error.

Quick summary of PCA

Given an orthogonal matrix \(\boldsymbol{W} \in \mathbb{R}^{p \times q}\), we can obtain

  • the projected data points \(\boldsymbol{Z} = \boldsymbol{X}\boldsymbol{W}\),
  • the reconstructed data points based on \(\boldsymbol{Z}\), \(\widetilde{\boldsymbol{X}} = \boldsymbol{Z}\boldsymbol{W}^T = \boldsymbol{X}\boldsymbol{W}\boldsymbol{W}^T\),
  • the variance of the projected data points \({\rm tr}(\boldsymbol{Z}^T\boldsymbol{Z}) = {\rm tr}(\boldsymbol{W}^T\boldsymbol{X}^T\boldsymbol{X}\boldsymbol{W})\),
  • the reconstruction error \({\rm tr}((\boldsymbol{I} - P_{\boldsymbol{W}})\boldsymbol{X}^T\boldsymbol{X}(\boldsymbol{I} - P_{\boldsymbol{W}}))\),
  • the total variance \({\rm tr}(\boldsymbol{X}^T\boldsymbol{X})\) can be decomposed as \[ \text{total var.} = \text{reconstruction error} + \text{projected var.}, \]
  • PCA finds the orthogonal matrix \(\boldsymbol{W}\) that maximizes the projected variance, which is equivalent to minimizing the reconstruction error,
  • the solution is given by the top \(q\) eigenvectors of \(\boldsymbol{X}^T\boldsymbol{X}\) (using Rayleigh quotients), or equivalently, the top \(q\) right singular vectors of the data matrix \(\boldsymbol{X}\).
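
A minimal NumPy sketch of this recipe on toy data (the dimensions and \(q\) are arbitrary choices), including a numerical check of the variance decomposition:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, q = 200, 10, 2
X = rng.normal(size=(n, p))
X = X - X.mean(axis=0)                       # center the data

# PCA via the SVD: the columns of W are the top-q right singular vectors of X.
U, s, Vt = np.linalg.svd(X, full_matrices=False)
W = Vt[:q].T                                 # (p, q), orthonormal columns

Z = X @ W                                    # projected data points (n, q)
X_tilde = Z @ W.T                            # reconstructed data points (n, p)

total_var = np.trace(X.T @ X)
projected_var = np.trace(W.T @ X.T @ X @ W)
reconstruction_error = np.sum((X - X_tilde) ** 2)

# total variance = reconstruction error + projected variance
assert np.isclose(total_var, reconstruction_error + projected_var)
```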

Probabilistic PCA

  • Probabilistic PCA (PPCA) casts PCA in a probabilistic framework.
  • Recall the linear factor model: \(\boldsymbol{x} = \boldsymbol{W}\boldsymbol{z} + \boldsymbol{\epsilon}\). We now add additional distributional assumptions: \[ \boldsymbol{z} \sim N_q(0, \boldsymbol{I}), \qquad \boldsymbol{\epsilon} \sim N_p(0, \sigma^2 \boldsymbol{I}). \]
  • Under these assumptions, the marginal distribution of \(\boldsymbol{x}\) is \(\boldsymbol{x} \sim N_p(0, \boldsymbol{W}\boldsymbol{W}^T + \sigma^2 \boldsymbol{I})\).
  • The MLEs for \(\boldsymbol{W}\) and \(\sigma^2\) are given by \[ \widehat{\boldsymbol{W}}_{\text{MLE}} = \boldsymbol{U}_q (\boldsymbol{\Lambda}_q - \hat{\sigma}^2_{\text{MLE}} \boldsymbol{I})^{1/2} \boldsymbol{R}, \qquad \hat{\sigma}^2_{\text{MLE}} = \frac{1}{p-q}\sum_{j=q+1}^p \lambda_j, \] where \(\boldsymbol{U}_q\) and \(\boldsymbol{\Lambda}_q = {\rm diag}(\lambda_1, \ldots, \lambda_q)\) contain the top \(q\) eigenvectors and eigenvalues of the sample covariance matrix, \(\lambda_{q+1} \geq \cdots \geq \lambda_p\) are the remaining eigenvalues, and \(\boldsymbol{R}\) is an arbitrary \(q \times q\) orthogonal matrix.
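
A minimal NumPy sketch of these formulas on toy data, taking \(\boldsymbol{R} = \boldsymbol{I}\) (the dimensions and \(q\) are arbitrary choices; `W_mle` and `sigma2_mle` are illustrative names):

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, q = 500, 10, 3
X = rng.normal(size=(n, p))
X = X - X.mean(axis=0)

S = X.T @ X / n                                   # sample covariance matrix
eigval, eigvec = np.linalg.eigh(S)                # eigenvalues in ascending order
eigval, eigvec = eigval[::-1], eigvec[:, ::-1]    # sort in descending order

sigma2_mle = eigval[q:].mean()                    # average of the discarded eigenvalues
U_q = eigvec[:, :q]                               # top-q eigenvectors
Lambda_q = np.diag(eigval[:q])                    # top-q eigenvalues
W_mle = U_q @ np.sqrt(Lambda_q - sigma2_mle * np.eye(q))   # taking R = I
```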

PPCA as a generative model

  • Once we obtain the MLEs \(\widehat{\boldsymbol{W}}_{\text{MLE}}\) and \(\hat{\sigma}^2_{\text{MLE}}\), the PPCA model can be viewed as a generative model:
    • Generate the latent variable \(\boldsymbol{z} \sim N_q(0, \boldsymbol{I})\) and the noise term \(\boldsymbol{e} \sim N_p(0, \boldsymbol{I})\).
    • Compute \(\boldsymbol{x}^{\star} = \widehat{\boldsymbol{W}}_{\text{MLE}}\boldsymbol{z} + \hat{\sigma}_{\text{MLE}}\boldsymbol{e}\), so that the noise term \(\hat{\sigma}_{\text{MLE}}\boldsymbol{e} \sim N_p(0, \hat{\sigma}^2_{\text{MLE}}\boldsymbol{I})\).
  • By the consistency of MLEs, we have \(\widehat{\boldsymbol{W}}_{\text{MLE}} \stackrel{p}{\to} \boldsymbol{W}\) and \(\hat{\sigma}^2_{\text{MLE}} \stackrel{p}{\to} \sigma^2\) as \(n \to \infty\).
  • By Slutsky’s theorem, we have \(\boldsymbol{x}^{\star} \stackrel{d}{\to} \boldsymbol{x}\) as \(n \to \infty\).
  • That is, when the sample size is large, the generated data points \(\boldsymbol{x}^{\star}\) have the same distribution as the observed data points \(\boldsymbol{x}\).
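
Continuing the sketch above, generating new data points from the fitted model is a direct translation of the two steps; note that the noise is scaled by the standard deviation \(\hat{\sigma}_{\text{MLE}}\):

```python
import numpy as np

def sample_ppca(m, W_mle, sigma2_mle, rng):
    """Generate m new data points from a fitted PPCA model."""
    p, q = W_mle.shape
    z = rng.normal(size=(m, q))                     # latent variables z ~ N_q(0, I)
    e = rng.normal(size=(m, p))                     # noise e ~ N_p(0, I)
    return z @ W_mle.T + np.sqrt(sigma2_mle) * e    # x* = W_hat z + sigma_hat e

# e.g., with W_mle and sigma2_mle from the previous sketch:
# X_star = sample_ppca(1000, W_mle, sigma2_mle, np.random.default_rng(1))
```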

Variants of PCA

There are many variants of PCA/PPCA:

  • Sparse PCA: adding a sparsity constraint (\(L_1\) penalty) on the projection matrix \(\boldsymbol{W}\).
  • Robust PCA: minimizing a weighted combination of the nuclear norm and the \(\ell_1\) norm.
  • Kernel PCA: using the kernel trick to perform nonlinear dimensionality reduction.
  • Contrastive PCA: performing PCA with reference information.

Autoencoder

Autoencoder

  • An autoencoder is a neural network that is trained to attempt to copy its input to its output.
  • It has a hidden layer \(\boldsymbol{z}\) that describes a code used to represent the input.
  • The network may be viewed as consisting of two parts: an encoder function \(\boldsymbol{z}= f(\boldsymbol{x})\) and a decoder that produces a reconstruction \(\boldsymbol{r} = g(\boldsymbol{z})\).
  • Usually, the input \(\boldsymbol{x} \in \mathbb{R}^p\) is high-dimensional, and the code \(\boldsymbol{z} \in \mathbb{R}^q\) is low-dimensional, \(q \ll p\).
  • The code can be used as a compact representation for the input data, and the decoder can be used to recover the inputs from the codes.
  • An autoencoder is trained to minimize the reconstruction error, i.e., the difference between the input and the reconstruction.
  • For example, if we use the squared error loss, the optimization problem is \[ \min_{f,g} \sum_{i=1}^n \|\boldsymbol{x}_i - g(f(\boldsymbol{x}_i))\|^2. \]
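
A minimal PyTorch sketch of an autoencoder trained with the squared-error reconstruction loss (the layer widths, activation, optimizer settings, and the toy mini-batch are arbitrary choices):

```python
import torch
import torch.nn as nn

p, q = 784, 32                                   # input and code dimensions

encoder = nn.Sequential(nn.Linear(p, 256), nn.ReLU(), nn.Linear(256, q))
decoder = nn.Sequential(nn.Linear(q, 256), nn.ReLU(), nn.Linear(256, p))

params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

X = torch.randn(128, p)                          # a toy mini-batch standing in for real data

for step in range(100):                          # in practice, loop over mini-batches of data
    z = encoder(X)                               # code z = f(x)
    r = decoder(z)                               # reconstruction r = g(z)
    loss = ((X - r) ** 2).sum(dim=1).mean()      # squared-error reconstruction loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```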

Linear Autoencoders

  • If we use linear functions for the encoder and decoder, i.e., \(f(\boldsymbol{x}) = \boldsymbol{W}_1 \boldsymbol{x}\) and \(g(\boldsymbol{z}) = \boldsymbol{W}_2 \boldsymbol{z}\), the optimization problem becomes \[ \min_{\boldsymbol{W}_1, \boldsymbol{W}_2} \sum_{i=1}^n \|\boldsymbol{x}_i - \boldsymbol{W}_2 \boldsymbol{W}_1 \boldsymbol{x}_i\|^2 \] where \(\boldsymbol{W}_1 \in \mathbb{R}^{q \times p}\) and \(\boldsymbol{W}_2 \in \mathbb{R}^{p \times q}\).
  • If we additionally assume that \(\boldsymbol{W}_2 = \boldsymbol{W}_1^T\) and that \(\boldsymbol{W}_1\) has orthonormal rows (i.e., \(\boldsymbol{W}_1\boldsymbol{W}_1^T = \boldsymbol{I}\)), the optimization problem is equivalent to PCA with \(\boldsymbol{W} = \boldsymbol{W}_1^T\).

Problem with autoencoders

  • If the encoder and decoder are allowed too much capacity, the autoencoder can learn to perform the copying task without extracting useful information about the distribution of the data.
  • For example, consider a very powerful encoder that maps \(\boldsymbol{x}_i\) to its index \(i\) and a very powerful decoder that maps \(i\) back to \(\boldsymbol{x}_i\).
  • In this case, although the reconstruction error is zero, the encoder is not extracting any useful information about the data.
  • Two approaches to address this issue:
    • Add a regularization term to the optimization problem to encourage the model to have other properties besides the copying ability.
    • Consider a probabilistic model for the autoencoder.

Probabilistic model for autoencoder

  • To cast the autoencoder in a probabilistic framework, we consider the following model: \[\begin{align*} \text{Latent variable} & : \boldsymbol{z} \sim p_{\boldsymbol{\theta}}(\boldsymbol{z}),\\ \text{Decoder} & : \boldsymbol{x} \mid \boldsymbol{z} \sim p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z}). \end{align*}\]
  • By Bayes’ theorem, the distribution of \(\boldsymbol{z} \mid \boldsymbol{x}\) is given by \[ p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x}) = \frac{p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z})p_{\boldsymbol{\theta}}(\boldsymbol{z})}{p_{\boldsymbol{\theta}}(\boldsymbol{x})}. \]
  • The marginal distribution of \(\boldsymbol{x}\) is given by \[ p_{\boldsymbol{\theta}}(\boldsymbol{x}) = \int p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z})p_{\boldsymbol{\theta}}(\boldsymbol{z})d\boldsymbol{z}. \]

Probabilistic model for autoencoder

  • The distribution of \(\boldsymbol{x}\) is usually complicated, while the prior on the latent space \(p_{\boldsymbol{\theta}}(\boldsymbol{z})\) is assumed to be simple.
  • The distribution \(p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})\) is intractable, but we can approximate it by a simple distribution \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\).
  • The approximating distribution \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\) is called the variational distribution and the parameter \(\boldsymbol{\phi}\) is called the variational parameter.
  • In the context of autoencoders, it is called the variational encoder.
  • To summarize, the goal here is to find
    • the decoder \(p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z})\),
    • the variational encoder \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\).

Assumptions

  • Since we are using \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\) to approximate \(p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})\), it is important to see what assumptions we impose on \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\).
  • Strong assumptions:
    • simpler \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\)
    • easier to optimize
    • larger approximation error
  • Weak assumptions:
    • more complex \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\)
    • harder to optimize
    • smaller approximation error
  • A common choice is the simple factorized Gaussian encoder \(z_i \mid \boldsymbol{x} \stackrel{\text{ind}}{\sim} N(\mu_i(\boldsymbol{x}, \boldsymbol{\phi}), \sigma_i^2(\boldsymbol{x}, \boldsymbol{\phi}))\) where \(\mu_i(\boldsymbol{x}, \boldsymbol{\phi})\) and \(\sigma_i^2(\boldsymbol{x}, \boldsymbol{\phi})\) are functions of \(\boldsymbol{x}\) and \(\boldsymbol{\phi}\).

Variational Autoencoder (VAE)

  • Using neural networks to model \(\mu_i(\boldsymbol{x}, \boldsymbol{\phi})\) and \(\sigma_i^2(\boldsymbol{x}, \boldsymbol{\phi})\), we obtain the variational autoencoder (VAE): an encoder network maps \(\boldsymbol{x}\) to \((\boldsymbol{\mu}, \boldsymbol{\sigma}^2)\), and a decoder network maps a sampled \(\boldsymbol{z}\) back to \(\boldsymbol{x}\).

  • The next question is how to train the VAE, more specifically, what loss function to use.

Wrap-up the notations

  • A VAE contains the following components:
    • The variational encoder \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\) with parameters \(\boldsymbol{\phi}\).
    • The decoder \(p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z})\) with parameters \(\boldsymbol{\theta}\).
  • If the encoder and decoder are parametrized by neural networks, the parameters \(\boldsymbol{\phi}\) and \(\boldsymbol{\theta}\) are the weights and biases of the encoder and decoder networks, respectively.
  • Training a VAE is to find the optimal parameters \(\boldsymbol{\theta}\) and \(\boldsymbol{\phi}\) that minimize the loss function (to be specified later).
  • We will introduce two criteria for training the VAE:
    • the variational inference approach (maximizing the evidence lower bound),
    • the maximum likelihood approach (maximizing the log-likelihood).

Variational Inference

  • Variational inference (VI) is an algorithmic framework for approximating distributions.
  • We want to approximate the distribution \(p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})\) with a simpler distribution \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\) by minimizing the Kullback-Leibler (KL) divergence \[ D_{\text{KL}}(q_{\boldsymbol{\phi}}\|p_{\boldsymbol{\theta}}) = \int \log\frac{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}{p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})} q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})d\boldsymbol{z} = \mathbb{E}_{q_{\boldsymbol{\phi}}}\left[\log\frac{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}{p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})}\right]. \]
  • The main difficulty, as in other variational methods, is that \(p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})\) is intractable.
  • Hence we are not able to compute the KL divergence directly.

Why the KL divergence?

  • The KL divergence is a measure of the difference between two distributions.
  • Suppose you want to measure the difference between two normal distributions \(N(\mu_1, 1)\) and \(N(\mu_2, 1)\).
  • The most straightforward way is to compute the Euclidean distance between the means, i.e., \(|\mu_1 - \mu_2|\).
  • However, this is not possible when the two distributions are not of the same family, for example, a normal distribution and a Laplace distribution.
  • The KL divergence can be computed for any two distributions (with some minimal assumptions).
  • For two normal distributions \(N(\mu_1, \sigma_1^2)\) and \(N(\mu_2, \sigma_2^2)\), the KL divergence is given by \[ D_{\text{KL}}(N(\mu_1, \sigma_1^2)\|N(\mu_2, \sigma_2^2)) = \frac{1}{2}\left(\frac{\sigma_1^2}{\sigma_2^2} + \frac{(\mu_2 - \mu_1)^2}{\sigma_2^2} - 1 + \log\frac{\sigma_2^2}{\sigma_1^2}\right). \]
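
A minimal sketch comparing this closed-form expression with a Monte Carlo estimate of the same KL divergence (the parameter values are arbitrary):

```python
import numpy as np

def kl_normal(mu1, s1sq, mu2, s2sq):
    """Closed-form KL( N(mu1, s1sq) || N(mu2, s2sq) )."""
    return 0.5 * (s1sq / s2sq + (mu2 - mu1) ** 2 / s2sq - 1 + np.log(s2sq / s1sq))

rng = np.random.default_rng(0)
mu1, s1sq, mu2, s2sq = 0.0, 1.0, 1.0, 2.0

# Monte Carlo estimate: E_{x ~ N(mu1, s1sq)}[ log p1(x) - log p2(x) ].
x = rng.normal(mu1, np.sqrt(s1sq), size=1_000_000)
log_p1 = -0.5 * np.log(2 * np.pi * s1sq) - (x - mu1) ** 2 / (2 * s1sq)
log_p2 = -0.5 * np.log(2 * np.pi * s2sq) - (x - mu2) ** 2 / (2 * s2sq)

print(kl_normal(mu1, s1sq, mu2, s2sq))   # closed form
print(np.mean(log_p1 - log_p2))          # Monte Carlo estimate, close to the closed form
```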

Evidence Lower Bound (ELBO)

The actual optimization objective of the variational autoencoder, like in other variational methods, is the evidence lower bound (ELBO), which is derived as follows:

\[ \begin{aligned} \log p_{\boldsymbol{\theta}}(\boldsymbol{x}) & =\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\right] \\ & =\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})}{p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})}\right] \\ & =\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \left(\frac{p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})} \frac{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}{p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})}\right)\right] \\ & =\underbrace{\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right]}_{=\mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x})\; (\text{ELBO})} +\underbrace{\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log \frac{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}{p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})}\right]}_{=D_{\text{KL}}\left(q_{\boldsymbol{\phi}} \| p_{\boldsymbol{\theta}}\right)} \end{aligned} \]

  • Since \(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) does not depend on \(\boldsymbol{\phi}\), minimizing the KL divergence over \(\boldsymbol{\phi}\) is equivalent to maximizing the ELBO over \(\boldsymbol{\phi}\).

Maximize the ELBO

  • Due to the non-negativity of the KL divergence, the ELBO is a lower bound on the log-likelihood of the data \[ \begin{aligned} \mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}) & =\log p_{\boldsymbol{\theta}}(\boldsymbol{x})-D_{\text{KL}}\left(q_{\boldsymbol{\phi}} \| p_{\boldsymbol{\theta}}\right) \\ & \leq \log p_{\boldsymbol{\theta}}(\boldsymbol{x}) \end{aligned} \]

  • Hence maximizing the ELBO \(\mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x})\) w.r.t. the parameters \(\boldsymbol{\theta}\) and \(\boldsymbol{\phi}\) will concurrently optimize two things:

    1. It will approximately maximize the marginal likelihood \(p_{\boldsymbol{\theta}}(\boldsymbol{x})\). This means that our generative model will become better.
    2. It will minimize the KL divergence of the approximation \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\) from the true distribution \(p_{\boldsymbol{\theta}}(\boldsymbol{z} \mid \boldsymbol{x})\), so \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\) becomes better.

ELBO based on empirical data

  • Given an i.i.d. dataset \(\mathcal{D} = \{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}\), the empirical distribution is \[ q_{\mathcal{D}}(\boldsymbol{x}) = \frac{1}{n} \sum_{i=1}^n \delta_{\boldsymbol{x}_i} \] where \(\delta_{\boldsymbol{x}_i}\) is the Dirac delta function at \(\boldsymbol{x}_i\).
  • With the empirical distribution \(q_{\mathcal{D}}(\boldsymbol{x})\) and the inference model \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\), we get a joint distribution \[ q_{\mathcal{D}, \phi}(\boldsymbol{x}, \boldsymbol{z})=q_{\mathcal{D}}(\boldsymbol{x}) q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}). \]
  • The KL divergence between \(q_{\mathcal{D}, \phi}(\boldsymbol{x}, \boldsymbol{z})\) and \(p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})\) is \[ \begin{aligned} D_{\text{KL}}\left(q_{\mathcal{D}, \phi}(\boldsymbol{x}, \boldsymbol{z}) \| p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})\right) & = \mathbb{E}_{q_{\mathcal{D}, \phi}(\boldsymbol{x}, \boldsymbol{z})}\left[\log \frac{q_{\mathcal{D}, \phi}(\boldsymbol{x}, \boldsymbol{z})}{p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})}\right]\\ & = \mathbb{E}_{q_{\mathcal{D}, \phi}(\boldsymbol{x}, \boldsymbol{z})}\left[\log \frac{q_{\mathcal{D}}(\boldsymbol{x}) q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}{p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})}\right] \\ & = \mathbb{E}_{q_{\mathcal{D}, \phi}(\boldsymbol{x}, \boldsymbol{z})}\left[\log q_{\mathcal{D}}(\boldsymbol{x}) + \log q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) - \log p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})\right] \\ & = \mathbb{E}_{q_{\mathcal{D}}(\boldsymbol{x})}\Big\{\mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\Big[\log q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})- \log p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})\Big] + \log q_{\mathcal{D}}(\boldsymbol{x})\Big\} \\ & = \mathbb{E}_{q_{\mathcal{D}}(\boldsymbol{x})}\Big[-\mathcal{L}_{\boldsymbol{\theta}, \phi}(\boldsymbol{x})\Big] + \mathrm{constant}\\ & = -\frac{1}{n} \sum_{i=1}^n \mathcal{L}_{\boldsymbol{\theta}, \phi}(\boldsymbol{x}_i) + \mathrm{constant}. \end{aligned} \]
  • To carry out the actual computation, we make the following assumptions:
    • the prior distribution \(p_{\boldsymbol{\theta}}(\boldsymbol{z})\) is a standard Gaussian \(N(\boldsymbol{0}, \boldsymbol{I})\),
    • the simple factorized Gaussian encoder \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) = \prod_{i=1}^q N(z_i \mid \mu_i(\boldsymbol{x}, \boldsymbol{\phi}), \sigma_i^2(\boldsymbol{x}, \boldsymbol{\phi}))\).
  • During the forward propagation, we sample \(\boldsymbol{z} \sim q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\).
  • A useful reparameterization trick is to sample \(\boldsymbol{\epsilon}^{(1)}, \ldots, \boldsymbol{\epsilon}^{(L)} \stackrel{\rm iid}{\sim}N(\boldsymbol{0}, \boldsymbol{I})\) and set \(\boldsymbol{z}^{(l)} = \boldsymbol{\mu}(\boldsymbol{x}, \boldsymbol{\phi}) + \boldsymbol{\sigma}(\boldsymbol{x}, \boldsymbol{\phi}) \odot \boldsymbol{\epsilon}^{(l)}\).
  • The ELBO is \[ \begin{aligned} \mathcal{L}_{\boldsymbol{\theta}, \boldsymbol{\phi}}(\boldsymbol{x}) & = \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\log p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z})\right] - D_{\text{KL}}\left(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) \| p_{\boldsymbol{\theta}}(\boldsymbol{z})\right) \\ & \approx \frac{1}{L} \sum_{l=1}^L \log p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z}^{(l)}) - D_{\text{KL}}\left(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x}) \| p_{\boldsymbol{\theta}}(\boldsymbol{z})\right)\\ & = \frac{1}{L} \sum_{l=1}^L \log p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z}^{(l)}) + \frac{1}{2}\sum_{i=1}^q \left(1 + \log \sigma_i^2(\boldsymbol{x}, \boldsymbol{\phi}) - \mu_i^2(\boldsymbol{x}, \boldsymbol{\phi}) - \sigma_i^2(\boldsymbol{x}, \boldsymbol{\phi})\right). \end{aligned} \]
  • The second line uses a Monte Carlo approximation with the reparameterized samples \(\boldsymbol{z}^{(l)}\), and the third line follows from the closed-form KL divergence between two Gaussians.
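
A minimal PyTorch sketch of training with the negative ELBO, assuming a Bernoulli decoder \(p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z})\) (a common choice for binarized images) and \(L = 1\); the architecture, optimizer settings, and the toy mini-batch are arbitrary choices:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, p=784, q=20, hidden=256):
        super().__init__()
        # Encoder network: outputs mu(x, phi) and log sigma^2(x, phi).
        self.enc = nn.Sequential(nn.Linear(p, hidden), nn.ReLU())
        self.enc_mu = nn.Linear(hidden, q)
        self.enc_logvar = nn.Linear(hidden, q)
        # Decoder network: outputs the logits of a Bernoulli distribution over x.
        self.dec = nn.Sequential(nn.Linear(q, hidden), nn.ReLU(), nn.Linear(hidden, p))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I), with L = 1.
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * logvar) * eps
        return self.dec(z), mu, logvar

def negative_elbo(x, logits, mu, logvar):
    # Reconstruction term: -log p_theta(x | z) for a Bernoulli decoder.
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(dim=1)
    # KL( q_phi(z | x) || N(0, I) ) in closed form for a factorized Gaussian encoder.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return (recon + kl).mean()

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(128, 784).round()                 # a toy binarized mini-batch

for step in range(100):                          # in practice, loop over mini-batches of data
    logits, mu, logvar = model(x)
    loss = negative_elbo(x, logits, mu, logvar)  # minimize the negative ELBO
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```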

Marginal Likelihood

  • We can also train the VAE by maximizing the marginal log-likelihood \(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\).
  • Given an i.i.d. dataset \(\mathcal{D} = \{\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n\}\), the empirical distribution is \[ q_{\mathcal{D}}(\boldsymbol{x}) = \frac{1}{n} \sum_{i=1}^n \delta_{\boldsymbol{x}_i} \] where \(\delta_{\boldsymbol{x}_i}\) is the Dirac delta function at \(\boldsymbol{x}_i\).
  • The KL divergence between the empirical distribution \(q_{\mathcal{D}}(\boldsymbol{x})\) and the marginal likelihood \(p_{\boldsymbol{\theta}}(\boldsymbol{x})\) \[ \begin{aligned} D_{\text{KL}}\left(q_{\mathcal{D}}(\boldsymbol{x}) \| p_{\boldsymbol{\theta}}(\boldsymbol{x})\right) & =-\mathbb{E}_{q_{\mathcal{D}}(\boldsymbol{x})}\left[\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\right]+\mathbb{E}_{q_{\mathcal{D}}(\boldsymbol{x})}\left[\log q_{\mathcal{D}}(\boldsymbol{x})\right] \\ & =-\frac{1}{n} \sum_{i=1}^n \log p_{\boldsymbol{\theta}}(\boldsymbol{x}_i)+\mathrm{constant}. \end{aligned} \]
  • That is, maximizing the marginal log-likelihood is equivalent to minimizing the KL divergence between the empirical distribution and the marginal likelihood.

Estimating the marginal likelihood

  • The marginal log-likelihood \(\log p_{\boldsymbol{\theta}}(\boldsymbol{x})\) can be estimated using an importance sampling technique: \[ \begin{aligned} \log p_{\boldsymbol{\theta}}(\boldsymbol{x})=\log \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[\frac{p_{\boldsymbol{\theta}}(\boldsymbol{x}, \boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right] = \log \mathbb{E}_{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\left[p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z})\frac{p_{\boldsymbol{\theta}}(\boldsymbol{z})}{q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})}\right] \end{aligned} \]
  • With random samples \(\boldsymbol{z}^{(1)}, \ldots, \boldsymbol{z}^{(L)}\) from \(q_{\boldsymbol{\phi}}(\boldsymbol{z} \mid \boldsymbol{x})\), a Monte Carlo estimator of this is: \[ \begin{aligned} \log p_{\boldsymbol{\theta}}(\boldsymbol{x}) \approx \log \frac{1}{L} \sum_{l=1}^L p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z}^{(l)})\frac{p_{\boldsymbol{\theta}}(\boldsymbol{z}^{(l)})}{q_{\boldsymbol{\phi}}(\boldsymbol{z}^{(l)} \mid \boldsymbol{x})}. \end{aligned} \]
  • Hence the loss for sample \(\boldsymbol{x}\) is \[ \begin{aligned} L_{\text{ML}}({\boldsymbol{\theta}, \boldsymbol{\phi}}; \boldsymbol{x}) & = - \log \frac{1}{L} \sum_{l=1}^L p_{\boldsymbol{\theta}}(\boldsymbol{x} \mid \boldsymbol{z}^{(l)})\frac{p_{\boldsymbol{\theta}}(\boldsymbol{z}^{(l)})}{q_{\boldsymbol{\phi}}(\boldsymbol{z}^{(l)} \mid \boldsymbol{x})}. \end{aligned} \]
  • This is the objective used in importance weighted autoencoders (IWAE).
  • Note that for \(L = 1\), the ML loss and the ELBO loss are equivalent.
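
A minimal sketch of this ML (IWAE-style) loss, reusing the `VAE` class from the previous sketch (the Bernoulli decoder and standard Gaussian prior are the same assumptions as before; \(L = 5\) is arbitrary). For \(L = 1\) it reduces to a single-sample estimate of the negative ELBO, consistent with the last bullet above.

```python
import math
import torch
import torch.nn.functional as F

def iwae_loss(model, x, L=5):
    """-log (1/L) sum_l p(x | z_l) p(z_l) / q(z_l | x), with z_l ~ q(z | x)."""
    h = model.enc(x)
    mu, logvar = model.enc_mu(h), model.enc_logvar(h)
    std = torch.exp(0.5 * logvar)
    log2pi = math.log(2 * math.pi)

    log_w = []
    for _ in range(L):
        eps = torch.randn_like(mu)
        z = mu + std * eps                               # z_l ~ q_phi(z | x)
        logits = model.dec(z)
        # log p_theta(x | z_l): Bernoulli decoder.
        log_px_z = -F.binary_cross_entropy_with_logits(
            logits, x, reduction="none").sum(dim=1)
        # log p_theta(z_l): standard Gaussian prior.
        log_pz = -0.5 * (z.pow(2) + log2pi).sum(dim=1)
        # log q_phi(z_l | x): factorized Gaussian encoder.
        log_qz_x = -0.5 * (((z - mu) / std).pow(2) + logvar + log2pi).sum(dim=1)
        log_w.append(log_px_z + log_pz - log_qz_x)

    log_w = torch.stack(log_w, dim=0)                    # shape (L, batch)
    # log of the average importance weight, computed stably with log-sum-exp.
    return -(torch.logsumexp(log_w, dim=0) - math.log(L)).mean()
```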

Image generation with VAE

Generated samples from a VAE trained on the MNIST dataset.

Different types of generative models
