A coupling-based approach to \(f\)-divergence diagnostics for Markov chain Monte Carlo

Adrien Corenflos

Department of Statistics, University of Warwick

Joint work with Hai-Dang Dau

Department of Statistics and Data Science, National University of Singapore

Motivation: MCMC

MCMC is used to compute expectations \(\pi(\varphi) = \mathbb{E}_{\pi}[\varphi(X)]\) when sampling from a target \(\pi\) is impossible.

Method:

  • Design a kernel \(K\) with invariant distribution \(\pi\): \(\pi K = \pi\).
  • Simulate \(X_{t+1} \sim K(X_t, \cdot)\), with \(\mathcal{L}(X_t) \to \pi\) as \(t \to \infty\).
  • Use time averages to approximate expectations: \[\begin{equation} \pi(\varphi) \approx \frac{1}{T-B+1} \sum_{t=B}^T \varphi(X_t) \end{equation}\]

Nagging Question: “are we there yet?”

Standard diagnostics:

  • Gelman-Rubin (\(\hat{R}\)): Compares variance within and between \(N\) parallel chains. A variance ratio \(\hat{R} \approx 1\) suggests convergence.
  • Effective Sample Size (ESS): Estimates the number of independent samples equivalent to the MCMC output for a given function \(\varphi\) (a naive estimator is sketched below). \[\begin{equation} \text{ESS} = \frac{T}{1 + 2 \sum_{k=1}^{\infty} \text{Corr}(\varphi(X_0), \varphi(X_k))} \end{equation}\]
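Below is a naive sketch of this ESS estimator (my own illustration, with hypothetical helper names): it truncates the autocorrelation sum at a fixed maximum lag rather than using a principled truncation rule such as Geyer's initial-sequence estimator.

```python
# Naive ESS sketch: truncate the autocorrelation sum at a fixed lag.
import numpy as np

def naive_ess(phi_values, max_lag=100):
    x = np.asarray(phi_values, dtype=float)
    x = x - x.mean()
    t = x.size
    variance = np.dot(x, x) / t
    # Empirical lag-k autocorrelations of phi(X_t), k = 1, ..., max_lag.
    rho = [np.dot(x[:-k], x[k:]) / (t * variance)
           for k in range(1, min(max_lag, t - 1) + 1)]
    return t / (1.0 + 2.0 * np.sum(rho))

# Example: a slowly mixing AR(1) chain has a much smaller ESS than its length.
rng = np.random.default_rng(0)
chain = np.zeros(5000)
for i in range(1, chain.size):
    chain[i] = 0.95 * chain[i - 1] + rng.standard_normal()
print(naive_ess(chain))  # far below 5000
```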

Limitations of Standard Diagnostics

The Gelman-Rubin diagnostic (Gelman and Rubin 1992) and ESS have a key limitation:

They are function-specific.

  • A chain might have a good ESS for \(\mathbb{E}[X]\) but a terrible one for \(\mathbb{E}[X^2]\).
  • We get a partial picture of convergence, tied to specific test functions \(\varphi\).

Wouldn’t it be better to diagnose the convergence of the distribution of the chain itself?

Distribution-based Diagnostics

Theoretical convergence is measured by statistical distances between the chain’s distribution \(\mu_t\) and the target \(\pi\).

  • Total Variation Distance: \(\|\mu_t - \pi\|_{\text{TV}}\)
  • Wasserstein Distance: \(W_p(\mu_t, \pi)\)

Biswas, Jacob, and Vanetti (2019) provide computable upper bounds on these distances using coupled chains.

A coupling runs two chains \((X_t, Y_t)\) jointly to make them meet:

If \(X_{t+L} \sim \pi\) and \(Y_{t} \sim \mu_t\), then, writing \(\tau\) for the meeting time, \(\max\left(0, \left\lceil \frac{\tau - t - L}{L} \right\rceil\right)\) is an estimator of an upper bound on \(\|\mu_t - \pi\|_{\text{TV}}\).
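As a point of reference, here is a minimal sketch (my reading of the Biswas, Jacob, and Vanetti (2019) estimator, not their code) of how meeting times from independent replicates of an \(L\)-lag coupling could be turned into an empirical TV bound:

```python
# Hedged sketch: average max(0, ceil((tau - L - t) / L)) over replicates.
import numpy as np

def tv_upper_bound(meeting_times, t, L=1):
    tau = np.asarray(meeting_times, dtype=float)
    return float(np.mean(np.maximum(0.0, np.ceil((tau - L - t) / L))))

# Example with hypothetical meeting times from 5 coupled replicates.
print(tv_upper_bound([12, 15, 9, 20, 11], t=16, L=1))  # 0.6
```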

Distribution-based Diagnostics

The independent Metropolis-Hastings case

For a target \(\pi\) and proposal \(q\), the independent Metropolis-Hastings (MH) algorithm iterates:

  1. Propose \(X' \sim q(\cdot)\)
  2. Accept \(X'\) with probability \(\alpha = 1 \wedge \tfrac{\pi(X') / q(X')}{\pi(X) / q(X)}\)

To couple two independent MH chains \((X_t, Y_t)\), we can use the same proposal \(X' \sim q(\cdot)\) for both chains at each step (a code sketch follows the list):

  1. Propose \(X' \sim q(\cdot)\)
  2. Compute acceptance probabilities \(\alpha_X\) and \(\alpha_Y\) for \(X_t\) and \(Y_t\) respectively.
  3. Draw a single uniform random variable \(U \sim \text{Uniform}(0,1)\)
  4. If \(U < \alpha_X\), accept the proposal for \(X_t\); if \(U < \alpha_Y\), accept it for \(Y_t\).
  5. If both chains accept, they meet: \(X_{t+1} = Y_{t+1} = X'\).
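Below is a minimal sketch of one such coupled step (my own toy example, not the talk's code): the target, the proposal, and the names `log_pi`, `log_q`, and `sample_q` are assumptions chosen for illustration.

```python
# Coupled independent MH: one shared proposal and one shared uniform per step.
import numpy as np

rng = np.random.default_rng(0)

def log_pi(x):          # unnormalized log target: N(0, 1)
    return -0.5 * x**2

def log_q(x):           # unnormalized log proposal: N(0, 3^2)
    return -0.5 * (x / 3.0) ** 2

def sample_q(rng):
    return 3.0 * rng.standard_normal()

def coupled_imh_step(x, y, rng):
    """Propose once, draw one uniform, apply the MH accept/reject to both chains."""
    x_prop = sample_q(rng)
    log_u = np.log(rng.uniform())
    log_ratio = lambda z: log_pi(z) - log_q(z)
    accept_x = log_u < log_ratio(x_prop) - log_ratio(x)
    accept_y = log_u < log_ratio(x_prop) - log_ratio(y)
    return (x_prop if accept_x else x), (x_prop if accept_y else y)

# The two chains meet exactly when both accept the same proposal.
x, y = 5.0, -5.0
for _ in range(50):
    x, y = coupled_imh_step(x, y, rng)
print(x == y)  # almost certainly True after 50 steps
```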

Limitations of TV/Wasserstein Bounds

Practical Challenges:

  • Interpretation: What does \(\|\mu_t - \pi\|_{\text{TV}} < 0.05\) imply for my estimates?
  • Lag requirement: The method requires a second chain started from a “converged” state \(Y_{t-L}\), which is circular.

We seek a diagnostic that is:

  • Distributional
  • Interpretable
  • Lag-free

The Ideal Scenario

Imagine we knew the density ratio \(\psi_t(x) = \frac{\mathrm{d}\pi}{\mathrm{d}\mu_t}(x)\). Given samples \(X_t^n \sim \mu_t\), we could:

  1. Get ideal estimates of \(\pi\) via importance sampling: \[\begin{equation} \pi(\varphi) \approx \sum_{n=1}^N W_t^n \varphi(X_t^n), \quad \text{where } W_t^n \propto \psi_t(X_t^n) \end{equation}\]
  2. Compute any \(f\)-divergence \(D_f(\pi \| \mu_t)\), which only depends on this ratio \[\begin{equation} D_f(\pi \| \mu_t) = \mathbb{E}_{\mu_t}\left[f\left(\psi_t(X)\right)\right] \end{equation}\]

Of course, \(\psi_t(x)\) is intractable. But what if we could approximate it?

This talk

  1. A novel “weight harmonization” scheme for parallel MCMC chains that produces a consistent, computable approximation of the importance weights \(\psi_t(X_t^n)\).

  2. A general method to use these weights to get online upper bounds on any \(f\)-divergence, including KL, \(\chi^2\), Hellinger, and TV distance.

  3. The resulting diagnostic is guaranteed to improve over time and requires no lag or warm-up.

  4. The \(\chi^2\)-divergence bound provides an estimate of the “effective number of active chains”, a highly practical and interpretable metric.

Background: \(f\)-divergences

Let \(f:[0, \infty) \to \mathbb{R}\) be a convex function with \(f(1)=0\). \[\begin{equation} D_f(\pi \| \mu) = \int \mu(\mathrm{d} x) f\left( \frac{\mathrm{d}\pi}{\mathrm{d}\mu}(x) \right) \end{equation}\]

This family includes many famous divergences:

  • Kullback-Leibler: \(f(t) = t \log t\)
  • Total Variation: \(f(t) = |t-1|/2\)
  • \(\chi^2\)-divergence: \(f(t) = (t-1)^2\)
  • Squared Hellinger distance: \(f(t) = (\sqrt{t}-1)^2\)

\(f\)-divergences for Weighted Samples

Consider two discrete measures:

  • Target: \(\pi_N = \sum_{n=1}^N W^n \delta_{X^n}\)
  • Reference: \(\mu_N = \frac{1}{N} \sum_{n=1}^N \delta_{X^n}\)

The \(f\)-divergence between them is simple to compute: \[\begin{equation} D_f(\pi_N \| \mu_N) = \frac{1}{N} \sum_{n=1}^N f(N W^n) \end{equation}\]
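As a sanity check, here is a small helper (my own illustration, not code from the paper) that evaluates this formula for the divergences listed above, together with the \(\chi^2\)-based effective sample size mentioned later in the talk:

```python
# Compute D_f(pi_N || mu_N) = (1/N) * sum_n f(N W^n) for several choices of f.
import numpy as np

def weighted_f_divergences(weights):
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()                        # normalized weights W^n
    n = w.size
    t = n * w                              # ratio values N W^n
    kl = np.mean(t * np.log(np.where(t > 0, t, 1.0)))   # f(t) = t log t (0 log 0 := 0)
    tv = np.mean(np.abs(t - 1.0) / 2.0)                  # f(t) = |t - 1| / 2
    chi2 = np.mean((t - 1.0) ** 2)                       # f(t) = (t - 1)^2
    hellinger2 = np.mean((np.sqrt(t) - 1.0) ** 2)        # f(t) = (sqrt(t) - 1)^2
    ess = 1.0 / np.sum(w**2)               # equals N / (1 + chi2)
    return {"KL": kl, "TV": tv, "chi2": chi2, "Hellinger2": hellinger2, "ESS": ess}

print(weighted_f_divergences([0.4, 0.3, 0.2, 0.1]))  # non-uniform weights give positive divergences
```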

From Weights to Diagnostics

Key Insight (Theorem 1):

If we have a “good” set of weights \(W_t^{1:N}\) for our MCMC samples \(X_t^{1:N} \sim \mu_t\) (i.e., \(\sum W_t^n \delta_{X_t^n}\) is a consistent estimator of \(\pi\)), then:

\[\begin{equation} \mathbb{P}\left( \frac{1}{N}\sum_{n=1}^N f(N W_t^n) \le D_f(\pi \| \mu_t) - \varepsilon \right) \to 0 \quad \text{as } N \to \infty \end{equation}\]

So, if we can generate good weights \(W_t^n\) at each MCMC step \(t\), we can track the convergence of the chain’s distribution.

A Naive (and flawed) Approach

How to get weights?

  1. Start \(N\) chains from \(X_0^n \sim \mu_0\).
  2. Compute initial weights \(w_0^n = \frac{\pi(X_0^n)}{\mu_0(X_0^n)}\) and normalize to \(W_0^n\).
  3. Run the chains: \(X_t^n \sim K^t(X_0^n, \cdot)\).
  4. Use the initial weights \(W_0^n\) for the samples \(X_t^n\).

Problem: This gives a consistent estimator, but the diagnostic \(\frac{1}{N}\sum f(N W_0^n)\) is static. It ignores all mixing done by the kernel.
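A toy illustration of the problem (my own sketch, with an assumed Gaussian \(\mu_0\) and target chosen for simplicity): the \(\chi^2\) diagnostic computed from the frozen weights is identical at every \(t\), however long the chains run.

```python
# Naive scheme: weights are computed once at t = 0 and never updated.
import numpy as np

rng = np.random.default_rng(1)
N = 1000
x0 = rng.normal(loc=2.0, scale=2.0, size=N)             # X_0^n ~ mu_0 = N(2, 2^2)
log_mu0 = -0.5 * ((x0 - 2.0) / 2.0) ** 2 - np.log(2.0)   # log mu_0, up to a constant
log_pi = -0.5 * x0**2                                     # log pi for a N(0, 1) target, up to a constant
w0 = np.exp(log_pi - log_mu0)
W0 = w0 / w0.sum()                                        # normalized weights W_0^n
static_chi2 = np.mean((N * W0 - 1.0) ** 2)                # this diagnostic never changes with t
print(static_chi2)
```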

We need weights that evolve.

Evolving Weights with Couplings

A coupling \(\bar{K}(x, y, \cdot, \cdot)\) runs two chains \((X_{t+1}, Y_{t+1})\) jointly from \((X_t, Y_t)\), encouraging them to meet.

Key Idea: If chains meet (\(X_{t+1} = Y_{t+1}\)), they become information-theoretically identical.

Therefore, they should share their historical information, i.e., their weights.

\[\begin{align} W_{t+1}^n &= \alpha W_t^n + (1-\alpha) W_t^m \\ W_{t+1}^m &= (1-\alpha) W_t^n + \alpha W_t^m \end{align}\] We use \(\alpha=1/2\), as this choice optimally reduces the variance of the weights.

The Harmonization Step

At step \(t\), we have two chains, \(X_t^n\) and \(X_t^m\), with weights \(W_t^n\) and \(W_t^m\).

We run one step of the coupled kernel: \((X_{t+1}^n, X_{t+1}^m) \sim \bar{K}(X_t^n, X_t^m, \cdot, \cdot)\).

  • If they don’t meet (\(X_{t+1}^n \neq X_{t+1}^m\)):
    • Weights are unchanged: \(W_{t+1}^n = W_t^n\), \(W_{t+1}^m = W_t^m\).
  • If they meet (\(X_{t+1}^n = X_{t+1}^m\)):
    • They now occupy the same state. It makes sense to pool their weights.
    • We harmonize their weights by averaging them: \[\begin{equation} W_{t+1}^n = W_{t+1}^m = \frac{W_t^n + W_t^m}{2} \end{equation}\]

The Weight Harmonization Algorithm

We need to allow all chains to interact. We run \(2N\) chains, viewed as \(N\) pairs.

At each step \(t \to t+1\):

  1. Couple: For each of the \(N\) current pairs \((X_t^n, X_t^{m})\), sample \((X_{t+1}^n, X_{t+1}^{m}) \sim \bar{K}(X_t^n, X_t^{m}, \cdot, \cdot)\).

  2. Harmonize: If \(X_{t+1}^n = X_{t+1}^{m}\), set \(W_{t+1}^n = W_{t+1}^{m} = (W_t^n + W_t^{m})/2\). Otherwise, weights are unchanged.

  3. Reshuffle: Randomly permute the pairings among the chains that just met. This allows information (weights) to propagate throughout the entire system of \(2N\) chains over time.
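An illustrative sketch of the full loop (my paraphrase of the three steps above, not the authors' implementation): it assumes a function `coupled_kernel(x, y, rng)` returning the next pair of states, such as the coupled independent MH step sketched earlier, and, for simplicity, it reshuffles all pairings rather than only those of chains that just met.

```python
# Weight harmonization: run 2N chains as N coupled pairs, averaging the
# weights of any pair that meets, then reshuffling the pairings.
import numpy as np

def harmonized_mcmc(x0, w0, coupled_kernel, n_steps, rng):
    x = np.array(x0, dtype=float)          # states of the 2N chains
    w = np.array(w0, dtype=float)          # unnormalized importance weights
    num_pairs = x.size // 2
    for _ in range(n_steps):
        pairs = rng.permutation(x.size).reshape(num_pairs, 2)   # 3. reshuffle (simplified)
        for i, j in pairs:
            x[i], x[j] = coupled_kernel(x[i], x[j], rng)          # 1. couple
            if x[i] == x[j]:                                       # 2. harmonize on meeting
                w[i] = w[j] = 0.5 * (w[i] + w[j])
    return x, w / w.sum()                  # states and normalized weights W_t^n
```

The normalized weights returned here can be fed to a helper like `weighted_f_divergences` above to track the \(f\)-divergence bounds over time.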


Theoretical Result 1: Consistency

Proposition (Invariance of Expectations) The un-normalized estimator \(\hat{I}_{t,N}(\varphi) = \sum_{n=1}^{2N} w_t^n \varphi(X_t^n)\) has a constant expectation over time: \[\begin{equation} \mathbb{E}[\hat{I}_{t,N}(\varphi)] = \mathbb{E}[\hat{I}_{0,N}(\varphi)] = 2N \int \varphi(x) \gamma(x)\mathrm{d} x \end{equation}\] where \(\gamma\) denotes the unnormalized target density.

Theorem (Consistency) For any time \(t\), as the number of particles \(N \to \infty\), our weighted sample provides a consistent estimate of the target expectation: \[\begin{equation} \sum_{n=1}^{2N} W_t^n \varphi(X_t^n) \xrightarrow{\mathbb{P}} \int \varphi(x) \pi(x) \mathrm{d} x \end{equation}\]

Theoretical Result 2: Data-processing Inequality

Proposition (Non-increasing bounds) The \(f\)-divergence upper bound is guaranteed to be non-increasing over time, almost surely. \[\begin{equation} \frac{1}{2N}\sum_{n=1}^{2N} f(2N\, W_{t+1}^n) \le \frac{1}{2N}\sum_{n=1}^{2N} f(2N\, W_t^n) \end{equation}\]
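The inequality follows directly from the convexity of \(f\): at each step only the two weights of a meeting pair change, and Jensen's inequality applied to their average gives (my restatement of the one-line argument) \[\begin{equation} 2\, f\left(2N \cdot \frac{W_t^n + W_t^m}{2}\right) = 2\, f\left(\frac{2N W_t^n + 2N W_t^m}{2}\right) \le f(2N W_t^n) + f(2N W_t^m). \end{equation}\]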

Thus the ESS is non-decreasing. \[\begin{equation} \text{ESS}_{t+1} \ge \text{ESS}_t \end{equation}\]

Theoretical Result 3: Convergence to Zero

Assumption (Minimum probability of coupling) For \((X', Y') \sim \bar{K}(x, y, \cdot, \cdot)\), \[\begin{equation} \mathbb{P}(X' = Y' \mid x,y) \ge p_c > 0 \quad \text{for all } x,y \end{equation}\]

Theorem (Geometric Weight Convergence) Under the uniform coupling assumption, the weights converge exponentially fast to the uniform distribution \(\bar{W} = (1/2N, \dots, 1/2N)\): \[\begin{equation} \mathbb{E}[\|W_t - \bar{W}\|_2^2] = O(\rho^t) \quad \text{for some } 0 <\rho < 1 \end{equation}\]

This implies that \(D_f(\pi_N \| \mu_N) \to 0\) almost surely, confirming our method provides a valid convergence diagnostic that vanishes at stationarity.

Experiment 1: Ornstein-Uhlenbeck Process

  • Diagnostic correctly tracks the true convergence profile.
  • The bound is conservative, as predicted by theory.
  • Conservativeness worsens slightly with more particles (\(N\)).

Experiment 2: Pólya-Gamma Gibbs Sampler

A Bayesian logistic regression (\(d=49\)) compared with the diagnostic of Biswas, Jacob, and Vanetti (2019).

  • Competitive on TV distance.
  • Key advantage: No lag/warm-up needed, works from \(t=0\).
  • ESS diagnostic suggests convergence after ~20 steps.

Experiment 3: MALA on Stochastic Volatility

A high-dimensional (\(d=2500\)) model.

  • The TV bound is more conservative here than the lag-based method's.
  • The lag-free diagnostic indicates convergence around \(t=300\).
  • Provides bounds for any \(f\)-divergence simultaneously.

Summary

We introduced a new method for MCMC convergence diagnostics based on weight harmonization.

  • It uses couplings to progressively average importance weights between parallel chains.
  • This provides consistent, online, non-increasing upper bounds for any \(f\)-divergence.
  • The diagnostic is guaranteed to converge to zero for ergodic chains.
  • It is easy to implement if a coupled kernel is available.
  • The \(\chi^2\)-based Effective Sample Size (ESS) provides a very intuitive measure of how many chains are contributing effectively to the estimate.

Limitations and Future Work

The main limitation is the conservativeness of the bound, especially for large \(N\). This is due to the pairwise interaction scheme.

Potential Improvements:

  1. Rao-Blackwellization: Instead of picking one random pairing, can we average over all possible pairings? This would dramatically increase interaction but is computationally challenging.

  2. Offline Correction: The current method is a “filtering” approach. Can we use ideas from particle smoothing (a “backward pass”) to refine the weights after the run is complete?

  3. Variance Reduction: Incorporate control variates to accelerate the convergence of the diagnostic itself.

Thank you!

References

Biswas, Niloy, Pierre E. Jacob, and Paul Vanetti. 2019. “Estimating Convergence of Markov Chains with L-Lag Couplings.” Advances in Neural Information Processing Systems 32.
Corenflos, Adrien, and Hai-Dang Dau. 2025. “A Coupling-Based Approach to f-Divergences Diagnostics for Markov Chain Monte Carlo.” https://arxiv.org/abs/2510.07559.
Gelman, Andrew, and Donald B. Rubin. 1992. “Inference from Iterative Simulation Using Multiple Sequences.” Statistical Science 7 (4): 457–72. https://doi.org/10.1214/ss/1177011136.
Jacob, Pierre E., John O’Leary, and Yves F. Atchadé. 2020. “Unbiased Markov Chain Monte Carlo Methods with Couplings.” Journal of the Royal Statistical Society: Series B (Statistical Methodology) 82 (3): 543–600.
Lindvall, Torgny. 2002. Lectures on the Coupling Method. Dover.
Thorisson, Hermann. 2000. Coupling, Stationarity, and Regeneration. Springer-Verlag.
Vats, Dootika, and Christina Knudson. 2021. “Revisiting the Gelman–Rubin Diagnostic.” Statistical Science 36 (4): 518–29. https://doi.org/10.1214/20-STS812.