10 Techniques to Solve Imbalanced Classes in Machine Learning (Updated 2024)

Introduction

While working as a data scientist, some of the most frequently occurring problem statements relate to binary classification. A common issue when solving these problems is class imbalance: when the number of observations in one class is much higher than in the other classes, a class imbalance exists. Example: detecting fraudulent credit card transactions. As shown in the graph below, there are around 400 fraudulent transactions compared to around 90,000 non-fraudulent ones.


Class imbalance is a common problem in machine learning, especially in classification problems. Imbalanced data can hamper model accuracy significantly. It appears in many domains, including fraud detection, spam filtering, disease screening, SaaS subscription churn, and advertising click-throughs. Let’s understand how to deal with imbalanced data in machine learning.

Learning Objectives

  • Get familiar with class imbalance in ML through coding tutorials in this article.
  • Understand various techniques for handling imbalanced data, such as Random under-sampling, Random over-sampling, and NearMiss.

Table of contents

  • The Problem With Class Imbalance in Machine Learning
  • Credit Card Fraud Detection Example
  • The Metric Trap
  • Resampling Techniques to Solve Class Imbalance
  • How to Balance Data With the Imbalanced-Learn Python Module?
  • Advantages and Disadvantages of Under-Sampling
  • Advantages and Disadvantages of Over-Sampling
  • Frequently Asked Questions

The Problem With Class Imbalance in Machine Learning

Most machine learning algorithms work best when the number of samples in each class is about equal. This is because most algorithms are designed to maximize accuracy and reduce errors.

However, if the dataset has imbalanced classes, you can get a pretty high accuracy just by predicting the majority class, while failing to capture the minority class, which is most often the point of creating the model in the first place. For example, if the class distribution shows that 99% of the data belongs to the majority class, then a basic classification model like logistic regression or a decision tree will struggle to identify the minority-class data points.

Credit Card Fraud Detection Example

Let’s say we have a dataset from a credit card company where we have to find out whether each credit card transaction is fraudulent or not.

But here’s the catch… fraudulent transactions are relatively rare. Only 6% of the transactions are fraudulent.

Now, before you even start, do you see how the problem might break? Imagine if you didn’t bother training a model at all. Instead, what if you just wrote a single line of code that always predicts ‘no fraudulent transaction’?

def transaction(transaction_data):
    return 'No fraudulent transaction'

Well, guess what? Your “solution” would have 94% accuracy!

Unfortunately, that accuracy is misleading.

  • For all those non-fraudulent transactions, you’d have 100% accuracy.
  • For those transactions which are fraudulent, you’d have 0% accuracy.
  • Your overall accuracy would be high simply because most of the transactions are not fraudulent (not because your model is any good).

This is clearly a problem because many machine learning algorithms are designed to maximize overall accuracy. In this article, we will see different techniques to handle imbalanced data.

Sample Dataset

We will use a credit card fraud detection dataset for this article. You can find the dataset here.

After loading the data, display the first five rows of the dataset.

Python Code:
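Below is a minimal sketch of this step, assuming the dataset has been saved locally as creditcard.csv (the file name is a placeholder); it also sets up the data, x, y, and train/test variables that the later examples use, with an arbitrary test_size:

# load libraries
import pandas as pd
from sklearn.model_selection import train_test_split

# read the credit card fraud dataset (file name is an assumption)
data = pd.read_csv('creditcard.csv')

# display the first five rows
print(data.head())

# check the class distribution (0 = non-fraudulent, 1 = fraudulent)
print(data['Class'].value_counts())

# split into features and target; later examples use x, y and the train/test split
x = data.drop('Class', axis=1)
y = data['Class']
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42)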

You can clearly see that there is a huge difference between the classes: 9000 non-fraudulent transactions versus 492 fraudulent ones.

The Metric Trap

One of the major issues that new developers fall into when dealing with unbalanced datasets relates to the metrics used to evaluate their machine learning model. Using simpler metrics like accuracy score can be misleading. In a dataset with highly unbalanced classes, a classifier that always “predicts” the most common class without performing any analysis of the features will still have a high accuracy rate, which is obviously not meaningful.

Let’s do this experiment using the simple XGBClassifier and no feature engineering:

# import library
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score

xgb_model = XGBClassifier().fit(x_train, y_train)

# predict
xgb_y_predict = xgb_model.predict(x_test)

# accuracy score
xgb_score = accuracy_score(y_test, xgb_y_predict)
print('Accuracy score is:', xgb_score)

OUTPUT
Accuracy score is: 0.992

We see 99% accuracy. We are getting this very high accuracy only because the model is mostly predicting the majority class, 0 (non-fraudulent).

Resampling Techniques to Solve Class Imbalance

One of the widely adopted class imbalance techniques for dealing with highly unbalanced datasets is called resampling. It consists of removing samples from the majority class (under-sampling) and/or adding more examples from the minority class (over-sampling).


Despite the advantage of balancing classes, these techniques also have their weaknesses (there is no free lunch).

The simplest implementation of over-sampling is to duplicate random records from the minority class, which can cause overfitting.

In under-sampling, the simplest technique involves removing random records from the majority class, which can cause a loss of information.

Let’s implement this with the credit card fraud detection example.

We will start by separating class 0 and class 1.

# class count
class_count_0, class_count_1 = data['Class'].value_counts()

# separate classes
class_0 = data[data['Class'] == 0]
class_1 = data[data['Class'] == 1]

# print the shape of each class
print('class 0:', class_0.shape)
print('class 1:', class_1.shape)

1. Random Under-Sampling

Undersampling can be defined as removing some observations of the majority class. This is done until the majority and minority classes are balanced.

Undersampling can be a good choice when you have a ton of data, think millions of rows. But a drawback of undersampling is that we are removing information that may be valuable.

class_0_under = class_0.sample(class_count_1)
test_under = pd.concat([class_0_under, class_1], axis=0)

print("total class of 1 and 0:", test_under['Class'].value_counts())

# plot the count after under-sampling
test_under['Class'].value_counts().plot(kind='bar', title='count (target)')

2. Random Over-Sampling

Oversampling can be defined as adding more copies of the minority class. Oversampling can be a good choice when you don’t have a ton of data to work with.

A con to consider with oversampling is that it can cause overfitting and poor generalization to your test set.

class_1_over = class_1.sample(class_count_0, replace=True)
test_over = pd.concat([class_1_over, class_0], axis=0)

print("total class of 1 and 0:", test_over['Class'].value_counts())

# plot the count after over-sampling
test_over['Class'].value_counts().plot(kind='bar', title='count (target)')

How to Balance Data With the Imbalanced-Learn Python Module?

A number of more sophisticated resampling techniques have been proposed in the scientific literature.

For example, we can cluster the records of the majority class and do the under-sampling by removing records from each cluster, thus seeking to preserve information. In over-sampling, instead of creating exact copies of the minority class records, we can introduce small variations into those copies, creating more diverse synthetic samples.

Let’s apply some of these resampling techniques using the Python library imbalanced-learn. It is compatible with scikit-learn and is part of the scikit-learn-contrib projects.

import imblearn
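As an aside, a close relative of the cluster-based under-sampling idea described above is available in imbalanced-learn as ClusterCentroids, which replaces majority-class samples with the centroids of k-means clusters. A minimal sketch (this estimator is not part of the walkthrough that follows):

# cluster-based under-sampling: replace majority-class samples with k-means centroids
from collections import Counter
from imblearn.under_sampling import ClusterCentroids

cc = ClusterCentroids(random_state=42)
x_cc, y_cc = cc.fit_resample(x, y)

print('Original dataset shape:', Counter(y))
print('Resample dataset shape:', Counter(y_cc))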

3. Random Under-Sampling With Imblearn

You may have heard about pandas, numpy, matplotlib, etc., while learning data science. But there is another library, imblearn, which is used to resample imbalanced datasets and improve model performance.

RandomUnderSampler is a fast and easy way to balance the data by randomly selecting a subset of data for the targeted classes. It under-samples the majority class(es) by randomly picking samples with or without replacement.

# import library
from collections import Counter
from imblearn.under_sampling import RandomUnderSampler

rus = RandomUnderSampler(random_state=42, replacement=True)

# fit predictor and target variable
x_rus, y_rus = rus.fit_resample(x, y)

print('original dataset shape:', Counter(y))
print('Resample dataset shape:', Counter(y_rus))

4. Random Over-Sampling With imblearn

One way to fight imbalanced data is to generate new samples in the minority classes. The most naive strategy is to generate new samples by random sampling with replacement from the currently available samples. The RandomOverSampler offers such a scheme.

# import library
from imblearn.over_sampling import RandomOverSampler

ros = RandomOverSampler(random_state=42)

# fit predictor and target variable
x_ros, y_ros = ros.fit_resample(x, y)

print('Original dataset shape', Counter(y))
print('Resample dataset shape', Counter(y_ros))

5. Under-Sampling: Tomek Links

Tomek links are pairs of very close instances but of opposite classes. Removing the instances of the majority class of each pair increases the space between the two classes, facilitating the classification process.

Tomek’s link exists if the two samples are the nearest neighbors of each other.


In the code below, we’ll use sampling_strategy='majority' to resample only the majority class.

# import library
from imblearn.under_sampling import TomekLinks

tl = TomekLinks(sampling_strategy='majority')

# fit predictor and target variable
x_tl, y_tl = tl.fit_resample(x, y)

print('Original dataset shape', Counter(y))
print('Resample dataset shape', Counter(y_tl))

6. Synthetic Minority Oversampling Technique (SMOTE)

This technique generates synthetic data for the minority class.

SMOTE (Synthetic Minority Oversampling Technique) works by randomly picking a point from the minority class and computing the k-nearest neighbors for this point. The synthetic points are added between the chosen point and its neighbors.


The SMOTE algorithm works in 4 simple steps:

  1. Choose a minority class as the input vector.
  2. Find its k nearest neighbors (k_neighbors is specified as an argument in the SMOTE() function).
  3. Choose one of these neighbors and place a synthetic point anywhere on the line joining the point under consideration and its chosen neighbor.
  4. Repeat the steps until the data is balanced.
# import library
from imblearn.over_sampling import SMOTE

smote = SMOTE()

# fit predictor and target variable
x_smote, y_smote = smote.fit_resample(x, y)

print('Original dataset shape', Counter(y))
print('Resample dataset shape', Counter(y_smote))

7. NearMiss

NearMiss is an under-sampling technique. Instead of resampling the minority class, it uses distance-based heuristics to shrink the majority class until it is the same size as the minority class.

from imblearn.under_sampling import NearMiss

nm = NearMiss()
x_nm, y_nm = nm.fit_resample(x, y)

print('Original dataset shape:', Counter(y))
print('Resample dataset shape:', Counter(y_nm))

8. Change the Performance Metric

Accuracy is not the best metric to use when evaluating imbalanced datasets, as it can be misleading.

Metrics that can provide better insight are listed below; a short code sketch computing them follows the list.

  • Confusion Matrix: a table showing correct predictions and the types of incorrect predictions.
  • Precision: the number of true positives divided by all positive predictions. Precision is also called Positive Predictive Value. It is a measure of a classifier’s exactness. Low precision indicates a high number of false positives.
  • Recall: the number of true positives divided by the number of positive values in the test data. Recall is also called Sensitivity or the True Positive Rate. It is a measure of a classifier’s completeness. Low recall indicates a high number of false negatives.
  • F1 Score: the harmonic mean of precision and recall.
  • Area Under the ROC Curve (AUROC): AUROC represents the likelihood of your model distinguishing observations from the two classes.
    In other words, if you randomly select one observation from each class, what’s the probability that your model will be able to “rank” them correctly?
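As a quick illustration, here is a sketch computing these metrics with scikit-learn for the XGBoost model trained in the metric-trap example (it assumes xgb_model, x_test, y_test, and xgb_y_predict are still in scope):

# evaluation metrics that are more informative than accuracy on imbalanced data
from sklearn.metrics import (confusion_matrix, precision_score, recall_score,
                             f1_score, roc_auc_score)

print('Confusion matrix:\n', confusion_matrix(y_test, xgb_y_predict))
print('Precision:', precision_score(y_test, xgb_y_predict))
print('Recall:', recall_score(y_test, xgb_y_predict))
print('F1 score:', f1_score(y_test, xgb_y_predict))

# AUROC is usually computed from predicted probabilities rather than hard labels
xgb_y_proba = xgb_model.predict_proba(x_test)[:, 1]
print('ROC AUC:', roc_auc_score(y_test, xgb_y_proba))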

9. Penalize Algorithms (Cost-Sensitive Training)

The next tactic is to use penalized learning algorithms that increase the cost of classification mistakes in the minority class.

A popular algorithm for this technique is Penalized-SVM.

During training, we can use the argument class_weight='balanced' to penalize mistakes on the minority class by an amount proportional to how under-represented it is.

We also want to include the argument probability=True if we want to enable probability estimates for SVM algorithms.

Let’s train a model using Penalized-SVM on the original imbalanced dataset:

# load library
from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score

# we can add class_weight='balanced' to penalize mistakes on the minority class
svc_model = SVC(class_weight='balanced', probability=True)
svc_model.fit(x_train, y_train)

svc_predict = svc_model.predict(x_test)

# check performance
print('ROCAUC score:', roc_auc_score(y_test, svc_predict))
print('Accuracy score:', accuracy_score(y_test, svc_predict))
print('F1 score:', f1_score(y_test, svc_predict))

10. Change the Algorithm

While in every machine learning problem, it’s a good rule of thumb to try a variety of algorithms, it can be especially beneficial with imbalanced datasets.

Decision trees frequently perform well on imbalanced data. In modern machine learning, tree ensembles (Random Forests, Gradient Boosted Trees, etc.) almost always outperform individual decision trees, so we’ll jump right into those:

Tree-based algorithms work by learning a hierarchy of if/else questions. This can force both classes to be addressed.

# load library
from sklearn.ensemble import RandomForestClassifier

rfc = RandomForestClassifier()

# fit the predictor and target
rfc.fit(x_train, y_train)

# predict
rfc_predict = rfc.predict(x_test)

# check performance
print('ROCAUC score:', roc_auc_score(y_test, rfc_predict))
print('Accuracy score:', accuracy_score(y_test, rfc_predict))
print('F1 score:', f1_score(y_test, rfc_predict))

Advantages and Disadvantages of Under-Sampling

Advantage:

  • It can help improve run time and storage problems by reducing the number of training data samples when the training data set is huge.

Disadvantages:

  • It can discard potentially useful information which could be important for building rule classifiers.
  • The sample chosen by random under-sampling may be biased and not an accurate representation of the population, resulting in inaccurate results on the actual test data set.

Advantages and Disadvantages of Over-Sampling

Advantages:

  • Unlike under-sampling, this method leads to no information loss.
  • It often outperforms under-sampling.

Disadvantages:

  • It increases the likelihood of overfitting since it replicates the minority class events.

Conclusion

To summarize, in this article, we have seen various techniques to handle class imbalance in a dataset. There are many methods to try when dealing with imbalanced data. You can check the implementation of these codes in my GitHub repository here.

Key Takeaways

  • In this article, we learned about the different techniques that we can perform to handle class imbalance in machine learning.
  • Some of the most widely used techniques are SMOTE and random over-sampling and under-sampling with imblearn.
  • There is no “best“ method for handling imbalance; it depends on your use case.

Frequently Asked Questions

Q1. What are class imbalances?

A. Class imbalances in ML happen when the categories in your dataset are not evenly represented. For example, in a medical dataset, you might have many more healthy patients than sick ones. This can make it hard for a model to learn to recognize the less common category (the sick patients in this case).

Q2. What is a class balance?

A. Class balance refers to having an even number of samples for each category in your dataset. For example, if you have a dataset for spam detection, a class-balanced dataset would have roughly the same number of spam and non-spam emails.

Q3. How to solve class imbalance problem?

A. There are several ways to address class imbalance:

  • Resampling: You can oversample the minority class or undersample the majority class to balance the dataset.
  • Synthetic Data: Generate new samples for the minority class using techniques like SMOTE (Synthetic Minority Over-sampling Technique).
  • Class Weighting: Adjust the weights of the classes in your loss function to give more importance to the minority class.
  • Anomaly Detection Models: Sometimes, models designed to detect anomalies can work well for imbalanced datasets.

Q4. Which loss is best for class imbalance?

A. One commonly used loss function for handling class imbalance in ML is Focal Loss. It reduces the weight of well-classified examples and focuses more on hard-to-classify examples, which helps the model to learn better from the minority class.
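For reference, here is a minimal sketch of binary focal loss in NumPy, following the standard formulation FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t); the gamma and alpha defaults shown are common choices, not values prescribed by this article:

import numpy as np

def binary_focal_loss(y_true, y_prob, gamma=2.0, alpha=0.25, eps=1e-7):
    """Binary focal loss: down-weights easy examples so training focuses on hard ones."""
    y_prob = np.clip(y_prob, eps, 1 - eps)
    # p_t is the predicted probability assigned to the true class
    p_t = np.where(y_true == 1, y_prob, 1 - y_prob)
    # alpha_t weights the positive (minority) class more heavily
    alpha_t = np.where(y_true == 1, alpha, 1 - alpha)
    return -np.mean(alpha_t * (1 - p_t) ** gamma * np.log(p_t))

# example: a confident correct prediction contributes far less loss than a confident mistake
print(binary_focal_loss(np.array([1, 0]), np.array([0.95, 0.9])))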
