One of the biggest problems we have when using machine learning in practice is distribution shift.
A distribution shift occurs when the distribution of the data the model sees in production starts to look different from the data used to train it.
A simple example that broke a lot of models was COVID. The quarantine simply changed how people behaved and the historical data became less representative.
Another good example is credit card fraud. The behavior of criminals evolves over time, and so the data that we use to train our models needs to be constantly updated or else we will start to see a lot of false positives.
This is often hard to detect and can cause models to perform much worse in production than they did during offline testing.
Looking at Criteo’s blog I found they were sponsoring a competition that dealt with a similar problem: out-of-domain generalization.
The idea is close to the shift situation. We are given data from a different distribution (e.g. different country) at test time and need to predict it even though we didn’t have access to the distribution at training time.
How Does Distribution Shift Happen In Practice?
In practice, shift usually doesn’t happen suddenly.
COVID was an example of an abrupt change, but even in very adversarial environments like crypto trading, the patterns change over a few days to months.
I have seen models deployed for years that kept working. Their performance clearly suffered, but they didn’t become useless.
So what we see in practice is not a completely different distribution, but a mix of the old and the new.
How Is Distribution Shift Solved In Practice?
The most used remedy for ever-changing distributions is monitoring and retraining. You need to keep track of the performance of your models and retrain them as soon as you see a significant drop in performance.
This is easier said than done. It’s hard enough to get people to monitor their models in the first place, let alone retrain them on a regular basis.
There are also technical challenges. For example, you need to have a way to quickly get the new data, process it, and train the model.
In cases where the data needs to be labeled by humans, this can be very slow and expensive. All of these challenges make distribution shifts a hard problem to solve in practice.
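As a minimal sketch of the "monitor, then retrain" loop described above: compare the log loss on recently labeled production data with the loss measured at deployment time, and flag a retrain when it degrades beyond a tolerance. The function names, the 10% tolerance, and the toy numbers are illustrative assumptions, not a setup described in this post.

```python
import numpy as np

def log_loss_binary(y_true, p_pred, eps=1e-15):
    """Mean binary log loss; probabilities are clipped to avoid log(0)."""
    y = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(p_pred, dtype=float), eps, 1 - eps)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

def needs_retraining(y_recent, p_recent, baseline_loss, tolerance=0.10):
    """Return (flag, recent_loss). The flag is True when the recent loss is
    more than `tolerance` (relative) worse than the baseline loss.
    The 10% default is an arbitrary, hypothetical threshold."""
    recent_loss = log_loss_binary(y_recent, p_recent)
    return bool(recent_loss > baseline_loss * (1 + tolerance)), recent_loss

# Toy numbers: the model was sharp at deployment time...
baseline = log_loss_binary([0, 1, 0, 1], [0.1, 0.9, 0.2, 0.8])
# ...but its recent predictions point the wrong way.
retrain, recent = needs_retraining([0, 1, 0, 1], [0.6, 0.4, 0.7, 0.3], baseline)
print(retrain, round(baseline, 3), round(recent, 3))
```

A check like this only works once labels arrive, which is exactly why slow or expensive labeling makes the problem harder.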
In this competition, Criteo wanted to find ways to create a model robust enough to make good predictions in domains it had never seen without needing to retrain it.
The Competition Setup
Although I am retired from competitions, this one sounded interesting because:
- It tried to solve a very relevant real-world problem.
- It had only categorical variables: 38 of them, with varying degrees of cardinality.
According to Criteo:
“to create different environments, samples from different advertisers are grouped together using the business vertical (e.g. Retail, Travel, Insurance, Classifieds etc). Such environments were found to exhibit all the different distribution shifts possible whilst still depending on the same underlying causal mechanism.”
The word environment equals domain for our purposes.
Criteo is an ad tech company, so the information they made available indicates the rows are either individual users or sessions from users.
The competition had two leaderboards: one showing the score in a domain that was not in training data, the other showing the score in new examples of the same domains as training.
The First Step: Research
By the way, I kept a journal during the competition, which helped me think about solutions, and now write this post. Try it ;)
While doing research I found that most of the papers talk about domain generalization in the context of unstructured data (mostly images).
It got me wondering whether the same findings apply to tabular data like this.
How To Create a Trustworthy Validation Split?
I read a paper that was recommended by the organizers.
The authors conclude that combining validation sets without excluding any of the domains gives a better estimate of the out-of-domain error than doing a leave-one-domain-out (LODO).
Intuitively I would think LODO would work better, as we get the closest to approximating the out-of-sample data distribution on the validation set.
Using data from the same domain to train and validate screams information leakage and unwarranted optimism in the metric. So I spent time thinking about why this would not be a problem…
Here is my hypothesis:
Mixing up the domain data, even though it doesn’t completely mimic the out-of-domain environment, better mimics how the final model will be trained (retrained with everything).
So leaving a domain out of training makes sense from a hard validation perspective, but not when you have to consider that this domain will be used to train the final model to predict other unseen domains.
We risk inflating our validation score by sharing domain info, but it’s better than completely ignoring that we will use it to retrain.
In practice, the tasks usually have domains that change through time or have some similarity between them at least temporarily, which makes a completely left out domain unrealistic.
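To make the two validation schemes concrete, here is a small sketch using scikit-learn's splitters on toy data. The domain names and array sizes are made up for illustration; the real competition data is not reproduced here.

```python
import numpy as np
from sklearn.model_selection import KFold, LeaveOneGroupOut

rng = np.random.default_rng(0)
n_rows = 12
X = rng.integers(0, 5, size=(n_rows, 3))                  # toy categorical codes
y = rng.integers(0, 2, size=n_rows)                       # toy binary label
domain = np.repeat(["retail", "travel", "insurance"], 4)  # hypothetical verticals

# Leave-one-domain-out (LODO): each validation fold is one entire domain,
# and that domain never appears in the corresponding training fold.
lodo = LeaveOneGroupOut()
lodo_folds = [(domain[tr], domain[va]) for tr, va in lodo.split(X, y, groups=domain)]

# Pooled random split: rows are shuffled regardless of domain, so a domain
# can show up on both sides of the split.
kfold = KFold(n_splits=3, shuffle=True, random_state=0)
random_folds = [(domain[tr], domain[va]) for tr, va in kfold.split(X)]

for train_domains, valid_domains in lodo_folds:
    print("LODO   train:", sorted(set(train_domains)), "valid:", sorted(set(valid_domains)))
for train_domains, valid_domains in random_folds:
    print("random train:", sorted(set(train_domains)), "valid:", sorted(set(valid_domains)))
```

The pooled split is the one that resembles the final retraining on all available domains, which is the hypothesis discussed above.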
The takeaway? Always question your assumptions and be open to learning.
I would have jumped straight to the LODO validation and probably have had a terrible time trying to find the best models for this task. Being open to learning from this paper saved me from it.
Domain Generalization Techniques Are Not Silver Bullets
Another interesting finding is that Empirical Risk Minimization (ERM), that is, scikit-learn’s usual “fit” and “predict” on the unchanged training data, performs on average very similarly to techniques created specifically to make models generalize to new domains.
More than showing if method X works better than Y, the paper shows how the model selection methodology (creating a trustworthy validation environment) is important.
Trying different algorithms with a suboptimal validation scheme is strictly worse than trying basic stuff with an optimal validation scheme.
What I mean is that ERM on the optimal validation scheme is better than whatever else (even ERM) on the suboptimal validation.
I like this Peter Norvig quote: “More data beats clever algorithms, but better data beats more data.”
In this case, it’s the same data but split in a clever way.
There is one caveat: the final evaluation in the paper was done with LODO. It seems inevitable due to the nature of the problem, but I have to think more about its implications.
Ideas That Worked and That Didn’t
The organizers published a benchmark logistic regression with the hashing trick.
It worked well. Hashing can be thought of as a regularization method, and here we needed strong regularization on the high cardinality variables.
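For readers who have not used the hashing trick, this is roughly what such a baseline looks like in scikit-learn. It is a sketch, not the organizers' actual benchmark: the column names, the toy rows, and the choice of 2**10 buckets are all assumptions.

```python
import numpy as np
from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model import LogisticRegression

# Toy rows with hypothetical categorical columns.
rows = [
    {"advertiser": "a1", "device": "mobile"},
    {"advertiser": "a1", "device": "desktop"},
    {"advertiser": "a2", "device": "mobile"},
    {"advertiser": "a3", "device": "mobile"},
    {"advertiser": "a2", "device": "desktop"},
    {"advertiser": "a3", "device": "desktop"},
]
y = np.array([1, 0, 1, 1, 0, 0])

def to_tokens(row):
    """Turn a row into 'column=value' strings so equal values in
    different columns hash to different buckets."""
    return [f"{col}={val}" for col, val in row.items()]

# Every token is hashed into one of n_features buckets. Fewer buckets means
# more collisions, which acts like extra regularization on rare categories.
hasher = FeatureHasher(n_features=2**10, input_type="string")
X = hasher.transform([to_tokens(r) for r in rows])

model = LogisticRegression(C=1.0)
model.fit(X, y)
proba = model.predict_proba(X)[:, 1]
print(X.shape, proba.round(3))
```

A category never seen in training still hashes to some bucket, so the model can score it without any special handling.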
Different Categorical Encoders and Gradient Boosting
My first attempt to beat the logistic regression was using a target encoder. In my experience, when this encoding works, it works really well, but when it doesn’t, it’s terrible.
Unfortunately, it didn’t work on its own here, validating my hypothesis that we need strong regularization instead of trying to capture every small detail of high cardinality features.
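For reference, here is a minimal sketch of smoothed target encoding. The smoothing formula and the value of `m` are one common variant chosen for illustration; the post does not say which variant was used, and the column names are hypothetical.

```python
import pandas as pd

train = pd.DataFrame({
    "advertiser": ["a", "a", "a", "a", "b", "c"],
    "click":      [1,   0,   1,   1,   1,   0],
})

prior = train["click"].mean()  # global click rate
m = 10.0                       # smoothing strength (hypothetical value)

# Shrink each category's mean toward the global rate; categories with few
# rows are pulled almost all the way to the prior.
stats = train.groupby("advertiser")["click"].agg(["sum", "count"])
encoding = (stats["sum"] + prior * m) / (stats["count"] + m)

raw_means = stats["sum"] / stats["count"]
print(pd.DataFrame({"raw_mean": raw_means, "smoothed": encoding}))
```

With a single row, category "b" has a raw mean of 1.0, which is the kind of small detail a model can overfit on high cardinality features. In a real pipeline the encoding would also be computed out-of-fold so a row never sees its own label.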
I tried to train a Random Forest, which probably beats the logistic regression in accuracy terms, but its probability estimates are not calibrated.
This means that the probability estimates are way off, and for log loss you need to have the probability estimates as close to the true probabilities as possible.
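A tiny worked example of why calibration matters for log loss: two sets of predictions with identical accuracy, one calibrated and one overconfident. The numbers are invented purely to show the arithmetic.

```python
import numpy as np
from sklearn.metrics import accuracy_score, log_loss

# Four rows, three of them positive: the true positive rate is 0.75.
y_true = np.array([1, 1, 1, 0])

predictions = {
    "calibrated": np.full(4, 0.75),     # matches the true rate
    "overconfident": np.full(4, 0.99),  # same decisions, inflated certainty
}

results = {}
for name, p in predictions.items():
    results[name] = {
        "accuracy": accuracy_score(y_true, (p >= 0.5).astype(int)),
        "log_loss": log_loss(y_true, p),
    }
    print(name, results[name])
```

Both predictors get 3 out of 4 rows right, but the overconfident one pays heavily for the single row where it was almost certain and wrong.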
Calibrating the Random Forest would take too long compared to simply training a gradient boosting model, so I moved forward.
I usually start with a Random Forest because it’s an easy model to get working fast: put in a lot of trees and let it run. Gradient boosting requires more tuning.
My first attempt was using LightGBM over the target encoded categoricals, which didn’t beat the benchmark logistic regression.
Then I created datasets with simple ordinal encoding and count encoding. Count encoding was the best but only slightly above the ordinal encoded solution.
GBMs can handle ordinal encoding pretty well even though it’s “wrong” for non-ordinal variables.
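Both encodings are a few lines of pandas. This sketch uses a hypothetical `advertiser` column, and the handling of unseen categories (-1 and 0) is my own assumption rather than something specified in the competition.

```python
import pandas as pd

train = pd.DataFrame({"advertiser": ["a", "b", "a", "c", "a", "b"]})
test = pd.DataFrame({"advertiser": ["a", "c", "d"]})  # "d" never appears in train

# Ordinal encoding: an arbitrary integer per category, learned on train only.
categories = train["advertiser"].astype("category").cat.categories
ordinal_map = {cat: code for code, cat in enumerate(categories)}
train["advertiser_ord"] = train["advertiser"].map(ordinal_map)
# Unseen categories get -1 (an assumption; any reserved value works).
test["advertiser_ord"] = test["advertiser"].map(ordinal_map).fillna(-1).astype(int)

# Count encoding: replace each category with its frequency in train.
counts = train["advertiser"].value_counts()
train["advertiser_cnt"] = train["advertiser"].map(counts)
# Unseen categories get a count of 0.
test["advertiser_cnt"] = test["advertiser"].map(counts).fillna(0).astype(int)

print(train)
print(test)
```

Unlike the ordinal codes, the counts carry some meaning (how common a category is), which may be part of why count encoding came out slightly ahead.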
I didn’t want to give up on target encoding yet, so I tried CatBoost, which has its own built-in target-based encoding for categorical features.
It performed very well on the validation score and on the part of the leaderboard that overlapped with the training data but not on the completely new domain.
So I decided to do the obvious: train LightGBM, CatBoost, and XGBoost, and average their predictions. It didn’t give a significant boost, but I like the stability of ensembles.
Stacking didn’t work either.
After looking at the results of the paper above, I was thinking: is domain generalization only about removing spurious correlations? Or reducing enough “legitimate” correlations from the original datasets so the model can generalize to new domains?
There is a method well-known to the Numerai community where you basically train a linear model over a feature set and subtract its predictions from your main model. I decided to try it here and it worked wonders on the new domain.
I took the average prediction of my gradient boosting models and subtracted 0.1 times the prediction of the logistic regression.
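In code, the blend described above is a one-liner. The sketch below uses made-up prediction arrays; the final clipping step is my own addition so the result stays a valid probability for log loss, and is not something stated in the post.

```python
import numpy as np

def blend_and_neutralize(gbm_preds, linear_pred, strength=0.1, eps=1e-6):
    """Average the gradient boosting predictions, then subtract
    `strength` times the linear model's prediction."""
    ensemble = np.mean(np.asarray(gbm_preds, dtype=float), axis=0)
    adjusted = ensemble - strength * np.asarray(linear_pred, dtype=float)
    # Clipping is an assumption: it keeps the output inside (0, 1).
    return np.clip(adjusted, eps, 1 - eps)

# Hypothetical predictions for two rows from three boosting models.
lgbm = [0.2, 0.6]
catb = [0.3, 0.5]
xgbo = [0.1, 0.7]
logreg = [0.5, 0.4]

final = blend_and_neutralize([lgbm, catb, xgbo], logreg, strength=0.1)
print(final)
```

Here the ensemble average is [0.2, 0.6], and subtracting 0.1 times the logistic regression output pulls it down to roughly [0.15, 0.56].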
It took the score on the unseen domain from 0.128 to 0.133. Although the score was log loss, it was normalized, so a higher number is better than a lower number.
The only issue is that it decreased the score for the seen domains, so I was not sure if this was good or bad in reality.
As I could select up to 3 submissions at the end of the competition, I decided to pick one that was good at each leaderboard score and a third that was more balanced.
Going back to the original question, I think this works because it reduces the effect of correlations that are stronger in the original domains while keeping an underlying signal that generalizes.
In the paper, the researchers try generalization methods on Colored MNIST, which is purposefully built with correlations that don’t generalize to the new domains, which tends to break traditional ML methods.
The takeaway is that if you are working with a very noisy dataset, or want your model to last a little longer in production at the cost of a bit of accuracy, feature neutralization might be a good choice.
Domain Related Sample Weights
Inspired by a few methods described in the paper, I tried to weight the samples according to the number of rows in each domain.
Specifically, I took the number of examples in a given domain, divided it by the number of examples in the largest domain, and used the result as the sample weight.
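The weighting scheme above can be sketched in a few lines of pandas. The domain names and row counts are invented for illustration.

```python
import pandas as pd

# Hypothetical domains with 4, 2, and 1 rows.
df = pd.DataFrame({"domain": ["retail"] * 4 + ["travel"] * 2 + ["insurance"]})

# Weight = rows in this domain / rows in the largest domain.
domain_counts = df["domain"].value_counts()
df["sample_weight"] = df["domain"].map(domain_counts) / domain_counts.max()

print(df.drop_duplicates())
```

The resulting column can be passed as `sample_weight` to the `fit` method of most scikit-learn style estimators. Note that under this scheme the largest domain gets weight 1.0 and smaller domains get less.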
It decreased the score for the seen domains and didn’t improve the score for the unseen, so I threw it away.
Denoising Autoencoders
I really love Michael Jahrer’s solution to the Porto Seguro competition on Kaggle. I imagine he tried that solution in other competitions until he found one where it worked beautifully.
The idea is to train a denoising autoencoder over your input features and then use the learned representation to train a downstream model.
The winners of the Jane Street competition did something similar and it worked well there too, so I decided to try it here.
I tried multiple versions: purely unsupervised, supervised, and predicting both the environment and the final label, but none of them beat XGBoost with ordinal and count encoding.
I still like this idea and will keep it in my toolbox so I can understand the situations when it’s most likely to work.
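The "denoising" part of a denoising autoencoder needs a way to corrupt the inputs. One option that suits categorical data is swap noise, where a fraction of cells is replaced by the value of the same column in another random row. The sketch below covers only that corruption step, not the network itself, and the 15% swap rate is an illustrative assumption rather than a setting I am reporting from this competition.

```python
import numpy as np

def swap_noise(X, swap_prob=0.15, seed=0):
    """Return a copy of X where each cell is replaced, with probability
    `swap_prob`, by the value of the same column in a random row."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X)
    n_rows, n_cols = X.shape
    # Which cells get corrupted.
    mask = rng.random(X.shape) < swap_prob
    # For every cell, pick a random donor row; the column stays the same.
    donor_rows = rng.integers(0, n_rows, size=X.shape)
    col_idx = np.broadcast_to(np.arange(n_cols), X.shape)
    donors = X[donor_rows, col_idx]
    return np.where(mask, donors, X)

# Toy matrix of ordinal-encoded categoricals (two columns).
X_demo = np.array([[0, 10], [1, 11], [2, 12], [3, 13]])
noisy = swap_noise(X_demo, swap_prob=0.5, seed=0)
print(noisy)
```

Because donors come from the same column, the corrupted rows contain only category codes that really exist for that feature. The autoencoder is then trained to reconstruct the clean row from the noisy one.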
Field-aware Factorization Machines
Field-aware Factorization Machines (LibFFM) is a classic tool used to win click-through prediction competitions with lots of categorical variables. I had to try it.
It didn’t work here.
My guess is the same as above: we benefit more from reducing the learning from our training dataset than trying to make the model understand the details of it better.
Practical Learnings For Projects Going Forward
Spending a week working on a problem, even part-time, can teach you a lot!
These are the learnings I am going to take to my projects outside of competitions.
Monitoring And Retraining Are King
The most important thing I learned in this competition is that monitoring and retraining are still the best ways to account for distribution shifts.
This holds even if the new domain dataset is small: any data from the right distribution is better than no data.
The best way to spend your time when trying to generalize to new distributions is by collecting data from the target distribution and automating the retraining pipeline of the model.
This was already my belief but I always like to challenge my assumptions to find a better way to do stuff.
I recently set up a model in a private project to retrain automatically every day. It improved the performance and it was a good opportunity to learn more about keeping a model fresh with the currently available tools.
A great place to learn about monitoring and retraining machine learning systems is Chip Huyen’s blog.
Production Is The True Test Set
It’s something I learned to accept a few years ago.
Going fast to production and testing your model in a real environment is the only way to really know its performance.
There is no excuse to not do it as there are a ton of techniques, like shadow deployment or A/B testing with a small sample of incoming data, that will allow you to test the models without breaking anything.
Making a trustworthy validation split is very important, but limited. No amount of offline evaluation can replace true production testing.
After this competition, my rule of thumb for out-of-domain generalization is to evaluate offline with random split and retrain often.
Caveat: we didn’t have an explicit time dimension in this data, but in industry we need to take it into account. So even with the same domain between training and validation, splitting by time still seems necessary.
What About Leave-One-Domain-Out (LODO) Validation?
I created a small experiment locally to compare the random split with the LODO split as the researchers did in the paper and got the same results.
A validation split that has both in and out-of-domain samples gives a much closer estimate of out-of-domain error than using LODO.
There must exist a case where this is false, but I am keeping it as a rule of thumb.
A final tip:
It’s very important to consider the retraining process during validation. It’s rare (as in pandemic rare) to have an abrupt change of domain.
So always simulate your retraining scheme with some rolling window validation as I did on the daily marketing mix model.
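A rolling window validation can be simulated with a small generator like the one below. It is a generic sketch, not the code from the marketing mix project: the function name, the window size, and the idea of indexing data by integer periods (for example, days) are assumptions.

```python
def rolling_window_splits(n_periods, train_size, step=1):
    """Yield (train_periods, test_period) pairs. Each model is trained on the
    `train_size` most recent periods and evaluated on the next one, which
    mimics retraining on a schedule."""
    for start in range(0, n_periods - train_size, step):
        train_periods = list(range(start, start + train_size))
        test_period = start + train_size
        yield train_periods, test_period

# Six periods of data, retraining on the last three each time.
splits = list(rolling_window_splits(n_periods=6, train_size=3))
for train_periods, test_period in splits:
    print("train on", train_periods, "-> validate on", test_period)
```

Averaging the metric over these windows gives an estimate of how the model behaves under the actual retraining schedule, instead of under a single static split.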