Hamiltonian Variational Auto-Encoder

Authors:
Anthony L Caterini University of Oxford
Arnaud Doucet University of Oxford
Dino Sejdinovic University of Oxford

Introduction:

Variational Auto-Encoders (VAE) have become very popular techniques to perform inference and learning in latent variable models as they allow us to leverage the rich representational power of neural networks to obtain flexible approximations of the posterior of latent variables as well as tight evidence lower bounds (ELBO).
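
For reference, the ELBO mentioned above is the standard variational lower bound on the marginal likelihood; writing $\theta$ for the generative parameters and $\phi$ for the variational (encoder) parameters (notation chosen here for illustration, not taken from the paper), it reads

\[
\log p_\theta(x) \;\ge\; \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x, z) - \log q_\phi(z \mid x)\big] \;=:\; \mathcal{L}(\theta, \phi; x).
\]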

Abstract:

Variational Auto-Encoders (VAE) have become very popular techniques to perform inference and learning in latent variable models as they allow us to leverage the rich representational power of neural networks to obtain flexible approximations of the posterior of latent variables as well as tight evidence lower bounds (ELBO). Combined with stochastic variational inference, this provides a methodology scaling to large datasets. However, for this methodology to be practically efficient, it is necessary to obtain low-variance unbiased estimators of the ELBO and its gradients with respect to the parameters of interest. While the use of Markov chain Monte Carlo (MCMC) techniques such as Hamiltonian Monte Carlo (HMC) has been previously suggested to achieve this [23, 26], the proposed methods require specifying reverse kernels which have a large impact on performance. Additionally, the resulting unbiased estimator of the ELBO for most MCMC kernels is typically not amenable to the reparameterization trick. We show here how to optimally select reverse kernels in this setting and, by building upon Hamiltonian Importance Sampling (HIS) [17], we obtain a scheme that provides low-variance unbiased estimators of the ELBO and its gradients using the reparameterization trick. This allows us to develop a Hamiltonian Variational Auto-Encoder (HVAE). This method can be re-interpreted as a target-informed normalizing flow [20] which, within our context, only requires a few evaluations of the gradient of the sampled likelihood and trivial Jacobian calculations at each iteration.
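
The following is a minimal sketch, not the authors' implementation, of an HIS/HVAE-style unbiased ELBO estimator: a deterministic, volume-preserving leapfrog flow on the augmented (z, v) space, fully reparameterized so its gradients could be taken by automatic differentiation. The toy Gaussian model, the fixed variational parameters mu and s, and the step size and number of leapfrog steps are assumptions made purely for illustration.

# Sketch of an HVAE/HIS-style reparameterized ELBO estimate on a toy model.
# All model and parameter choices below are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)

# Toy latent variable model: z ~ N(0, I), x | z ~ N(z, sigma_x^2 I).
sigma_x = 0.5
x = np.array([1.0, -0.5])          # a single observed data point (assumed)

def log_joint(z):
    """log p(x, z) for the toy Gaussian model."""
    log_prior = -0.5 * np.sum(z ** 2) - 0.5 * z.size * np.log(2 * np.pi)
    log_lik = (-0.5 * np.sum((x - z) ** 2) / sigma_x ** 2
               - 0.5 * x.size * np.log(2 * np.pi * sigma_x ** 2))
    return log_prior + log_lik

def grad_log_joint(z):
    """Gradient of log p(x, z) with respect to z (analytic for this toy model)."""
    return -z + (x - z) / sigma_x ** 2

def log_std_normal(u):
    return -0.5 * np.sum(u ** 2) - 0.5 * u.size * np.log(2 * np.pi)

# Initial variational distribution q0(z | x) = N(mu, diag(s^2)).
# In a VAE these would be produced by an encoder network; fixed here.
mu = np.array([0.5, 0.0])
s = np.array([1.0, 1.0])

def hvae_elbo_estimate(K=5, step=0.1):
    """One reparameterized, unbiased ELBO estimate using K leapfrog steps."""
    eps = rng.standard_normal(mu.shape)
    z = mu + s * eps                       # reparameterized draw from q0
    v = rng.standard_normal(mu.shape)      # auxiliary momentum, v ~ N(0, I)
    log_q0 = log_std_normal(eps) - np.sum(np.log(s)) + log_std_normal(v)

    # Deterministic leapfrog integration of H(z, v) = -log p(x, z) + 0.5||v||^2.
    # The map is volume-preserving, so its Jacobian contributes no extra term.
    for _ in range(K):
        v = v + 0.5 * step * grad_log_joint(z)
        z = z + step * v
        v = v + 0.5 * step * grad_log_joint(z)

    # Log importance weight of the transformed pair (z_K, v_K).
    return log_joint(z) + log_std_normal(v) - log_q0

estimates = [hvae_elbo_estimate() for _ in range(1000)]
print("mean ELBO estimate:", np.mean(estimates))

With the tempering scheme used in the paper, each step additionally rescales the momentum, which contributes only a diagonal, and hence trivially computable, Jacobian term; the sketch above omits tempering for brevity.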
