The Score Function Estimator
a.k.a. REINFORCE; a gradient estimator for expectations
September 13, 2024 — September 13, 2024
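The score function (gradient) estimator, a.k.a. the log-derivative trick, a.k.a. REINFORCE (all-caps, for some reason?), is a generic method that works on various types of variables, discrete as well as continuous. It only differentiates the log-density \(\log p(x;\theta)\) with respect to \(\theta\), never the loss with respect to \(x\). It is usually credited to (Williams 1992), but it must be older than that. It is so named because the score function \(\frac{\partial}{\partial \theta}\log p(x;\theta)\) features prominently in it.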
It has notoriously high variance if done naïvely.
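This is the fundamental insight, assuming we may exchange differentiation and integration and using \(\frac{\partial}{\partial \theta} p(x;\theta) = p(x;\theta)\frac{\partial}{\partial \theta}\log p(x;\theta)\):
\[ \begin{aligned} g(\ell) &= \frac{\partial}{\partial \theta} \mathbb{E}_{\mathsf{x}\sim p(\mathsf{x};\theta)} \ell(\mathsf{x}) \\ &= \frac{\partial}{\partial \theta} \int \ell(x) p(x;\theta) \mathrm{d} x\\ &= \int \ell(x) \frac{\partial}{\partial \theta} p(x;\theta) \mathrm{d} x\\ &= \int \ell(x) p(x;\theta) \frac{\partial}{\partial \theta}\log p(x;\theta) \mathrm{d} x\\ &= \mathbb{E}_{\mathsf{x}\sim p(\mathsf{x};\theta)} \ell(\mathsf{x}) \frac{\partial}{\partial \theta}\log p(\mathsf{x};\theta) \end{aligned} \]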
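A one-variable sanity check, using a Bernoulli example of my own choosing: let \(\mathsf{x}\sim\operatorname{Bernoulli}(\theta)\), so \(p(x;\theta)=\theta^{x}(1-\theta)^{1-x}\), and let \(\ell\) be any function on \(\{0,1\}\). Nothing in the calculation needs \(\ell\) to be differentiable.

```latex
\begin{aligned}
\frac{\partial}{\partial\theta}\log p(x;\theta) &= \frac{x}{\theta} - \frac{1-x}{1-\theta},\\
\mathbb{E}_{\mathsf{x}\sim p(\mathsf{x};\theta)} \ell(\mathsf{x}) \frac{\partial}{\partial \theta}\log p(\mathsf{x};\theta)
&= \theta\,\ell(1)\,\frac{1}{\theta} + (1-\theta)\,\ell(0)\left(-\frac{1}{1-\theta}\right)\\
&= \ell(1) - \ell(0)
= \frac{\partial}{\partial\theta}\bigl[\theta\,\ell(1) + (1-\theta)\,\ell(0)\bigr]
= \frac{\partial}{\partial \theta} \mathbb{E}_{\mathsf{x}\sim p(\mathsf{x};\theta)} \ell(\mathsf{x}).
\end{aligned}
```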
This suggests a simple and obvious Monte Carlo estimate of the gradient: draw samples \(x_i\sim p(x;\theta)\), \(i=1,\dots,n\), and average,
\[ \begin{aligned} \hat{g}_{\text{REINFORCE}}(\ell) &= \frac{1}{n}\sum_{i=1}^{n} \ell(x_i) \frac{\partial}{\partial \theta}\log p(x_i;\theta) \end{aligned} \]
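A minimal pure-Python sketch of this estimator on a toy problem with a known answer. The model \(\mathsf{x}\sim\mathcal{N}(\theta,1)\) and the loss \(\ell(x)=x^2\) are my own illustrative choices. The score is \(x-\theta\), and \(\mathbb{E}\,\ell(\mathsf{x})=\theta^2+1\), so the exact gradient is \(2\theta\). The sketch also tries subtracting a constant baseline \(b\) from \(\ell\), a standard variance-reduction device. Because \(\mathbb{E}\,\frac{\partial}{\partial\theta}\log p(\mathsf{x};\theta)=0\), the shift leaves the estimator unbiased. Using the exact mean \(\theta^2+1\) as \(b\) is a cheat that is only available in a toy problem.

```python
import random
import statistics

def score_function_terms(theta, loss, n, rng, baseline=0.0):
    """Per-sample score-function terms for d/dtheta E[loss(x)], x ~ N(theta, 1).

    For a unit-variance Gaussian, d/dtheta log p(x; theta) = x - theta.
    Subtracting a constant `baseline` leaves the mean unchanged, because
    the score has expectation zero.
    """
    terms = []
    for _ in range(n):
        x = rng.gauss(theta, 1.0)                         # x_i ~ p(x; theta)
        terms.append((loss(x) - baseline) * (x - theta))  # (ell(x_i) - b) * score
    return terms

def loss(x):
    return x ** 2

theta = 1.5
n = 200_000
# Same seed for both runs, so both see identical samples.
plain = score_function_terms(theta, loss, n, random.Random(0))
# Hypothetical baseline: the exact E[loss] = theta^2 + 1, known only because this is a toy.
shifted = score_function_terms(theta, loss, n, random.Random(0), baseline=theta ** 2 + 1)

print(f"exact gradient 2*theta = {2 * theta}")
for name, terms in [("plain", plain), ("baseline", shifted)]:
    print(f"{name:>8}: estimate {statistics.fmean(terms):.3f}, "
          f"per-sample variance {statistics.variance(terms):.1f}")
```

For this toy problem I work out the per-sample variances as \(\theta^4+14\theta^2+15\) without the baseline and \(8\theta^2+10\) with it, about 51.6 and 28 at \(\theta=1.5\). The pathwise (reparameterization) estimator \(2x\) has a per-sample variance of 4. In practice \(b\) has to be estimated, for example from a running average of past losses or a leave-one-out average over the batch, so that it stays independent of the sample it multiplies.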
For unifying overviews, see (Mohamed et al. 2020; Schulman et al. 2015; van Krieken, Tomczak, and Teije 2021) and the Storchastic docs.
- Shakir Mohamed, Log Derivative Trick
- Syed Ashar Javed, REINFORCE vs Reparameterization Trick
It is annoyingly hard to find a clear example of this method online, despite its simplicity; all the code examples I see wrap it up with reinforcement learning or some other unnecessarily specific complexity.
Laurence Davies and I put together this demo, in which we try to find the parameters that minimise the difference between the categorical distribution we sample from and some target distribution.
import torch
# True target distribution probabilities
true_probs = torch.tensor([0.1, 0.6, 0.3])
# Optimisation parameters
n_batch = 1000
n_iter = 3000
lr = 0.01
def loss(x):
    """
    The target loss, a negative log-likelihood for a
    categorical distribution with the given probabilities.
    """
    return -torch.distributions.Multinomial(
        total_count=1, probs=true_probs).log_prob(x)
# Set the seed for reproducibility
torch.manual_seed(42)
# Initialize the parameter estimates
theta_hat = torch.nn.Parameter(torch.tensor([0., 0., 0.]))
optimizer = torch.optim.Adam([theta_hat], lr=lr)
for epoch in range(n_iter):
    optimizer.zero_grad()
    # Sample one-hot draws x_i from the current estimate q_theta
    q_theta = torch.distributions.Multinomial(1, logits=theta_hat)
    x_sample = q_theta.sample((n_batch,))
    # Evaluate log q_theta(x_i) at the sample points (differentiable in theta_hat)
    log_p_theta_x = q_theta.log_prob(x_sample)
    # Per-sample weights f(x_i) = log q_theta(x_i) - log p(x_i), the integrand of
    # the reverse KL divergence KL(q_theta || p). The NLL loss(x) alone is the
    # cross-entropy E_q[-log p(x)], which is minimised by a point mass on the mode
    # of p rather than by q_theta = p. The weights are detached; the extra gradient
    # of the log q_theta term has expectation zero, so the estimate stays unbiased.
    f_hat = (log_p_theta_x + loss(x_sample)).detach()
    # Score-function gradient (1/n) sum_i f(x_i) d/dtheta log q_theta(x_i).
    # Passing the weights as `grad_outputs` scales each sample's score by its
    # own f(x_i) in a single backward pass.
    final_gradients = torch.autograd.grad(
        outputs=log_p_theta_x,
        inputs=theta_hat,
        grad_outputs=f_hat / n_batch)[0]
    theta_hat.grad = final_gradients
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Estimated Probs: "
              f"{torch.softmax(theta_hat, dim=0).detach().numpy()}")
# Display the final estimated probabilities
estimated_final_probs = torch.softmax(theta_hat, dim=0)
print("Final Estimated Probabilities: "
      f"{estimated_final_probs.detach().numpy()}"
      f" (True Probabilities: {true_probs.numpy()})")
Note that the batch size here is very large. With much smaller batches, the variance of the estimator can be too high for it to be useful.
Classically we might address this with a diminishing learning rate, as in Robbins-Monro stochastic approximation, but I have lazily not done that here.
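For completeness, here is what a diminishing step size looks like on a toy problem. This is a pure-Python sketch rather than a change to the PyTorch demo. The model \(\mathsf{x}\sim\mathcal{N}(\theta,1)\), the loss \(\ell(x)=(x-2)^2\) (so \(\mathbb{E}\,\ell=(\theta-2)^2+1\), minimised at \(\theta=2\)), and the schedule \(a_t=0.5/(t+100)\) are all my assumptions. Step sizes of the form \(a/(t+t_0)\) satisfy the Robbins-Monro conditions \(\sum_t a_t=\infty\) and \(\sum_t a_t^2<\infty\).

```python
import random

def sgd_score_function(theta0, target, n_steps, rng, a=0.5, t0=100):
    """Minimise E[(x - target)^2] over theta, where x ~ N(theta, 1), by SGD
    using a single-sample score-function gradient and step size a / (t + t0)."""
    theta = theta0
    for t in range(n_steps):
        x = rng.gauss(theta, 1.0)                   # one sample x ~ p(x; theta)
        grad_hat = (x - target) ** 2 * (x - theta)  # ell(x) * d/dtheta log p(x; theta)
        theta -= a / (t + t0) * grad_hat            # diminishing (Robbins-Monro) step
    return theta

theta_final = sgd_score_function(theta0=0.0, target=2.0, n_steps=100_000,
                                 rng=random.Random(0))
print(f"final theta {theta_final:.3f} (optimum 2.0)")
```

I chose a fairly large \(t_0\) to keep the early steps small. In this toy, the gradient noise grows quickly as \(\theta\) moves away from the optimum, so one large early step can throw \(\theta\) far enough that the following gradients become wild.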
1 Rao-Blackwellization
Rao-Blackwellization (Casella and Robert 1996) seems like a natural extension to reduce the variance. How would it work? Liu et al. (2019) is a contemporary example; I have a vague feeling that I saw something similar in Rubinstein and Kroese (2016). 🚧TODO🚧 clarify.
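Below is one way it could work for a categorical variable. It is my reading of the "sum-and-sample" idea in Liu et al. (2019), not a faithful reproduction of their method. Enumerate a few categories exactly, and use a single score-function sample only for the leftover probability mass. The exactly summed part is replaced by its expectation, which is where the variance reduction comes from. The sketch assumes softmax logits \(\theta\), for which \(\frac{\partial}{\partial\theta}\log q_k = e_k - \operatorname{softmax}(\theta)\). It also assumes a weight \(f\) that does not depend on \(\theta\). The helper names and the choice of which categories to sum are hypothetical.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax of a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def score(k, q):
    """d/dtheta log q_k for softmax logits theta, i.e. the vector e_k - q."""
    return [(1.0 if j == k else 0.0) - q_j for j, q_j in enumerate(q)]

def add_scaled(grad, weight, vec):
    """In-place grad += weight * vec."""
    for j, v in enumerate(vec):
        grad[j] += weight * v

def exact_gradient(theta, f):
    """d/dtheta sum_k q_k f(k), by enumerating every category (zero variance)."""
    q = softmax(theta)
    grad = [0.0] * len(q)
    for k in range(len(q)):
        add_scaled(grad, q[k] * f(k), score(k, q))
    return grad

def sum_and_sample(theta, f, summed, rng):
    """Sum the categories in `summed` exactly, then add a one-sample
    score-function estimate for the remaining probability mass."""
    q = softmax(theta)
    grad = [0.0] * len(q)
    for k in summed:
        add_scaled(grad, q[k] * f(k), score(k, q))
    rest = [k for k in range(len(q)) if k not in summed]
    if rest:
        rest_mass = sum(q[k] for k in rest)
        # Draw k from q conditioned on k lying outside `summed`
        k = rng.choices(rest, weights=[q[k] for k in rest])[0]
        add_scaled(grad, rest_mass * f(k), score(k, q))
    return grad

# Toy setting borrowed from the demo: f is the NLL under a fixed target.
true_probs = [0.1, 0.6, 0.3]

def f(k):
    return -math.log(true_probs[k])

theta = [0.0, 0.5, -0.5]
summed = [1]  # hypothetical choice: the most probable category under q_theta
rng = random.Random(0)
n = 20_000
draws = [sum_and_sample(theta, f, summed, rng) for _ in range(n)]
mean = [sum(d[j] for d in draws) / n for j in range(len(theta))]
print("sum-and-sample mean:", [round(m, 3) for m in mean])
print("exact gradient:     ", [round(g, 3) for g in exact_gradient(theta, f)])
```

With `summed` covering every category, the estimator reduces to `exact_gradient`. With `summed` empty, it reduces to a single-sample version of the plain estimator. Anything in between trades computation for variance.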