Computing importance weights for Markov chain Monte Carlo via couplings: an application to f-divergence diagnostics

Adrien Corenflos

Department of Statistics, University of Warwick

Joint work with Hai-Dang Dau


Department of Statistics and Data Science, National University of Singapore

Motivation: MCMC

MCMC is used to compute expectations \(\pi(\varphi) = \mathbb{E}_{\pi}[\varphi(X)]\) when sampling from a target \(\pi\) is impossible.

Method:

  • Design a kernel \(K\) with invariant distribution \(\pi\): i.e., \(\pi\, K = \pi\).
  • Simulate \(X_{t+1} \sim K(X_t, \cdot)\), with \(\mathcal{L}(X_t) \to \pi\) as \(t \to \infty\).
  • Use time averages to approximate expectations: \[\begin{equation} \pi(\varphi) \approx \frac{1}{T-B+1} \sum_{t=B}^T \varphi(X_t) \end{equation}\]
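As a concrete sketch of this recipe (the kernel and target are illustrative choices, not from the talk): a random-walk Metropolis kernel \(K\) targeting a standard normal, with a burn-in-discarded time average estimating \(\mathbb{E}_\pi[X^2]\).

```python
import numpy as np

def rw_metropolis(log_target, x0, n_steps, step=1.0, seed=None):
    """Random-walk Metropolis: one concrete kernel K with pi K = pi."""
    rng = np.random.default_rng(seed)
    x, lp = x0, log_target(x0)
    chain = np.empty(n_steps)
    for t in range(n_steps):
        prop = x + step * rng.normal()
        lp_prop = log_target(prop)
        if np.log(rng.uniform()) < lp_prop - lp:  # accept w.p. 1 ∧ pi(x')/pi(x)
            x, lp = prop, lp_prop
        chain[t] = x
    return chain

# Time-average estimate of E_pi[X^2] for pi = N(0, 1), discarding burn-in B
chain = rw_metropolis(lambda x: -0.5 * x**2, x0=5.0, n_steps=20_000, seed=0)
B = 1_000
print(np.mean(chain[B:] ** 2))  # should be close to 1
```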

Nagging Question: “are we there yet?”

Standard diagnostics:

  • Gelman-Rubin (\(\hat{R}\)): Compares variance within and between \(2N\) parallel chains. A variance ratio \(\hat{R} \approx 1\) suggests convergence.
  • Effective Sample Size (ESS): Estimates the number of independent samples equivalent to the MCMC output for a given function \(\varphi\). \[\begin{equation} \text{ESS} = \frac{T}{1 + 2 \sum_{k=1}^T \text{Corr}(\varphi(X_0), \varphi(X_k))} \end{equation}\]
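A minimal sketch of the ESS formula above, estimating the autocorrelations empirically and truncating the sum at the first non-positive term (the truncation rule is a common convention, not specified here):

```python
import numpy as np

def ess(samples):
    """ESS = T / (1 + 2 * sum_k rho_k), truncating the autocorrelation
    sum at the first non-positive term (a simple, common cutoff)."""
    x = np.asarray(samples, float)
    x = x - x.mean()
    n = len(x)
    acf = np.correlate(x, x, mode="full")[n - 1:] / (n * x.var())
    rho_sum = 0.0
    for k in range(1, n):
        if acf[k] <= 0:
            break
        rho_sum += acf[k]
    return n / (1.0 + 2.0 * rho_sum)

rng = np.random.default_rng(0)
iid = rng.normal(size=5000)
ar = np.empty(5000)                # AR(1) with strong positive correlation
ar[0] = 0.0
for t in range(1, 5000):
    ar[t] = 0.9 * ar[t - 1] + rng.normal()
print(ess(iid), ess(ar))           # correlated chain has a far smaller ESS
```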

Limitations of Standard Diagnostics

The Gelman-Rubin diagnostic (Gelman and Rubin 1992) and ESS have a key limitation:

They are function-specific.

  • A chain might have a good ESS for \(\mathbb{E}[X]\) but a terrible one for \(\mathbb{E}[X^2]\).
  • We get a partial picture of convergence, tied to specific test functions \(\varphi\).

Wouldn’t it be better to diagnose the convergence of the distribution of the chain itself?

Distribution-based Diagnostics

Theoretical convergence is measured by statistical distances between the chain’s distribution \(\mu_t\) and the target \(\pi\).

  • Total Variation Distance: \(\|\mu_t - \pi\|_{\text{TV}}\)
  • Wasserstein Distance: \(W_p(\mu_t, \pi)\)

Biswas, Jacob, and Vanetti (2019) provide computable upper bounds on these distances using coupled chains.

A coupling runs two chains \((X_t, Y_{t-L})\), lagged by \(L\) steps, jointly so that they eventually meet at a random time \(\tau\).

The meeting time then yields a computable bound on average: \(\|\mu_t - \pi\|_{\text{TV}} \le \mathbb{E}\left[\max\left(0, \left\lceil \frac{\tau - L - t}{L} \right\rceil\right)\right]\).

Distribution-based Diagnostics

The independent Metropolis-Hastings case

For a target \(\pi\) and proposal \(q\), the independent Metropolis-Hastings (MH) algorithm iterates:

  1. Propose \(X' \sim q(\cdot)\)
  2. Accept \(X'\) with probability \(\alpha = 1 \wedge \tfrac{\pi(X') / q(X')}{\pi(X) / q(X)}\)

To couple two independent MH chains \((X_t, Y_t)\), we can use the same proposal \(X' \sim q(\cdot)\) for both chains at each step:

  1. Propose \(X' \sim q(\cdot)\)
  2. Compute acceptance probabilities \(\alpha_X\) and \(\alpha_Y\) for \(X_t\) and \(Y_t\) respectively.
  3. Draw a single uniform random variable \(U \sim \text{Uniform}(0,1)\), shared by both chains.
  4. If \(U < \alpha_X\), accept for \(X_t\); if \(U < \alpha_Y\), accept for \(Y_t\).
  5. If both chains accept the same proposal, they meet: \(X_{t+1} = Y_{t+1} = X'\).
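The coupled step above can be sketched as follows; the target/proposal pair (standard normal target, wider normal proposal) is a hypothetical choice for illustration.

```python
import numpy as np

def coupled_imh_step(x, y, log_pi, log_q, sample_q, rng):
    """One coupled independent-MH step: both chains share the proposal
    X' ~ q and the uniform U, so they meet as soon as both accept."""
    prop = sample_q(rng)                     # shared proposal
    lw = lambda z: log_pi(z) - log_q(z)      # log of pi/q
    log_u = np.log(rng.uniform())            # shared uniform
    if log_u < lw(prop) - lw(x):             # accept for the X chain
        x = prop
    if log_u < lw(prop) - lw(y):             # accept for the Y chain
        y = prop
    return x, y

# Hypothetical setup: target N(0,1), proposal N(0, 2^2); run to the meeting time
rng = np.random.default_rng(1)
log_pi = lambda z: -0.5 * z**2
log_q = lambda z: -0.5 * (z / 2.0) ** 2      # constants cancel in lw differences
x, y, t = 4.0, -4.0, 0
while x != y:
    x, y = coupled_imh_step(x, y, log_pi, log_q, lambda r: 2.0 * r.normal(), rng)
    t += 1
print("met after", t, "step(s)")
```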

Limitations of TV/Wasserstein Bounds

Practical Challenges:

  • Portability: I can’t use \(\|\mu_t - \pi\|_{\text{TV}} < 0.05\) to improve my estimates of \(\mathbb{E}_\pi[\varphi(X)]\) directly.
  • Lag requirement: The method requires a second chain started from a near-converged state \(Y_{t-L}\), with a suitably chosen lag \(L\).

We seek a diagnostic that is:

  • Distributional
  • Usable
  • Lag-free

Background: \(f\)-divergences

Let \(f:[0, \infty) \to \mathbb{R}\) be a convex function with \(f(1)=0\). \[\begin{equation} D_f(\pi \| \mu) = \int \mu(\mathrm{d} x) f\left( \frac{\mathrm{d}\pi}{\mathrm{d}\mu}(x) \right) \end{equation}\]

This family includes many famous divergences:

  • Kullback-Leibler: \(f(t) = t \log t\)
  • Total Variation: \(f(t) = |t-1|/2\)
  • \(\chi^2\)-divergence: \(f(t) = (t-1)^2\)
  • Squared Hellinger distance: \(f(t) = (\sqrt{t}-1)^2\)

\(f\)-divergences for Weighted Samples

Consider two discrete measures:

  • Target: \(\pi_{2N} = \sum_{n=1}^{2N} W^n \delta_{X^n}\)
  • Reference: \(\mu_{2N} = \frac{1}{2N} \sum_{n=1}^{2N} \delta_{X^n}\)

The \(f\)-divergence between them is simple to compute: \[\begin{equation} D_f(\pi_{2N} \| \mu_{2N}) = \frac{1}{2N} \sum_{n=1}^{2N} f(2N W^n) \end{equation}\]
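A sketch of this computation, with generator functions \(f\) for the divergences listed earlier (the dictionary of generators and the normalization step are illustrative choices):

```python
import numpy as np

# Generator functions f (convex, f(1) = 0) for the divergences listed above
F = {
    "KL":           lambda t: t * np.log(np.maximum(t, 1e-300)),
    "TV":           lambda t: np.abs(t - 1) / 2,
    "chi2":         lambda t: (t - 1) ** 2,
    "sq_hellinger": lambda t: (np.sqrt(t) - 1) ** 2,
}

def f_div(weights, f):
    """D_f(pi_2N || mu_2N) = (1/2N) sum_n f(2N W^n), W normalized."""
    w = np.asarray(weights, float)
    w = w / w.sum()
    return np.mean(f(len(w) * w))

uniform = np.ones(100) / 100
print({name: round(f_div(uniform, f), 6) for name, f in F.items()})
# uniform weights give f(1) = 0: every divergence is (numerically) zero
```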

The Ideal Scenario

Imagine we knew the density ratio \(\psi_t(x) = \frac{\mathrm{d}\pi}{\mathrm{d}\mu_t}(x)\). Given samples \(X_t^n \sim \mu_t\), we could:

  1. Get ideal estimates of \(\pi\) via importance sampling: \[\begin{equation} \pi(\varphi) \approx \sum_{n=1}^{2N} W_t^n \varphi(X_t^n), \quad \text{where } W_t^n \propto \psi_t(X_t^n) \end{equation}\]
  2. Compute any \(f\)-divergence \(D_f(\pi \| \mu_t)\), which only depends on this ratio \[\begin{equation} D_f(\pi \| \mu_t) = \mathbb{E}_{\mu_t}\left[f\left(\psi_t(X)\right)\right] \end{equation}\]

Of course, \(\psi_t(x)\) is intractable. But what if we could approximate it?

A Naive (and flawed) Approach

The likelihood ratio over the entire path simplifies to the initial weight only:

\[ \frac{\mathbb{P}_\pi(X_{0:t})}{\mathbb{P}_{\mu_0}(X_{0:t})} = \frac{\pi(X_0) \prod_{s=1}^t K(X_{s-1}, X_s)}{\mu_0(X_0) \prod_{s=1}^t K(X_{s-1}, X_s)} = \frac{\pi(X_0)}{\mu_0(X_0)} =: w_0 \]

This motivates a simple diagnostic method:

  1. Start \(2N\) chains from \(X_0^n \sim \mu_0\).
  2. Compute initial weights \(w_0^n = \frac{\pi(X_0^n)}{\mu_0(X_0^n)}\) and normalize to \(W_0^n\).
  3. Run the chains: \(X_t^n \sim K^t(X_0^n, \cdot)\).
  4. Use the initial weights \(W_0^n\) for the samples \(X_t^n\) and compute the diagnostic \(\frac{1}{2N}\sum f(2N W_0^n)\).

The flaw: the weights are fixed at \(t = 0\), so the diagnostic is constant in \(t\) and can never reflect the chains' convergence.
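A sketch of this naive scheme for a hypothetical pair \(\mu_0 = \mathcal{N}(0, 3^2)\), \(\pi = \mathcal{N}(0,1)\), using the \(\chi^2\) generator; note the computation never touches \(X_t^n\) for \(t > 0\).

```python
import numpy as np

rng = np.random.default_rng(2)
M = 1000                                    # 2N chains
x0 = 3.0 * rng.normal(size=M)               # X_0^n ~ mu_0 = N(0, 3^2)
# log w_0 = log pi(x) - log mu_0(x), with pi = N(0, 1)
log_w0 = -0.5 * x0**2 - (-0.5 * (x0 / 3.0) ** 2 - np.log(3.0))
w0 = np.exp(log_w0 - log_w0.max())          # stabilized un-normalized weights
W0 = w0 / w0.sum()                          # normalized initial weights

# chi^2 diagnostic (1/M) sum (M W^n - 1)^2 simplifies to M sum W^2 - 1
chi2_diag = M * np.sum(W0**2) - 1.0
print(chi2_diag)
# The weights never change with t, so this number is the diagnostic at
# every t: it cannot register the chains' convergence.
```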

This talk

  1. A novel “weight harmonization” scheme for parallel MCMC chains that produces a consistent, computable approximation of the importance weights \(\psi_t(X_t^n)\).

  2. A general method to use these weights to get online upper bounds on any \(f\)-divergence, including KL, \(\chi^2\), Hellinger, and TV distance.

  3. The resulting diagnostic is guaranteed to improve over time and requires no lag or warm-up.

  4. The \(\chi^2\)-divergence bound provides an estimate of the “effective number of active chains”, a highly practical and interpretable metric.
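On the last point: for normalized weights, the classical importance-sampling ESS \(1/\sum_n (W^n)^2\) and the \(\chi^2\) diagnostic \(\frac{1}{2N}\sum_n f(2N W^n)\) with \(f(t) = (t-1)^2\) are linked by the exact identity \(\text{ESS} = 2N/(1 + \hat{D}_{\chi^2})\). A quick check (function names are mine):

```python
import numpy as np

def chi2_diagnostic(W):
    """(1/M) sum_n (M W^n - 1)^2 for normalized weights, which
    simplifies to M sum_n (W^n)^2 - 1."""
    W = np.asarray(W, float)
    return len(W) * np.sum(W**2) - 1.0

def effective_chains(W):
    """Classical importance-sampling ESS: 1 / sum_n (W^n)^2."""
    return 1.0 / np.sum(np.asarray(W, float) ** 2)

W = np.random.default_rng(4).dirichlet(np.ones(64))  # 2N = 64 random weights
print(effective_chains(W), 64 / (1.0 + chi2_diagnostic(W)))  # same value (up to rounding)
```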

From Weights to Diagnostics

Key Insight (Theorem 1):

If we have a “good” set of weights \(W_t^{1:M}\) for our MCMC samples \(X_t^{1:M} \sim \mu_t\) (i.e., \(\sum W_t^m \delta_{X_t^m}\) is a consistent estimator of \(\pi\)), then:

\[\begin{equation} \mathbb{P}\left( \frac{1}{M}\sum_{m=1}^{M} f(M W_t^m) \le D_f(\pi \| \mu_t) - \varepsilon \right) \to 0 \quad \text{as } M \to \infty \end{equation}\]

So, if we can generate “good” weights \(W_t^m\) at each MCMC step \(t\), we can track the convergence of the chain’s distribution.

Weight Harmonization via Couplings

A coupling \(\bar{K}(X_t^n, X_t^m, \cdot, \cdot)\) runs two weighted chains \((X_{t+1}^n, X_{t+1}^m)\) jointly from \((X_t^n, X_t^m)\), encouraging them to meet.

  • If they don’t meet (\(X_{t+1}^n \neq X_{t+1}^m\)):
    • Weights are unchanged: \(W_{t+1}^n = W_t^n\) and \(W_{t+1}^m = W_t^m\).
  • If they meet (\(X_{t+1}^n = X_{t+1}^m = x\)):
    • Their joint contribution is \(W_t^n \delta_{x} + W_t^m \delta_{x} = (W_t^n + W_t^m) \delta_{x}\).
    • The empirical measure depends only on the total weight, not on how we allocate it between the chains!

We harmonize their weights by taking a convex combination:

\[\begin{align} W_{t+1}^n \gets W_{t+1}^m \gets (W_t^n + W_t^m) / 2 \end{align}\]

The Weight Harmonization Algorithm

We need to allow all chains to interact. We run \(2N\) chains, viewed as \(N\) pairs.

At each step \(t \to t+1\):

  1. Couple: For each of the \(N\) current pairs \((X_t^n, X_t^{m})\), sample \((X_{t+1}^n, X_{t+1}^{m}) \sim \bar{K}(X_t^n, X_t^{m}, \cdot, \cdot)\).

  2. Harmonize: If \(X_{t+1}^n = X_{t+1}^{m}\), set \(W_{t+1}^n = W_{t+1}^{m} = (W_t^n + W_t^{m})/2\). Otherwise, weights are unchanged.

  3. Reshuffle: Randomly permute the pairings among the chains that just met. This allows information (weights) to propagate throughout the entire system of \(2N\) chains over time.
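The three steps above can be sketched as one sweep, with the coupled kernel abstracted behind a `coupled_step` callback and a toy discrete-state coupling standing in for a real one (both are illustrative assumptions):

```python
import numpy as np

def harmonize_step(xs, ws, coupled_step, rng):
    """One sweep: couple the N fixed pairs, average the weights of every
    pair that met, then reshuffle the chains that just met."""
    met = []
    for i in range(0, len(xs), 2):          # pairs are positions (i, i+1)
        xs[i], xs[i + 1] = coupled_step(xs[i], xs[i + 1], rng)
        if xs[i] == xs[i + 1]:              # the pair met: harmonize
            ws[i] = ws[i + 1] = 0.5 * (ws[i] + ws[i + 1])
            met += [i, i + 1]
    if met:                                 # permute met chains so weight
        perm = rng.permutation(met)         # information spreads over time
        xs[met], ws[met] = xs[perm], ws[perm]
    return xs, ws

def toy_coupled_step(x, y, rng):
    """Toy coupling on 5 states: with prob 1/2 both chains jump to the
    same uniform state (they meet); otherwise they move independently."""
    if rng.uniform() < 0.5:
        z = float(rng.integers(5))
        return z, z
    return float(rng.integers(5)), float(rng.integers(5))

rng = np.random.default_rng(3)
xs = rng.integers(5, size=8).astype(float)  # 2N = 8 chains
ws = rng.dirichlet(np.ones(8))              # normalized starting weights
for _ in range(50):
    xs, ws = harmonize_step(xs, ws, toy_coupled_step, rng)
print(ws)  # after many sweeps the weights approach (1/8, ..., 1/8)
```

Note that harmonization only ever averages weights of chains sharing an atom, so the total weight, and hence the weighted empirical measure's expectation, is preserved.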


Theoretical Results 1 & 2: Consistency

Proposition (Invariance of Expectations) The un-normalized estimator \(\hat{I}_{t,2N}(\varphi) = \sum_{n=1}^{2N} w_t^n \varphi(X_t^n)\) has a constant expectation over time: \[\begin{equation} \mathbb{E}[\hat{I}_{t,2N}(\varphi)] = \mathbb{E}[\hat{I}_{0,2N}(\varphi)] = 2N \int \varphi(x) \gamma(x)\mathrm{d} x \end{equation}\] Here \(\gamma\) denotes the un-normalized target density, \(\pi \propto \gamma\).

Theorem (Consistency) For any time \(t\), as the number of particles \(2N \to \infty\): \[\begin{equation} \sum_{n=1}^{2N} W_t^n \varphi(X_t^n) \xrightarrow{a.s.} \int \varphi(x) \pi(x) \mathrm{d} x \end{equation}\]

Corollary (Non-Asymptotic Bounds) These can be refined to finite-\(N\) guarantees: with probability \(\ge 1-\alpha\), \[\begin{equation} \frac{1}{2N}\sum_{n=1}^{2N} f(2N W_t^n) + \text{Error}_{2N,t}(\alpha) \ge D_f(\pi \| \mu_t) \end{equation}\]

Theoretical Results 3 & 4: Convergence to Zero

Theorem (Geometric Convergence) If \(\mathbb{P}(X' = Y' \mid x,y) \ge p_c > 0\) for all \(x,y\), the weights converge exponentially fast: \[\begin{equation} \mathbb{E}[\|W_t - \bar{W}\|_2^2] = O(\rho^t) \quad \text{for some } 0 < \rho < 1 \end{equation}\]

Theorem (Almost Sure Convergence) Under general ergodicity conditions (no rate required): \[\begin{equation} W_t \xrightarrow{a.s.} \bar{W} = (1/2N, \dots, 1/2N) \quad \text{as } t \to \infty \end{equation}\]

In both cases, the diagnostic \(\frac{1}{2N}\sum_{n=1}^{2N} f(2N W_t^n) \to 0\) almost surely: it vanishes at stationarity.

Experiment 2: MALA on Stochastic Volatility

A high-dimensional (\(d=2500\)) model.

  • Still conservative compared to SOTA, but less so than in the previous example: we can benefit from a warm-start.
  • Provides bounds for any \(f\)-divergence simultaneously.

Summary

We introduced a new method for MCMC convergence diagnostics based on weight harmonization.

  • It uses couplings to progressively average importance weights between parallel chains.
  • This provides consistent, online, non-increasing upper bounds for any \(f\)-divergence.
  • The diagnostic is guaranteed to converge to zero for ergodic chains.
  • It is easy to implement if a coupled kernel is available.
  • The \(\chi^2\)-based Effective Sample Size (ESS) provides a very intuitive measure of how many chains are contributing effectively to the estimate.

Limitations and Future Work

The main limitation is the conservativeness of the bound, especially for large \(2N\). This is due to the pairwise interaction scheme.

Potential Improvements:

  1. Rao-Blackwellization: Instead of picking one random pairing, can we average over all possible pairings? This would dramatically increase interaction but is computationally challenging.

  2. Offline Correction: The current method is a “filtering” approach. Can we use ideas from particle smoothing (a “backward pass”) to refine the weights after the run is complete?

  3. Variance Reduction: Incorporate control variates to accelerate the convergence of the diagnostic itself.

Thank you!

References

Biswas, Niloy, Pierre E. Jacob, and Paul Vanetti. 2019. “Estimating Convergence of Markov Chains with L-Lag Couplings.” Advances in Neural Information Processing Systems 32.
Corenflos, Adrien, and Hai-Dang Dau. 2025. “A Coupling-Based Approach to f-Divergences Diagnostics for Markov Chain Monte Carlo.” https://arxiv.org/abs/2510.07559.
Gelman, Andrew, and Donald B. Rubin. 1992. “Inference from Iterative Simulation Using Multiple Sequences.” Statistical Science 7 (4): 457–72. https://doi.org/10.1214/ss/1177011136.
Jacob, Pierre E., John O’Leary, and Yves F. Atchadé. 2020. “Unbiased Markov Chain Monte Carlo Methods with Couplings.” Journal of the Royal Statistical Society: Series B (Statistical Methodology) 82 (3): 543–600.
Lindvall, Torgny. 2002. Lectures on the Coupling Method. Dover.
Thorisson, Hermann. 2000. Coupling, Stationarity, and Regeneration. Springer-Verlag.
Vats, Dootika, and Christina Knudson. 2021. “Revisiting the Gelman–Rubin Diagnostic.” Statistical Science 36 (4): 518–29. https://doi.org/10.1214/20-STS812.