Abstract
For the problem of adaptation to test-time distribution shift, we show that the highly cited method of BatchNorm adaptation [1,2,3] can be extended to networks that do not use BatchNorm. The possibility of this extension implies that there is no special relationship between BatchNorm and distribution shift.
Note: I don’t plan to write this up as a paper, because I think it belongs to a prior, now obsolete era. So I will just communicate it via this text post.
Introduction
Test-time adaptation refers to the following problem:
Given a trained neural network classifier, suppose we are confronted with a test dataset drawn from a different distribution than the training dataset, and assume that we don't have access to its labels. How can we adapt the trained network in order to get better performance on this test dataset?
The most cited method, from the 2020 era, is BatchNorm adaptation [1,2,3]: given a BatchNorm-based network, we re-estimate the statistics of the BN layers (namely, their running mean and variance) on the test dataset before running inference. In PyTorch, this is as simple as doing a few forward passes in train mode on the test dataset, before switching back to eval mode for the full inference. Doing this typically leads to large gains in performance on standard benchmarks, e.g., +10% accuracy when adapting a pretrained ResNet-50 to corrupted ImageNet distributions [1].
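For concreteness, here is a minimal PyTorch sketch of this procedure. The function name, the test_loader interface, and the choice of 32 adaptation batches are my own assumptions, not prescribed by [1,2,3]:

```python
import torch

@torch.no_grad()
def adapt_bn_statistics(model, test_loader, num_batches=32):
    """Re-estimate BN running statistics on the test distribution."""
    bn_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
    for m in model.modules():
        if isinstance(m, bn_types):
            m.reset_running_stats()  # forget the training-set statistics
            m.momentum = None        # use a cumulative average over all batches
    model.train()                    # train mode: forward passes update the stats
    for i, (x, _) in enumerate(test_loader):
        if i >= num_batches:
            break
        model(x)
    model.eval()                     # inference now uses the re-estimated stats
```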
The common intuition behind this method seems to be that BN layers have the special property that their statistics are fitted to whatever dataset the network is trained on, so it is natural to re-fit those statistics to the test-time distribution.
In this post I’ll show that this intuition is wrong: BN adaptation can be extended to a mathematically equivalent general method which works for all networks (including norm-free networks and LayerNorm-based networks).
Method
Given the trained network that we wish to adapt, let N refer to some specific neuron within the network. Let (mu, sigma) be the mean and standard deviation of that neuron’s pre-activation, across the training dataset.
Then our goal will be to shift and rescale the pre-activation of N such that its mean and standard deviation, over the new test distribution, are equal to (mu, sigma). Note that this is exactly the effect that BN adaptation has for BN-based networks.
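Concretely, if (mu_test, sigma_test) denote the mean and standard deviation of N’s pre-activation on the test distribution, then we want to apply the affine map x -> sigma * (x - mu_test) / sigma_test + mu to the pre-activation, which by construction has mean mu and standard deviation sigma over the test distribution.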
To accomplish this goal for all neurons in the network at once, we propose the following adaptation method.
- Measure (mu, sigma) across the training set by collecting pre-activations over several batches and computing their means and standard deviations.
- Insert a BatchNorm into the network after every layer. Set its shift coefficient (beta) to mu and its scale coefficient (gamma) to sigma.
- Put the added BatchNorms in training mode, and re-estimate their running statistics on the test set (i.e., run several forward passes).
- Put them in evaluation mode.
- “Fuse” the added BNs back into their preceding layers [6], to recover the original architecture.
Once these steps are completed, our goal will be satisfied, as the reader can verify. Note that this method is strongly related to prior work REPAIR [4].
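For a single convolutional layer, the steps above might look as follows in PyTorch. This is a minimal sketch under my own assumptions (per-channel statistics, a hypothetical adapt_conv_layer helper, and pre-collected inputs to the layer); applying it to a whole network at once requires hooks to route each layer's test-time inputs, which is omitted here. The fusing in step 5 uses torch.nn.utils.fusion.fuse_conv_bn_eval, which is one way to do it in PyTorch:

```python
import torch
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

def adapt_conv_layer(conv, mu, sigma, test_batches):
    """Adapt one Conv2d via the proposed method.

    mu, sigma:    per-channel training-set pre-activation statistics (step 1),
                  each a tensor of shape (conv.out_channels,).
    test_batches: iterable of test-set inputs to this layer.
    """
    # Step 2: insert a BN whose shift (beta) is mu and scale (gamma) is sigma.
    bn = nn.BatchNorm2d(conv.out_channels, momentum=None)  # cumulative stats
    with torch.no_grad():
        bn.weight.copy_(sigma)
        bn.bias.copy_(mu)
    # Step 3: train mode + forward passes re-estimate the running statistics
    # of this layer's pre-activations on the test distribution.
    bn.train()
    with torch.no_grad():
        for x in test_batches:
            bn(conv(x))
    # Step 4: freeze the re-estimated statistics.
    bn.eval()
    conv.eval()
    # Step 5: fuse the BN back into the conv, recovering the original
    # architecture with adapted weights.
    return fuse_conv_bn_eval(conv, bn)
```

In eval mode the inserted BN computes sigma * (x - mu_test) / sqrt(sigma_test^2 + eps) + mu, which is exactly the affine map described in the Method section.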
Results
We consider two test-time adaptation tasks: corrupted CIFAR-10 and corrupted ImageNet (also known as CIFAR-10-C and ImageNet-C) [5]. The goal is to take a network trained on the clean version of each dataset, and maximize its performance on the corrupted test sets.
On CIFAR-10 we experiment with a BatchNorm-based ResNet-18 as well as a norm-free VGG network. On ImageNet we experiment with a wide array of pretrained networks from HuggingFace. For normalization-free networks, we experiment with VGG-11, -13, -16, and -19, as well as Fixup-ResNet50 and NF-ResNet50. For BatchNorm-based networks, we experiment with ResNet-18, ResNet-50, and EfficientNet-B0 through B7. For LayerNorm-based networks, we experiment with ConvNeXt networks across a wide range of sizes.
For every network that uses BN, we compute the boost in accuracy on the corrupted test set yielded by BN adaptation; for every network that does not use BN, we compute the boost in corrupted accuracy yielded by our proposed method. We additionally compute the base corrupted accuracy of each network without any adaptation method.
The result is as follows: across all of our experiments, networks with similar base corrupted accuracy (i.e., before adaptation) receive a similar boost, whether that boost comes from BN reset (for networks that use BN) or from our proposed method (for networks that do not).
That is, given a BN-based network f_1 and a norm-free network f_2 which attain the same accuracy on ImageNet, the boost given by BN adaptation to the corrupted accuracy of f_1 will be similar to the boost given by our method to the corrupted accuracy of f_2.
For example, given a ResNet-18 which is trained for a short duration on CIFAR-10 such that it has the same performance as a fully trained VGG-16 network, we find that they both yield (a) the same corrupted accuracy before adaptation, and (b) the same (improved) corrupted accuracy after the respective adaptations.
As a downside, we observe that both methods have decreasing benefits with scale. For the largest (BatchNorm-based) EfficientNets, BN reset provides no benefit in terms of corrupted accuracy. And for the largest (LayerNorm-based) ConvNeXt models, our proposed method similarly provides no benefit. Thus, both methods appear to be only useful at the smaller scale, and can be considered obsolete at the large scale.
Conclusion
Our proposed adaptation method successfully extends BatchNorm adaptation [1,2,3] to networks which do not use BatchNorm, in the sense of causing both the same statistical effects and the same benefits to test-time performance. The possibility of this extension implies that there is no special relationship between BatchNorm and distribution shift.
[1] https://arxiv.org/abs/2006.16971
[2] https://arxiv.org/abs/2006.10963
[3] https://arxiv.org/abs/2006.10726
[4] https://arxiv.org/abs/2211.08403
[5] https://arxiv.org/abs/1903.12261