Sampling from High-Dimensional Gaussian Distributions without the Full Covariance Matrix
October 30, 2024 — October 31, 2024
Assumed audience:
ML people
Sampling from high-dimensional Gaussian distributions can become computationally expensive when the covariance matrix is large and dense. The traditional approach, a Cholesky factorisation of the covariance, needs \(O(D^2)\) storage and \(O(D^3)\) time, which quickly becomes impractical. However, if we can efficiently compute the product of the covariance matrix with arbitrary vectors, we can use Langevin dynamics to sample from the distribution without ever forming the full covariance matrix.
I have been doing this recently in a setting where \(\Sigma\) is outrageously large, but I can nonetheless compute the matrix-vector product \(\Sigma \mathbf{v}\) for arbitrary vectors \(\mathbf{v}\). This arises, for example, when I have a kernel which I can evaluate and need to use to generate samples from a random field, especially where the kernel arises as an inner product under some feature map.
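To make that concrete with a toy case: if the kernel is an inner product under a feature map with feature matrix \(\Phi \in \mathbb{R}^{D \times d}\) (with \(d \ll D\)) plus a nugget \(\sigma^2\), then

\[ \Sigma = \Phi \Phi^\top + \sigma^2 \mathbf{I}_D, \qquad \Sigma \mathbf{v} = \Phi (\Phi^\top \mathbf{v}) + \sigma^2 \mathbf{v}, \]

and each matrix-vector product costs \(O(Dd)\) time and memory instead of the \(O(D^2)\) needed to form \(\Sigma\) explicitly. I use this form as the running example in the code below, but nothing else in the procedure depends on it.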
TODO: evaluate actual computational complexity of this method.
Note that these are really just notes I have made to myself. I still need to sanity check the procedure on a real problem.
1 Problem Setting
We aim to sample from a multivariate Gaussian distribution:
\[ \mathbf{x} \sim \mathcal{N}(\boldsymbol{\mu}, \Sigma) \]
where:
- \(\boldsymbol{\mu} \in \mathbb{R}^D\) is the known mean vector.
- \(\Sigma \in \mathbb{R}^{D \times D}\) is the notionally-known covariance matrix, which might be too large to actually compute, let alone factorise for sampling in the usual way.
2 Langevin Dynamics for Sampling
Langevin dynamics provide a way to sample from a target distribution by simulating a stochastic differential equation (SDE) whose stationary distribution is the desired distribution. For a Gaussian target the drift term is linear in \(\mathbf{x}\), which keeps everything tractable (Gaussians all the way down).
2.1 The Langevin Equation
The continuous-time Langevin equation is
\[ d\mathbf{x}_t = -\nabla U(\mathbf{x}_t) \, dt + \sqrt{2} \, d\mathbf{W}_t \]
where:
- \(U(\mathbf{x})\) is the potential function related to the target distribution \(p(\mathbf{x})\) via \(p(\mathbf{x}) \propto e^{-U(\mathbf{x})}\).
- \(d\mathbf{W}_t\) represents the increment of a Wiener process (standard Brownian motion).
For our Gaussian distribution, the potential function is:
\[ U(\mathbf{x}) = \frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^\top \Sigma^{-1} (\mathbf{x} - \boldsymbol{\mu}) \]
We discretize the Langevin equation using the Euler-Maruyama method with time step \(\epsilon\):
\[ \mathbf{x}_{k+1} = \mathbf{x}_k - \epsilon \nabla U(\mathbf{x}_k) + \sqrt{2\epsilon} \, \boldsymbol{\eta}_k \]
where \(\boldsymbol{\eta}_k \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_D)\).
Next, the gradient of the potential function is:
\[ \nabla U(\mathbf{x}) = \Sigma^{-1} (\mathbf{x} - \boldsymbol{\mu}) \]
Instead of computing \(\Sigma^{-1}\) directly, we can solve the linear system:
\[ \Sigma \mathbf{v} = \mathbf{x} - \boldsymbol{\mu} \]
for \(\mathbf{v}\), which gives \(\mathbf{v} = \Sigma^{-1} (\mathbf{x} - \boldsymbol{\mu})\).
3 Now, to solve that linear equation
To solve \(\Sigma \mathbf{v} = \mathbf{r}\) efficiently without forming \(\Sigma\), we use the Conjugate Gradient (CG) method. CG applies to symmetric positive-definite systems (which a non-degenerate covariance matrix is) and relies only on matrix-vector products \(\Sigma \mathbf{p}\), so the matrix never needs to be stored explicitly.
Given \(\Sigma \mathbf{v} = \mathbf{r}\):
- Initialize \(\mathbf{v}_0 = \mathbf{0}\), \(\mathbf{r}_0 = \mathbf{r} - \Sigma \mathbf{v}_0\), \(\mathbf{p}_0 = \mathbf{r}_0\).
- For \(k = 0, 1, \ldots\):
  - \(\alpha_k = \frac{\mathbf{r}_k^\top \mathbf{r}_k}{\mathbf{p}_k^\top \Sigma \mathbf{p}_k}\)
  - \(\mathbf{v}_{k+1} = \mathbf{v}_k + \alpha_k \mathbf{p}_k\)
  - \(\mathbf{r}_{k+1} = \mathbf{r}_k - \alpha_k \Sigma \mathbf{p}_k\)
  - If \(\|\mathbf{r}_{k+1}\| < \text{tolerance}\), stop.
  - \(\beta_k = \frac{\mathbf{r}_{k+1}^\top \mathbf{r}_{k+1}}{\mathbf{r}_k^\top \mathbf{r}_k}\)
  - \(\mathbf{p}_{k+1} = \mathbf{r}_{k+1} + \beta_k \mathbf{p}_k\)
4 Plug the bits together
We have the following algorithm:
- Initialization:
  - Start with \(\mathbf{x}_0 = \boldsymbol{\mu}\) or any arbitrary vector.
- For \(k = 0, 1, \ldots, N\):
  - Compute \(\mathbf{r}_k = \mathbf{x}_k - \boldsymbol{\mu}\).
  - Solve \(\Sigma \mathbf{v}_k = \mathbf{r}_k\) using CG to get \(\mathbf{v}_k = \Sigma^{-1} (\mathbf{x}_k - \boldsymbol{\mu})\).
  - Update \(\mathbf{x}_{k+1} = \mathbf{x}_k - \epsilon \mathbf{v}_k + \sqrt{2\epsilon} \, \boldsymbol{\eta}_k\), where \(\boldsymbol{\eta}_k \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_D)\).
5 PyTorch Implementation
For my sins, I am cursed to never escape PyTorch. Here is an implementation in that language that I got an LLM to construct for me from the above algorithm.
5.1 Define the Matrix-Vector Product
First, we need a function to compute \(\Sigma \mathbf{v}\) efficiently.
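Here is a minimal sketch, assuming the illustrative feature-map covariance \(\Sigma = \Phi \Phi^\top + \sigma^2 \mathbf{I}_D\) from the problem-setting example; Phi, sigma2, and sigma_mv_prod are placeholder names, and the only contract the rest of the code relies on is that sigma_mv_prod(v) returns \(\Sigma \mathbf{v}\).

import torch

# Placeholder problem setup, purely for illustration: a feature-map covariance
# Σ = Φ Φᵀ + σ² I, kept small here so the dense check at the end of the post
# stays feasible. In the real use case D is far too large for that.
D, d = 200, 16
torch.manual_seed(0)
Phi = torch.randn(D, d) / D ** 0.5   # feature matrix Φ, scaled so Σ has O(1) eigenvalues
sigma2 = 0.5                         # nugget, keeps Σ well-conditioned

def sigma_mv_prod(v):
    """Return Σ v = Φ (Φᵀ v) + σ² v without ever forming Σ (O(Dd) per call)."""
    return Phi @ (Phi.T @ v) + sigma2 * v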
5.2 Conjugate Gradient Solver
Oh dang, the LLM did a really good job on this.
def cg_solver(b, tol=1e-5, max_iter=100):
    """Solve Σ x = b by conjugate gradients, touching Σ only through sigma_mv_prod."""
    x = torch.zeros_like(b)
    r = b.clone()              # residual b - Σ x, with x initialised to zero
    p = r.clone()              # first search direction
    rs_old = torch.dot(r, r)
    for _ in range(max_iter):
        Ap = sigma_mv_prod(p)
        alpha = rs_old / torch.dot(p, Ap)   # optimal step length along p
        x += alpha * p
        r -= alpha * Ap
        rs_new = torch.dot(r, r)
        if torch.sqrt(rs_new) < tol:        # residual small enough; converged
            break
        p = r + (rs_new / rs_old) * p       # next conjugate search direction
        rs_old = rs_new
    return x
5.3 Langevin Dynamics Sampler
def sample_mvn_langevin(mu, num_samples=1000, epsilon=1e-3, burn_in=100):
    """
    Samples from N(mu, Σ) using Langevin dynamics.

    Parameters:
    - mu: Mean vector (torch.Tensor of shape [D])
    - num_samples: Number of samples to collect after burn-in
    - epsilon: Time step size
    - burn_in: Number of initial iterations to discard
    """
    D = mu.shape[0]
    x = mu.clone().detach()
    samples = []
    total_steps = num_samples + burn_in
    for n in range(total_steps):
        # Compute gradient: v = Σ^{-1} (x - μ)
        r = x - mu
        v = cg_solver(r, tol=1e-5, max_iter=100)
        # Langevin update
        noise = torch.randn(D)
        x = x - epsilon * v + torch.sqrt(torch.tensor(2 * epsilon)) * noise
        if n >= burn_in:
            samples.append(x.detach().clone())
    return torch.stack(samples)
5.4 Usage Example
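A minimal end-to-end sketch, reusing the placeholder sigma_mv_prod from above; the zero mean and the sampler settings here are arbitrary and untuned.

# Tie the pieces together on the toy feature-map covariance defined earlier.
mu = torch.zeros(D)
samples = sample_mvn_langevin(mu, num_samples=5000, epsilon=1e-2, burn_in=1000)
print(samples.shape)  # -> torch.Size([num_samples, D])

Note that consecutive Langevin iterates are strongly correlated; a larger step size mixes faster at the cost of more discretisation bias, and even so the empirical moments below will be noisy unless the chain is run for much longer than this.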
6 Validation
After sampling, it’s wise to verify that the samples approximate the target distribution.
import numpy as np
import matplotlib.pyplot as plt

empirical_mean = samples.mean(dim=0)
empirical_cov = torch.from_numpy(np.cov(samples.numpy(), rowvar=False))
print("Empirical Mean:\n", empirical_mean)
print("Empirical Covariance Matrix:\n", empirical_cov)
# Plot histogram for the first dimension
plt.hist(samples[:, 0].numpy(), bins=30, density=True)
plt.title("Histogram of First Dimension")
plt.xlabel("Value")
plt.ylabel("Density")
plt.show()
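Because the toy covariance above is small enough to materialise, we can also compare against the exact \(\Sigma = \Phi \Phi^\top + \sigma^2 \mathbf{I}_D\). This check is only feasible in the small-\(D\) sanity-check setting, not in the genuinely large problems this whole exercise is aimed at, and don't expect the error to be tiny unless the chain is run much longer than the defaults above.

# Only feasible for small D: build Σ densely and compare with the sample covariance.
true_cov = Phi @ Phi.T + sigma2 * torch.eye(D)
rel_err = torch.linalg.norm(empirical_cov.float() - true_cov) / torch.linalg.norm(true_cov)
print("Relative Frobenius error of empirical covariance:", rel_err.item())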