[Python] Finding the number of clusters (components) in data using the Gaussian mixture model

• Post category: Algorithms

A Gaussian mixture model (GMM) is a probabilistic mixture model, that is, a weighted combination of several probability density functions. It assumes that the data points are generated from a mixture of “K” Gaussian (normal) distributions with unknown parameters. The most common technique for estimating these parameters is maximum likelihood estimation (MLE).
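To make the definition concrete, here is a small numpy sketch of a one-dimensional mixture (a simplification; the post's data has 25 features). The functions gmm_pdf_1d and sample_gmm_1d are hypothetical helpers, not part of scikit-learn. The density is the weighted sum of the component densities, and sampling follows the generative story: first draw a hidden component label, then draw the point from that component's Gaussian.

```python
import numpy as np


def gmm_pdf_1d(x, weights, means, stds):
    """Density of a 1-D Gaussian mixture: sum_k w_k * N(x | mu_k, sigma_k^2)."""
    x = np.asarray(x, dtype=float)[..., None]  # add a trailing axis to broadcast over K
    weights = np.asarray(weights, dtype=float)
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    # Normal density of every component at every x, shape (..., K)
    comp = np.exp(-0.5 * ((x - means) / stds) ** 2) / (stds * np.sqrt(2.0 * np.pi))
    return comp @ weights  # weighted sum over the K components


def sample_gmm_1d(n, weights, means, stds, seed=0):
    """Draw z ~ Categorical(weights), then x ~ N(mu_z, sigma_z^2)."""
    rng = np.random.default_rng(seed)
    z = rng.choice(len(weights), size=n, p=weights)  # latent component labels
    x = rng.normal(np.asarray(means)[z], np.asarray(stds)[z])
    return x, z
```

The label z is exactly the "latent variable" that the EM algorithm has to infer, because only x is observed.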

The GaussianMixture object of the scikit-learn library fits mixture-of-Gaussian models with the expectation-maximization (EM) algorithm. EM is a powerful method for finding maximum likelihood solutions for models with latent variables; in a GMM, the latent variable of each data point is the component that generated it. The EM algorithm comprises two steps:

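• E-step (Expectation step): Using the current parameters, compute the posterior probability (responsibility) of each latent variable, i.e., how likely each component is to have generated each data point.
• M-step (Maximization step): Re-estimate the parameters of the distributions (e.g., means, covariances, and mixing coefficients) by maximum likelihood, weighting each data point by its responsibilities from the E-step.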

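The two steps alternate until the estimates stop changing appreciably. In scikit-learn, iteration stops when the gain in the log-likelihood lower bound falls below tol (default 1e-3) or after max_iter iterations (default 100).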
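The two steps can be sketched for a one-dimensional mixture in plain numpy. This is an illustrative implementation under simplifying assumptions (1-D data, quantile-based initial means, a small fixed variance floor), not scikit-learn's code; GaussianMixture handles multivariate data, initializes with k-means by default, and regularizes covariances through its reg_covar parameter. The name em_gmm_1d is hypothetical.

```python
import numpy as np


def em_gmm_1d(x, k, n_iter=200, tol=1e-6):
    """Fit a 1-D Gaussian mixture with k components by expectation-maximization."""
    x = np.asarray(x, dtype=float)
    n = x.size
    # Initialization: equal weights, means spread over data quantiles, a shared variance
    weights = np.full(k, 1.0 / k)
    means = np.quantile(x, (np.arange(k) + 0.5) / k)
    variances = np.full(k, x.var())
    prev_ll = -np.inf
    for _ in range(n_iter):
        # E-step: log of w_j * N(x_i | mu_j, var_j), shape (n, k)
        log_p = (np.log(weights)
                 - 0.5 * np.log(2.0 * np.pi * variances)
                 - 0.5 * (x[:, None] - means) ** 2 / variances)
        log_norm = np.logaddexp.reduce(log_p, axis=1)  # log p(x_i) under current params
        resp = np.exp(log_p - log_norm[:, None])         # responsibilities, rows sum to 1
        # M-step: closed-form maximum-likelihood updates weighted by responsibilities
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp * x[:, None]).sum(axis=0) / nk
        variances = (resp * (x[:, None] - means) ** 2).sum(axis=0) / nk + 1e-9  # variance floor
        # Stop when the average log-likelihood barely changes
        ll = log_norm.mean()
        if abs(ll - prev_ll) < tol:
            break
        prev_ll = ll
    return weights, means, variances
```

With well-separated components the returned means should land near the true centers. Like scikit-learn's EM, this sketch only finds a local maximum, so poor starting means can lead to a worse fit.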

This post explains how to find the number of clusters in given data using the GaussianMixture object. I use the make_classification() function of the scikit-learn library to generate synthetic data with five classes. GaussianMixture takes “n_components” (the number of clusters) as a parameter; its default value is 1. The code below calls GaussianMixture repeatedly with different values of “n_components.” In each iteration, after fitting the model, it computes the Bayesian Information Criterion (BIC), which rewards goodness of fit but penalizes the number of model parameters. The value of “n_components” that gives the minimum BIC is taken as the estimated number of clusters in the data.
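BIC is defined as BIC = -2 ln L + p ln n, where L is the maximized likelihood, p the number of free parameters, and n the number of samples. The penalty p ln n grows with the number of components, which is what keeps BIC from always favoring more clusters. The sketch below, with hypothetical helper names, computes both pieces for covariance_type='full': each component has d mean values and d(d+1)/2 distinct covariance entries, and the K mixing weights contribute K - 1 free values because they must sum to 1.

```python
import math


def n_params_full_gmm(n_components, n_features):
    """Number of free parameters of a GMM with full covariance matrices."""
    means = n_components * n_features
    covariances = n_components * n_features * (n_features + 1) // 2  # symmetric d x d matrices
    weights = n_components - 1                                         # weights sum to 1
    return means + covariances + weights


def bic(total_log_likelihood, n_params, n_samples):
    """Bayesian Information Criterion: -2 ln L + p ln n (lower is better)."""
    return -2.0 * total_log_likelihood + n_params * math.log(n_samples)
```

For a fitted model gm, gm.score(X) returns the average log-likelihood per sample, so bic(gm.score(X) * X.shape[0], n_params_full_gmm(gm.n_components, X.shape[1]), X.shape[0]) should agree with gm.bic(X) up to floating-point rounding. For this post's setup (5 components, 25 features), the parameter count is 1754.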

In very high-dimensional data, GMM may not return the correct number of clusters. If your data has many features, consider reducing its dimensionality (for example, with feature selection or principal component analysis) before running GMM.
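One common way to reduce dimensionality is principal component analysis (PCA). The sketch below implements PCA with numpy's SVD so the idea is visible; in practice sklearn.decomposition.PCA does the same job. pca_reduce is a hypothetical helper, and how many components to keep (for example, by cumulative explained variance) is a judgment call this post does not specify.

```python
import numpy as np


def pca_reduce(X, n_components):
    """Project X onto its top principal components (a numpy-only PCA sketch)."""
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)  # center each feature
    # Rows of Vt are the principal directions, ordered by decreasing singular value
    _, s, Vt = np.linalg.svd(Xc, full_matrices=False)
    explained = s ** 2 / np.sum(s ** 2)  # fraction of total variance per direction
    return Xc @ Vt[:n_components].T, explained[:n_components]
```

The reduced matrix can then be passed to determine_clusters in place of the original features, e.g. Z, _ = pca_reduce(X_all, 5). Whether this recovers the true cluster count depends on the data.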

import numpy as np
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture


def generate_sample_data(sc, fc, nf, tc):
    """
    Generate synthetic data using sklearn.

    sc: number of samples, fc: number of features,
    nf: fraction of samples whose class is assigned randomly, tc: number of classes
    """
    X, y = make_classification(n_samples=sc, n_features=fc, flip_y=nf,
                               n_redundant=0, n_classes=tc, random_state=4,
                               n_clusters_per_class=1, n_informative=tc)
    return X, y


def determine_clusters(X, n_clusters):
    """
    Fit a GMM with n_clusters components and return its BIC on X
    """
    gm = GaussianMixture(n_components=n_clusters,
                         random_state=0,
                         max_iter=250,
                         covariance_type='full').fit(X)
    return gm.bic(X)


if __name__ == '__main__':
    # Use the GMM algorithm to find the number of clusters in the data
    noise_fraction_in_data = 0
    feature_count = 25
    cls_count = 5
    sample_count = 5000
    bic_vals = []
    max_clusters = 10  # candidate counts are 1 .. max_clusters - 1

    print("\nGenerate simulated data with {0} classes".format(cls_count))
    X_all, y_all = generate_sample_data(sample_count, feature_count,
                                        noise_fraction_in_data, cls_count)
    print("Total records: {0}".format(len(y_all)))
    for c in range(cls_count):
        print("class {0} records: {1}".format(c, np.sum(y_all == c)))

    print("Find clusters")
    for v in range(1, max_clusters, 1):
        print("Compute BIC using {0} clusters".format(v))
        # GaussianMixture.fit does not modify X, so no defensive copy is needed
        b = determine_clusters(X_all, v)
        bic_vals.append(b)
    print("\nTrue cluster count: ", cls_count)
    # bic_vals[i] belongs to i + 1 components, so shift the argmin index by one
    print("Estimated cluster count: ", np.argmin(bic_vals) + 1)

The above code produces the following output:

Generate simulated data with 5 classes
Total records: 5000
class 0 records: 1000
class 1 records: 1000
class 2 records: 1000
class 3 records: 1000
class 4 records: 1000
Find clusters
Compute BIC using 1 clusters
Compute BIC using 2 clusters
Compute BIC using 3 clusters
Compute BIC using 4 clusters
Compute BIC using 5 clusters
Compute BIC using 6 clusters
Compute BIC using 7 clusters
Compute BIC using 8 clusters
Compute BIC using 9 clusters

True cluster count: 5
Estimated cluster count: 5
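The script converts the index of the minimum BIC into a cluster count with np.argmin(bic_vals) + 1, which works only because the candidate counts start at 1 and increase by 1. A small, hypothetical helper that keeps each candidate count paired with its BIC avoids that coupling if you change the range of candidates:

```python
def best_component_count(candidates, bic_values):
    """Return the candidate n_components with the lowest BIC, plus the k -> BIC mapping."""
    if len(candidates) != len(bic_values):
        raise ValueError("need exactly one BIC value per candidate")
    scores = dict(zip(candidates, bic_values))  # insertion order follows candidates
    return min(scores, key=scores.get), scores
```

For the script above, best_k, _ = best_component_count(list(range(1, max_clusters)), bic_vals) gives the same answer as the argmin expression. If two counts tie, the first candidate in the list wins, so listing candidates in increasing order prefers the simpler model.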

If you find any issues with the code, please let me know in the comment section.
