In the last post we showed that the mean hypothesis produced by neural network training depends on the choice of training data in a locally additive manner.
That is, given a large random base dataset $D$, and an extra dataset $S$, there exists a set of hypotheses $\bar f_1, \dots, \bar f_n \in \mathcal H$ such that \begin{equation} \bar f_{D \sqcup S'} = \bar f_D + \sum_{i=1}^n \bar f_i \cdot 1\{(x_i, y_i) \in S'\} \end{equation}
We showed that this holds to a high degree of accuracy when $D$ is the first 40K examples of a random ordering of the CIFAR-10 training set, and $S$ is the last 10K.
In this section our goal will be to directly estimate the values of $\bar f_1(x), \dots, \bar f_n(x)$ for each test input $x$ in the CIFAR-10 test set. To do so we will use 10 million runs of training.
We first analyze the situation. Let $w_i := \bar f_i(x)$ and assume that Definition 1 holds. We want to know how many training runs are required to estimate $w_i$ to a given level of accuracy.
Rather than trying to ascertain whether [Equation 1](#eq1) indeed holds for all subset choices $S_1 \subset S$ (it basically does), in the following section we will simply proceed on the assumption that [Theorem 1](#thm1) is true and see what happens.

Estimation of Local Datamodel Weights
To estimate the value of the $i$th local datamodel weight $w_i$ of an input $x \in \mathcal X$, we use the following procedure.
- First, we sample $M$ random subsets $S_1', \dots, S_M' \subseteq S$.
- Second, we train $K$ models on $D \sqcup S_m'$, for $m = 1, \dots, M$.
- Third, we calculate the following estimator $\hat w_i$.
\begin{equation} \hat w_i := \sum_{m = 1}^M \left(\frac{1\{(x_i, y_i) \in S_m'\}}{\textstyle\sum_{m'=1}^M 1\{(x_i, y_i) \in S_{m'}'\}} - \frac{1\{(x_i, y_i) \notin S_m'\}}{\textstyle\sum_{m'=1}^M 1\{(x_i, y_i) \notin S_{m'}'\}}\right)\frac1K \sum_{k=1}^K f^{(m, k)}(x) \end{equation}
where $f^{(m, k)}$ is the $k$th model trained on $D \sqcup S_m'$, and the subsets are sampled via IID Bernoulli variables, so that $P((x_n, y_n) \in S_m') = 0.5$ for each training example $(x_n, y_n) \in S$.
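In code, the estimator is just a difference of means over the subsets that do and do not contain example $i$. Below is a minimal numpy sketch; the array names and shapes (`masks`, `mean_outputs`) are illustrative assumptions, where `masks[m, n]` indicates whether example $n$ of $S$ is in $S_m'$ and `mean_outputs[m]` is $\frac1K \sum_k f^{(m,k)}(x)$ for one fixed test input $x$.

```python
import numpy as np

def estimate_weights(masks, mean_outputs):
    """Difference-of-means estimator for the local datamodel weights of one test input.

    masks:        (M, N) boolean array; masks[m, n] is True iff example n of S
                  was included in the random subset S'_m.
    mean_outputs: (M,) array; mean_outputs[m] is the average output
                  (1/K) * sum_k f^{(m,k)}(x) of the K models trained on D ∪ S'_m.
    Returns an (N,) array of estimates w_hat[n].
    """
    masks = masks.astype(float)
    in_counts = masks.sum(axis=0)             # number of subsets containing each example
    out_counts = (1.0 - masks).sum(axis=0)    # number of subsets excluding each example
    in_mean = masks.T @ mean_outputs / in_counts            # mean output when example n is included
    out_mean = (1.0 - masks).T @ mean_outputs / out_counts  # mean output when it is excluded
    return in_mean - out_mean

# The subsets themselves would be sampled as IID Bernoulli(0.5) masks, e.g.:
# masks = np.random.default_rng(0).random((M, N)) < 0.5
```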
Theorem 2. If $\hat w_i$ is defined as above, then $\mathbb E[\hat w_i] = w_i$ and $\mathrm{Var}(\hat w_i) = \frac1M \left(\sum_{n \neq i} w_n^2 + \frac{4\sigma^2}{K}\right)$, where $\sigma^2$ is the inter-run variance in outputs.
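As a quick sanity check on this formula, a minimal simulation of a noisy linear datamodel reproduces the predicted variance. The weights and noise level below are hypothetical placeholders, not the measured quantities from this post.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M, K, sigma = 200, 400, 2, 0.5
w = rng.normal(scale=0.05, size=N)        # hypothetical true local weights

def one_estimate():
    masks = (rng.random((M, N)) < 0.5).astype(float)
    # Average over K runs of w^T 1_{S'_m} + noise; the average has variance sigma^2 / K.
    mean_outputs = masks @ w + rng.normal(scale=sigma / np.sqrt(K), size=M)
    in_counts, out_counts = masks.sum(0), (1.0 - masks).sum(0)
    return masks.T @ mean_outputs / in_counts - (1.0 - masks).T @ mean_outputs / out_counts

estimates = np.array([one_estimate() for _ in range(500)])
empirical = estimates.var(axis=0, ddof=1).mean()              # observed Var(w_hat_i), averaged over i
theory = ((w @ w - w**2) + 4.0 * sigma**2 / K).mean() / M     # (sum_{n != i} w_n^2 + 4 sigma^2 / K) / M
print(empirical, theory)                                      # the two values should roughly agree
```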
Estimating the variance
Assuming there are no extreme outlier weights $w_i$ such that $w_i^2 \gg \sum_{n \neq i} w_n^2$, we can replace $\sum_{n \neq i} w_n^2$ by the full sum, so that all weights have roughly the same variance $\mathrm{Var}(\hat w_i) \approx \frac1M\left(\sum_{n=1}^N w_n^2 + 4\sigma^2/K\right)$.
We can estimate $\sigma^2$ using repeated trainings on a single subset; we measure $\sigma^2 \approx 0.238$, and confirm that this quantity is highly consistent across varying subsets, as assumed in the proof.
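As a sketch, this measurement amounts to taking the across-run variance of the outputs of repeated trainings on one fixed subset; the array name and shape below are illustrative, and averaging over test inputs is our assumption about how the scalar $\sigma^2$ is summarized.

```python
import numpy as np

def estimate_sigma2(outputs):
    """Estimate the inter-run variance sigma^2 from repeated trainings on one fixed dataset D ∪ S'.

    outputs: (R, T) array of model outputs, for R independent training runs
             evaluated on T test inputs.
    """
    return outputs.var(axis=0, ddof=1).mean()   # variance across runs, averaged over test inputs
```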
To obtain an estimate for $\sum_{n=1}^N w_n^2$ we use the following method. Let $S' \subseteq S$ be a random subset, sampled as above with each example included independently with probability 0.5. Then \begin{align} \mathbb E_{S'}[(\bar f_{D \sqcup S'}(x) - \bar f_{D \sqcup (S - S')}(x))^2] &= \mathbb E_{S'}[((\bar f_D(x) + w^T 1_{S'}) - (\bar f_D(x) + w^T 1_{S - S'}))^2] \newline &= \mathbb E_{S'}[(w^T (2 \cdot 1_{S'} - 1))^2] \newline &= \mathbb E_{S'}\left[\sum_{n=1}^N (w_n (2 \cdot (1_{S'})_n - 1))^2\right] \newline &= \sum_{n=1}^N w_n^2, \end{align} where the third equality holds because the membership indicators are independent and each $2 \cdot (1_{S'})_n - 1$ has mean zero, so the cross terms vanish in expectation, and the last holds because each squared factor $(2 \cdot (1_{S'})_n - 1)^2$ equals 1.
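As a sanity check, the identity can also be verified numerically. The weights below are hypothetical placeholders (not the measured ones); the check only exercises the algebra of the derivation.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 10_000
w = rng.normal(scale=0.001, size=N)            # hypothetical datamodel weights

samples = []
for _ in range(2_000):
    m = (rng.random(N) < 0.5).astype(float)    # membership indicators for a random subset S'
    diff = w @ (2.0 * m - 1.0)                 # w^T 1_{S'} - w^T 1_{S - S'}
    samples.append(diff ** 2)

print(np.mean(samples), np.sum(w ** 2))        # the Monte Carlo average approaches sum_n w_n^2
```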
For a single test input this estimator has relatively high variance, so we instead compute its average across all the test inputs for which we wish to estimate datamodel weights. \begin{equation} \Omega_{S'} := \frac{1}{|D_{test}|}\sum_{x \in D_{test}} (\bar f_{D \sqcup S'}(x) - \bar f_{D \sqcup (S - S')}(x))^2 \end{equation} We find that this average has very low empirical variance across the choice of subset $S'$; it is consistently $\Omega_{S'} \approx 0.0184$.
We conclude that our estimator has variance roughly given by $\mathrm{Var}(\hat w_i) \approx (0.0184 \cdot K + 0.952)/(MK)$. To trade off the storage cost implied by a small $K$ against the statistical inefficiency implied by a large $K$, we set $K = 50$. At a fixed total number of training runs $MK$, this results in a roughly 50% loss of efficiency relative to $K = 1$, while saving 98% of the storage required for the outputs of the trained models (because we can store the average output over the $K$ runs per subset, instead of each output individually).
This causes the formula for variance to simplify to $\mathrm{Var}(\hat w_i) \approx 0.0374/M$.
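The bookkeeping behind this choice of $K$ can be reproduced directly; the sketch below just plugs the measured values from this section into the variance formula.

```python
sigma2 = 0.238    # measured inter-run variance sigma^2
sum_w2 = 0.0184   # measured estimate of sum_n w_n^2 (i.e. Omega)

def weight_variance(M, K):
    """Var(w_hat_i) ≈ (sum_n w_n^2 + 4 * sigma^2 / K) / M, per Theorem 2."""
    return (sum_w2 + 4.0 * sigma2 / K) / M

# Efficiency at a fixed total training budget T = M * K is governed by
# Var(w_hat_i) * T = sum_w2 * K + 4 * sigma2.
for K in (1, 50):
    print(K, sum_w2 * K + 4.0 * sigma2)   # 0.9704 for K=1 vs 1.872 for K=50: roughly a 2x (~50%) efficiency loss

print(weight_variance(M=1, K=50))         # ≈ 0.0374, i.e. Var(w_hat_i) ≈ 0.0374 / M at K = 50
```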
Using this to predict the behavior
In the last section we showed that each weight estimator has variance $\mathrm{Var}(\hat w_i) \approx 0.0374/M$. We now want to determine the variance of the counterfactual logit predictions for a new subset $S' \subseteq S$.
Given local datamodel weight estimates $\hat w_1, \dots, \hat w_N$ for an input $x$, and a subset $S' \subseteq S$, the counterfactual prediction is given by \begin{equation} \hat y := \bar f_D(x) + \sum_{n=1}^N 1\{(x_n, y_n) \in S'\} \hat w_n. \end{equation}
If we make a further simplifying assumption that the estimated weights $\hat w_1, \dots, \hat w_N$ are independent, then $\mathrm{Var}(\hat y) = \sum_{n=1}^N 1\{(x_n, y_n) \in S'\} \cdot \mathrm{Var}(\hat w_n) = |S'| \cdot 0.0374/M$.
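A minimal sketch of the counterfactual prediction and its predicted variance, under the independence assumption above; the argument names are illustrative.

```python
import numpy as np

def counterfactual_prediction(f_D_x, w_hat, subset_mask):
    """y_hat = f_bar_D(x) + sum_n 1{(x_n, y_n) in S'} * w_hat_n.

    f_D_x:       scalar, the mean output f_bar_D(x) of models trained on D alone.
    w_hat:       (N,) array of estimated local datamodel weights for input x.
    subset_mask: (N,) boolean array indicating which examples of S are in S'.
    """
    return f_D_x + w_hat @ subset_mask.astype(float)

def prediction_variance(subset_size, M, per_weight_var_times_M=0.0374):
    """Var(y_hat) ≈ |S'| * Var(w_hat_n) ≈ |S'| * 0.0374 / M, assuming independent weight estimates."""
    return subset_size * per_weight_var_times_M / M
```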
In the case we are studying, we have $N = 10000$. Therefore, if $S' = S$ is the full set, then this becomes $\mathrm{Var}(\hat y) = 374/M$. The other case of interest is that of a random subset $S' \subseteq S$, in which case $|S'| \approx N/2$, so that $\mathrm{Var}(\hat y) = 187/M$.
We need one further measurement to complete our theory: the mean square of the quantity we are trying to model. For $S' = S$, this is the mean over test inputs of $(\bar f_{D \sqcup S}(x) - \bar f_D(x))^2$, which we measure to be $0.0272$.
When $S' \subseteq S$ is a random 50% subset, it is UNKNOWN.
If $Y$ is a random variable with variance $\mathrm{Var}(Y) = 0.0272$, and $X$ is a model of it whose error $X - Y$ is independent of $Y$ with $\mathrm{Var}(X - Y) = 374/M$, then the Pearson correlation between $X$ and $Y$ will be \begin{equation} \frac{1}{\sqrt{1 + 374/(M \cdot 0.0272)}} = \frac{1}{\sqrt{1 + 13750/M}}. \end{equation}
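As a small helper for reading off the implied correlation at a given number of subsets $M$ (assuming, as above, that the estimation error is independent of the target):

```python
def predicted_correlation(M, signal_var=0.0272, error_var_times_M=374.0):
    """Pearson correlation between prediction X and target Y when X = Y + noise,
    with the noise independent of Y: corr = 1 / sqrt(1 + Var(X - Y) / Var(Y))."""
    return (1.0 + error_var_times_M / (M * signal_var)) ** -0.5

print(predicted_correlation(13_750))    # ≈ 0.71
print(predicted_correlation(200_000))   # ≈ 0.97, e.g. 10 million runs at K = 50 gives M = 200,000 subsets
```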
Thus we have arrived at the following conclusions. If we estimate local datamodel weights $\hat w_1, \dots, \hat w_N$ for a dataset $S$ using the estimator defined above, then
- The weight estimates themselves will have $\mathrm{Var}(\hat w_n) \approx 0.0374/M$.
- Counterfactual predictions on a random subset $S' \subseteq S$ will have variance $\approx 187/M$.
- The counterfactual predictions on the full set $S' = S$ will have variance $\approx 374/M$.