Computing importance weights for Markov chain Monte Carlo via couplings: an application to f-divergence diagnostics

Adrien Corenflos

Department of Statistics, University of Warwick

Joint work with Hai-Dang Dau (College of Computing and Data Science, Nanyang Technological University)

Motivation: MCMC

MCMC is used to compute expectations \(\pi(\varphi) = \mathbb{E}_{\pi}[\varphi(X)]\) when sampling directly from the target \(\pi\) is intractable.

Method:

  • Design a kernel \(K\) with invariant distribution \(\pi\): i.e., \(\pi\, K = \pi\).
  • Simulate \(X_{t+1} \sim K(X_t, \cdot)\), with \(\mathcal{L}(X_t) \to \pi\) as \(t \to \infty\).
  • Use time averages, discarding the first \(B\) states \(X_0, \ldots, X_{B-1}\) as burn-in, to approximate expectations: \[\begin{equation} \pi(\varphi) \approx \frac{1}{T-B+1} \sum_{t=B}^T \varphi(X_t) \end{equation}\]
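A minimal sketch of these three steps, assuming a one-dimensional standard normal target and a random-walk Metropolis kernel as the \(\pi\)-invariant \(K\) (one possible choice, used here only for illustration). The names `log_target` and `rwm_time_average` are hypothetical.

```python
import math
import random


def log_target(x):
    # Unnormalized log-density of the illustrative target pi = N(0, 1).
    return -0.5 * x * x


def rwm_time_average(phi, x0, T, B, step=2.4, seed=0):
    """Run random-walk Metropolis X_0, ..., X_T; return the average of phi(X_t) for t = B, ..., T."""
    rng = random.Random(seed)
    x = x0
    total = 0.0
    for t in range(T + 1):
        if t > 0:
            # One application of K: Gaussian proposal, then Metropolis accept/reject,
            # which leaves pi invariant (pi K = pi).
            y = x + step * rng.gauss(0.0, 1.0)
            if rng.random() < math.exp(min(0.0, log_target(y) - log_target(x))):
                x = y
        if t >= B:
            total += phi(x)
    return total / (T - B + 1)


# Start far from the mode and discard a burn-in of B = 1000 states.
estimate = rwm_time_average(lambda x: x * x, x0=10.0, T=50_000, B=1_000)
```

The burn-in \(B\) is exactly what the rest of the talk tries to avoid choosing by hand.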

The Target: Importance Weights

MCMC simulations generate samples \(X_t \sim \mu_t\), where \(\mu_t = \mu_0 K^t\) is the law of the chain at time \(t\), but our ultimate goal is to compute expectations with respect to \(\pi\).

The discrepancy between \(\mu_t\) and \(\pi\) is captured by their density ratio: \[\begin{equation} \psi_t(x) \propto \frac{\pi(x)}{\mu_t(x)} \end{equation}\]

This quantity \(\psi_t(x)\) is the importance weight of \(x\) at time \(t\). If we had access to \(\psi_t\), we could reweight our MCMC samples:

\[\begin{equation} \pi(\varphi) = \int \varphi(x) \pi(x) dx = \frac{\int \varphi(x) \psi_t(x) \mu_t(x) dx}{\int \psi_t(x) \mu_t(x) dx} \end{equation}\]

Two Birds with One Stone

If we can accurately compute or approximate the importance weight \(\psi_t(x)\), we immediately unlock two capabilities:

  1. Bias-corrected Estimation: Compute expectations via importance sampling early in the chain. \[\begin{equation} \pi(\varphi) \approx \sum_{n=1}^{2N} W_t^n \varphi(X_t^n), \quad \text{where } W_t^n \propto \psi_t(X_t^n) \end{equation}\]
  2. Diagnostics: Track convergence by estimating the discrepancy between \(\mu_t\) and \(\pi\). This will be our main application in this talk, but the potential for bias-corrected estimation is equally exciting!
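The bias-corrected estimator is self-normalized importance sampling. The sketch below shows it on i.i.d. draws from a known Gaussian \(\mu\), an illustrative assumption: in MCMC, \(\mu_t\) is unknown, which is exactly why \(\psi_t\) must be approximated. The densities and function names are hypothetical.

```python
import math
import random


def log_pi_unnorm(x):
    # Unnormalized log-density of the illustrative target pi = N(1, 1).
    return -0.5 * (x - 1.0) ** 2


def log_mu_unnorm(x, scale=2.0):
    # Unnormalized log-density of the illustrative sampling law mu = N(0, scale^2).
    return -0.5 * (x / scale) ** 2


def snis_estimate(phi, xs):
    """Self-normalized IS: sum_n W^n phi(x^n), with W^n proportional to pi(x^n) / mu(x^n)."""
    log_w = [log_pi_unnorm(x) - log_mu_unnorm(x) for x in xs]
    shift = max(log_w)  # subtract the max before exponentiating, for numerical stability
    w = [math.exp(lw - shift) for lw in log_w]
    total = sum(w)
    return sum(wi / total * phi(x) for wi, x in zip(w, xs))


rng = random.Random(0)
samples = [rng.gauss(0.0, 2.0) for _ in range(20_000)]
mean_estimate = snis_estimate(lambda x: x, samples)  # should be close to E_pi[X] = 1
```

Because the weights are normalized, unknown constants in \(\pi\) and \(\mu\) cancel, matching the \(\propto\) in the definition of \(\psi_t\).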

The Core Objective: Find a method to dynamically approximate \(\psi_t(x)\).

Importance Weights from MCMC Paths

Compare the law of the path \(X_{0:t}\) when the chain is started at stationarity (\(X_0 \sim \pi\)) with its law when started from \(X_0 \sim \mu_0\):

\[\begin{align*} \mathbb{P}_\pi(X_{0:t}) &= \pi(X_0) \prod_{s=1}^t K(X_{s-1}, X_s)\\ \mathbb{P}_{\mu_0}(X_{0:t}) &= \mu_0(X_0) \prod_{s=1}^t K(X_{s-1}, X_s) \end{align*}\]

The likelihood ratio over the entire path simplifies:

\[ \frac{\mathbb{P}_\pi(X_{0:t})}{\mathbb{P}_{\mu_0}(X_{0:t})} = \frac{\pi(X_0) \prod_{s=1}^t K(X_{s-1}, X_s)}{\mu_0(X_0) \prod_{s=1}^t K(X_{s-1}, X_s)} = \frac{\pi(X_0)}{\mu_0(X_0)} = w_0 \]

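Taking a test function that depends only on the last component, \(h(X_{0:t}) = \varphi(X_t)\), gives \(\mathbb{E}_{\mu_0}[w_0\, \varphi(X_t)] = \mathbb{E}_{\pi}[\varphi(X_t)] = \pi(\varphi)\) (since \(\pi K^t = \pi\)), so \(w_0\) remains a valid weight for time \(t\).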
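On a finite state space this identity can be checked exactly with rational arithmetic. The sketch assumes a 3-state target and a Metropolis kernel with uniform proposals, both chosen purely for illustration; the function names are hypothetical.

```python
from fractions import Fraction


def metropolis_kernel(pi):
    """Metropolis kernel on {0, ..., S-1} with uniform proposals; it is pi-reversible, so pi K = pi."""
    S = len(pi)
    K = [[Fraction(0)] * S for _ in range(S)]
    for x in range(S):
        for y in range(S):
            if y != x:
                K[x][y] = Fraction(1, S) * min(Fraction(1), pi[y] / pi[x])
        K[x][x] = 1 - sum(K[x])  # rejected and self-proposed moves stay at x
    return K


def weighted_expectation(mu0, pi, K, phi, t):
    """Exact E_{mu0}[w_0(X_0) phi(X_t)] = sum_x mu0(x) w_0(x) (K^t phi)(x), with w_0 = pi / mu0."""
    h = list(phi)
    for _ in range(t):
        # (K h)(x) = sum_y K(x, y) h(y)
        h = [sum(K[x][y] * h[y] for y in range(len(h))) for x in range(len(h))]
    return sum(mu0[x] * (pi[x] / mu0[x]) * h[x] for x in range(len(pi)))


pi = [Fraction(1, 2), Fraction(1, 3), Fraction(1, 6)]
mu0 = [Fraction(1, 6), Fraction(1, 6), Fraction(2, 3)]
phi = [Fraction(0), Fraction(1), Fraction(5)]
K = metropolis_kernel(pi)
target_value = sum(p * f for p, f in zip(pi, phi))  # pi(phi) = 7/6
```

For every \(t\), `weighted_expectation(mu0, pi, K, phi, t)` equals `target_value` exactly, even though \(\mu_0\) is far from \(\pi\).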

This talk

  1. A novel “weight harmonization” scheme for parallel MCMC chains that produces a consistent, computable approximation of the importance weights \(\psi_t(X_t^n)\).

  2. A general method to use these weights to build online upper bounds on any \(f\)-divergence (e.g., total variation, Kullback–Leibler, \(\chi^2\)).

  3. The resulting bounds provide a diagnostic that never gets worse over time (the bounds are non-increasing) and requires no lag or warm-up.

The Key Intuition: Reallocating Weights

Suppose we are tracking multiple weighted particles to approximate our target distribution.

Say two particles happen to reach the exact same state, \(X = Y = x\), but they carry different weights \(W_X\) and \(W_Y\).

Their joint contribution to the empirical measure is simply the sum of their masses at \(x\): \[\begin{equation} W_X \delta_{x} + W_Y \delta_{x} = (W_X + W_Y) \delta_{x} \end{equation}\]

Crucial Insight:

As long as we preserve this sum \(W_X + W_Y\), we are perfectly free to reallocate the individual weights between the two particles however we choose, without altering the empirical measure at all!
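The invariance can be made concrete by collapsing weighted particles into their empirical measure. This is a minimal sketch with hypothetical names, using exact fractions so that equality is not affected by rounding.

```python
from collections import defaultdict
from fractions import Fraction


def empirical_measure(states, weights):
    """Collapse weighted particles into {state: total mass}, i.e. sum_n W^n delta_{X^n}."""
    mass = defaultdict(Fraction)
    for x, w in zip(states, weights):
        mass[x] += w
    return dict(mass)


states = ["x", "z", "x"]  # particles 0 and 2 sit at the same state "x"
before = [Fraction(1, 2), Fraction(1, 5), Fraction(3, 10)]
# Reallocate W_X + W_Y = 4/5 equally between particles 0 and 2; particle 1 is untouched.
after = [Fraction(2, 5), Fraction(1, 5), Fraction(2, 5)]
unchanged = empirical_measure(states, before) == empirical_measure(states, after)
```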

Weight Harmonization via Couplings

A coupling \(\bar{K}(X_t^n, X_t^m, \cdot, \cdot)\) of \(K\) with itself moves two weighted chains jointly from \((X_t^n, X_t^m)\) to \((X_{t+1}^n, X_{t+1}^m)\): each chain marginally follows \(K\), but the joint move encourages them to meet.

  • If they don’t meet (\(X_{t+1}^n \neq X_{t+1}^m\)):
    • Weights are unchanged: \(W_{t+1}^n = W_t^n\) and \(W_{t+1}^m = W_t^m\).
  • If they meet (\(X_{t+1}^n = X_{t+1}^m = x\)):
    • Following our intuition, their joint contribution is \((W_t^n + W_t^m) \delta_{x}\).
    • We harmonize their weights by equally reallocating the sum!

\[\begin{align} W_{t+1}^n \gets W_{t+1}^m \gets (W_t^n + W_t^m) / 2 \end{align}\]
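One coupled step followed by harmonization can be sketched on a finite state space, where \(K\) is a matrix of row probabilities. Here \(\bar{K}\) is taken to be the standard maximal coupling of the two rows; this is an illustrative choice, and all names are hypothetical rather than the authors' implementation.

```python
import random


def sample_index(rng, weights):
    """Draw i with probability weights[i] / sum(weights)."""
    u = rng.random() * sum(weights)
    acc = 0.0
    for i, w in enumerate(weights):
        acc += w
        if u < acc:
            return i
    # Floating-point round-off fallback: last index with positive weight.
    positive = [i for i, w in enumerate(weights) if w > 0]
    return positive[-1] if positive else len(weights) - 1


def maximal_coupling(rng, p, q):
    """Sample (X, Y) with X ~ p, Y ~ q and P(X = Y) = sum_z min(p(z), q(z))."""
    overlap = [min(a, b) for a, b in zip(p, q)]
    if rng.random() < sum(overlap):
        z = sample_index(rng, overlap)
        return z, z
    # The residuals p - overlap and q - overlap have disjoint supports, so X != Y here.
    x = sample_index(rng, [a - c for a, c in zip(p, overlap)])
    y = sample_index(rng, [b - c for b, c in zip(q, overlap)])
    return x, y


def coupled_step(rng, K, x, y, wx, wy):
    """Move a pair of weighted chains with the coupled kernel, then harmonize if they meet."""
    x_new, y_new = maximal_coupling(rng, K[x], K[y])
    if x_new == y_new:
        w = (wx + wy) / 2  # equal reallocation of the joint mass
        return x_new, y_new, w, w
    return x_new, y_new, wx, wy  # no meeting: weights unchanged
```

Maximizing the meeting probability is a natural choice here, since weights only move when chains meet.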

The Weight Harmonization Algorithm

We run \(2N\) chains, viewed as \(N\) pairs.

Initialization (\(t=0\)): Sample \(X_0^n \sim \mu_0\) and compute normalized weights \(W_0^n \propto \pi(X_0^n) / \mu_0(X_0^n)\).

At each step \(t \to t+1\):

  1. Couple: For each of the \(N\) current pairs \((X_t^n, X_t^{m})\), sample \((X_{t+1}^n, X_{t+1}^{m}) \sim \bar{K}(X_t^n, X_t^{m}, \cdot, \cdot)\).

  2. Harmonize: If \(X_{t+1}^n = X_{t+1}^{m}\), set \(W_{t+1}^n = W_{t+1}^{m} = (W_t^n + W_t^{m})/2\). Otherwise, weights are unchanged.

  3. Reshuffle: Randomly permute the pairings among the chains that just met. This allows information (weights) to propagate throughout the entire system of \(2N\) chains over time.
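The three steps can be sketched end to end on a finite state space. Two assumptions for brevity: \(K\) is given as a matrix of row probabilities, and \(\bar{K}\) is a common-random-numbers coupling (one shared uniform drives both inverse-CDF moves). This is a valid coupling but not a maximal one, and it is swapped in here for illustration only. Names are hypothetical.

```python
import random


def inverse_cdf(probs, u):
    """Smallest index i with u < probs[0] + ... + probs[i] (last index on round-off)."""
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if u < acc:
            return i
    return len(probs) - 1


def weight_harmonization(K, pi, mu0, N, T, seed=0):
    """Run 2N chains on {0, ..., S-1} as N pairs for T steps; return (states, normalized weights)."""
    rng = random.Random(seed)
    # Initialization: X_0^n ~ mu0 and W_0^n proportional to pi(X_0^n) / mu0(X_0^n).
    states = [inverse_cdf(mu0, rng.random()) for _ in range(2 * N)]
    raw = [pi[x] / mu0[x] for x in states]
    total = sum(raw)
    weights = [w / total for w in raw]
    pairs = [(2 * k, 2 * k + 1) for k in range(N)]
    for _ in range(T):
        met = []
        for n, m in pairs:
            # 1. Couple: one shared uniform drives both moves; each chain still moves by K.
            u = rng.random()
            states[n] = inverse_cdf(K[states[n]], u)
            states[m] = inverse_cdf(K[states[m]], u)
            # 2. Harmonize: chains that meet split their total weight equally (sum preserved).
            if states[n] == states[m]:
                weights[n] = weights[m] = (weights[n] + weights[m]) / 2
                met += [n, m]
        # 3. Reshuffle: randomly re-pair the chains that just met; other pairs are kept.
        met_set = set(met)
        pairs = [(n, m) for n, m in pairs if n not in met_set]
        rng.shuffle(met)
        pairs += [(met[i], met[i + 1]) for i in range(0, len(met), 2)]
    return states, weights


# Illustrative run: every row of K equals pi, so coupled chains always meet.
pi = [0.5, 0.3, 0.2]
K = [list(pi) for _ in range(3)]
states, W = weight_harmonization(K, pi, mu0=[1 / 3] * 3, N=50, T=200, seed=1)
```

In this deliberately easy case all pairs meet at every step, so the weights are averaged across freshly shuffled pairs and flatten towards \(1/2N\).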


Theoretical Result 1: Consistency

Proposition (Invariance of Expectations) Let \(\gamma \propto \pi\) be the unnormalized target density and \(w_t^n\) the unnormalized weights, initialized as \(w_0^n = \gamma(X_0^n) / \mu_0(X_0^n)\). The un-normalized estimator \(\hat{I}_{t,2N}(\varphi) = \sum_{n=1}^{2N} w_t^n \varphi(X_t^n)\) has a constant expectation over time: \[\begin{equation} \mathbb{E}[\hat{I}_{t,2N}(\varphi)] = \mathbb{E}[\hat{I}_{0,2N}(\varphi)] = 2N \int \varphi(x) \gamma(x)\mathrm{d} x \end{equation}\]

Theorem (Consistency) For any time \(t\), as the number of particles \(2N \to \infty\), our weighted sample provides a consistent estimate of the target expectation: \[\begin{equation} \sum_{n=1}^{2N} W_t^n \varphi(X_t^n) \xrightarrow{a.s.} \int \varphi(x) \pi(x) \mathrm{d} x \end{equation}\]

Application: MCMC Diagnostics

What can we actually do with these approximated weights \(W_t^{1:2N}\)?

  1. We could use them to compute debiased expectations before the chain has converged.
  2. We can use them to diagnose the convergence of the algorithm in an online manner because measuring how close \(W_t^{1:2N}\) is to the uniform vector \((1 / 2N, \ldots, 1 / 2N)\) is a proxy for how close \(\mu_t\) is to \(\pi\).

Let’s focus on diagnostics.

Background: \(f\)-divergences

Let \(f:[0, \infty) \to \mathbb{R}\) be a convex function with \(f(1)=0\). \[\begin{equation} D_f(\pi \| \mu) = \int \mu(\mathrm{d} x) f\left( \frac{\mathrm{d}\pi}{\mathrm{d}\mu}(x) \right) \end{equation}\]

This family includes many famous divergences:

  • Kullback-Leibler: \(f(t) = t \log t\)
  • Total Variation: \(f(t) = |t-1|/2\)
  • \(\chi^2\)-divergence: \(f(t) = (t-1)^2\)
  • Squared Hellinger distance: \(f(t) = (\sqrt{t}-1)^2\) (up to a factor of 2, depending on convention)
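For discrete distributions the definition becomes a finite sum. A minimal sketch of the four generators and of \(D_f(\pi \| \mu)\), assuming \(\mu(x) > 0\) wherever \(\pi(x) > 0\); function names are hypothetical.

```python
import math


# Generators f of the divergences above: convex on [0, inf) with f(1) = 0.
def f_kl(t):
    return t * math.log(t) if t > 0 else 0.0  # convention 0 log 0 = 0


def f_tv(t):
    return abs(t - 1.0) / 2.0


def f_chi2(t):
    return (t - 1.0) ** 2


def f_hellinger(t):
    return (math.sqrt(t) - 1.0) ** 2


def f_divergence(f, pi, mu):
    """D_f(pi || mu) = sum_x mu(x) f(pi(x) / mu(x)); assumes mu(x) > 0 wherever pi(x) > 0."""
    return sum(m * f(p / m) for p, m in zip(pi, mu) if m > 0)


pi = [0.5, 0.5]
mu = [0.25, 0.75]
```

For this pair, total variation is 1/4, \(\chi^2\) is 1/3, and KL is \(\tfrac{1}{2}\log(4/3)\).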

\(f\)-divergences for Weighted Samples

Consider two discrete measures:

  • Target: \(\pi_{2N} = \sum_{n=1}^{2N} W^n \delta_{X^n}\)
  • Reference: \(\mu_{2N} = \frac{1}{2N} \sum_{n=1}^{2N} \delta_{X^n}\)

The \(f\)-divergence between them is simple to compute: \[\begin{equation} D_f(\pi_{2N} \| \mu_{2N}) = \frac{1}{2N} \sum_{n=1}^{2N} f(2N W^n) \end{equation}\]
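Only the weights enter this formula. The sketch below computes it with the total variation generator, for which it reduces to \(\tfrac{1}{2}\sum_n |W^n - 1/2N|\). It assumes distinct atoms, as in the formula above; names are hypothetical.

```python
def f_tv(t):
    # Total variation generator f(t) = |t - 1| / 2.
    return abs(t - 1.0) / 2.0


def empirical_f_divergence(f, weights):
    """(1/M) sum_n f(M W^n) for normalized weights W^1..W^M (M = 2N in the talk)."""
    M = len(weights)
    return sum(f(M * w) for w in weights) / M


uniform = [0.25] * 4                  # harmonized weights: diagnostic is zero
degenerate = [0.5, 0.5, 0.0, 0.0]     # mass on half of the particles
```

Uniform weights give 0, while the degenerate vector gives a total variation of 1/2.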

From Weights to Diagnostics

Key Insight (Theorem 1):

If we have a “good” set of weights \(W_t^{1:M}\) for our MCMC samples \(X_t^{1:M} \sim \mu_t\) (i.e., \(\sum W_t^m \delta_{X_t^m}\) is a consistent estimator of \(\pi\)), then:

\[\begin{equation} \mathbb{P}\left( \frac{1}{M}\sum_{m=1}^{M} f(M W_t^m) \le D_f(\pi \| \mu_t) - \varepsilon \right) \to 0 \quad \text{as } M \to \infty \end{equation}\]

So, because our weight harmonization algorithm generates consistent weights \(W_t^n\) at each MCMC step \(t\), we can directly track the convergence of the chain’s distribution using empirical \(f\)-divergences!

Theoretical Result 2: Data-processing Inequality

Proposition (Non-increasing bounds) The \(f\)-divergence upper bound is guaranteed to be non-increasing with time, almost surely. \[\begin{equation} \sum_{n=1}^{2N} f(2N W_{t+1}^n) \le \sum_{n=1}^{2N} f(2N W_t^n) \end{equation}\]

This mirrors the data-processing inequality for \(f\)-divergences: since \(\mu_{t+1} = \mu_t K\) and \(\pi K = \pi\), applying the Markov kernel cannot move the chain's distribution further from the target:

\[\begin{equation} D_f(\pi \| \mu_{t+1}) \le D_f(\pi \| \mu_t) \end{equation}\]
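Why the proposition holds: a pair that meets replaces its weights \((a, b)\) by \(((a+b)/2, (a+b)/2)\), and convexity gives \(f(2Na) + f(2Nb) \ge 2 f(N(a+b))\), while pairs that do not meet keep their weights. The sketch below checks this numerically. It tracks weights only and replaces the coupling with a hypothetical Bernoulli meeting event of probability `meet_prob`. This abstraction is an assumption made for illustration.

```python
import math
import random

# Generators of the f-divergences listed earlier (all convex with f(1) = 0).
GENERATORS = {
    "kl": lambda t: t * math.log(t) if t > 0 else 0.0,
    "tv": lambda t: abs(t - 1.0) / 2.0,
    "chi2": lambda t: (t - 1.0) ** 2,
    "hellinger": lambda t: (math.sqrt(t) - 1.0) ** 2,
}


def diagnostic(f, weights):
    # (1/M) sum_n f(M W^n)
    M = len(weights)
    return sum(f(M * w) for w in weights) / M


def harmonize_random_pairs(rng, weights, meet_prob):
    """Weights-only step: random pairs 'meet' with probability meet_prob, then average."""
    order = list(range(len(weights)))
    rng.shuffle(order)
    new = list(weights)
    for n, m in zip(order[0::2], order[1::2]):
        if rng.random() < meet_prob:
            new[n] = new[m] = (weights[n] + weights[m]) / 2
    return new


rng = random.Random(0)
raw = [rng.expovariate(1.0) for _ in range(100)]
total = sum(raw)
history = [[w / total for w in raw]]
for _ in range(30):
    history.append(harmonize_random_pairs(rng, history[-1], meet_prob=0.3))
```

For each generator, the sequence `diagnostic(f, w)` over `history` is non-increasing, up to floating-point round-off.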

Theoretical Result 3: Convergence to Zero

Assumption (Minimum probability of coupling) For \((X', Y') \sim \bar{K}(x, y, \cdot, \cdot)\), \[\begin{equation} \mathbb{P}(X' = Y' \mid x,y) \ge p_c > 0 \quad \text{for all } x,y \end{equation}\]

Theorem (Geometric Weight Convergence) Under this minimum coupling assumption, the weights converge exponentially fast to the uniform distribution \(\bar{W} = (1/2N, \dots, 1/2N)\): \[\begin{equation} \mathbb{E}[\|W_t - \bar{W}\|_2^2] = O(\rho^t) \quad \text{for some } 0 <\rho < 1 \end{equation}\]

This implies that \(D_f(\pi_{2N} \| \mu_{2N}) \to 0\) almost surely, confirming our method provides a valid convergence diagnostic that vanishes at stationarity.
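For a finite kernel with a maximal coupling, the meeting probability from \((x, y)\) equals \(\sum_z \min(K(x,z), K(y,z)) = 1 - \mathrm{TV}(K(x,\cdot), K(y,\cdot))\), so the best constant \(p_c\) in the assumption can be computed directly. A minimal sketch with hypothetical names and illustrative kernels:

```python
def meeting_probability(p, q):
    """P(X' = Y') under a maximal coupling of p and q: sum_z min(p(z), q(z)) = 1 - TV(p, q)."""
    return sum(min(a, b) for a, b in zip(p, q))


def min_coupling_probability(K):
    """Largest p_c for which the assumption holds, for a finite kernel K and a maximal coupling."""
    S = len(K)
    return min(meeting_probability(K[x], K[y]) for x in range(S) for y in range(S))


lazy_cycle = [[0.5, 0.5, 0.0], [0.0, 0.5, 0.5], [0.5, 0.0, 0.5]]
identity = [[1.0, 0.0], [0.0, 1.0]]
```

The lazy cycle satisfies the assumption with \(p_c = 1/2\). The identity kernel never lets distinct chains meet (\(p_c = 0\)), so it falls outside this theorem.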

Theoretical Result 4: Convergence without Rates

Theorem (Almost Sure Convergence) Under general ergodicity conditions (without requiring the uniform coupling bound), the weights converge almost surely to the uniform distribution \(\bar{W} = (1/2N, \dots, 1/2N)\): \[\begin{equation} W_t \xrightarrow{a.s.} \bar{W} \quad \text{as } t \to \infty \end{equation}\]

Again, this implies that \(D_f(\pi_{2N} \| \mu_{2N}) \to 0\) almost surely, but without any rate guarantees.

Experiment 1: Pólya-Gamma Gibbs Sampler

A Bayesian logistic regression (\(d=49\)) compared with the diagnostic of Biswas, Jacob, and Vanetti (2019).

  • Conservative compared to SOTA
  • Key advantage: No lag/warm-up needed, works from \(t=0\).

Experiment 2: MALA on Stochastic Volatility

A high-dimensional (\(d=2500\)) model.

  • Still conservative compared to SOTA, but less so than in the previous example, as the method benefits from a warm start.
  • Provides bounds for any \(f\)-divergence simultaneously.

Summary

We introduced a new method for MCMC convergence diagnostics based on weight harmonization.

  • It uses couplings to progressively average importance weights between parallel chains.
  • This provides consistent, online, non-increasing upper bounds for any \(f\)-divergence.
  • The diagnostic is guaranteed to converge to zero for ergodic chains.
  • It is easy to implement if a coupled kernel is available.

Limitations and Future Work

The main limitation is the conservativeness of the bound, especially for large \(2N\). This is due to the pairwise interaction scheme.

Potential Improvements:

  1. Rao-Blackwellization: Instead of picking one random pairing, can we average over all possible pairings? This would dramatically increase interaction but is computationally challenging.

  2. Offline Correction: The current method is a “filtering” approach. Can we use ideas from particle smoothing (a “backward pass”) to refine the weights after the run is complete?

  3. Variance Reduction: Incorporate control variates to accelerate the convergence of the diagnostic itself.

Some recent news

A recent paper just appeared on arXiv: Multi-Marginal Couplings for Metropolis–Hastings by Buu Phan, Gergely Flamich, Ashish Khisti, Shahab Asoodeh, which proposes something that can be seen as a Rao–Blackwellized version of our method.

To be continued…

Thank you!

References

Biswas, Niloy, Pierre E. Jacob, and Paul Vanetti. 2019. “Estimating Convergence of Markov chains with L-lag couplings.” Advances in Neural Information Processing Systems 32.
Corenflos, Adrien, and Hai-Dang Dau. 2025. “A coupling-based approach to f-divergences diagnostics for Markov chain Monte Carlo.” https://arxiv.org/abs/2510.07559.
Jacob, Pierre E., John O’Leary, and Yves F. Atchadé. 2020. “Unbiased Markov Chain Monte Carlo Methods with Couplings.” Journal of the Royal Statistical Society: Series B (Statistical Methodology) 82 (3): 543–600.
Lindvall, Torgny. 2002. Lectures on the Coupling Method. Dover.
Thorisson, Hermann. 2000. Coupling, Stationarity, and Regeneration. Springer-Verlag.
Vats, Dootika, and Christina Knudson. 2021. “Revisiting the Gelman–Rubin Diagnostic.” Statistical Science 36 (4): 518–29. https://doi.org/10.1214/20-STS812.