Machine Learning Overtraining

Overtraining is a concern we must be aware of when training a machine learning model.  Vaimal offers several methods to reduce the potential for overtraining.  Before discussing how to prevent overtraining, let's see what overtraining looks like.

Overtraining

Training error is a function of the difference between predicted and actual outcomes over the training data points.  Validation error is the same measure computed, with the network being trained, on a separate validation data set.  The chart below shows the training and validation error of an MLP neural network over the course of training.
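Vaimal's exact error function isn't specified here, but as an illustration, a mean squared error could serve as either measure.  In this sketch, predicted and actual are NumPy arrays of model outputs and true outcomes:

    import numpy as np

    def mse(predicted, actual):
        """Mean squared error between predicted and actual outcomes."""
        return np.mean((predicted - actual) ** 2)

Computed once per epoch on the training set and on the validation set, this yields the two curves plotted below.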

[Chart: Error vs. Epoch]

The training and validation errors generally trend downward until about epoch 6500.  Then the two errors start diverging, as shown in the close-up view below.

[Chart: Error vs. Epoch (close-up)]

When this divergence becomes clear, we have overtrained the network.  As training error continues to decrease, validation error increases.  This means the network is learning the training data well, but it no longer generalizes to new data.

One question that arises from the charts above is how we can be sure that we're overtraining.  If we were watching the errors at epoch 3000, we might be tempted to stop training because validation error had spiked upward.  However, validation error continued to decrease later on.  Fortunately, Vaimal takes care of this by saving the network parameters that produced the lowest validation error.  In this case, the MLP weights were saved at epoch 6471, the epoch with the lowest validation error.
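This checkpointing is internal to Vaimal, but the idea is simple.  Here is a minimal sketch, assuming a net object with readable and writable weights, plus hypothetical train_one_epoch and val_error_fn callables:

    import copy

    def train_with_best_checkpoint(net, n_epochs, train_one_epoch, val_error_fn):
        """Train for n_epochs, keeping the weights from the epoch with the
        lowest validation error."""
        best_error = float("inf")
        best_weights = None
        best_epoch = -1
        for epoch in range(n_epochs):
            train_one_epoch(net)           # one pass over the training data
            error = val_error_fn(net)      # error on the validation set
            if error < best_error:         # new low: checkpoint the network
                best_error = error
                best_weights = copy.deepcopy(net.weights)
                best_epoch = epoch
        net.weights = best_weights         # restore the best checkpoint
        return best_epoch, best_error

Training can then run well past a temporary spike in validation error without losing the best network found so far.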

Ways to Avoid Overtraining

Use a Train/Validation/Test Partition

If ample data is available, it can be partitioned into three sets.  The training set is used to train the model.  The model is evaluated on the validation set during training to measure error on unseen data, that is, data that was not used to train the model.  Using this method, Vaimal saves the model at the point where validation error is lowest.  Finally, the test set is used to verify the generalization performance of the model.
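Vaimal flags data for each role; as a generic illustration, a 70/15/15 split (the proportions here are an assumption, not a Vaimal default) might look like this:

    import numpy as np

    def partition(x, y, train_frac=0.70, val_frac=0.15, seed=0):
        """Shuffle the data, then split it into train/validation/test sets."""
        rng = np.random.default_rng(seed)
        idx = rng.permutation(len(x))
        n_train = int(train_frac * len(x))
        n_val = int(val_frac * len(x))
        train = idx[:n_train]
        val = idx[n_train:n_train + n_val]
        test = idx[n_train + n_val:]       # the remainder becomes the test set
        return (x[train], y[train]), (x[val], y[val]), (x[test], y[test])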

Regularization

Vaimal has two types of regularization available for MLPs: L1 and L2.  In both, a penalty that depends on the magnitudes of the weights is added to the training error.  Large-magnitude weights can be a sign of overtraining, where the network is trying to learn noisy features of the training data set.  Regularization steers the model toward smaller weights while still minimizing error.  The constant λ is a multiplier on the penalty.

L1: total error = training error + λ Σ |weight|

L2: total error = training error + λ Σ weight²
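As a sketch of how these penalties could be computed (not Vaimal's implementation; weights is assumed to be a list of the network's weight arrays):

    import numpy as np

    def total_error(training_error, weights, lam, kind="L2"):
        """Add an L1 or L2 weight penalty, scaled by lambda, to the training error."""
        w = np.concatenate([np.ravel(layer) for layer in weights])
        if kind == "L1":
            penalty = lam * np.sum(np.abs(w))   # λ Σ |weight|
        else:
            penalty = lam * np.sum(w ** 2)      # λ Σ weight²
        return training_error + penalty

A larger λ pushes the optimizer harder toward small weights; λ = 0 recovers the unregularized training error.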

Bagging

Bagging (bootstrap aggregating) is an ensemble method that combines multiple models into a meta-model.  For each model, the training data is randomly sampled with replacement from all data flagged as training.  This means some data points can appear in the “bag” more than once for a given model, while other training points may not appear at all.  This sampling creates a different training set for each model, which increases model diversity and helps avoid overtraining.
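As an illustration of the sampling step (a sketch; the train_model function and n_models count in the usage comment are hypothetical):

    import numpy as np

    def bootstrap_sample(x, y, rng):
        """Draw len(x) rows with replacement: some rows repeat, others are left out."""
        idx = rng.integers(0, len(x), size=len(x))
        return x[idx], y[idx]

    # Hypothetical ensemble loop: each model trains on its own bag.
    # rng = np.random.default_rng(0)
    # models = [train_model(*bootstrap_sample(x_train, y_train, rng))
    #           for _ in range(n_models)]

The individual models' predictions are then combined, typically by averaging for regression or majority vote for classification.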