
Cross-Validation

Overview

Teaching: 10 min
Exercises: 10 min
Questions
  • How can we determine whether a model generalizes to new data?

  • How do we diagnose over-/underfitting in our model?

Objectives
  • Learners can apply cross-validation to their candy data set to test the validity of their algorithm.

  • Learners understand the difference between overfitting and underfitting and the problems they cause.

How well does a model generalize to new data?

This is a key question in building a machine learning model. Earlier in this lesson, we learned about evaluation metrics and suggested splitting your training data set (the data for which you know the outcomes) into two subsets: one to be used for training, and one for testing.
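If you are working in Python with scikit-learn, a minimal sketch of such a split might look like the following. The synthetic data generated here is only a stand-in for the candy measurements; in practice you would load your own feature matrix and labels.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the candy data: 100 "candies", 2 measured features,
# 4 candy types. Replace this with your own measurements and labels.
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           random_state=0)

# Hold back 25% of the labelled data for testing, train on the rest.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25,
                                                    random_state=0)
print(X_train.shape, X_test.shape)  # (75, 2) (25, 2)
```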

Challenge

Can you think of reasons why your algorithm wouldn’t generalize to new, unknown examples?

Solution

There are several reasons why an algorithm might not generalize to new examples, some of which we talked about before. For example, if your training data set is missing examples from a class, the model won’t know about that class, and won’t be able to recognize its examples when it sees them later on. Another key reason might be that the model is either not flexible enough to represent the structure in the data, or too flexible, meaning it also picks up on minute differences caused by noise or measurement error. These two problems are called underfitting and overfitting, respectively.

Underfitting and Overfitting

Earlier, we talked about different parameters algorithms might have. For example, logistic regression includes a parameter for each feature you use in the model. In general, more parameters make the model more flexible; decision boundaries have more freedom to wiggle to fit the data. On the surface, this sounds great! More parameters mean a better ability to represent structure in the data.
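To make this concrete, here is a small sketch (assuming scikit-learn and a synthetic stand-in for the candy data) that counts the parameters a fitted logistic regression actually carries: one coefficient per feature for each class, plus an intercept per class.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in for the candy data: 4 classes, 2 features.
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           random_state=0)

model = LogisticRegression(max_iter=1000)
model.fit(X, y)

# One row of coefficients per class, one column per feature, plus intercepts.
print(model.coef_.shape)       # (4, 2): 4 classes x 2 features
print(model.intercept_.shape)  # (4,)
```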

John von Neumann famously said, “With four parameters I can fit an elephant, and with five I can make him wiggle his trunk.” His point was that more complex models aren’t necessarily better, because with a complex enough function you can fit any data set. How is this relevant to our machine learning problem?

Think back to when you recorded data.

You’re probably pretty good at taking measurements on average, but for any individual candy, there’s a good chance your measurement is off by just a bit. This is called measurement error, and it’s important in machine learning: you want your model to fit real differences between classes, but not the differences generated by the errors in your measurements. If your model isn’t flexible enough, it won’t be able to capture all the differences between the different classes you’re interested in. This is called underfitting. Conversely, if it is too flexible, it will fit the differences between classes, but also variations in your data caused by measurement error. That’s called overfitting.

Challenge

Look at the data you recorded (or downloaded). Can you think of potential sources of measurement error in these data points?

Solution

Possible sources include the limited precision of the ruler or scale you used, small inconsistencies in how you positioned or read off each measurement, and rounding when you wrote the values down. Any of these can shift an individual candy’s recorded values slightly away from its true ones.

Training your algorithm on part of your data set, and testing its performance on both the data you trained on and a subset of your data you held back from training, is a great way to explore whether your model is under- or overfitting.

If you’ve underfit your data, then the performance on both the training and test data sets will be bad. That is because the model can’t represent all the complexities of your data properly, and so a significant fraction of samples in both your training and test sets will likely be misclassified. A model that overfits, on the other hand, will do really well on the training set: the key property of overfitting is that it tries to make sure it represents all of the training data correctly, so a very high fraction of your training set will be classified correctly. But because overfit models don’t generalize well, performance on your test data set will still be bad: the model will be so fine-tuned to the specific training data set it has seen, including the noise that it contains, that it won’t do a good job on the test data.
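One way to see this pattern in practice is sketched below, assuming scikit-learn and a synthetic stand-in for the candy data: compare training and test accuracy for a very rigid model and a very flexible one. Decision trees are used here purely because their flexibility is easy to dial up and down via the tree depth.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the candy data.
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25,
                                                    random_state=0)

# max_depth=1 is very rigid (likely to underfit);
# max_depth=None grows the tree until every leaf is pure (likely to overfit).
for depth in (1, None):
    tree = DecisionTreeClassifier(max_depth=depth, random_state=0)
    tree.fit(X_train, y_train)
    print(f"max_depth={depth}: "
          f"train accuracy = {tree.score(X_train, y_train):.2f}, "
          f"test accuracy = {tree.score(X_test, y_test):.2f}")

# Underfitting: both accuracies low.
# Overfitting: training accuracy near 1.0, test accuracy noticeably lower.
```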

Cross-Validation

So far, we’ve talked about how you can evaluate whether your model performs well using the test data set. However, earlier we also cautioned that you should set aside the test data set until the very end of the process — after you’ve chosen an algorithm and the options and parameters for its use, and are ready to train the model you will apply to new, unknown examples. This leaves us in a bind: how can we figure out which algorithm and which options to choose, when we don’t have the test data set to evaluate their performance?

The answer is called cross-validation. This is a process by which you split up your training data set again, and use one part for training and the other for testing. This additional test set is often called a validation set, because you use it to validate your algorithm’s performance. However, by doing this only once, you might get either a good or a bad result just by chance: the specific subset you chose for training might be missing examples from one class, or from a particular part of parameter space (for example, it might contain only very small peanut M&Ms purely by chance). Instead of choosing a single split between training set and validation set, in practice you repeat that process multiple times and then see how the performance varies across the different splits.
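A sketch of this repeated splitting, again assuming scikit-learn and a synthetic stand-in for the candy data: ShuffleSplit draws a fresh random train/validation split on every repetition, so you can see how much the score moves around.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ShuffleSplit, cross_val_score

# Synthetic stand-in for the candy data.
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           random_state=0)

# Five independent random splits, each holding out 25% as a validation set.
splitter = ShuffleSplit(n_splits=5, test_size=0.25, random_state=0)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=splitter)

print(scores)                       # one validation accuracy per random split
print(scores.mean(), scores.std())  # average performance and its spread
```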

One specific type of cross-validation is called leave-one-out cross-validation (LOO CV). As the name suggests, you select all but one of your training examples to train your algorithm. Then you use the final example (the one you left out) to test your algorithm. This process is repeated many times, leaving out a different example to test on each time. By the end, each training example has been left out as a test example once. You can average the results of each test to get an estimate of the overall error.
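A minimal sketch of LOO CV with scikit-learn, again on a synthetic stand-in for the candy data:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneOut, cross_val_score

# Synthetic stand-in for the candy data: 100 examples.
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           random_state=0)

# Each of the 100 examples is held out exactly once; the "score" for each
# round is simply 1 (classified correctly) or 0 (misclassified).
loo_scores = cross_val_score(LogisticRegression(max_iter=1000), X, y,
                             cv=LeaveOneOut())

print(len(loo_scores))    # one score per training example
print(loo_scores.mean())  # overall fraction classified correctly
```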

Leave-one-out cross-validation has a number of issues. For example, many evaluation metrics won’t work on a single sample, and often you are left with a simple error metric that measures whether a particular sample was classified correctly or not. More importantly, LOO cross-validation is computationally expensive: you need to retrain your algorithm n times, where n is the number of examples in your training set. Most algorithms also take longer to train the more training data you have. As a result, LOO CV might be infeasible if your data set is very large.

A strategy that’s used much more often is called k-fold cross-validation. Instead of leaving out a single example, you pick a subset of examples to hold out as a validation set, train on the rest, and then test on the held-out data. In the next step, you pick a different subset to hold out as a new validation set, and repeat the process. As an example, imagine you have 100 candies. For 10-fold cross-validation, you would split your candies into 10 subsets of 10 candies each. In the first step, you train your algorithm on the first nine subsets and test the performance on the tenth. In the next step, you train on subsets 1-8 and 10, then test on subset 9. You continue this process until each of your subsets–also called folds–has been used for testing the performance once. This process allows for a wider range of performance metrics and is far more computationally feasible. After you have run through all combinations, you can look at the performance metrics across the different folds, which help you gauge the algorithm’s performance and stability.
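Here is a sketch of 10-fold cross-validation with scikit-learn, mirroring the 100-candy example above (the data is again a synthetic stand-in):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

# Synthetic stand-in for 100 candies with 2 measured features and 4 types.
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           random_state=0)

# 10 folds of 10 candies each; shuffle first so the folds are random.
kfold = KFold(n_splits=10, shuffle=True, random_state=0)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=kfold)

print(scores)                       # one accuracy per fold
print(scores.mean(), scores.std())  # overall performance and its stability
```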

Challenge

In the example above, I used 10 folds to demonstrate how to do k-fold cross-validation. Given that we have four types of candy, with approximately equal proportions, do you think 10 folds is a good split? Would you use more or fewer folds? How would your answer change if one type of candy was over- or underrepresented?
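One tool worth knowing about as you think through this challenge is stratified k-fold splitting, which keeps the proportion of each candy type roughly the same in every fold. A sketch with scikit-learn, where the fold count and the (deliberately imbalanced) synthetic data are assumptions for illustration:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic stand-in for the candy data; the weights make one class rarer.
X, y = make_classification(n_samples=100, n_features=2, n_informative=2,
                           n_redundant=0, n_classes=4, n_clusters_per_class=1,
                           weights=[0.4, 0.3, 0.2, 0.1], random_state=0)

# Stratified folds preserve the class proportions in every train/validation split.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=skf)
print(scores)
```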

Key Points

  • Training is not enough: we need to make sure the model generalizes to new data points it hasn’t seen before.

  • Underfitting and overfitting are common problems when applying machine learning models that can be diagnosed with cross-validation.

  • Cross-validation is the process of randomly dividing the data into subsets, and using different combinations of subsets as training and validation sets.