Here’s the deal: I used PyMC, matplotlib, and Jake Vanderplas’ JSAnimation to create JavaScript animations of three MCMC sampling algorithms: Metropolis-Hastings, slice sampling, and NUTS.
I like visualizations because they provide a good intuition for how the samplers work and what problems they can run into.
You can download the full notebook here or view it in your browser. Note that for this post I embedded videos rather than the JavaScript animations, since the animations get quite large when uncompressed. The notebook contains code for both.
The model is a simple linear model, as explained in my previous blog post on Bayesian GLMs. Essentially, I generated some data and estimated the intercept and slope. In the animations below, the lower-left plot shows the joint posterior, the plot above it shows the trace of the marginal posterior of the intercept, and the plot to the right shows the trace of the marginal posterior of the slope. Each point represents a sample drawn from the posterior. Three quarters of the way through, I add a thousand samples at once to show that all samplers eventually draw from the posterior.
Metropolis-Hastings
First, let’s see how our old-school Metropolis-Hastings (MH) performs. The code uses matplotlib’s handy FuncAnimation (see here for a tutorial), my own animation code, and the recently merged iterative sampling function iter_sample().
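Roughly, the pieces fit together like this (a sketch, not the notebook’s exact code; it assumes the `model` from above and a PyMC3-era `iter_sample()` that yields the growing trace after each draw):

```python
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

with model:
    step = pm.Metropolis()

# iter_sample returns a generator that yields the full trace after
# every draw, which is what makes frame-by-frame animation possible
sampler = pm.iter_sample(2000, step, model=model)

fig, ax = plt.subplots()

def update(frame):
    trace = next(sampler)  # advance the chain by one sample
    ax.clear()
    ax.scatter(trace['intercept'], trace['slope'], s=4, alpha=0.5)
    ax.set_xlabel('intercept')
    ax.set_ylabel('slope')

anim = FuncAnimation(fig, update, frames=200, interval=50)
```

JSAnimation then turns `anim` into HTML that can be embedded in the notebook.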
As you can see, there is quite some correlation between intercept and slope: if we believe in a higher intercept, we must also believe in a lower slope (which makes geometric sense if you think about how lines could fit through the point cloud). This often makes it difficult for an MCMC algorithm to converge (i.e. to sample from the true posterior), as we witness here.

The reason MH does not do anything at first is that it proposes huge jumps that are not accepted because they land way outside the posterior. PyMC then tunes the proposal distribution so that smaller jumps are proposed. These smaller jumps, however, lead to the random-walk behavior you can see, which makes sampling inefficient (for a good intuition about this “drunken walk”, see here).
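For intuition about what is being tuned, here is a bare-bones random-walk Metropolis step in plain NumPy (my own sketch, not PyMC’s implementation); `scale` is the proposal width that the tuning phase adjusts:

```python
import numpy as np

def mh_step(theta, logp, scale, rng=np.random):
    """One random-walk Metropolis step with proposal width `scale`."""
    proposal = theta + rng.normal(scale=scale, size=np.shape(theta))
    # Accept with probability min(1, p(proposal) / p(theta))
    if np.log(rng.uniform()) < logp(proposal) - logp(theta):
        return proposal  # accepted: the chain moves
    return theta  # rejected: the chain stays put (the stalls in the animation)
```

With a large `scale`, almost every proposal lands outside the posterior and gets rejected; with a small one, almost everything is accepted but the chain diffuses in tiny random-walk steps.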
Slice sampling
Let’s see how slice sampling fares.
As you can see, slice sampling does a much better job. For one thing, there are no rejections (a property of the algorithm). But there’s still room for improvement. At its core, slice sampling updates one random variable at a time while keeping all others constant. This leads to small steps being taken (imagine trying to move along a diagonal of the chessboard with a rook: you can only alternate horizontal and vertical moves) and makes sampling from correlated posteriors inefficient, as the sketch below illustrates.
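Here is a sketch of a single univariate slice-sampling update (the stepping-out and shrinkage procedure from Neal’s 2003 paper, not necessarily PyMC’s exact implementation):

```python
import numpy as np

def slice_update(x, logp, w=1.0, rng=np.random):
    """One slice-sampling update of a single variable; `w` is the
    initial interval width."""
    # Draw the slice height: u ~ Uniform(0, p(x)), done in log space
    log_y = logp(x) + np.log(rng.uniform())
    # Step out: place an interval of width w around x and grow it
    # until both ends lie outside the slice
    left = x - w * rng.uniform()
    right = left + w
    while logp(left) > log_y:
        left -= w
    while logp(right) > log_y:
        right += w
    # Shrink: draw uniformly from the interval, shrinking it towards x
    # whenever the draw falls outside the slice
    while True:
        x_new = rng.uniform(left, right)
        if logp(x_new) > log_y:
            return x_new  # always ends in an accepted point
        if x_new < x:
            left = x_new
        else:
            right = x_new
```

In a multivariate model this update is applied to one variable at a time, which is exactly the rook-like movement described above.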
NUTS (Hamiltonian Monte Carlo)
NUTS, on the other hand, is a newer gradient-based sampler that operates on the joint posterior. Correlations are not a problem because this sampler can also move diagonally (more like the queen). As you can see, it does a much better job at exploring the posterior and takes much wider steps.
Mesmerizing, ain’t it?
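In PyMC, switching samplers is a one-line change (again a sketch against the PyMC3-era API):

```python
with model:
    step = pm.NUTS()  # uses the gradient of the model's log-posterior
    trace = pm.sample(2000, step=step)
```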
What surprised me about slice sampling is that, looking only at the individual traces (top and right plots), I would have said the chains hadn’t converged. It seems instead that although the step size is small, averaging samples over a longer run still provides meaningful inference.
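One way to check this impression numerically is the autocorrelation of the trace: small steps show up as autocorrelation that decays slowly with lag, yet the running mean still converges, just with fewer effectively independent samples. A quick sketch (the re-run of the slice sampler and the variable names are hypothetical):

```python
import numpy as np

def autocorr(samples, lag):
    """Lag-k autocorrelation of a 1-D chain of samples."""
    x = np.asarray(samples) - np.mean(samples)
    return np.dot(x[:-lag], x[lag:]) / np.dot(x, x)

with model:
    slice_trace = pm.sample(2000, step=pm.Slice())

# Slowly decaying values indicate that each draw adds little
# independent information, even though the chain covers the posterior
for lag in (1, 5, 20, 50):
    print(lag, autocorr(slice_trace['intercept'], lag=lag))
```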
Where to go from here
I initially set out to add real-time plotting during sampling to PyMC. What I’ve shown here only creates an animation after sampling has finished. Unfortunately, I don’t think real-time plotting is currently possible in the IPython Notebook, as it requires embedding HTML, for which we need the finished product. If anyone has an idea here, that might be a very interesting extension.
Further reading
- Jake’s tutorial on matplotlib animations
- Jake’s blog post on embedding JS animations in the notebook
- Abe Flaxman’s much prettier videos on MCMC (it would be nice to replace my crappy plotting code with his; PRs welcome)