
Address class imbalance easily with PyTorch

Data augmentation in computer vision. Credits for the picture to fastai.

What can you do when your model is overfitting your data?

Mastafa Foufa

This problem often occurs when we are dealing with an imbalanced dataset. If your dataset contains several classes and one of them is much less represented than the others, it is difficult to learn the true underlying distribution of that minority class.

As explained in this must-read paper, the method for addressing class imbalance that emerged as dominant in almost all analyzed scenarios is oversampling. Oversampling should be applied to a level that completely eliminates the imbalance, whereas the optimal undersampling ratio depends on the extent of the imbalance. Unlike with some classical machine learning models, oversampling does not cause CNNs to overfit.
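For intuition, plain random oversampling simply repeats minority-class examples until both classes occur equally often. Here is a minimal sketch of that idea, assuming hypothetical labels (this is an illustration, not the paper's exact procedure):

```python
import numpy as np

# Hypothetical labels: 900 majority-class (0) and 100 minority-class (1) examples.
labels = np.array([0] * 900 + [1] * 100)

minority_idx = np.where(labels == 1)[0]
majority_idx = np.where(labels == 0)[0]

# Resample minority indices with replacement until both classes have the same count.
oversampled_minority_idx = np.random.choice(
    minority_idx, size=len(majority_idx), replace=True
)
balanced_idx = np.concatenate([majority_idx, oversampled_minority_idx])
np.random.shuffle(balanced_idx)

print(np.bincount(labels[balanced_idx]))  # -> [900 900]
```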

In practice, when training a machine learning model, one would follow some key steps:

  1. Split the data into a training/testing set (80%, 20%).

When using deep learning architectures, it is common to split the training data into batches that we feed to our neural network at training time. To build such batches, we usually sample randomly from the training set, following a uniform distribution over the set of observations.
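As a minimal sketch of this standard setup (the random features stand in for encoded sentences, and the 80%/20% split mirrors step 1 above), a plain DataLoader with shuffle=True draws each batch uniformly at random:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical encoded sentences (1000 examples, 16 features) and imbalanced labels.
X = torch.randn(1000, 16)
y = torch.cat([torch.zeros(900), torch.ones(100)]).long()

dataset = TensorDataset(X, y)
train_set, test_set = random_split(dataset, [800, 200])  # 80% / 20% split

# shuffle=True means each batch is sampled uniformly at random from the training set.
train_loader = DataLoader(train_set, batch_size=100, shuffle=True)

for batch_x, batch_y in train_loader:
    print(batch_y.float().mean())  # fraction of class 1 per batch, around 0.1
```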

Some simple statistics are now needed. Suppose we have a dataset with two classes, class_1 and class_2. What is the probability of randomly sampling a point from, say, class_1?

Following a uniform distribution over the set of points, this probability is simply the fraction of observations belonging to that class: P(class_1) = n_1 / (n_1 + n_2), where n_1 and n_2 denote the number of observations in class_1 and class_2 respectively.

In practice, class imbalance in a binary problem appears when we have many more observations from one class than from the other: n_1 >> n_2.

As a result, we have P(class_1) >> P(class_2):

Class imbalance in a binary problem is described by an unequal likelihood of drawing an observation from each class.

In other words, it is more likely to draw a point from class_1 than from class_2. Because the model sees far fewer examples of class_2, it is not surprising that it fails to learn useful features for that class…
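As a quick numerical illustration (using a hypothetical 900/100 split, the same ratio as the dataset built below), these probabilities can be computed directly from the label counts:

```python
from collections import Counter

# Hypothetical labels: 900 observations of class_1 and 100 of class_2.
labels = ["class_1"] * 900 + ["class_2"] * 100

counts = Counter(labels)
total = sum(counts.values())

p_class_1 = counts["class_1"] / total   # 900 / 1000 = 0.9
p_class_2 = counts["class_2"] / total   # 100 / 1000 = 0.1
print(p_class_1, p_class_2)
```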

Now, before diving into the code, we need to understand a key idea behind artificially augmenting the data. What we want is to make sure that, by artificially augmenting the minor class, we end up with P(class_1) ≈ P(class_2):

After augmenting our data, our goal is to make the likelihoods of drawing a sample from each class as close to each other as possible.

It is time! Let’s write some code to solve this problem with WeightedRandomSampler from PyTorch.

Dataset: We build a dataset with 900 observations from class_major labeled 0 and 100 observations from class_minor labeled 1 (90% / 10%).

Sample of our dataset. A label of 1 corresponds to a sentence in French and a label of 0 to a sentence in English.
Class distribution for an imbalanced dataset with textual data and two classes labeled 0 and 1. We have 900 sentences of class 0 and 100 sentences of class 1.
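A possible way to build such a toy dataset (the sentence contents below are placeholders, not the author's actual data) is:

```python
import pandas as pd

# Placeholder sentence pools; in practice these would be real English/French sentences.
english_sentences = [f"english sentence {i}" for i in range(900)]
french_sentences = [f"phrase française {i}" for i in range(100)]

df = pd.DataFrame({
    "sentence": english_sentences + french_sentences,
    "label": [0] * 900 + [1] * 100,   # 0 = English (major class), 1 = French (minor class)
})

print(df["label"].value_counts())      # 0: 900, 1: 100  -> (90%, 10%)
```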

Assuming we build 10 batches of 100 sentences each, each batch would on average contain 10 sentences of class 1 and 90 sentences of class 0.

Distribution of classes in each of the 10 batches of 100 sentences. The minor class is shown in red and the major class in blue. The imbalance is clearly visible in each batch of training data. The estimated proportion of class 0 is now 90.5% and that of class 1 is 9.5%.

How can we rebalance this easily? Let’s write a few lines of code using the PyTorch library.

24 lines of Python magic to build balanced batches.
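The gist itself is not reproduced on this page, but a sketch along the following lines builds the balanced batches described next; the features tensor is a placeholder for encoded sentences, while example_weights and arr_batch match the names referenced below:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Labels of our toy dataset: 900 examples of class 0 and 100 examples of class 1.
labels = torch.tensor([0] * 900 + [1] * 100)
# Placeholder features standing in for encoded sentences.
features = torch.randn(len(labels), 16)

# One weight per class, inversely proportional to its frequency.
class_counts = torch.bincount(labels)            # tensor([900, 100])
class_weights = 1.0 / class_counts.float()       # tensor([0.0011, 0.0100])

# One weight per example, looked up from its class.
example_weights = class_weights[labels]

sampler = WeightedRandomSampler(
    weights=example_weights,
    num_samples=len(example_weights),  # draw 1000 samples per epoch
    replacement=True,                  # minority examples can be drawn several times
)

dataset = TensorDataset(features, labels)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Collect the labels of each batch for inspection.
arr_batch = [batch_labels for _, batch_labels in loader]
```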

From the above, we can see that WeightedRandomSampler uses the array example_weights, which assigns to each example the weight of its class. The goal is to give a higher weight to the minor class. This changes the likelihood of drawing a point from each class, moving from a uniform distribution to a multinomial distribution with controlled parameters.

Now we can look in detail at the batches contained in arr_batch; each of them should contain 100 sentences in practice. For visualisation purposes, we focus on the labels only here.
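Continuing the sketch above, the overall label proportions across the batches in arr_batch can be estimated with a few lines:

```python
import torch

# Concatenate the labels of all batches and estimate the class proportions (in %).
all_labels = torch.cat(arr_batch)
proportions = torch.bincount(all_labels).float() / len(all_labels) * 100
print(proportions)  # roughly tensor([50., 50.]) thanks to the weighted sampler
```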

The estimated proportion of class 0 is now 51.4% and the estimated proportion of class 1 is 48.6%.

As we can see from the figure above, we now have balanced batches of data. As a result, during training, our model will not see significantly more of one class than the other, and the risk of overfitting is therefore reduced.

In summary, we saw that:

  1. Oversampling is a key strategy to address class imbalance and hence reduce the risk of overfitting.

In the next article, we will dive into the implementation of WeightedRandomSampler to better understand the weighting scheme. We will also apply oversampling in a simple machine learning scenario and analyse its impact on overall performance.

Thanks for reading, please leave a comment below if you have any feedback! 🤗