Python:相对简洁的基于高斯混合模型的聚类算法(GMM)

"""
GMM clustering algorithm
By Daniel He
At CQUPT
"""
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn import datasets
import matplotlib.pyplot as plt



class GMM():
    """Gaussian mixture model clustering fitted with the EM algorithm.

    The model is initialised from a k-means partition of ``X`` and fitted
    immediately on construction (``run`` is called from ``__init__``).

    Parameters
    ----------
    X : array of shape (nSample, nDim)
        Samples to cluster.
    Y : array-like
        Reference labels. Stored on the instance but not used for fitting.
    K : int
        Number of Gaussian components.
    max_iters : int, default 100
        Maximum number of EM iterations.
    tol : float or None, default None
        If given, EM stops early once the change in total log-likelihood
        between two consecutive iterations drops below ``tol``. ``None``
        always runs ``max_iters`` iterations.
    reg_covar : float, default 1e-6
        Value added to the diagonal of every covariance matrix so that it
        stays positive definite.

    Attributes set by fitting
    -------------------------
    mu, cov, weight : component means, covariances and mixing weights.
    gamma : (nSample, K) responsibilities from the most recent E-step.
    assignments : hard label (argmax of ``gamma``) for every sample.
    log_likelihood : total data log-likelihood from the most recent E-step.
    """

    def __init__(self, X, Y, K, max_iters=100, tol=None, reg_covar=1e-6):
        self.X = np.asarray(X, dtype=float)
        self.Y = Y
        self.K = K  # number of Gaussian components
        self.max_iters = max_iters
        self.tol = tol
        self.reg_covar = reg_covar
        self.nSample, self.nDim = self.X.shape
        self.mu, self.cov, self.weight = self.Initial()
        self.gamma = np.zeros((self.nSample, self.K))
        self.log_likelihood = -np.inf
        self.assignments = None
        self.run()

    def Initial(self):
        """Derive initial means, covariances and weights from k-means.

        Returns
        -------
        (mu, cov, weight) : arrays of shape (K, nDim), (K, nDim, nDim), (K,)
        """
        kmeans = KMeans(n_clusters=self.K)
        assignments = kmeans.fit_predict(self.X)
        # Count per *label* so weights and covariances stay aligned with the
        # rows of cluster_centers_ even if some label ends up unused.
        counts = np.bincount(assignments, minlength=self.K)
        # Copy: the M-step updates mu in place and must not touch the
        # KMeans object's own centres.
        mu = np.array(kmeans.cluster_centers_, dtype=float)
        weight = counts / counts.sum()
        ridge = self.reg_covar * np.eye(self.nDim)
        cov = np.empty((self.K, self.nDim, self.nDim))
        for k in range(self.K):
            if counts[k] > 1:
                members = self.X[assignments == k]
                cov[k] = np.atleast_2d(np.cov(members.T)) + ridge
            else:
                # A sample covariance is undefined for fewer than two
                # points; fall back to the identity.
                cov[k] = np.eye(self.nDim)
        return mu, cov, weight

    def Expectation(self):
        """E-step: recompute the responsibilities ``gamma``.

        Works in log space and subtracts the per-row maximum before
        exponentiating, so samples far from every component do not
        underflow to 0/0.
        """
        with np.errstate(divide='ignore'):
            # A zero weight becomes -inf, i.e. zero responsibility.
            log_weight = np.log(self.weight)
        log_joint = np.empty((self.nSample, self.K))
        for k in range(self.K):
            log_joint[:, k] = log_weight[k] + multivariate_normal.logpdf(
                self.X, self.mu[k], self.cov[k])
        row_max = log_joint.max(axis=1, keepdims=True)
        scaled = np.exp(log_joint - row_max)
        row_sum = scaled.sum(axis=1, keepdims=True)
        self.gamma = scaled / row_sum
        self.log_likelihood = float((row_max + np.log(row_sum)).sum())

    def Maximization(self):
        """M-step: update means, covariances and weights from ``gamma``."""
        # The tiny offset keeps the divisions finite when a component has
        # lost all of its responsibility mass.
        mass = self.gamma.sum(axis=0) + 10 * np.finfo(float).eps
        ridge = self.reg_covar * np.eye(self.nDim)
        for k in range(self.K):
            resp = self.gamma[:, k][:, np.newaxis]
            self.mu[k] = (resp * self.X).sum(axis=0) / mass[k]
            centered = self.X - self.mu[k]
            self.cov[k] = centered.T.dot(centered * resp) / mass[k] + ridge
        self.weight = mass / mass.sum()

    def run(self):
        """Alternate E- and M-steps for up to ``max_iters`` iterations.

        ``assignments`` is taken from the responsibilities of the latest
        E-step. Stops early only when ``tol`` is set and the log-likelihood
        has changed by less than ``tol`` since the previous iteration.
        """
        previous = -np.inf
        for _ in range(self.max_iters):
            self.Expectation()
            self.assignments = self.gamma.argmax(axis=1)
            self.Maximization()
            if self.tol is not None and abs(self.log_likelihood - previous) < self.tol:
                break
            previous = self.log_likelihood


if __name__ == '__main__':
    # Demo on three synthetic 2-D blobs. (The iris data, restricted to its
    # first two features, can be substituted via load_iris(return_X_y=True).)
    features, labels = datasets.make_blobs(
        n_samples=800,
        n_features=2,
        centers=3,
        cluster_std=[1, 1, 1],
        random_state=4,
    )

    # First figure: samples coloured by their generating blob.
    plt.scatter(features[:, 0], features[:, 1], c=labels)
    plt.show()

    # Second figure: samples coloured by the component the fitted GMM
    # assigns them to (fitting happens inside the constructor).
    model = GMM(X=features, Y=labels, K=3, max_iters=100)
    predicted = model.assignments
    plt.scatter(features[:, 0], features[:, 1], c=predicted)
    plt.show()

Python:相对简洁的基于高斯混合模型的聚类算法(GMM)_第1张图片

 

Python:相对简洁的基于高斯混合模型的聚类算法(GMM)_第2张图片

 

你可能感兴趣的:(算法,python,GMM,聚类)