Improving Worst-Group Accuracy with Large Pre-Training Datasets
You Only Need a Good Embeddings Extractor to Fix Spurious Correlations
Spurious correlations in training data often lead to robustness issues since models learn to use them as shortcuts.
For example, when predicting whether an object is a cow, a model might learn to rely on the green pasture background and would therefore perform poorly on a cow standing on a sandy background.
A standard benchmark for measuring the state of the art in methods that mitigate this problem is the Waterbirds dataset.
The best current method, group distributionally robust optimization (GroupDRO), achieves 89% worst-group accuracy, while standard training from scratch on raw images reaches only 72%.
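For context, GroupDRO minimizes the worst-case loss over a set of predefined groups by maintaining adaptive per-group weights that upweight the groups currently performing worst. The snippet below is a minimal sketch of one GroupDRO-style reweighting step in PyTorch; the function, variable names, and the step size eta are illustrative assumptions, not code from any particular implementation.

```python
import torch

def group_dro_step(per_sample_loss, group_ids, group_weights, num_groups, eta=0.01):
    """One GroupDRO-style step: upweight groups with high current loss."""
    group_losses = []
    for g in range(num_groups):
        mask = group_ids == g
        if mask.any():
            group_losses.append(per_sample_loss[mask].mean())
        else:
            # No samples from this group in the batch: contribute zero loss.
            group_losses.append(per_sample_loss.new_zeros(()))
    group_losses = torch.stack(group_losses)

    # Exponentiated-gradient update of the per-group weights (higher loss -> higher weight).
    group_weights = group_weights * torch.exp(eta * group_losses.detach())
    group_weights = group_weights / group_weights.sum()

    # Training loss is the weighted sum of per-group losses, so worst groups count more.
    loss = (group_weights * group_losses).sum()
    return loss, group_weights
```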
Authors
Raghav Mehta, Vítor Albiero, Li Chen, Ivan Evtimov, Tamar Glaser, Zhiheng Li, Tal Hassner
Many machine learning models may rely on spurious correlations prevalent in the training dataset when learning classification boundaries.
When these spurious correlations are not present in real-world (validation or testing) datasets, model performance degrades severely.
Most methods considered in the literature so far mitigate this robustness gap by using explicit or inferred labels of spuriously correlated subsets in the training data to achieve high worst-group accuracy (WGA).
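Worst-group accuracy is simply the minimum per-group accuracy on the evaluation set, where a group is typically a (class, spurious attribute) pair. A small sketch of the metric, assuming NumPy arrays of predictions, labels, and group ids (names are illustrative):

```python
import numpy as np

def worst_group_accuracy(preds, labels, group_ids):
    """Minimum per-group accuracy over all groups in the evaluation set."""
    accs = [(preds[group_ids == g] == labels[group_ids == g]).mean()
            for g in np.unique(group_ids)]
    return min(accs)
```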
Here, we observe that training simple linear classifiers on embeddings extracted from frozen, large pre-trained networks might be enough to mitigate spurious correlations.
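As a rough illustration of this recipe, the sketch below extracts embeddings with a frozen ImageNet-pretrained backbone and fits a linear classifier on top. It uses a torchvision ResNet-50 as a stand-in for the larger models studied here, and the data loaders are assumed to be standard (image, label) loaders for the target dataset; no group labels are used.

```python
import torch
import torchvision
from sklearn.linear_model import LogisticRegression

# Load a pre-trained backbone and freeze it (ResNet-50 as a stand-in;
# the strongest results reported here use a SWAG-pretrained ViT-H/14).
backbone = torchvision.models.resnet50(weights="IMAGENET1K_V2")
backbone.fc = torch.nn.Identity()  # drop the classification head, keep embeddings
backbone.eval()

@torch.no_grad()
def extract_embeddings(loader):
    feats, labels = [], []
    for images, targets in loader:
        feats.append(backbone(images))
        labels.append(targets)
    return torch.cat(feats).numpy(), torch.cat(labels).numpy()

# train_loader / test_loader are assumed to be standard DataLoaders
# yielding (image, label) batches for the target dataset.
X_train, y_train = extract_embeddings(train_loader)
X_test, y_test = extract_embeddings(test_loader)

# Train only a linear classifier on the frozen embeddings.
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("overall accuracy:", clf.score(X_test, y_test))
```

The key design choice is that the backbone is never updated on the target dataset: only the linear head is fit, so the quality of the pre-trained embeddings does all the work.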
Result
In this paper, we show that by only training a linear classifier on embeddings extracted from a ViT-H/14 network pre-trained on SWAG and then finetuned end-to-end on ImageNet, without performing any end-to-end training on the Waterbirds dataset and without needing group labels during training, we can achieve 90.13% worst group accuracy (WGA) and 89.2% overall accuracy (OA), higher than the maximum reported in the literature.
Our experiments also show that, within each family of pre-trained networks, higher-capacity networks generally perform marginally better than lower-capacity ones.