Department of Statistics, University of Warwick

Department of Statistics and Data Science, National University of Singapore
MCMC is used to compute expectations \(\pi(\varphi) = \mathbb{E}_{\pi}[\varphi(X)]\) when direct sampling from the target \(\pi\) is intractable.
Method:
Standard diagnostics:
The Gelman-Rubin diagnostic (Gelman and Rubin 1992) and ESS have a key limitation:
They are function-specific.
Wouldn’t it be better to diagnose the convergence of the distribution of the chain itself?
Theoretical convergence is measured by statistical distances between the chain’s distribution \(\mu_t\) and the target \(\pi\).
Biswas, Jacob, and Vanetti (2019) provide computable upper bounds on these distances using coupled chains.
A coupling runs two chains \((X_t, Y_t)\) jointly to make them meet:
If \(X_{t+L} \sim \pi\) and \(Y_{t} \sim \mu_t\), then the meeting time \(\tau\) bounds the distance on average: \(\|\mu_t - \pi\|_{\text{TV}} \le \mathbb{E}\left[\max\left(0, \left\lceil \frac{\tau - L - t}{L} \right\rceil\right)\right]\).
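Given meeting times from independent replicates of an \(L\)-lag coupled pair, the bound reduces to a Monte Carlo average. A minimal sketch (the meeting times below are made up purely for illustration):

```python
import math

def tv_upper_bound(taus, L, t):
    """Monte Carlo estimate of the TV upper bound at time t, given
    meeting times `taus` from independent L-lag coupled chain pairs."""
    return sum(max(0, math.ceil((tau - L - t) / L)) for tau in taus) / len(taus)

# Hypothetical meeting times from 4 coupled runs with lag L = 1.
taus = [3, 5, 2, 8]
print(tv_upper_bound(taus, L=1, t=0))   # large early on
print(tv_upper_bound(taus, L=1, t=10))  # 0.0 once t exceeds all meeting times
```

The bound is non-increasing in \(t\) and hits zero once \(t\) passes the largest observed meeting time.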
For a target \(\pi\) and proposal \(q\), the independent Metropolis-Hastings (MH) algorithm iterates:
To couple two independent MH chains \((X_t, Y_t)\), we can use the same proposal \(X' \sim q(\cdot)\) and the same acceptance uniform for both chains at each step; whenever both chains accept, they meet and remain equal thereafter.
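A minimal sketch of this common-proposal, common-uniform coupling, with an illustrative \(N(0,1)\) target and \(N(0,4)\) proposal (both choices are ours, not from the source):

```python
import math, random

def log_target(x):   return -0.5 * x * x           # N(0,1) target (illustrative)
def log_proposal(x): return -0.5 * (x / 2.0) ** 2  # N(0,4) independent proposal

def coupled_imh_step(x, y):
    """Advance two independent-MH chains with a shared proposal and uniform."""
    xp = random.gauss(0.0, 2.0)        # shared proposal X' ~ q
    log_u = math.log(random.random())  # shared acceptance uniform
    def accept(cur):
        # independent-MH ratio: pi(X') q(cur) / (pi(cur) q(X'))
        return log_u < (log_target(xp) - log_target(cur)
                        + log_proposal(cur) - log_proposal(xp))
    return (xp if accept(x) else x), (xp if accept(y) else y)

random.seed(1)
x, y, t = 5.0, -5.0, 0
while x != y:                          # once both accept X', the chains merge
    x, y = coupled_imh_step(x, y)
    t += 1
print("chains met after", t, "step(s)")
```

Sharing both the proposal and the uniform maximizes the chance that the two accept/reject decisions coincide, which is what makes the chains meet quickly.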
Practical Challenges:
We seek a diagnostic that is:
Let \(f:[0, \infty) \to \mathbb{R}\) be a convex function with \(f(1)=0\). \[\begin{equation} D_f(\pi \| \mu) = \int \mu(\mathrm{d} x) f\left( \frac{\mathrm{d}\pi}{\mathrm{d}\mu}(x) \right) \end{equation}\]
This family includes many famous divergences: KL (\(f(u) = u \log u\)), \(\chi^2\) (\(f(u) = (u-1)^2\)), total variation (\(f(u) = \tfrac{1}{2}|u-1|\)), and squared Hellinger (\(f(u) = (\sqrt{u}-1)^2\)).
Consider two discrete measures supported on the same atoms \(X^{1:2N}\): the empirical measure \(\mu_{2N} = \frac{1}{2N} \sum_{n=1}^{2N} \delta_{X^n}\) and the weighted measure \(\pi_{2N} = \sum_{n=1}^{2N} W^n \delta_{X^n}\).
The \(f\)-divergence between them is simple to compute: \[\begin{equation} D_f(\pi_{2N} \| \mu_{2N}) = \frac{1}{2N} \sum_{n=1}^{2N} f(2N W^n) \end{equation}\]
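The formula above is a one-liner; a small sketch with the \(\chi^2\) and KL choices of \(f\) (the example weights are arbitrary):

```python
import math

def f_divergence(weights, f):
    """D_f(pi_2N || mu_2N) between the uniform empirical measure and the
    weighted measure with normalized weights `weights` on the same atoms."""
    M = len(weights)
    return sum(f(M * w) for w in weights) / M

chi2 = lambda u: (u - 1.0) ** 2
kl = lambda u: u * math.log(u) if u > 0 else 0.0

uniform = [0.25] * 4
skewed = [0.7, 0.1, 0.1, 0.1]
print(f_divergence(uniform, chi2))  # 0.0: the two measures coincide
print(f_divergence(skewed, chi2))   # positive: weights far from uniform
```

Uniform weights give \(f(1) = 0\) for every \(f\) in the family, so the divergence vanishes exactly when the weighted measure equals the empirical one.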
Imagine we knew the density ratio \(\psi_t(x) = \frac{\mathrm{d}\pi}{\mathrm{d}\mu_t}(x)\). Given samples \(X_t^n \sim \mu_t\), we could estimate any \(f\)-divergence by the plug-in average \(\frac{1}{N}\sum_{n} f(\psi_t(X_t^n))\).
Of course, \(\psi_t(x)\) is intractable. But what if we could approximate it?
The likelihood ratio over the entire path simplifies to the initial weight only:
\[ \frac{\mathbb{P}_\pi(X_{0:t})}{\mathbb{P}_{\mu_0}(X_{0:t})} = \frac{\pi(X_0) \prod_{s=1}^t K(X_{s-1}, X_s)}{\mu_0(X_0) \prod_{s=1}^t K(X_{s-1}, X_s)} = \frac{\pi(X_0)}{\mu_0(X_0)} =: w_0 \]
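A quick numerical check of this cancellation. The Gaussian AR(1) kernel below is arbitrary (not \(\pi\)-invariant) and chosen only to show that the shared kernel factors drop out of the path ratio:

```python
import random

def log_kernel(x_prev, x_next):
    # a shared Markov kernel K (AR(1) Gaussian, purely illustrative)
    return -0.5 * (x_next - 0.5 * x_prev) ** 2

def log_pi(x):   return -0.5 * x * x            # target density (unnormalized)
def log_mu0(x):  return -0.5 * (x - 3.0) ** 2   # initial density (unnormalized)

random.seed(0)
path = [random.gauss(0, 1) for _ in range(6)]   # any path X_0, ..., X_5

# Full path log-ratio: both paths use the same kernel K, so those terms cancel ...
full = (log_pi(path[0]) + sum(log_kernel(a, b) for a, b in zip(path, path[1:]))) \
     - (log_mu0(path[0]) + sum(log_kernel(a, b) for a, b in zip(path, path[1:])))
# ... leaving only the initial log-weight log w_0.
log_w0 = log_pi(path[0]) - log_mu0(path[0])
print(abs(full - log_w0) < 1e-9)
```

The identity holds for any path and any shared kernel; only the initial densities matter.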
This motivates a simple diagnostic method:
A novel “weight harmonization” scheme for parallel MCMC chains that produces a consistent, computable approximation of the importance weights \(\psi_t(X_t^n)\).
A general method to use these weights to get online upper bounds on any \(f\)-divergence, including KL, \(\chi^2\), Hellinger, and TV distance.
The resulting diagnostic is guaranteed to improve over time and requires no lag or warm-up.
The \(\chi^2\)-divergence bound provides an estimate of the “effective number of active chains”, a highly practical and interpretable metric.
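One way to read this metric off normalized weights: the \(\chi^2\) estimate is \(M \sum_m (W^m)^2 - 1\), so \(M / (1 + \chi^2) = 1 / \sum_m (W^m)^2\), the familiar effective-sample-size formula. A small sketch (example weights are ours):

```python
def effective_chains(weights):
    """Effective number of active chains from normalized weights:
    M / (1 + chi^2 estimate), which simplifies to 1 / sum(W^2)."""
    return 1.0 / sum(w * w for w in weights)

print(effective_chains([0.25] * 4))                 # 4.0: all chains active
print(effective_chains([0.97, 0.01, 0.01, 0.01]))   # near 1: one chain dominates
```

At stationarity the weights are uniform and the metric equals the number of chains; early on, a few dominant weights reveal that most chains are not yet contributing.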
Key Insight (Theorem 1):
If we have a “good” set of weights \(W_t^{1:M}\) for our MCMC samples \(X_t^{1:M} \sim \mu_t\) (i.e., \(\sum W_t^m \delta_{X_t^m}\) is a consistent estimator of \(\pi\)), then:
\[\begin{equation} \mathbb{P}\left( \frac{1}{M}\sum_{m=1}^{M} f(M W_t^m) \le D_f(\pi \| \mu_t) - \varepsilon \right) \to 0 \quad \text{as } M \to \infty \end{equation}\]
So, if we can generate “good” weights \(W_t^m\) at each MCMC step \(t\), we can track the convergence of the chain’s distribution.
A coupling \(\bar{K}(X_t^n, X_t^m, \cdot, \cdot)\) runs two weighted chains \((X_{t+1}^n, X_{t+1}^m)\) jointly from \((X_t^n, X_t^m)\), encouraging them to meet.
When a coupled pair meets, we harmonize their weights by replacing both with their average (a convex combination):
\[\begin{align} W_{t+1}^n \gets W_{t+1}^m \gets (W_t^n + W_t^m) / 2 \end{align}\]
We need to allow all chains to interact. We run \(2N\) chains, viewed as \(N\) pairs.
At each step \(t \to t+1\):
Couple: For each of the \(N\) current pairs \((X_t^n, X_t^{m})\), sample \((X_{t+1}^n, X_{t+1}^{m}) \sim \bar{K}(X_t^n, X_t^{m}, \cdot, \cdot)\).
Harmonize: If \(X_{t+1}^n = X_{t+1}^{m}\), set \(W_{t+1}^n = W_{t+1}^{m} = (W_t^n + W_t^{m})/2\). Otherwise, weights are unchanged.
Reshuffle: Randomly permute the pairings among the chains that just met. This allows information (weights) to propagate throughout the entire system of \(2N\) chains over time.
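The couple / harmonize / reshuffle loop above can be sketched end to end. Everything model-specific here is an illustrative assumption: an \(N(0,1)\) target, an \(N(0,4)\) independent-MH proposal coupled via a shared proposal and uniform, and chains initialized from \(\mu_0 = N(3,1)\):

```python
import math, random

def log_target(x):   return -0.5 * x * x            # pi = N(0,1), illustrative
def log_proposal(x): return -0.5 * (x / 2.0) ** 2   # q = N(0,4)
def log_mu0(x):      return -0.5 * (x - 3.0) ** 2   # mu_0 = N(3,1)

def coupled_step(x, y):
    """Common-proposal, common-uniform coupling of two independent-MH chains."""
    xp = random.gauss(0.0, 2.0)
    log_u = math.log(random.random())
    acc = lambda c: log_u < (log_target(xp) - log_target(c)
                             + log_proposal(c) - log_proposal(xp))
    return (xp if acc(x) else x), (xp if acc(y) else y)

random.seed(42)
N = 8                                               # 2N = 16 chains
X = [random.gauss(3.0, 1.0) for _ in range(2 * N)]
logw = [log_target(x) - log_mu0(x) for x in X]      # w_0 = pi(X_0) / mu_0(X_0)
mx = max(logw)
W = [math.exp(lw - mx) for lw in logw]
s = sum(W)
W = [w / s for w in W]                              # normalized weights

pairs = [(2 * i, 2 * i + 1) for i in range(N)]
chi2_init = 2 * N * sum(w * w for w in W) - 1.0     # chi^2 bound at t = 0

for t in range(200):
    met = []
    for (i, j) in pairs:                            # 1. couple each pair
        X[i], X[j] = coupled_step(X[i], X[j])
        if X[i] == X[j]:                            # 2. harmonize on meeting
            W[i] = W[j] = 0.5 * (W[i] + W[j])
            met.extend([i, j])
    random.shuffle(met)                             # 3. reshuffle met chains
    pairs = ([p for p in pairs if p[0] not in met]
             + [(met[k], met[k + 1]) for k in range(0, len(met), 2)])

chi2_final = 2 * N * sum(w * w for w in W) - 1.0
print(f"chi^2 bound: {chi2_init:.3f} -> {chi2_final:.6f}")
```

Harmonization only ever averages pairs of weights, so the total weight is preserved while \(\sum_n (W_t^n)^2\) (and hence the \(\chi^2\) bound) can only decrease.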
Proposition (Invariance of Expectations) The un-normalized estimator \(\hat{I}_{t,2N}(\varphi) = \sum_{n=1}^{2N} w_t^n \varphi(X_t^n)\) has a constant expectation over time: \[\begin{equation} \mathbb{E}[\hat{I}_{t,2N}(\varphi)] = \mathbb{E}[\hat{I}_{0,2N}(\varphi)] = 2N \int \varphi(x) \gamma(x)\mathrm{d} x \end{equation}\] Here \(\gamma\) denotes the unnormalized density of \(\pi\).
Theorem (Consistency) For any time \(t\), as the number of particles \(2N \to \infty\): \[\begin{equation} \sum_{n=1}^{2N} W_t^n \varphi(X_t^n) \xrightarrow{a.s.} \int \varphi(x) \pi(x) \mathrm{d} x \end{equation}\]
Corollary (Non-Asymptotic Bounds) These can be refined to finite-\(N\) guarantees: with probability \(\ge 1-\alpha\), \[\begin{equation} \frac{1}{2N}\sum_{n=1}^{2N} f(2N W_t^n) + \text{Error}_{2N,t}(\alpha) \ge D_f(\pi \| \mu_t) \end{equation}\]
Theorem (Geometric Convergence) If \(\mathbb{P}(X' = Y' \mid x,y) \ge p_c > 0\) for all \(x,y\), the weights converge exponentially fast: \[\begin{equation} \mathbb{E}[\|W_t - \bar{W}\|_2^2] = O(\rho^t) \quad \text{for some } 0 < \rho < 1 \end{equation}\]
Theorem (Almost Sure Convergence) Under general ergodicity conditions (no rate required): \[\begin{equation} W_t \xrightarrow{a.s.} \bar{W} = (1/2N, \dots, 1/2N) \quad \text{as } t \to \infty \end{equation}\]
In both cases, \(D_f(\pi \| \mu_t) \to 0\) almost surely: the diagnostic vanishes at stationarity.
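The driving mechanism is that repeated random pairwise averaging contracts the weight vector toward uniform. A stripped-down demo of just that dynamic, with an arbitrary starting weight vector and every chain re-paired each round (an idealization of the meeting-and-reshuffling process):

```python
import random

random.seed(0)
M = 16                                    # number of chains
W = [random.random() for _ in range(M)]
s = sum(W)
W = [w / s for w in W]                    # normalized starting weights
wbar = 1.0 / M

def sq_dist(W):
    """Squared distance to the uniform weight vector W-bar."""
    return sum((w - wbar) ** 2 for w in W)

d0 = sq_dist(W)
for t in range(100):
    idx = list(range(M))
    random.shuffle(idx)                   # fresh random pairing each round
    for k in range(0, M, 2):
        i, j = idx[k], idx[k + 1]
        W[i] = W[j] = 0.5 * (W[i] + W[j])
print(sq_dist(W) / d0)                    # tiny ratio: fast decay toward uniform
```

In the full algorithm pairs only average when they actually meet, so the contraction rate depends on the coupling's meeting probability \(p_c\), as in the geometric convergence theorem above.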
A high-dimensional (\(d=2500\)) model.
We introduced a new method for MCMC convergence diagnostics based on weight harmonization.
The main limitation is the conservativeness of the bound, especially for large \(2N\). This is due to the pairwise interaction scheme.
Potential Improvements:
Rao-Blackwellization: Instead of picking one random pairing, can we average over all possible pairings? This would dramatically increase interaction but is computationally challenging.
Offline Correction: The current method is a “filtering” approach. Can we use ideas from particle smoothing (a “backward pass”) to refine the weights after the run is complete?
Variance Reduction: Incorporate control variates to accelerate the convergence of the diagnostic itself.
Thank you!