Sling Academy

Nearest Centroid Classification in Scikit-Learn

Last updated: December 17, 2024

Nearest centroid classification is a simple yet effective supervised learning method for pattern recognition. It assigns a label to each sample based on the sample's proximity to the computed centroid of each class. In this article, we'll walk through using the Nearest Centroid classifier with Python's Scikit-Learn, explaining each step and providing complete code examples.

Scikit-Learn, a robust machine learning library for Python, provides an easy-to-use implementation of the nearest centroid classifier. The classifier's simplicity makes it especially suitable for multi-class classification problems.

Introduction to Nearest Centroid Classifier

The Nearest Centroid Classifier works by computing the mean point (centroid) of each class and then assigning a new data point to the class with the nearest centroid. While this method is straightforward and less computationally expensive than many other classification techniques, it can be quite effective, especially when the classes in a dataset are well separated.
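To make the idea concrete, here is a small illustrative sketch (the toy coordinates are made up for demonstration) that computes each class centroid by hand with NumPy and checks the result against Scikit-Learn's NearestCentroid:

```python
import numpy as np
from sklearn.neighbors import NearestCentroid

# Toy 2D dataset with two well-separated classes (illustrative values)
X = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.5],   # class 0
              [8.0, 8.0], [8.5, 9.0], [9.0, 8.5]])  # class 1
y = np.array([0, 0, 0, 1, 1, 1])

# Compute each class centroid by hand: the mean of that class's points
centroids = np.array([X[y == c].mean(axis=0) for c in np.unique(y)])

# Classify a new point by its distance to each centroid
new_point = np.array([2.0, 2.0])
distances = np.linalg.norm(centroids - new_point, axis=1)
manual_label = int(np.argmin(distances))

# Scikit-Learn's NearestCentroid performs the same computation
clf = NearestCentroid().fit(X, y)
print(manual_label, clf.predict([new_point])[0])  # both are 0
```

Both the manual computation and the fitted classifier assign the new point to class 0, whose centroid (1.5, 1.5) is much closer than the centroid of class 1 (8.5, 8.5).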

Steps for Implementing Nearest Centroid Classification

Before diving into the implementation, ensure that you have the necessary libraries installed. If not, you can install Scikit-Learn using pip:

pip install scikit-learn

We'll proceed with the following steps to use the Nearest Centroid Classifier:

  1. Import the required libraries.
  2. Prepare the dataset.
  3. Initialize the nearest centroid classifier.
  4. Fit the model on the training data.
  5. Predict the class labels for new data.
  6. Evaluate the model performance.

1. Importing Required Libraries

Begin by importing the necessary libraries. Apart from Scikit-Learn, we'll use NumPy for handling our data.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import NearestCentroid

2. Preparing the Dataset

We will work with the Iris dataset provided by Scikit-Learn, which is an ideal dataset for simple classification tasks.

# Load the iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Split the dataset into training and testing data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

3. Initializing the Nearest Centroid Classifier

The NearestCentroid classifier in Scikit-Learn can be instantiated like this:

# Initialize the Nearest Centroid classifier
clf = NearestCentroid()
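The constructor also accepts a metric parameter (Euclidean or Manhattan distance) and a shrink_threshold parameter, which shrinks each centroid toward the overall data mean (the "nearest shrunken centroid" variant, useful for de-noising high-dimensional data). The threshold value of 0.5 below is purely illustrative; a suitable value depends on your data:

```python
from sklearn.datasets import load_iris
from sklearn.neighbors import NearestCentroid

X, y = load_iris(return_X_y=True)

# Default: Euclidean distance, no centroid shrinkage
clf_default = NearestCentroid()

# Manhattan (L1) distance instead of Euclidean
clf_manhattan = NearestCentroid(metric="manhattan")

# Nearest Shrunken Centroid: 0.5 is an illustrative threshold, not a recommendation
clf_shrunken = NearestCentroid(shrink_threshold=0.5)

for name, clf in [("default (Euclidean)", clf_default),
                  ("Manhattan", clf_manhattan),
                  ("shrunken", clf_shrunken)]:
    print(name, clf.fit(X, y).score(X, y))
```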

4. Fitting the Model

Now let's fit the classifier to the training data:

# Fit the model on the training data
clf.fit(X_train, y_train)
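Once fitted, the classifier exposes the learned centroids through its centroids_ attribute: one row per class, one column per feature. A self-contained sketch on the full Iris dataset:

```python
from sklearn.datasets import load_iris
from sklearn.neighbors import NearestCentroid

X, y = load_iris(return_X_y=True)
clf = NearestCentroid().fit(X, y)

# One centroid per class, one coordinate per feature
print(clf.centroids_.shape)  # (3, 4)
print(clf.classes_)          # [0 1 2]
```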

5. Making Predictions

After the model is trained, we can predict the class labels for the test set:

# Predict the labels for the test set
predictions = clf.predict(X_test)

6. Evaluating the Model Performance

Finally, evaluate the performance by calculating accuracy:

# Calculate the accuracy
accuracy = accuracy_score(y_test, predictions)
print('Accuracy:', accuracy)

The accuracy score offers insight into how well the nearest centroid classifier performs on our dataset. Remember, you can deepen the evaluation by computing a confusion matrix or performing cross-validation for a more thorough analysis.

Conclusion

In conclusion, the nearest centroid classifier is an excellent starting point for classification problems. Its speed and simplicity make it preferable in certain scenarios, especially with clean, well-structured datasets. While more sophisticated algorithms may offer better accuracy, the nearest centroid classifier can be your go-to for quick and effective classification tasks.

Feel free to explore further by trying new datasets and comparing nearest centroid classification with other classifiers like SVM or k-NN to broaden your machine learning expertise.
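As a starting point for such a comparison, here is a brief sketch (using the Iris dataset and default hyperparameters, which you would normally tune) that cross-validates the nearest centroid classifier alongside k-NN and an SVM:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Mean 5-fold cross-validated accuracy for each classifier
results = {}
for name, clf in [("NearestCentroid", NearestCentroid()),
                  ("k-NN", KNeighborsClassifier()),
                  ("SVM", SVC())]:
    results[name] = cross_val_score(clf, X, y, cv=5).mean()
    print(f"{name}: {results[name]:.3f}")
```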

