When train, test, and deploy (really) diverge

Published

February 3, 2023

A common problem in machine learning is when your training data and real-world data diverge. One of the ways they can diverge is in their class balances.

For example, let’s say you are building a model to help a semiconductor factory detect defects in their microchips before they are shipped off to their customers. Let’s say they are only expecting 1 defective chip for every 10,000 chips they produce. Now, let’s also say you have procured a dataset of images of 300 microchips with defects. You would then need 300 * 10,000 = 3,000,000 more images of non-defective chips for your training dataset, each of them inspected to make sure the chip really isn’t defective, to match the real class balance. Instead, you include just a few thousand such images (say 3,000, for a 1:10 balance), reasoning that this kind of model should easily converge with a more balanced dataset anyway.

But this poses a problem downstream. If the class balance of the training data and real world data diverge (in this case, 1:10 vs 1:10,000), the model’s behavior during deployment will be poorly calibrated. Here it expects the defect class to appear 1000 times more frequently than it actually does, which will cause an excess of false positives relative to true positives (low precision) at runtime. Simply put, you run the risk of saying there are 1000 times more broken chips than there really are.
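
To make the effect concrete, here is a minimal sketch of the expected precision of a detector at the two class balances. The detector is hypothetical: the assumed rates (it flags 90% of defective chips and 1% of good chips) are illustrative numbers, not measurements.

```python
def expected_precision(tpr: float, fpr: float, pos: float, neg: float) -> float:
    """Expected precision of a detector with fixed per-class flag rates.

    tpr: fraction of defective chips that get flagged (assumed).
    fpr: fraction of good chips that get flagged (assumed).
    pos, neg: relative numbers of defective and good chips.
    """
    true_pos = tpr * pos
    false_pos = fpr * neg
    return true_pos / (true_pos + false_pos)


# Hypothetical detector: these two rates are assumptions for illustration.
TPR, FPR = 0.90, 0.01

train_like = expected_precision(TPR, FPR, pos=1, neg=10)
deployed = expected_precision(TPR, FPR, pos=1, neg=10_000)

print(f"precision at 1:10     = {train_like:.3f}")   # 0.900
print(f"precision at 1:10,000 = {deployed:.4f}")     # 0.0089
```

The detector itself has not changed between the two lines; only the mix of chips it sees has.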

Therefore, we must take additional steps to estimate the model’s theoretical runtime (deployment) accuracy.

One approach is to run your evaluation metrics on a sample of test data with the correct class balance (i.e. 1:10,000). You can then choose your decision threshold based on the precision-recall or ROC curve generated on this test data, and feel good that you know the approximate runtime precision, recall, TPR, and FPR when deploying your model with this same threshold.

However, this might not be feasible in every scenario. Perhaps you don’t have access to all the millions of non-defect images taken, or don’t have the resources to label that many images. Or even if you did, perhaps your class balance is so extreme that it’s simply not computationally tractable to reproduce. For example, let’s imagine you want to generate early earthquake warnings by building a model that uses 10 seconds of high-resolution seismic data. In California, there may only be 2-3 earthquakes of any significance per year, but there are 3,153,600 non-overlapping 10-second windows in a year, nearly all of which could be considered negative examples. Multiply that by the dozens of different geolocations you would want different alert systems for and… you get the idea. This is such an extreme imbalance that you likely would never use all the negative examples, even for testing purposes.
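
The window count is simple arithmetic, sketched below. It assumes a 365-day year, non-overlapping windows, and 3 significant earthquakes per year (the upper end of the range above); the variable names are only illustrative.

```python
SECONDS_PER_YEAR = 365 * 24 * 60 * 60   # 31,536,000 seconds in a 365-day year
WINDOW_SECONDS = 10

# Non-overlapping 10-second windows in one year at one location.
windows_per_year = SECONDS_PER_YEAR // WINDOW_SECONDS

# Assumed: 3 significant earthquakes per year.
quakes_per_year = 3
windows_per_quake = windows_per_year // quakes_per_year

print(f"{windows_per_year:,} windows per year")
print(f"roughly 1 positive per {windows_per_quake:,} windows")
```

Overlapping (sliding) windows would make the imbalance even more extreme.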

As before, our training dataset has a different class balance relative to the deployment scenario, but now, so does our test data. We still need a way to estimate the model’s theoretical accuracy during runtime (deployment). How?

Well, there’s good news and bad news, and they’re both the same: it depends on which accuracy metric you mean. Since recall is conditioned on the actual class being positive, it is insensitive to the class ratio, so the recall you measure at any given threshold on the test data is already an estimate of the deployment recall, regardless of imbalance. Precision, because it is conditioned on the predicted class, is what’s affected. Specifically:

\(Recall = P(\hat{y}=1|y=1)\)

\(Precision = P(y=1|\hat{y}=1)\)
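
To see why precision depends on the class balance while recall does not, it can help to expand precision with Bayes’ rule. The symbols \(FPR = P(\hat{y}=1|y=0)\) and \(\pi = P(y=1)\) are notation introduced here for illustration, and the expansion assumes the model behaves the same within each class in test and in deployment:

\(Precision = \frac{Recall \cdot \pi}{Recall \cdot \pi + FPR \cdot (1 - \pi)}\)

Recall and FPR are each conditioned on the true class, so they stay fixed when the balance changes. Only \(\pi\) moves, and precision moves with it.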

If these are the two metrics we care about, then we only need to adjust our estimate of the model’s precision during deployment. Let’s begin with

\(Precision_{test} = TP_{test} / (TP_{test} + FP_{test})\)

Where \(TP_{test}\) and \(FP_{test}\) are the total number of true positives and false positives in the test dataset.

Let’s assume now that we know the class ratios in both the test dataset and the deployed scenario. Since the negative class is far more common in deployment than in the test data, we simulate this by keeping the value for TP fixed and multiplying the value for FP (since these are negative class examples) by the fold change in the negative-to-positive class ratio.

\(FP_{deploy} = FP_{test} \times \frac{neg_{deploy} / pos_{deploy}}{neg_{test} / pos_{test}}\)

Where \(neg\) and \(pos\) refer to the counts of negative and positive examples within either the test data or deployment scenario.

Then we get

\(Precision_{deploy} = TP_{test} / (TP_{test} + FP_{deploy})\)
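
As a minimal sketch, the correction can be wrapped in one small function. The function name, argument names, and the example counts (a test set of 300 defects and 3,000 non-defects, with 270 true positives and 30 false positives at some threshold) are assumptions made up for illustration.

```python
def deploy_precision(tp_test, fp_test, pos_test, neg_test, pos_deploy, neg_deploy):
    """Estimate deployment precision from test-set counts.

    TP is kept fixed; FP is scaled by the fold change in the
    negative-to-positive ratio between deployment and test.
    """
    fold_change = (neg_deploy / pos_deploy) / (neg_test / pos_test)
    fp_deploy = fp_test * fold_change
    if tp_test + fp_deploy == 0:
        raise ValueError("no positive predictions at this threshold")
    return tp_test / (tp_test + fp_deploy)


# Hypothetical test-set counts at one decision threshold.
tp, fp = 270, 30
test_precision = tp / (tp + fp)
corrected = deploy_precision(
    tp, fp,
    pos_test=300, neg_test=3_000,      # 1:10 in the test data
    pos_deploy=1, neg_deploy=10_000,   # 1:10,000 in deployment
)

print(f"test precision:   {test_precision:.3f}")
print(f"deploy precision: {corrected:.4f}")
```

Only the ratios matter, so the deployment counts can be given as `1` and `10_000` rather than real totals.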

As you generate the precision-recall curve using your test data, you can use this as your corrected precision value at each threshold. Again, recall is unchanged. The resulting curve approximates the theoretical behavior of the model during deployment.
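
Here is one way this could look in plain Python, sweeping the threshold over the test scores. It is a sketch under assumptions: labels are 0/1, a score at or above the threshold counts as a positive prediction, and the toy labels, scores, and the 1000-fold change are made up for illustration.

```python
def corrected_pr_curve(labels, scores, fold_change):
    """Return (threshold, recall, corrected_precision), one per distinct score.

    fold_change is the deployment negative-to-positive ratio divided by
    the same ratio in the test data; it must be positive.
    """
    if fold_change <= 0:
        raise ValueError("fold_change must be positive")
    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("need at least one positive example")

    # Walk the examples from highest to lowest score, lowering the threshold.
    ranked = sorted(zip(scores, labels), reverse=True)
    curve = []
    tp = fp = 0
    for i, (score, label) in enumerate(ranked):
        if label == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only after all examples tied at this score are counted.
        last_of_tie = i + 1 == len(ranked) or ranked[i + 1][0] != score
        if last_of_tie:
            recall = tp / n_pos                          # unchanged by the correction
            precision = tp / (tp + fp * fold_change)     # FP scaled, TP fixed
            curve.append((score, recall, precision))
    return curve


# Toy test data (hypothetical): balanced 1:1, deployment assumed 1:1000.
labels = [1, 1, 0, 1, 0, 0]
scores = [0.95, 0.90, 0.80, 0.70, 0.40, 0.10]

curve = corrected_pr_curve(labels, scores, fold_change=1000)
for threshold, recall, precision in curve:
    print(f"threshold={threshold:.2f}  recall={recall:.2f}  precision={precision:.4f}")
```

Passing `fold_change=1` reproduces the ordinary, uncorrected curve, which makes it easy to compare the two side by side.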

Let’s take our microchip example again to see just how big of an impact the difference in class balance can have.

Let’s say we want a precision of 0.33 when deployed, and the difference in class balance between the test data and deployment is still 1000-fold.

Solving \(0.33 = TP / (TP + FP \times 1000)\) for FP gives us \(FP = 0.67 \times TP / 330\), which gives us an initial (uncorrected) precision of \(1 / (1 + 0.67 / 330) \approx 0.998\).

That means that at a (nearly perfect) precision of 0.998 in the test data, the same threshold during deployment has a (pretty bad) precision of 0.33.
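
This kind of arithmetic is easy to get wrong by hand, so a small helper can double-check it. The sketch below inverts the correction: given a target deployment precision and the fold change, it returns the test-set precision you would need. The function name and arguments are hypothetical.

```python
def required_test_precision(target_deploy_precision, fold_change):
    """Test-set precision needed to reach a target precision in deployment."""
    if not 0 < target_deploy_precision <= 1:
        raise ValueError("target precision must be in (0, 1]")
    if fold_change <= 0:
        raise ValueError("fold_change must be positive")
    # False positives allowed per true positive in deployment ...
    fp_per_tp_deploy = (1 - target_deploy_precision) / target_deploy_precision
    # ... which shrinks by the fold change when counted on the test data.
    fp_per_tp_test = fp_per_tp_deploy / fold_change
    return 1 / (1 + fp_per_tp_test)


needed = required_test_precision(target_deploy_precision=0.33, fold_change=1000)
print(f"required test precision: {needed:.4f}")

# Round trip: scale FP back up by the fold change to recover the target.
fp_per_tp_test = 1 / needed - 1
recovered = 1 / (1 + fp_per_tp_test * 1000)
print(f"recovered deploy precision: {recovered:.2f}")
```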

Finally, keep in mind that this method makes some pretty strong assumptions:

  1. Within each class, your training and test examples are random samples of what the model will see in deployment (only the class balance differs)
  2. You know the “true” class balance in your deployment scenario
  3. The model’s probabilities are well calibrated (they reflect the true likelihoods). That’s especially important for higher probabilities, where models are often not so well calibrated! In my experience, this is often where things get quite tricky. It is worth doing some research to see whether your model’s highest-confidence predictions can be better calibrated.

Some or all of these assumptions are frequently wrong. Ultimately, there’s no equal substitute for having a correctly balanced representation of your real world data. However, that’s often not feasible, and in those cases this will help you get a better approximation of your model’s real world behavior.