Strategies for Sanity-Checking Machine Learning Models

Training high-quality machine learning (ML) models and successfully productionizing them are surprisingly difficult tasks when you’re used to creating more traditional computer software. Data collection, model training, model evaluation, and real-world performance are all tightly coupled, often in subtle ways that make it easy to accidentally fool yourself about your model’s effectiveness.

Modern ML libraries make it easy for novices to create models, but even the experts sometimes kid themselves about how good their models are.

There are many potential issues when using machine learning, but I want to highlight a few of the most basic questions I ask when sanity-checking machine learning models that I or others have created, with an eye toward classification, particularly file classification.

How Do You Obtain, Generate, and Validate Labels for Your Data?

Labeled data is essential for training machine learning classifiers. The difference between a successful and failed ML model is often the existence of a large and accurate set of labels. Unfortunately, generating labels is typically one of the biggest bottlenecks when building a classifier: labeling data is difficult, expensive, and time-consuming.

However, it turns out that for file classification in particular, there are several open source collections of labels specifying which files are benign and which are malicious (or even more fine-grained labels, like the particular family of malware). So, if you want to train a model, it's as simple as aggregating these file labels and using them to train an ML classifier, right?

Well, unfortunately, many of these open source label collections tend to suffer from mediocre accuracy, especially for non-Portable Executable (PE) file types. And a single file may have multiple conflicting labels (e.g., two labelers disagree on whether a particular file is benign or malicious, or how to define what constitutes a potentially unwanted program, or PUP). How one chooses to resolve these disagreements can have a significant impact on the resulting model.
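To make the impact of the resolution policy concrete, here is a minimal Python sketch of one possible approach: a vote among label sources with a configurable agreement threshold. The function name `resolve_label`, the `min_agreement` parameter, and the string labels are assumptions made for illustration, not a description of any particular production pipeline.

```python
from collections import Counter


def resolve_label(votes, min_agreement=0.5):
    """Return the label whose share of votes exceeds min_agreement, else None.

    `votes` is a list of label strings from different (hypothetical) sources.
    Returning None marks the file as unresolved, so it can be sent for
    review or left out of training.
    """
    if not votes:
        return None
    label, count = Counter(votes).most_common(1)[0]
    return label if count / len(votes) > min_agreement else None


votes = ["malicious", "malicious", "benign"]
lenient = resolve_label(votes)                    # simple majority wins
strict = resolve_label(votes, min_agreement=0.9)  # near-unanimity required
tie = resolve_label(["malicious", "benign"])      # no majority at all
print(lenient, strict, tie)
```

The same three votes produce a training label under the lenient policy and no label under the strict one, so the choice of policy changes which files, and which labels, the model ever sees.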

To make matters more difficult, we've found that aggregating existing labels from the Internet usually isn't sufficient to create really stellar models. To build state-of-the-art classifiers, we typically need to generate labels ourselves. To generate labels, we have a team of human experts who spend their time fixing incorrect labels and coming up with labels for new unknown files. This is a difficult and expensive process, but machine learning can help here, too, with techniques like active learning.
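One common active learning strategy is uncertainty sampling: ask the human experts to label the files the current model is least sure about. The sketch below assumes the model outputs a probability that each file is malicious; the function name `select_for_labeling` and the file IDs are hypothetical, and this is only one of several possible selection strategies.

```python
def select_for_labeling(scores, budget):
    """Pick the `budget` files whose scores are closest to 0.5.

    `scores` maps a file ID to the model's predicted probability that the
    file is malicious. Scores near 0.5 mean the model is most uncertain.
    """
    by_uncertainty = sorted(scores, key=lambda file_id: abs(scores[file_id] - 0.5))
    return by_uncertainty[:budget]


scores = {"file-a": 0.98, "file-b": 0.52, "file-c": 0.03, "file-d": 0.40}
to_label = select_for_labeling(scores, budget=2)
print(to_label)
```

Here the expert's time goes to `file-b` and `file-d`, where a label is likely to teach the model the most, rather than to files it already classifies confidently.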

What Model Performance Metrics Do You Use?

You'll often hear people tout the accuracy of their machine learning models, but accuracy by itself is a misleading metric, especially when there is a large class imbalance: i.e., when one label occurs much more frequently than another.

For example, let's say you want to create a classifier that can identify malicious emails with 99% accuracy. You train your model, and lo and behold, you get a classifier that's 99% accurate. But, like all good ML practitioners, you’re skeptical of good results. So, you dig deeper and realize that, in general, only 1% of emails are malicious. That means that all your classifier needed to learn to reach 99% accuracy was to always classify the email as benign.

Here's an implementation of a 99% accurate malicious email classifier:
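A minimal Python sketch of such a classifier; the function name `classify_email`, the string labels, and the toy data are illustrative assumptions:

```python
def classify_email(email_text: str) -> str:
    """Ignore the input entirely and always answer "benign"."""
    return "benign"


# Toy data with the 1% malicious rate from the example above.
emails = ["quarterly report attached"] * 99 + ["click this totally safe link"]
labels = ["benign"] * 99 + ["malicious"]

correct = sum(classify_email(e) == y for e, y in zip(emails, labels))
accuracy = correct / len(emails)
print(f"accuracy = {accuracy:.0%}")
```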

Obviously, such a classifier is useless, even though it's 99% accurate. What we really want are performance metrics that take into account class imbalances. There are many ways to do this, but perhaps the simplest way to think about imbalance-aware metrics is to visualize what's called a confusion matrix.

A confusion matrix tells you, for each class, how many times the classifier was correct, and how many times it was wrong. For example, the confusion matrix for our malicious email classifier might look like this:

| | Classifier calls it benign | Classifier calls it malicious |
|---|---|---|
| Truly benign | 300 | 0 |
| Truly malicious | 3 | 0 |

From this view, it’s much more obvious that the malicious email classifier isn’t very good—the classifier simply calls everything benign in order to achieve 99% accuracy.

Each cell in the confusion matrix has a special name, which you may already be familiar with:

| | Classifier calls it benign | Classifier calls it malicious |
|---|---|---|
| Truly benign | True negatives | False positives |
| Truly malicious | False negatives | True positives |

Here we consider “negative” to mean benign (there was not a detection of a malicious email), and we consider “positive” to mean malicious (there was a detection of a malicious email).

These values can also be expressed as percentages. For example, the false negative rate is the percentage of truly malicious emails that were mistakenly classified as benign, which for the above example gives us a false negative rate of 100%—the classifier was wrong about all of the malicious emails.

There are many other metrics that account for class imbalances in various ways, some of which are more intuitive than others, but many of them are related to the confusion matrix directly or indirectly, such as precision, recall, F1 score, and area under the ROC curve.
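As a rough illustration, several of these metrics can be computed directly from the four cells of the confusion matrix. The sketch below uses the counts from the malicious email example; the helper names are assumptions, and reporting undefined ratios as `None` is a choice made here for clarity (libraries often report 0 with a warning instead). Area under the ROC curve is left out because it needs the classifier's scores, not just its final decisions.

```python
def safe_div(numerator, denominator):
    """Return numerator / denominator, or None when the ratio is undefined."""
    return numerator / denominator if denominator else None


def confusion_metrics(tn, fp, fn, tp):
    precision = safe_div(tp, tp + fp)  # of predicted malicious, how many were
    recall = safe_div(tp, tp + fn)     # of truly malicious, how many caught
    if precision is None or recall is None or precision + recall == 0:
        f1 = None
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return {
        "accuracy": safe_div(tp + tn, tn + fp + fn + tp),
        "precision": precision,
        "recall": recall,
        "false_negative_rate": safe_div(fn, fn + tp),
        "false_positive_rate": safe_div(fp, fp + tn),
        "f1": f1,
    }


# Counts from the always-benign email classifier's confusion matrix.
m = confusion_metrics(tn=300, fp=0, fn=3, tp=0)
for name, value in m.items():
    print(name, value)
```

Accuracy comes out at roughly 99%, while recall is 0 and precision is undefined because the classifier never predicts "malicious" at all, which is the failure that accuracy alone hides.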

How Do You Validate Your Models Before Releasing Them to Customers?

I’ll talk a little bit about standard machine learning best practices, but I’ll also touch on some subtleties specific to file classification.

When you train a machine learning classifier, you want to end up with a model that generalizes well. In other words, you want a classifier that can accurately classify data that wasn’t seen during training. Or, to say it in machine learning lingo, you want a classifier that avoids overfitting.

Typically, in machine learning, generalization is measured by taking your dataset and splitting it into a training set and a validation set; note that “validation set,” “test set,” and “hold-out set” are sometimes used interchangeably, although these terms can have important differences depending on context. You then train the classifier on the training set. Crucially, you don’t show the validation set to the classifier during training.

Finally, once you have a trained model, you use it to classify everything in your validation set and see how well it did. If the classifier is good at classifying things in the training set, but not in the validation set, then the classifier doesn’t generalize well—the classifier simply memorized the training set without learning anything deeper about the data.

However, if the classifier does well on the validation set, even though it’s never seen that data before, then you have a classifier that generalizes pretty well.

There are further complications and subtleties here that are outside the scope of this blog post, but the key idea is that to evaluate a classifier’s ability to generalize, you need a set of validation data that’s totally separate from your training data.
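A toy Python sketch of the idea, using a deliberately bad "classifier" that only memorizes its training set. Everything here is hypothetical: the integer "files", the `true_label` rule used to generate labels, and the fallback to "benign" for unseen inputs.

```python
def true_label(x):
    # Hypothetical ground truth, used only to generate toy data.
    return "malicious" if x % 3 == 0 else "benign"


samples = [(x, true_label(x)) for x in range(100)]

# Deterministic 80/20 split so the example is reproducible;
# in practice you would normally shuffle before splitting.
train, validation = samples[:80], samples[80:]

# "Training" is pure memorization: a lookup table of the training set.
memorized = dict(train)


def classify(x):
    # Anything not seen during training falls back to "benign".
    return memorized.get(x, "benign")


def accuracy(dataset):
    return sum(classify(x) == y for x, y in dataset) / len(dataset)


train_accuracy = accuracy(train)
validation_accuracy = accuracy(validation)
print(train_accuracy, validation_accuracy)
```

The memorizer is perfect on the training set and noticeably worse on the held-out data. That gap between the two numbers is the signal that the model has not generalized.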

For file classification in particular, ensuring this separation is trickier than it might appear at first glance. You need to be very careful to avoid subtle, unintentional duplication between your training and validation sets: most malware is a very slight variation of some other piece of malware; a single piece of malware may have thousands of variants, sometimes with only one-bit differences meant to fool blacklists of hashes; or, maybe many different unique files were generated with the same toolkit and are nearly functionally identical to one another.

If you’re not doing deduplication, you’ll end up with these sorts of near-duplicates sprinkled across both your training and validation sets, and your validation accuracy will look unrealistically high.
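One way to keep near-duplicates from straddling the split is to group files into clusters first and then assign whole clusters to one side. The sketch below assumes each sample already carries a hypothetical `cluster_id` (for example, from a similarity-clustering step, which is not shown); the field names and the hash-based assignment are illustrative choices, not a prescribed method.

```python
import hashlib


def split_by_cluster(samples, validation_fraction=0.2):
    """Send every sample in a cluster to the same side of the split.

    The side is chosen by hashing the cluster ID, so the assignment is
    deterministic and stable as new samples arrive.
    """
    train, validation = [], []
    for sample in samples:
        digest = hashlib.sha256(sample["cluster_id"].encode("utf-8")).digest()
        bucket = int.from_bytes(digest[:8], "big") / 2**64  # value in [0, 1)
        (validation if bucket < validation_fraction else train).append(sample)
    return train, validation


# Toy data: 50 hypothetical clusters with 3 near-duplicate variants each.
samples = [
    {"file_id": f"file-{c}-{v}", "cluster_id": f"cluster-{c}"}
    for c in range(50)
    for v in range(3)
]
train, validation = split_by_cluster(samples)

train_clusters = {s["cluster_id"] for s in train}
validation_clusters = {s["cluster_id"] for s in validation}
print(len(train), len(validation), train_clusters & validation_clusters)
```

The realized validation share will only approximate `validation_fraction`, since clusters differ in size, but no cluster ever appears on both sides.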

A common but imperfect method to make the validation set separate from the training set is to use temporal separation. In other words, train your file classifier on all files first seen before time t, and validate the resulting classifier on all files first seen after time t. The intuition is that the validation set couldn’t have “accidentally leaked” into your training set because information can’t leak backwards in time. Of course, temporal separation does nothing about the near-duplicate problem mentioned above: a variant first seen after time t may differ only trivially from a file seen before it, so the resulting accuracy numbers can still be misleadingly high.
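A minimal sketch of a temporal split, assuming each sample records a hypothetical `first_seen` date. Files first seen exactly at the cutoff are placed in the validation set here, which is an arbitrary choice.

```python
from datetime import date


def temporal_split(samples, cutoff):
    """Train on files first seen before `cutoff`, validate on the rest."""
    train = [s for s in samples if s["first_seen"] < cutoff]
    validation = [s for s in samples if s["first_seen"] >= cutoff]
    return train, validation


samples = [
    {"file_id": "file-a", "first_seen": date(2023, 1, 15)},
    {"file_id": "file-b", "first_seen": date(2023, 6, 1)},
    {"file_id": "file-c", "first_seen": date(2023, 9, 30)},
]
train, validation = temporal_split(samples, cutoff=date(2023, 6, 1))
print([s["file_id"] for s in train], [s["file_id"] for s in validation])
```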

Another thing to look out for is sampling bias: your validation set needs to reflect the distribution of data your classifier will actually see in real-world scenarios; otherwise, your accuracy measures are irrelevant at best and misleading at worst.

For example, maybe you have a file classifier that is really good at identifying bitcoin miners; if bitcoin miners appear a hundred times more often in your validation set than they do in the real world, your validation accuracy will look higher than the accuracy you actually get once the classifier is deployed.
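The effect can be worked through with some made-up numbers. In the sketch below, every figure is hypothetical: the classifier is assumed to be 99% accurate on bitcoin miners and 80% accurate on everything else, miners make up 50% of the validation set, and they make up 0.5% of real-world files (a hundred times less).

```python
def expected_accuracy(per_class_accuracy, class_share):
    """Overall accuracy as the share-weighted average of per-class accuracy."""
    return sum(per_class_accuracy[c] * class_share[c] for c in per_class_accuracy)


# Hypothetical per-class accuracy of the classifier.
per_class_accuracy = {"bitcoin_miner": 0.99, "other": 0.80}

# Hypothetical class mix in the validation set versus the real world.
validation_share = {"bitcoin_miner": 0.50, "other": 0.50}
real_world_share = {"bitcoin_miner": 0.005, "other": 0.995}

validation_accuracy = expected_accuracy(per_class_accuracy, validation_share)
real_world_accuracy = expected_accuracy(per_class_accuracy, real_world_share)
print(f"{validation_accuracy:.3f} vs {real_world_accuracy:.3f}")
```

Under these assumptions the validation set suggests about 89.5% accuracy, while the same classifier would deliver roughly 80.1% on the real-world mix. The model did not change, only the sampling.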

In practice, a simple validation technique that often works well is to run your classifier in “silent mode” against real-time customer data, i.e., recording classifier decisions, but not actually taking any action based on those decisions. If this model performs more poorly with real-time customer data than you saw in your internal testing, then you may need to go back to the drawing board.
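A sketch of what silent mode might look like in code, assuming a hypothetical `classify` callable that returns a verdict string and a hypothetical `take_action` callable (quarantine, block, and so on). The names and the logging format are illustrative only.

```python
import logging

logger = logging.getLogger("silent_mode")


def handle_file(file_id, classify, take_action, silent=True):
    """Classify a file, always record the verdict, and act only if not silent."""
    verdict = classify(file_id)
    # The decision is recorded in both modes so it can be compared later
    # against ground truth or against the currently deployed model.
    logger.info("file=%s verdict=%s silent=%s", file_id, verdict, silent)
    if verdict == "malicious" and not silent:
        take_action(file_id)
    return verdict


quarantined = []
verdict = handle_file("file-1", lambda _file_id: "malicious", quarantined.append)
print(verdict, quarantined)
```

With `silent=True` the verdict is logged but `quarantined` stays empty, so a bad model can embarrass you in the logs without affecting customers.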

How Do You Address Model Degradation?

For file classification, models inevitably degrade over time as new files appear with fundamental differences from the files in the original training set; these fundamental differences can be extreme enough to go beyond the classifier’s ability to generalize. Additionally, these new, fundamentally different files may have properties that are not well represented by the classifier’s existing features.

Reasonable remediations for degraded models include:

  • Retraining the model from scratch, possibly with new features.
  • Continuing training from where you left off, this time with newer, updated data, instead of retraining from scratch. Some types of classifiers are more amenable to this than others, and it typically isn’t an option if you’re adding new features.
  • Training a new classifier to work in concert with the old classifier via some kind of ensemble and/or boosting technique.

Of course, the key to dealing with model degradation is being able to detect it. Fundamentally, this comes down to gathering real-world statistics about how your classifier performs in your customers’ environments. From these statistics, you can build automated systems to detect degradation, as well as have human experts selectively confirm or refute your model’s decisions.
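One simple form such an automated check could take is a rolling error rate over the most recent decisions that have a confirmed outcome (for example, from human review). The class name `DegradationMonitor`, the window size, and the threshold below are all assumptions chosen for illustration; real thresholds would come from your own baseline measurements.

```python
from collections import deque


class DegradationMonitor:
    """Track the error rate over the last `window` confirmed decisions."""

    def __init__(self, window=1000, max_error_rate=0.05):
        self.outcomes = deque(maxlen=window)  # True when the model was right
        self.max_error_rate = max_error_rate

    def record(self, predicted, confirmed):
        self.outcomes.append(predicted == confirmed)

    def error_rate(self):
        if not self.outcomes:
            return None
        return 1 - sum(self.outcomes) / len(self.outcomes)

    def degraded(self):
        # Only raise the flag once the window is full, to avoid noisy
        # alerts based on a handful of samples.
        window_full = len(self.outcomes) == self.outcomes.maxlen
        return window_full and self.error_rate() > self.max_error_rate


monitor = DegradationMonitor(window=10, max_error_rate=0.2)
for _ in range(10):
    monitor.record("benign", "benign")        # model agrees with reviewers
healthy = monitor.degraded()
for _ in range(5):
    monitor.record("benign", "malicious")     # model starts missing malware
degraded = monitor.degraded()
print(healthy, degraded, monitor.error_rate())
```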

Conclusion

Successfully training and productionizing high-quality machine learning models can be a daunting task. There are many potential issues when developing machine learning models for use in the real world, but asking some of the basic questions highlighted above can provide sanity checks that help keep you on the right track.