Graphical model / machine learning decoder ring
January 18, 2024 — June 3, 2024
I’m thinking something through for myself. Details are absent right now. Twin to Causality+ML, perhaps.
Various ways of partitioning a vector of observed and unobserved variates, their relationships to graphical models and the ML distinction between supervised and unsupervised learning, and the generalisations we might want of these.
1 What is supervised learning?
We look to Bayesian inference to solve problems with the following structure: I would like to infer the probability of some noisy label \(Y_*\) given some predictor \(X_*\). Since this is data-driven learning, I will suppose that I want to work out how to do this from a dataset of labelled examples, \(\mathcal{D} = \{(x_i, y_i)\}_{i=1}^N\). So in fact, I would like to know \(p(y_*|x_*,\mathcal{D})\).
We are using lazy Bayesian density notation where \(p(z)\) is the density of some RV \(Z\). I feel ambivalent about this notation, but it does get the job done. We can throw out the densities later and go to nice clean unambiguous measures, but this helps exposition.
How do I compute this equation-in-densities? Let us suppose that by some miracle it happens to be true that we know the generating process \(Y= G(X, W, \xi)\), where \(\xi\) is some unobserved noise, and there is a finite vector of parameters \(W\).
I hope to find good values for \(W\) such that \(p(y_*|x_*,W,\mathcal{D})\) is a good model for the data, in the sense that if we take our model out into the world and show it lots of different values \(X_*=x_*\), the implied \(p(y_*|X_*=x_*,W,\mathcal{D})\), pumped through \(G\), describes the distribution of \(Y_*|X_*=x_*\).
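To make that concrete, here is a minimal sketch in Python, assuming (purely for illustration) a linear generating process \(G(x, w, \xi) = wx + \xi\) with Gaussian noise; nothing in the argument depends on this choice, it just gives us something to simulate.

```python
import numpy as np

rng = np.random.default_rng(0)

def G(x, w, xi):
    # Hypothetical generating process: linear in x with additive noise.
    return w * x + xi

w_true = 2.5                                            # the W we hope to recover
x_obs = rng.uniform(-1.0, 1.0, size=3)                  # predictors, N = 3 as below
y_obs = G(x_obs, w_true, rng.normal(0.0, 0.3, size=3))  # noisy labels
D = list(zip(x_obs, y_obs))                             # the labelled dataset
```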
How do we find \(W\)? There is the graphical model formalism which (details to come, trust me) tells us how to infer stuff about \(W\) from \(\mathcal{D}\): Suppose \(N=3\). Then the following graphical model describes our task:
The circled \(W\) node is what we are trying to get “correct”. By looking at those \(X_i,Y_i\) pairs we refine our estimate of \(W\). I am leaving out the details of what “refining” the estimate means; we will come back to that.
OK, we talked about refining \(W\). What does that look like in practice? In the Bayes setting, we condition a prior \(p(w)\) on the observed data \(\mathcal{D}\) to obtain the posterior distribution \(p(w|\mathcal{D})=p(w|X_1, X_2, X_3, Y_1, Y_2, Y_3)\). I am pretty sure we can alternatively do this as a frequentist thing, but TBH I am just not smart enough: the phrasing gets very weird and requires all kinds of circumlocutions, so it is not very clear, at least to my modest-sized brain, what the hell is going on.
According to Bayes’ rule, we can write
\[ p(w \mid x_1, x_2, x_3, y_1, y_2, y_3) = \frac{p(y_1, y_2, y_3 \mid x_1, x_2, x_3, w) \, p(w)}{p(y_1, y_2, y_3 \mid x_1, x_2, x_3)} \]
The numerator can be further factorized due to the conditional independencies implied by the graphical model:
\[ p(y_1, y_2, y_3 \mid x_1, x_2, x_3, w) \, p(w) = p(y_1 \mid x_1, w) \, p(y_2 \mid x_2, w) \, p(y_3 \mid x_3, w) \, p(w) \]
The denominator, which is the marginal likelihood, can be obtained by integrating out \(W\):
\[ p(y_1, y_2, y_3 \mid x_1, x_2, x_3) = \int p(y_1 \mid x_1, w) \, p(y_2 \mid x_2, w) \, p(y_3 \mid x_3, w) \, p(w) \, dw \]
Putting it all together, the posterior distribution of \(W\) is:
\[ p(w \mid x_1, x_2, x_3, y_1, y_2, y_3) = \frac{p(y_1 \mid x_1, w) \, p(y_2 \mid x_2, w) \, p(y_3 \mid x_3, w) \, p(w)}{\int p(y_1 \mid x_1, w) \, p(y_2 \mid x_2, w) \, p(y_3 \mid x_3, w) \, p(w) \, dw} \]
This equation represents the posterior distribution of \(W\) in terms of the densities of the observed and latent variables, conditioned on the observed data.
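When the model is simple enough to brute-force, we can evaluate that posterior numerically on a grid of \(w\) values. A sketch, continuing the hypothetical linear-Gaussian example above (the prior, noise scale, and data are all made up for illustration):

```python
import numpy as np
from scipy.stats import norm

# Toy data standing in for \mathcal{D}.
x_obs = np.array([-0.6, 0.1, 0.8])
y_obs = np.array([-1.4, 0.3, 2.1])
sigma = 0.3                            # assumed (known) noise scale

w_grid = np.linspace(-5.0, 5.0, 2001)
dw = w_grid[1] - w_grid[0]

log_prior = norm.logpdf(w_grid, loc=0.0, scale=1.0)        # p(w)
log_lik = sum(norm.logpdf(y, loc=w_grid * x, scale=sigma)  # prod_i p(y_i | x_i, w)
              for x, y in zip(x_obs, y_obs))

log_joint = log_prior + log_lik
posterior = np.exp(log_joint - log_joint.max())
posterior /= posterior.sum() * dw  # denominator: a Riemann sum standing in for the integral over w
```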
This focus on \(W\) is misleading, though, since I said before we wanted a \(W\) that gave us good predictions for \(Y_*|X_*\). Maybe \(W\) is just a nuisance parameter and we don’t really care about it as long as we get good predictions for \(Y_*\). Maybe we actually want to solve a problem like this one, in which \(W\) is a nuisance variable:
In ML, that is usually what we are doing; \(W\) is just some parameters rather than being something with intrinsic, semi-universal meaning, like the speed of light or the boiling point of lead in a vacuum. Our true target is to get this \(p(y_*|x_*,\mathcal{D})\), which we give the special name, posterior predictive distribution.
Let us replay all that conditioning mathematics, but make the target \(p(y_*|x_*,\mathcal{D})\), the posterior of \(Y_*\) given \(X_*\) and the observed data \(X_1, Y_1, \ldots\), i.e.
\[ p(y_* \mid x_*, x_1, y_1, \ldots) \]
Since information from the \(X_i, Y_i\) about \(Y_*\) is mediated wholly through \(W\), we might like to think about that calculation in terms of the posterior over \(W\),
\[ p(y_* \mid x_*, (W|\mathcal{D}))=:p(y_* \mid x_*, x_1, y_1, \ldots) \]
Hold on to that thought, because the idea of dealing with “conditioned” variables turns out to take us somewhere interesting.
If we want to integrate out \(W\) to get only the marginal posterior predictive distribution, we do this:
\[ p(y_* \mid x_*, x_1, y_1, \ldots) = \int p(y_* \mid x_*, w) \, p(w \mid x_1, y_1, \ldots) \, dw \]
where \(p(w \mid x_1, y_1, \ldots)\) is the posterior distribution of \(W\) given the observed data, which can be computed using Bayes’ rule as we did above.
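In the grid sketch above, that integral is just another weighted sum: average the conditional density \(p(y_* \mid x_*, w)\) against the grid posterior. A continuation of that sketch (the new point \(x_*\) is arbitrary, and `w_grid`, `posterior`, `dw`, `sigma` are reused from it):

```python
# Posterior predictive density at a new point x_*:
# p(y_* | x_*, D) ≈ sum_j p(y_* | x_*, w_j) p(w_j | D) dw
x_star = 0.5
y_star_grid = np.linspace(-4.0, 4.0, 401)

predictive = np.array([
    np.sum(norm.pdf(y, loc=w_grid * x_star, scale=sigma) * posterior * dw)
    for y in y_star_grid
])
# `predictive` now approximates the density of Y_* | X_* = 0.5, D on y_star_grid.
```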
I’m getting tired of writing out the indexed \(X_i,Y_i\) pairs, so let’s just write \(\mathcal{D}\) for the observed data, and \(X_*,Y_*\) for the unobserved data, and use plate notation to write a graph that automatically indexes over these variables:
OK, now in a Bayes setting we can talk about the marginal of interest, which the red line in the diagram denotes.
2 What is unsupervised learning?
As seen in unconditional generative modeling. Now we don’t assume that we will observe a special \(X_*\) and infer \(Y_*\); rather, we want to know something about the distribution of both jointly.
In fact, we are treating \(X\) and \(Y\) symmetrically now, so we might as well concatenate them both into \(X\):
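For instance, a minimal sketch of that symmetric treatment, stacking \(X\) and \(Y\) into one vector and fitting an (entirely arbitrary) joint Gaussian to it:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2))                                        # former "predictors"
Y = X @ np.array([[1.0], [-0.5]]) + 0.3 * rng.normal(size=(500, 1))  # former "labels"

Z = np.concatenate([X, Y], axis=1)  # treat everything symmetrically
mu = Z.mean(axis=0)                 # a (crude) joint density model:
Sigma = np.cov(Z, rowvar=False)     # a single Gaussian over the concatenation
```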
3 What are inverse problems?
I observe an output. What is my “posterior retrodictive”?
4 Other cool problems?
Hierarchical models.
5 Causally valid inference
TBD
6 Learning on non-i.i.d. data
7 Factor graphs
A natural setting for approximate inference is to devise variational algorithms over factor graphs, a representation which conveniently encodes some useful approximate marginalisation algorithms.
The way that these work is we rewrite every conditional probability in the directed graph as a factor:
\[ p(x_1|x_2) \to f_{12}(x_1, x_2) \]
Now, we introduce a new set of nodes, the factors, which are the \(f_{ij}\), and we connect them to the implicated variable nodes.
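One way to encode that rewrite in code, assuming discrete variables so that factors are just arrays indexed by variable states (the names and numbers here are made up):

```python
import numpy as np

# Toy directed model p(x1 | x2) p(x2) over binary variables.
p_x2 = np.array([0.7, 0.3])                  # p(x2)
p_x1_given_x2 = np.array([[0.9, 0.2],        # p(x1 | x2): rows index x1, columns x2
                          [0.1, 0.8]])

# Each conditional becomes an (unnormalised) factor over the variables it touches.
factors = {
    "f_2":  {"vars": ("x2",),      "table": p_x2},
    "f_12": {"vars": ("x1", "x2"), "table": p_x1_given_x2},
}
# The factor graph is bipartite: variable nodes {"x1", "x2"} connect to exactly
# those factor nodes whose "vars" tuple mentions them.
```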
Why? Everyone asks me why. My answer for now is that it comes out easier that way, so suck it up.
A nicer answer might be: the various belief propagation rules come out nice this way, if we consider the variables and their relations independently.
Here is the \(W\) inference problem as a factor graph:
Here is that inverse problem from earlier.
We can spend a lot of time circumlocuting about when we can actually update over such a graph. The answer is complicated, but we are neural network people, so we’ll just YOLO it and see what happens.
7.1 BP as generalized conditioning
👷
7.2 BP as variational approximation
Our goal in variational inference is to find an equation in beliefs, which are approximate marginals, and hope that we can calculate them in such a way that they end up being good approximations to the true marginals.
We want to represent every quantity of interest in terms of marginals at each node, and then solve for those marginals. If we can find some fixed point iteration that is local in some sense, and promises convergence to something useful, then we can declare victory.
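Here is a minimal sketch of that idea on a tiny chain, where belief propagation’s local updates finish in one pass and the beliefs really are the marginals; the factor tables are made up:

```python
import numpy as np

# Chain x1 - x2 - x3 with pairwise factors (arbitrary positive tables).
f12 = np.array([[1.0, 2.0],
                [3.0, 1.0]])   # f12[x1, x2]
f23 = np.array([[2.0, 1.0],
                [1.0, 4.0]])   # f23[x2, x3]

# Sum-product messages into x2 from each neighbouring factor
# (the leaf variables x1, x3 send uniform messages).
m_f12_to_x2 = f12.sum(axis=0)  # sum over x1
m_f23_to_x2 = f23.sum(axis=1)  # sum over x3

# The belief at x2 is the normalised product of incoming messages.
b_x2 = m_f12_to_x2 * m_f23_to_x2
b_x2 /= b_x2.sum()
# On a tree this equals the exact marginal of x2; on loopy graphs the same local
# update rule is iterated to a fixed point and the beliefs are only approximate.
```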
8 How about neural networks?
9 To touch upon
- Dense non-causal connections
- re-factorisation
- inference estimate
- (approximate) independence and identifiability are simpler as variational approximation problems