Pyro

Approximate maximum in the density of probabilistic programming effort

October 2, 2019 — May 2, 2024

Bayes
generative
how do science
Monte Carlo
sciml
statistics
Figure 1: Typical posterior density landscape

A probabilistic programming language built on modern NN frameworks (pytorch and jax) that implements many fashionable algorithms from the probabilistic programming literature.

1 Classic Pyro

pytorch + Bayesian inference = pyro (Pradhan et al. 2018).

For rationale, see the pyro launch announcement:

We believe the critical ideas to solve AI will come from a joint effort among a worldwide community of people pursuing diverse approaches. By open sourcing Pyro, we hope to encourage the scientific world to collaborate on making AI tools more flexible, open, and easy-to-use. We expect the current (alpha!) version of Pyro will be of most interest to probabilistic modellers who want to leverage large data sets and deep networks, PyTorch users who want easy-to-use Bayesian computation, and data scientists ready to explore the ragged edge of new technology.

Pyro is a friendly, well-documented, consistent framework with less of the designed-during-interdepartmental-turf-war feel of the tensorflow frameworks, and it seems to be where much of the probabilistic programming effort is going.

Framework documentation asserts that if you can understand one file, pyro/minipyro.py, you can understand the whole system.
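The core idea in minipyro is a global stack of effect handlers ("messengers"): every sample statement becomes a message that each active handler may inspect or modify. Below is a toy pure-Python sketch in that spirit. It is not the contents of minipyro.py; the names `HANDLER_STACK`, `Messenger`, `trace`, `condition` and `sample` are simplified illustrative stand-ins (the real file has more handlers and uses torch distributions, and Pyro's full versions live in `pyro.poutine`).

```python
import random

# Toy global handler stack, similar in spirit to minipyro's handler stack.
HANDLER_STACK = []


class Messenger:
    """Base effect handler: it sits on the stack while fn runs."""

    def __init__(self, fn):
        self.fn = fn

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc_info):
        HANDLER_STACK.pop()

    def process_message(self, msg):  # runs before a value is drawn
        pass

    def postprocess_message(self, msg):  # runs after a value is drawn
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)


class trace(Messenger):
    """Records every site that passes through it."""

    def get_trace(self, *args, **kwargs):
        self.sites = {}
        self(*args, **kwargs)
        return self.sites

    def postprocess_message(self, msg):
        self.sites[msg["name"]] = dict(msg)


class condition(Messenger):
    """Pins named sites to given values instead of sampling them."""

    def __init__(self, fn, data):
        super().__init__(fn)
        self.data = data

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]


def sample(name, draw):
    """Send a sample statement through the handler stack as a message."""
    msg = {"name": name, "value": None}
    for handler in reversed(HANDLER_STACK):  # innermost handler first
        handler.process_message(msg)
    if msg["value"] is None:  # no handler supplied a value, so draw one
        msg["value"] = draw()
    for handler in HANDLER_STACK:
        handler.postprocess_message(msg)
    return msg["value"]


def model():
    x = sample("x", lambda: random.gauss(0.0, 1.0))
    y = sample("y", lambda: random.gauss(x, 1.0))
    return y


tr = trace(condition(model, data={"x": 2.0})).get_trace()
print(tr["x"]["value"])  # 2.0, pinned by condition
print(sorted(tr))        # ['x', 'y']
```

The model itself never mentions conditioning or tracing; those behaviours are layered on from outside by wrapping it in handlers, which is what makes this design composable.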

2 Numpyro

Numpyro is an alternative version of pyro which uses jax as a backend instead of pytorch. In line with the general jax aesthetic, it is elegant, fast, increasingly well documented (it used to be badly documented), and missing some conveniences. The API is not identical to pyro's, but they rhyme.

UPDATE: Numpyro is really coming along. It has all kinds of features now, e.g. automatic PGM (probabilistic graphical model) diagrams. Also, it turns out that the jax backend is (IMO) frequently less confusing than pytorch.

Fun tip: the numpyro.render_model function will automatically draw graphical model diagrams.

3 Tutorials and textbooks

4 Tips, gotchas

4.1 Distributed

Multi-GPU Pyro is not necessarily obvious: many of the inference methods involved are not just plain SGD, so they do not parallelize the way a simple neural network might. The docs give an example of distributed training via Horovod.

4.2 Regression

Regression was not (for me) obvious, and the various ways you can set it up are illustrative of how to set up stuff in pyro generally.

We define the model as a linear regression of a response variable (GDP, log-transformed in the code) on two predictor variables (whether the country is in Africa, and terrain ruggedness) plus an interaction term between them.

Suppose we want to solve a posterior inference problem of the following form:

\[\begin{aligned} \text{GDP}_i &\sim \mathcal{N}(\mu, \sigma)\\ \mu &= a + b_a \cdot \operatorname{InAfrica}_i + b_r \cdot \operatorname{Ruggedness}_i + b_{ar} \cdot \operatorname{InAfrica}_i \cdot \operatorname{Ruggedness}_i \\ a &\sim \mathcal{N}(0, 10)\\ b_a &\sim \mathcal{N}(0, 1)\\ b_r &\sim \mathcal{N}(0, 1)\\ b_{ar} &\sim \mathcal{N}(0, 1)\\ \sigma &\sim \operatorname{Gamma}(1, \frac12) \end{aligned}\]

```python
import pyro
import pyro.distributions as dist

pyro.clear_param_store()

def model():
    # Priors on the regression coefficients and the noise scale.
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    # Placeholder distributions for the inputs; their values are
    # always replaced with data via poutine.condition.
    is_cont_africa = pyro.sample(
        "is_cont_africa", dist.Bernoulli(0.5))  # <- overridden
    ruggedness = pyro.sample(
        "ruggedness", dist.Normal(1.0, 0.5))    # <- overridden
    mean = (a + b_a * is_cont_africa
            + b_r * ruggedness
            + b_ar * is_cont_africa * ruggedness)
    s = pyro.sample(
        "log_gdp", dist.Normal(mean, sigma))    # <- overridden
    return s
```

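Note the trick here: we gave distributions even to the regression inputs. Because poutine.condition can only override sample sites, this is how we need to do it in this style, even though those distributions are never used to generate anything. During inference we always override the values at those sites with data; their log densities still enter the joint density, but only as constants with respect to the unknowns, so they do not affect the posterior.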

Inference proceeds by conditioning the model on the observed data, giving us updated estimates for the unknowns. In the MCMC setting, we approximate those posterior distributions with samples:

\[\begin{aligned} &p (a, b_a, b_{ar}, b_r,\sigma \mid \operatorname{GDP}, \operatorname{Ruggedness},\operatorname{InAfrica} )\\ &\quad \propto \prod_i p (\operatorname{GDP}_i \mid \operatorname{Ruggedness}_i,\operatorname{InAfrica}_i ,a, b_a, b_{ar}, b_r,\sigma)\\ & \qquad \cdot p (a, b_a, b_{ar}, b_r,\sigma) \end{aligned}\]


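```python
from pyro import poutine
from pyro.infer import MCMC, NUTS

# log_gdp, ruggedness and is_cont_africa are assumed to be 1-D float
# tensors of equal length, loaded from the dataset beforehand.
observed_model = poutine.condition(model, data={
    "log_gdp": log_gdp, "ruggedness": ruggedness, "is_cont_africa": is_cont_africa})
nuts_kernel = NUTS(observed_model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run()  # model() takes no arguments; the data enter via condition
```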

To actually make predictions, we use the Predictive class, which is IMO not well explained in the docs, although you can work it out from their example. An only-slightly-confusing explanation is here. tl;dr: to predict the (log) GDP of a country NOT in Africa with a ruggedness of 3, we would do this:

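```python
import torch
from pyro import poutine
from pyro.infer import Predictive
# One log_gdp draw per posterior sample, with the inputs fixed to new values.
Predictive(poutine.condition(model, data={
    "ruggedness": torch.tensor(3.0),
    "is_cont_africa": torch.tensor(0.)}),
    posterior_samples=mcmc.get_samples())()['log_gdp']
```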

TODO: re-do this example with

  1. default arguments.
  2. factory functions.

4.3 Complex numbers

Currently do not work.

5 Algebraic effects

TBD

Sounds like this lands not too far from message passing ideas?

6 Funsors

I’ve seen funsors mentioned in this context. I gather they are some kind of graphical-model inference abstraction in the algebraic effect vein. What do they do exactly? Obermeyer et al. (2020) attempt to explain them, although I do not feel like I got it:

It is a significant challenge to design probabilistic programming systems that can accommodate a wide variety of inference strategies within a unified framework. Noting that the versatility of modern automatic differentiation frameworks is based in large part on the unifying concept of tensors, we describe a software abstraction for integration —functional tensors— that captures many of the benefits of tensors, while also being able to describe continuous probability distributions. Moreover, functional tensors are a natural candidate for generalized variable elimination and parallel-scan filtering algorithms that enable parallel exact inference for a large family of tractable modeling motifs.

…This property is extensively exploited by the Pyro probabilistic programming language (Pradhan et al. 2018) and its implementation of tensor variable elimination for exact inference in discrete latent variable models, in which each random variable in a model is associated with a distinct tensor dimension and broadcasting is used to compile a probabilistic program into a discrete factor graph. Functional tensors (hereafter “funsors”) both formalize and extend this seemingly idiosyncratic but highly successful approach to probabilistic program compilation by generalising tensors and broadcasting to allow free variables of non-integer types that appear in probabilistic models, such as real number, real-valued vector, or real-valued matrix. Building on this, we describe a simple language of lazy funsor expressions that can serve as a unified intermediate representation for a wide variety of probabilistic programs and inference algorithms. While in general there is no finite representation of functions of real variables, we provide a funsor interface for restricted classes of functions, including lazy algebraic expressions, non-normalized Gaussian functions, and Dirac delta distributions.

Sounds like this lands not so far from message passing?
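To make the quote's "each random variable is associated with a distinct tensor dimension and broadcasting is used" concrete, here is a plain-torch sketch, not the funsor API, of exact inference in a tiny discrete chain z1 → z2 → x with x observed. All the factor values are random stand-ins. For this chain, summing out z1 leaves a vector indexed by z2, which plays the role of a sum-product message from z1 to z2.

```python
import torch

torch.manual_seed(0)
K = 3  # number of states of each discrete latent variable

# Each latent variable owns one tensor dimension: z1 on dim -2, z2 on dim -1.
log_p_z1 = torch.randn(K).log_softmax(-1).reshape(K, 1)  # log p(z1), shape (K, 1)
log_p_z2 = torch.randn(K, K).log_softmax(-1)             # log p(z2 | z1), shape (K, K)
log_lik = torch.randn(K).reshape(1, K)                   # made-up log p(x | z2), shape (1, K)

# Brute force: broadcast all factors into the full joint table, sum out everything.
joint = log_p_z1 + log_p_z2 + log_lik                    # shape (K, K)
log_evidence_brute = joint.reshape(-1).logsumexp(0)

# Variable elimination: sum out z1 first, leaving a function of z2 alone,
# then sum out z2.
msg_z1_to_z2 = (log_p_z1 + log_p_z2).logsumexp(-2)       # shape (K,)
log_evidence_ve = (msg_z1_to_z2 + log_lik.reshape(K)).logsumexp(-1)

print(log_evidence_brute.item(), log_evidence_ve.item())  # equal up to float error
```

The brute-force table grows exponentially with the number of variables, whereas eliminating one dimension at a time keeps every intermediate tensor small; that is the payoff of the dimension-per-variable convention.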

7 Incoming

8 References

Baudart, Burroni, Hirzel, et al. 2021. “Compiling Stan to Generative Probabilistic Languages and Extension to Deep Probabilistic Programming.” arXiv:1810.00873 [Cs, Stat].
Moore, and Gorinova. 2018. “Effect Handling for Composable Program Transformations in Edward2.” arXiv:1811.06150 [Cs, Stat].
Obermeyer, Bingham, Jankowiak, et al. 2020. “Functional Tensors for Probabilistic Programming.” arXiv:1910.10775 [Cs, Stat].
Pradhan, Chen, Jankowiak, et al. 2018. “Pyro: Deep Universal Probabilistic Programming.” arXiv:1810.09538 [Cs, Stat].
Ritter, and Karaletsos. 2022. “TyXe: Pyro-Based Bayesian Neural Nets for Pytorch.” Proceedings of Machine Learning and Systems.