Learning under distribution shift
Also transfer learning, covariate shift, transferable learning, domain adaptation, transportability etc
October 16, 2020 — October 24, 2023
Predictive models can be trained on independent and identically distributed data without much fuss. Sometimes our data is not identically distributed but is drawn from several different distributions. Say I am training a model which predicts customer behaviour, and I have customers in Australia and customers in India. Can I nonetheless train a model which works well on all of the data?
If we are using a parametric hierarchical model, we can pool data in the normal way and learn interaction effects.
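A minimal sketch of that pooling, assuming one covariate, two countries and a linear model. A ridge penalty on the country-specific deviations stands in for the Gaussian prior that a hierarchical model would place on them. The data, the coefficient values and the helper name fit_partially_pooled are all invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: one covariate x and a country flag g (0 = Australia, 1 = India).
n = 400
x = rng.normal(size=n)
g = rng.integers(0, 2, size=n)
# Assumed ground truth: shared slope 1.0 plus a country-specific slope shift of 0.5.
y = 0.3 + 1.0 * x + 0.2 * g + 0.5 * g * x + rng.normal(scale=0.5, size=n)

# Columns: shared intercept, shared slope, then the country deviations (interaction terms).
X = np.column_stack([np.ones(n), x, g, g * x])


def fit_partially_pooled(X, y, penalty):
    """Least squares with a ridge penalty on the two deviation columns only.

    penalty = 0 fits each country separately (no pooling); a very large
    penalty forces the deviations to zero (complete pooling).
    """
    D = np.diag([0.0, 0.0, penalty, penalty])
    return np.linalg.solve(X.T @ X + D, X.T @ y)


beta_no_pool = fit_partially_pooled(X, y, penalty=0.0)
beta_partial = fit_partially_pooled(X, y, penalty=10.0)
beta_full_pool = fit_partially_pooled(X, y, penalty=1e8)
```

In a proper hierarchical model the penalty would not be hand-picked; it would be inferred as the variance of the country-level effects.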
If we are doing Neural Network Stuff though, it is not really clear how to do that. We might be vexed, and then surprised, and then write an article about it. If we are a typical researcher, that article might be blind to prior art in statistics; see e.g. Google AI Blog: How Underspecification Presents Challenges for Machine Learning, or Sebastian Ruder’s NN-style introduction to “transfer learning”.
I hope I don’t sound (too) snarky; there can be virtue in reinventing things with fresh eyes. Transfer learning and domain adaptation and such, these are all concepts that arise in the NN framing, and sometimes the methods overlap with statistical classics and sometimes they extend the repertoire.
Here we will investigate all of them that I have time to.
1 What is transfer learning or domain adaptation actually?
Everyone I talk to seems to have a different notion, and also to think that their idea is canonical.
We need a taxonomy. How about this one? In thuml/A-Roadmap-for-Transfer-Learning, Junguang Jiang, Yang Shu, Jianmin Wang and Mingsheng Long propose the following taxonomy of transfer methods (Jiang et al. 2022):
-
- Meta-Learning (see my meta-learning page)
- Causal Learning (see my causal learning page)
They handball to zhaoxin94/awesome-domain-adaptation for a finer domain adaptation taxonomy.
One survey paper not enough? Want a better taxonomy? Here are survey papers harvested from the above links:
(Csurka 2017; Gulrajani and Lopez-Paz 2020; Jiang et al. 2022; Kouw and Loog 2019; Ouali, Hudelot, and Tami 2020; Pan and Yang 2010; Patel et al. 2015; Sun, Shi, and Wu 2015; Tan et al. 2018; M. Wang and Deng 2018; Wilson and Cook 2020; Yuchen Zhang et al. 2019; J. Zhang et al. 2019; L. Zhang and Gao 2024; S. Zhao et al. 2020; Zhuang et al. 2020).
Transfer learning also connects to semi-supervised learning and fairness, argue Schölkopf et al. (2012) and Schölkopf (2022).
2 Generic theory
(Bareinboim and Pearl 2016, 2013, 2014; Ben-David et al. 2010; Kaddour et al. 2022; Kulinski and Inouye 2022; Mansour, Mohri, and Rostamizadeh 2009; Pearl and Bareinboim 2014; Schölkopf et al. 2012, 2021; Schölkopf 2022; Subbaswamy, Schulam, and Saria 2019; Zellinger, Moser, and Saminger-Platz 2021; Yuchen Zhang et al. 2019; J. Wang and Chen 2023; Q. Yang et al. 2020)
3 Graphical models
To my mind the most straightforward thing is simply to do causal inference in a hierarchical model which encodes all the causal constraints. All the tools of graphical modelling are still well-posed. It is easy to explain in a Bayesian framework in particular. I think this is what is referred to in Elias Bareinboim’s data fusion framing (Bareinboim and Pearl 2016, 2013, 2014; Pearl and Bareinboim 2014). In this case we can use standard statistical tooling, such as HMC, to sample from some posterior under various interventions, e.g. a shift in some parameter of the population distribution.
The hairy part is that this breaks down in neural networks. There is a million-dimensional nuisance parameter that we need to integrate out, i.e. the neural weights. For reasons of size alone that is frequently impractical, with the computation cost blowing out.
Some other works that look related: (Gong et al. 2018; Moraffah et al. 2019; Yue et al. 2021; Xu, Wang, and Ni 2022; Rothenhäusler et al. 2020).
A graphical model approach has many things to recommend it if it works, though: we do not need to worry about missing values (they may also be inferred), we can estimate intervention distributions, etc.
4 Pre-training
Train on some big set of generic tasks. Now specialize for this particular one.
5 Sample weighting
If the proportions of the various sub-populations have changed, we can use stratified sampling to estimate the quantity of interest over the new population.
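A minimal sketch of the reweighting version of this idea (post-stratification rather than literal resampling), assuming that the only thing that changes between source and target is the mix of strata, and that the target mix is known. The strata, outcome values and target proportions below are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical source sample: stratum labels (0 = Australia, 1 = India) and an outcome.
strata = rng.choice([0, 1], size=5000, p=[0.8, 0.2])
outcome = np.where(strata == 0, 1.0, 3.0) + rng.normal(scale=0.1, size=strata.size)

# Assumed known stratum proportions in the target population.
target_props = np.array([0.3, 0.7])

# Empirical stratum proportions in the source sample.
source_props = np.bincount(strata, minlength=2) / strata.size

# Weight each observation by (target share) / (source share) of its stratum.
weights = target_props[strata] / source_props[strata]

naive_mean = outcome.mean()  # estimates the mean under the source mix
reweighted_mean = np.average(outcome, weights=weights)  # estimates it under the target mix
```

If the within-stratum distributions also shift, reweighting by stratum alone no longer corrects the estimate.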
6 Model stacking
Numpyro worked example: Bayesian Hierarchical Stacking: Well Switching Case Study (Yao et al. 2022).
7 Bi-level / adversarial
OK, all that graphical model stuff failed to scale to my problem of interest; what next? As noted in Yuchen Zhang et al. (2019), many domain adaptation strategies can be framed as bi-level optimisation problems of minimax type, so that presumably corresponds to Domain Adversarial Learning. I think that invariant risk minimisation can probably be put in this minimax framework too, but “learning invariants” is also somehow conceptually separate.
Update: Yes, Ahuja et al. (2020) are helpful in giving us some semblance of taxonomy:
The standard risk minimization paradigm of machine learning is brittle when operating in environments whose test distributions are different from the training distribution due to spurious correlations. Training on data from many environments and finding invariant predictors reduces the effect of spurious features by concentrating models on features that have a causal relationship with the outcome. In this work, we pose such invariant risk minimization as finding the Nash equilibrium of an ensemble game among several environments. By doing so, we develop a simple training algorithm that uses best response dynamics and, in our experiments, yields similar or better empirical accuracy with much lower variance than the challenging bi-level optimization problem of Arjovsky et al. (2020). One key theoretical contribution is showing that the set of Nash equilibria for the proposed game are equivalent to the set of invariant predictors for any finite number of environments, even with nonlinear classifiers and transformations. As a result, our method also retains the generalization guarantees to a large set of environments shown in Arjovsky et al. (2020). The proposed algorithm adds to the collection of successful game-theoretic machine learning algorithms such as generative adversarial networks.
I’m a little confused that people seem to describe the method of Arjovsky et al. (2020) as bi-level optimisation; the paper discusses a bi-level optimisation, but the authors go on to implement an approximation which seems to be a basic single-level regularised optimisation. I am missing something, either in the original paper or in the detractors.
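For concreteness, here are the two objectives as I read them, transcribed into this page's notation (a featurizer \(\Phi\), a readout \(\beta\), and per-environment empirical risk \(\mathfrak{R}^{e}\)); the original paper uses different symbols, so treat this as my paraphrase. The idealised IRM problem is bi-level, because the constraint is itself an optimisation:

\[
\min_{\Phi, \beta} \sum_{e} \mathfrak{R}^{e}(\Phi, \beta)
\quad \text{subject to} \quad
\beta \in \operatorname*{arg\,min}_{\bar{\beta}} \mathfrak{R}^{e}(\Phi, \bar{\beta}) \ \text{ for all } e .
\]

The IRMv1 approximation fixes the readout to a scalar dummy multiplier \(\beta = 1.0\) and replaces the constraint with a gradient penalty of strength \(\lambda\), which leaves a single-level problem:

\[
\min_{\Phi} \sum_{e} \left[ \mathfrak{R}^{e}(\Phi, 1.0) + \lambda \left\| \nabla_{\beta \mid \beta = 1.0} \mathfrak{R}^{e}(\Phi, \beta) \right\|^{2} \right].
\]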
I will inspect IBM/OoD: Repository for theory and methods for Out-of-Distribution (OoD) generalization.
8 Semi-supervised learning
9 Source and target empirical risks
What does this heading even mean? I had some idea, but I have forgotten it, I confess (Ben-David et al. 2006; Ben-David et al. 2010; Blitzer et al. 2007; Mansour, Mohri, and Rostamizadeh 2009).
10 Learning invariants
I am not sure if the various sub-methods in this category are in fact distinct. H. Zhao et al. (2019) devise necessary conditions for invariant representation learning to work. Possibly this is a special case/particular framing of what I called “bi-level” optimisation, above.
10.1 Regularising features towards invariance
DAN (Long et al. 2015)
Recent studies reveal that a deep neural network can learn transferable features which generalize well to novel tasks for domain adaptation. However, as deep features eventually transition from general to specific along the network, the feature transferability drops significantly in higher layers with increasing domain discrepancy. Hence, it is important to formally reduce the dataset bias and enhance the transferability in task-specific layers. In this paper, we propose a new Deep Adaptation Network (DAN) architecture, which generalizes deep convolutional neural network to the domain adaptation scenario. In DAN, hidden representations of all task-specific layers are embedded in a reproducing kernel Hilbert space where the mean embeddings of different domain distributions can be explicitly matched. The domain discrepancy is further reduced using an optimal multi-kernel selection method for mean embedding matching. DAN can learn transferable features with statistical guarantees, and can scale linearly by unbiased estimate of kernel embedding. Extensive empirical evidence shows that the proposed architecture yields state-of-the-art image classification error rates on standard domain adaptation benchmarks.
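The quantity being matched there is a maximum mean discrepancy (MMD) between source and target feature distributions. A minimal sketch of the basic estimator follows, assuming a single Gaussian kernel with a hand-picked bandwidth and the simple quadratic-time biased estimate. DAN itself, per the abstract, uses multi-kernel selection and a linear-time unbiased estimate, so this illustrates the idea and is not their method. The feature arrays are synthetic.

```python
import numpy as np


def rbf_kernel(a, b, bandwidth):
    """Gaussian kernel matrix with entries exp(-||a_i - b_j||^2 / (2 * bandwidth^2))."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    # Clip tiny negative values caused by floating-point cancellation.
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * bandwidth**2))


def mmd2_biased(source, target, bandwidth=1.0):
    """Biased (V-statistic) estimate of the squared MMD between two samples."""
    k_ss = rbf_kernel(source, source, bandwidth)
    k_tt = rbf_kernel(target, target, bandwidth)
    k_st = rbf_kernel(source, target, bandwidth)
    return k_ss.mean() + k_tt.mean() - 2.0 * k_st.mean()


rng = np.random.default_rng(2)
feats_source = rng.normal(size=(200, 5))            # hypothetical source-domain features
feats_same = rng.normal(size=(200, 5))              # fresh draw from the same distribution
feats_shifted = rng.normal(loc=1.0, size=(200, 5))  # mean-shifted "target" features

mmd_same = mmd2_biased(feats_source, feats_same)
mmd_shifted = mmd2_biased(feats_source, feats_shifted)
```

Used as a regulariser, a term like this would be added to the task loss on the hidden representations, so that the network is pushed towards features whose distribution looks the same in both domains.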
10.2 Invariant risk minimisation
A trick from Arjovsky et al. (2020). Emin Orhan summarises the method plus several negative results (Gulrajani and Lopez-Paz 2020; Rosenfeld, Ravikumar, and Risteski 2020) about IRM:
Take invariant risk minimization (IRM), one of the more popular domain generalization methods proposed recently. IRM considers a classification problem that takes place in multiple domains or environments, \(e_1, e_2, \ldots, e_E\) (in an image classification setting, these could be natural images, drawings, paintings, computer-rendered images etc.). We decompose the learning problem into learning a feature backbone \(\Phi\) (a featurizer), and a linear readout \(\beta\) on top of it. Intuitively, in our classifier, we only want to make use of features that are invariant across different environments (for instance, the shapes of objects in our image classification example), and not features that vary from environment to environment (for example, the local textures of objects). This is because the invariant features are more likely to generalize to a new environment. We could, of course, do the old, boring empirical risk minimization (ERM), your grandmother’s dumb method. This would simply lump the training data from all environments into one single giant training set and minimize the loss on that, with the hope that whatever features are more or less invariant across the environments will automatically emerge out of this optimization. Mathematically, ERM in this setting corresponds to solving the following well-known optimization problem (assuming the same amount of training data from each domain): \(\min_{\Phi, \beta} \frac{1}{E} \sum_{e} \mathfrak{R}^{e}(\Phi, \beta)\), where \(\mathfrak{R}^{e}\) is the empirical risk in environment \(e\). IRM proposes something much more complicated instead: why don’t we learn a featurizer with the same optimal linear readout on top of it in every environment? The hope is that in this way, the extractor will only learn the invariant features, because the non-invariant features will change from environment to environment and can’t be decoded optimally using the same fixed readout.
The IRM objective thus involves a difficult bi-level optimization problem…
Does it though? The general IRM objective is difficult, but there is a simple approximation in the paper, IRMv1, which is claimed to be easier. Either way, though, the critiques of (Gulrajani and Lopez-Paz 2020; Rosenfeld, Ravikumar, and Risteski 2020) are useful.
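To see why I call IRMv1 single-level, here is a minimal sketch of its penalty for squared-error loss, where the gradient with respect to a scalar dummy readout has a closed form. The two-feature data-generating process (one causal feature, one feature that depends on the outcome) and all function names are my own toy assumptions; this is not the code from either repository listed here.

```python
import numpy as np


def make_environment(rng, n, sigma):
    """Toy environment: x1 causes y, while x2 is caused by y (a spurious feature)."""
    x1 = rng.normal(scale=sigma, size=n)
    y = x1 + rng.normal(scale=sigma, size=n)
    x2 = y + rng.normal(scale=1.0, size=n)
    return np.column_stack([x1, x2]), y


def irmv1_penalty(f, y):
    """Squared gradient of mean((w * f - y)^2) with respect to scalar w, at w = 1."""
    grad = np.mean(2.0 * (f - y) * f)
    return grad**2


def irmv1_objective(phi, envs, lam):
    """Sum over environments of risk plus lam times the penalty, for a linear featurizer."""
    total = 0.0
    for x, y in envs:
        f = x @ phi
        total += np.mean((f - y) ** 2) + lam * irmv1_penalty(f, y)
    return total


rng = np.random.default_rng(4)
envs = [make_environment(rng, 20000, 0.5), make_environment(rng, 20000, 2.0)]

# Pooled ERM solution (least squares on all environments lumped together).
x_all = np.vstack([x for x, _ in envs])
y_all = np.concatenate([y for _, y in envs])
phi_erm, *_ = np.linalg.lstsq(x_all, y_all, rcond=None)

# The predictor that uses only the causal feature.
phi_invariant = np.array([1.0, 0.0])

penalty_erm = sum(irmv1_penalty(x @ phi_erm, y) for x, y in envs)
penalty_invariant = sum(irmv1_penalty(x @ phi_invariant, y) for x, y in envs)
```

With lam = 0 the pooled ERM solution has the lower objective by construction; once lam is large the penalty dominates and the causal-feature predictor wins. Everything here is one ordinary regularised objective in phi, with no inner optimisation.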
- facebookresearch/InvariantRiskMinimization: PyTorch code to run synthetic experiments. (Arjovsky et al. 2020)
- reiinakano/invariant-risk-minimization: Implementation of Invariant Risk Minimization
Interesting variants: (Ahuja et al. 2022, 2020; Shah et al. 2021)
11 Conformal
Conformal learning + distributional shift.
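One way these two can be combined under covariate shift is to reweight the calibration scores of split conformal prediction by a likelihood ratio between test and calibration covariates. A minimal sketch of the weighted quantile step follows, assuming the likelihood-ratio weights are already available (estimating them is the hard part and is not shown). The function name and the toy scores are hypothetical.

```python
import numpy as np


def weighted_conformal_quantile(scores, cal_weights, test_weight, alpha):
    """Level (1 - alpha) quantile of the weighted calibration scores.

    The test point's own weight is treated as probability mass at +infinity,
    so the result is infinite when the calibration points carry too little mass.
    """
    scores = np.asarray(scores, dtype=float)
    cal_weights = np.asarray(cal_weights, dtype=float)
    order = np.argsort(scores)
    sorted_scores = scores[order]
    probs = cal_weights[order] / (cal_weights.sum() + test_weight)
    cum = np.cumsum(probs)
    # First index whose cumulative mass reaches 1 - alpha.
    idx = int(np.searchsorted(cum, 1.0 - alpha, side="left"))
    return sorted_scores[idx] if idx < sorted_scores.size else np.inf


rng = np.random.default_rng(3)
cal_scores = rng.permutation(np.arange(1.0, 100.0))  # 99 hypothetical nonconformity scores
uniform = np.ones_like(cal_scores)

# Uniform weights recover the ordinary split-conformal quantile.
q_unweighted = weighted_conformal_quantile(cal_scores, uniform, 1.0, alpha=0.105)

# Weights that favour high-score regions push the threshold up.
tilted = cal_scores / cal_scores.mean()
q_tilted = weighted_conformal_quantile(cal_scores, tilted, tilted.max(), alpha=0.105)
```

The prediction set for a test point is then every candidate label whose score does not exceed the returned threshold.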
12 FiLM
13 Justification for batch normalization
Apparently a thing? Should probably note some of the literature about that.
14 Tools
14.1 Transfer-Learning-Library
TLlib (Jiang et al. 2022) is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consistent with torchvision. You can easily develop new algorithms or readily apply existing ones.
Our API is divided by methods, which include:
- domain alignment methods (tllib.alignment)
- domain translation methods (tllib.translation)
- self-training methods (tllib.self_training)
- regularization methods (tllib.regularization)
- data reweighting/resampling methods (tllib.reweight)
- model ranking/selection methods (tllib.ranking)
- normalization-based methods (tllib.normalization)
14.2 DomainBed
facebookresearch/DomainBed: DomainBed is a suite to test domain generalization algorithms
DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in Gulrajani and Lopez-Paz (2020)
14.3 Salad
salad is a library to easily set up experiments using the current state-of-the-art techniques in domain adaptation. It features several recent approaches, with the goal of running fair comparisons between algorithms and transferring them to real-world use cases.
14.4 transferlearning.xyz
Jindong Wang’s site is a good resource and also includes a library of popular techniques.
- Transfer Learning | Transfer learning / domain adaptation / domain generalization / multi-task learning etc. Papers, codes, datasets, applications, tutorials. (迁移学习, “transfer learning”)
- jindongwang/transferlearning: Transfer learning / domain adaptation / domain generalization / multi-task learning etc. Papers, codes, datasets, applications, tutorials. (迁移学习, “transfer learning”)
14.5 WILDS
WILDS: A Benchmark of in-the-Wild Distribution Shifts
To facilitate the development of ML models that are robust to real-world distribution shifts, our ICML 2021 paper presents WILDS, a curated benchmark of 10 datasets that reflect natural distribution shifts arising from different cameras, hospitals, molecular scaffolds, experiments, demographics, countries, time periods, users, and codebases.