What can you do when your model is overfitting your data?
This problem often occurs when dealing with an imbalanced dataset. If your dataset contains several classes, one of which is much less represented than the others, then it is difficult to learn the true underlying distribution of that minority class.
As explained in this must-read paper, the method of addressing class imbalance that emerged as dominant in almost all analyzed scenarios is oversampling. Oversampling should be applied to a level that completely eliminates the imbalance, whereas the optimal undersampling ratio depends on the extent of the imbalance. As opposed to some classical machine learning models, oversampling does not cause overfitting of CNNs.
In practice, when training a machine learning model, one would follow some key steps:
- Split the data into training and testing sets (e.g. 80% / 20%).
- Train a machine learning model by fitting it on the training data.
- Evaluate its performance on the testing set.
When using deep learning architectures, it is common to split the training data into batches that we feed to the neural network during training. To build these batches, we usually sample randomly from the training set, following a uniform distribution over the set of observations.
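This default uniform batching can be sketched as follows (a minimal example; the random features, their dimension, and the batch size are assumptions for illustration):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy dataset: 1,000 observations with 8 placeholder features each
features = torch.randn(1000, 8)
labels = torch.randint(0, 2, (1000,))
dataset = TensorDataset(features, labels)

# shuffle=True makes the DataLoader draw each epoch's batches by
# uniform random sampling (without replacement) over the observations
loader = DataLoader(dataset, batch_size=100, shuffle=True)

n_batches = sum(1 for _ in loader)
print(n_batches)  # 10 batches of 100 observations each
```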
Some simple statistics are needed here. Suppose we have a dataset with two classes, class_1 and class_2. What is the probability of randomly sampling a point from, say, class_1?
Following a uniform distribution over the set of observations, this probability is simply: P(class_1) = n_1 / (n_1 + n_2), where n_1 and n_2 are the numbers of observations in class_1 and class_2.
In practice, a class imbalance in a binary problem appears when we have many more observations from one class than the other: n_1 >> n_2.
As a result, we have: P(class_1) >> P(class_2).
In other words, it is more likely to draw a point from class_1 than from class_2. Because the model sees far fewer examples of class_2, it is not surprising that it fails to learn useful features for that class…
Now, before diving into the code, we need to understand a key idea when artificially augmenting the data. What we want is to make sure that, by artificially augmenting the minority class, we have batches in which both classes are (roughly) equally likely to be drawn: P(class_1) ≈ P(class_2).
It is time! Let’s code a solution to this problem with WeightedRandomSampler from PyTorch.
Dataset: We build a dataset with 900 observations from class_major (labeled 0) and 100 observations from class_minor (labeled 1), i.e. a 90% / 10% split.
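A minimal sketch of such a dataset; the real data consists of sentences, so the random tensors and their feature dimension used here are placeholder assumptions:

```python
import torch
from torch.utils.data import TensorDataset

# 900 observations labeled 0 (class_major), 100 labeled 1 (class_minor)
n_major, n_minor = 900, 100
features = torch.randn(n_major + n_minor, 16)  # placeholder features
labels = torch.cat([torch.zeros(n_major, dtype=torch.long),
                    torch.ones(n_minor, dtype=torch.long)])
dataset = TensorDataset(features, labels)

print(len(dataset))  # 1000
```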
Assuming we build 10 batches of 100 sentences each, we would end up with, on average, 10 sentences of class 1 and 90 sentences of class 0 per batch.
How can we rebalance the above easily? Let’s write a few lines of code using the PyTorch library.
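Here is a sketch consistent with the setup described above: example_weights holds one weight per observation, set to the inverse frequency of its class, so that the 900/100 dataset is rebalanced at sampling time.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# Imbalanced toy dataset: 900 observations of class 0, 100 of class 1
labels = torch.cat([torch.zeros(900, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)])
dataset = TensorDataset(torch.randn(1000, 16), labels)

# One weight per class: inverse class frequency, so the minority
# class gets the larger weight
class_counts = torch.bincount(labels)        # tensor([900, 100])
class_weights = 1.0 / class_counts.float()

# One weight per observation, looked up from its label
example_weights = class_weights[labels]

# Draw 1,000 samples per epoch, with replacement, proportionally
# to example_weights
sampler = WeightedRandomSampler(example_weights,
                                num_samples=len(labels),
                                replacement=True)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)
```

Note that replacement=True is what allows the minority class to be oversampled: the same minority observation can appear several times per epoch.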
From the above, we can see that WeightedRandomSampler uses the array example_weights, which assigns each observation the weight of its class. The goal is to give a higher weight to the minority class. This changes the likelihood of drawing a point from each class, moving from a uniform distribution to a multinomial distribution with controlled parameters.
Now we can look in detail at the batches contained in arr_batch; each of them should have 100 sentences in practice. For visualisation purposes, we focus on the labels only here.
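A sketch of this inspection, assuming arr_batch collects the label tensor of each batch drawn by the weighted sampler (the sampler setup is rebuilt here so the snippet is self-contained):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# Rebuild the 900/100 toy dataset and the weighted sampler
labels = torch.cat([torch.zeros(900, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)])
dataset = TensorDataset(torch.randn(1000, 4), labels)
example_weights = (1.0 / torch.bincount(labels).float())[labels]
sampler = WeightedRandomSampler(example_weights,
                                num_samples=len(labels),
                                replacement=True)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Keep only the labels of each drawn batch
arr_batch = [batch_labels for _, batch_labels in loader]

# Count each class per batch: counts should hover around 50/50
for i, batch_labels in enumerate(arr_batch):
    counts = torch.bincount(batch_labels, minlength=2)
    print(f"batch {i}: class 0 -> {int(counts[0])}, class 1 -> {int(counts[1])}")
```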
Inspecting the labels, we now have balanced batches of data. As a result, during training, our model will not see significantly more of one class than the other, and the risk of overfitting is hence reduced.
In summary, we saw that:
- Oversampling is a key strategy to address class imbalance and hence reduce the risk of overfitting.
- Uniformly sampling batches from your dataset is a bad idea when it has class imbalance.
- Weighted random sampling with WeightedRandomSampler rebalances the training classes by oversampling the minority class.
In the next article, we will dive into the implementation of WeightedRandomSampler to better understand the weighting scheme. We will also apply oversampling in a simple machine learning scenario and analyse its consequences on overall performance.
Thanks for reading, please leave a comment below if you have any feedback! 🤗