Implementing a Gaussian Mixture Model from Scratch (EM-GMM)

Problem:

  1. Please build a Gaussian mixture model (GMM) to model the data in the file TrainingData_GMM.csv. Note that the data is composed of 4 clusters, and the model should be trained with the expectation-maximization (EM) algorithm (the standard updates are summarized after this list).

  2. Based on the GMM learned above, assign each training data point to one of the 4 clusters.
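
For reference (my summary, not part of the assignment statement), the standard EM updates for a k-component GMM, which the code below implements, are:

E-step (responsibilities):

\gamma_{ij} = \frac{\alpha_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)}{\sum_{l=1}^{k} \alpha_l\,\mathcal{N}(x_i \mid \mu_l, \Sigma_l)}

M-step (with N_j = \sum_{i=1}^{n} \gamma_{ij}):

\mu_j = \frac{1}{N_j}\sum_{i=1}^{n} \gamma_{ij}\,x_i, \qquad \Sigma_j = \frac{1}{N_j}\sum_{i=1}^{n} \gamma_{ij}\,(x_i-\mu_j)(x_i-\mu_j)^\top, \qquad \alpha_j = \frac{N_j}{n}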

Questions:

1) Show how the log-likelihood evolves as the training proceeds


[Figure: log-likelihood curve during training. x-axis: number of iterations; y-axis: log-likelihood.]
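
The quantity plotted is the data log-likelihood under the current parameters, which is the standard GMM objective and exactly what the log_likelihood function below computes:

\ell(\theta) = \sum_{i=1}^{n} \log \sum_{j=1}^{k} \alpha_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)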

2) The learned mathematical expression for the GMM model after training on the given dataset

Learned covariance matrices (one per cluster):

\Sigma = \left\{ \begin{bmatrix} 0.03446 & -0.01299 \\ -0.01299 & 0.03458 \end{bmatrix},\; \begin{bmatrix} 0.02259 & -0.00761 \\ -0.00761 & 0.02361 \end{bmatrix},\; \begin{bmatrix} 0.00886 & 0.00187 \\ 0.00187 & 0.00881 \end{bmatrix},\; \begin{bmatrix} 0.07024 & 0.03731 \\ 0.03731 & 0.06498 \end{bmatrix} \right\}
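
Plugging the learned parameters into the standard two-dimensional mixture density gives the full model (the general form below is standard; the learned means and mixing weights of this run are not reproduced here):

p(x) = \sum_{j=1}^{4} \alpha_j\,\mathcal{N}(x \mid \mu_j, \Sigma_j), \qquad \mathcal{N}(x \mid \mu, \Sigma) = \frac{1}{2\pi\,|\Sigma|^{1/2}} \exp\!\left(-\frac{1}{2}(x-\mu)^\top \Sigma^{-1} (x-\mu)\right)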
3) Randomly select 500 data points from the given dataset and plot them on a 2-dimensional coordinate system. Mark the data points coming from the same cluster (using the results of Problem 2) with the same color.

[Figure: 500 randomly selected data points in 2-D, colored by cluster assignment.]

4) Some analyses on the impacts of initialization on the converged values of EM algorithm
Different initial parameters have a very large impact on the values the EM-GMM algorithm converges to. My initial parameters are generated randomly; the best runs converge to a log-likelihood of about -1946, but other runs have converged around -3000 and -2000, i.e. at poorer local optima.
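
One common remedy, sketched below, is to restart EM from several random initializations and keep the parameters with the highest log-likelihood. This is a minimal sketch assuming the E, M, and log_likelihood functions and the globals defined in the code below; the run_em helper is my own and not part of the original code.

# Hypothetical multi-restart wrapper around the EM functions defined below.
def run_em(max_iter=500, tol=0.01):
    global miu, sigma, alpha
    miu = np.random.randint(0, 2, (k, d))             # same random initialization as below
    sigma = np.array([np.eye(d) for _ in range(k)])   # identity covariances
    alpha = np.full(k, 1.0 / k)                       # uniform mixing weights
    history = []
    for it in range(max_iter):
        gamma, _ = E()
        miu, sigma, alpha = M(gamma)
        history.append(log_likelihood())
        if it > 5 and abs(history[-1] - history[-2]) < tol:
            break
    return history[-1], (miu.copy(), sigma.copy(), alpha.copy())

best_ll, best_params = -np.inf, None
for _ in range(10):                                   # 10 random restarts
    ll, params = run_em()
    if ll > best_ll:
        best_ll, best_params = ll, params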

5) Some analyses on the results you obtained
Judging from the results, the algorithm already separates the 4 clusters well, which shows that EM-GMM works well on problems with latent variables, and on clustering problems in particular. For this assignment I also started from the derivations and worked out EM, GMM, and the 2-D Gaussian distribution; I learned a lot!
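
As an extra sanity check (my suggestion, not in the original write-up), the result can be compared against scikit-learn's GaussianMixture once the code below has loaded data and n; a good EM run should reach a similar total log-likelihood.

# Optional cross-check with scikit-learn (assumes scikit-learn is installed,
# and that data and n from the code below are already in scope).
from sklearn.mixture import GaussianMixture

gm = GaussianMixture(n_components=4, covariance_type='full', n_init=10, random_state=0)
gm.fit(data)
print(gm.score(data) * n)     # total log-likelihood, comparable to log_likelihood()
sk_labels = gm.predict(data)  # hard assignments, comparable to np.argmax(gamma, 1)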

import numpy as np
# Read data from csv
data = np.genfromtxt('TrainingData_GMM.csv', delimiter=',')
# n = number of samples, d = dimension of the data; here d = 2.
n,d = np.shape(data)
# k = number of clusters, set to 4 by the problem description.
k = 4
# maximum number of EM iterations
max_iter = 500
# density of a 2-D Gaussian; the constant (2*pi)**(d/2) equals 2*pi since d = 2
def gaussian_distribution(_x,_miu,_sigma):
    tmp1 = 1/((2*np.pi)*np.linalg.det(_sigma)**(1/2))
    tmp2 = np.exp(-0.5*(_x-_miu)@np.linalg.inv(_sigma)@(_x-_miu))
    return tmp1*tmp2
# log-likelihood of the data under the current parameters:
# sum_i log( sum_j alpha_j * N(x_i | miu_j, sigma_j) )
def log_likelihood():
    prob = np.zeros([n,k])
    for i in range(n):
        for j in range(k):
            prob[i, j] = gaussian_distribution(data[i],miu[j],sigma[j])
    return np.sum(np.log(prob@alpha))
# E step: responsibilities gamma[i,j] proportional to alpha_j * N(x_i | miu_j, sigma_j),
# normalized over j; the unnormalized values are also returned.
def E():
    gamma = np.zeros([n,k])
    for i in range(n):
        for j in range(k):
            gamma[i,j] = alpha[j] * gaussian_distribution(data[i],miu[j],sigma[j])
    new_gamma = gamma / np.sum(gamma,axis=1).reshape((n,1))
    return new_gamma,gamma
# M step: re-estimate the parameters from the responsibilities.
def M(gamma):
    K_sum_gamma = np.sum(gamma,axis=0)  # effective number of points per cluster
    miu_tmp = np.zeros((k, d))
    sigma_tmp = np.zeros((k, d, d))
    alpha_tmp = np.zeros(k)
    for j in range(k):
        # update the mean first, then use the new mean in the covariance update
        miu_tmp[j] = gamma[:,j] @ data / K_sum_gamma[j]
        tmp_sigma = 0
        for i in range(n):
            diff = (data[i]-miu_tmp[j]).reshape(d,1)
            tmp_sigma += gamma[i,j] * diff @ diff.T
        sigma_tmp[j] = tmp_sigma/K_sum_gamma[j]
        alpha_tmp[j] = K_sum_gamma[j]/n
    return miu_tmp,sigma_tmp,alpha_tmp
# parameter initialization
# miu and sigma are the parameters of the Gaussian components.
# Since the Gaussians here are 2-D, each sigma is a d x d covariance matrix;
# in the 1-D case sigma^2 would simply be the variance.

miu = np.random.randint(0,2,(k,d))              # random integer means with entries in {0,1}
sigma = np.array([np.eye(d) for _ in range(k)]) # identity covariance for every cluster
alpha = np.array([0.25,0.25,0.25,0.25])         # uniform mixing weights
print(miu)
    
index = 0
likelihood_list = []
while index < max_iter:
    index += 1
    gamma,_ = E()
    miu, sigma, alpha = M(gamma)
    loglike = log_likelihood()
    likelihood_list.append(loglike)
    print(np.sum(gamma,axis=0))  # effective number of points per cluster this iteration
    if index > 5 and abs(loglike - likelihood_list[-2]) < 0.01:
        print('Finish training')
        break
    print('index %d, log likelihood %f'%(index,loglike))
[[1 1]
 [1 0]
 [0 0]
 [0 1]]
[ 611.45229966 1430.6772165  1926.06578173 1031.80470211]
index 1, log likelihood -9672.165311
[ 335.75137577 1331.29076907 2399.18394139  933.77391378]
index 2, log likelihood -7516.916844
[ 295.54773607 1315.88106212 2098.17872769 1290.39247412]
index 3, log likelihood -6732.294851
[ 237.72130526 1330.64190269 1722.02088632 1709.61590574]
index 4, log likelihood -5824.873937
[ 207.66570465 1342.17076547 1470.71920008 1979.44432979]
index 5, log likelihood -5031.472039
[ 197.38485415 1269.03430146 1470.63230013 2062.94854426]
index 6, log likelihood -4229.474246
[ 179.51063735 1174.58748844 1603.0608239  2042.8410503 ]
index 7, log likelihood -3351.133138
[ 171.67195764 1150.05077496 1709.70143046 1968.57583694]
index 8, log likelihood -2959.046982
[ 178.79795077 1149.99864314 1792.51929472 1878.68411137]
index 9, log likelihood -2927.094710
[ 193.96805643 1149.99873754 1861.41354903 1794.619657  ]
index 10, log likelihood -2905.027898
[ 213.54444589 1149.99946804 1917.28085395 1719.17523212]
index 11, log likelihood -2886.123975
[ 236.35308699 1149.99983307 1960.1619301  1653.48514984]
index 12, log likelihood -2868.333690
[ 262.01822615 1149.99992085 1989.96797803 1598.01387497]
index 13, log likelihood -2850.162029
[ 290.40638865 1149.99994905 2006.96904628 1552.62461602]
index 14, log likelihood -2831.098999
[ 320.9322868  1149.99996915 2012.72943982 1516.33830423]
index 15, log likelihood -2812.377171
[ 352.0658956  1149.99998355 2010.51437783 1487.41974302]
index 16, log likelihood -2795.924313
[ 381.96342042 1149.99999176 2003.92893738 1464.10765044]
index 17, log likelihood -2782.622865
[ 409.29180942 1149.99999574 1995.47462626 1445.23356858]
index 18, log likelihood -2772.198095
[ 433.42623915 1149.99999759 1986.46653381 1430.10722945]
index 19, log likelihood -2763.908830
[ 454.32611636 1149.99999847 1977.46869471 1418.20519047]
index 20, log likelihood -2756.982532
[ 472.3401729  1149.99999891 1968.60577922 1409.05404897]
index 21, log likelihood -2750.755481
[ 488.03176146 1149.99999913 1959.72715051 1402.2410889 ]
index 22, log likelihood -2744.664587
[ 502.07146729 1149.99999924 1950.48262339 1397.44591008]
index 23, log likelihood -2738.171491
[ 515.21134548 1149.99999928 1940.33514025 1394.453515  ]
index 24, log likelihood -2730.646916
[ 528.34343007 1149.99999927 1928.51915412 1393.13741655]
index 25, log likelihood -2721.213850
[ 542.65002974 1149.99999921 1913.95371358 1393.39625747]
index 26, log likelihood -2708.543565
[ 559.8341125  1149.9999991  1895.14228341 1395.02360499]
index 27, log likelihood -2690.605833
[ 582.32328095 1149.99999894 1870.14857631 1397.52814381]
index 28, log likelihood -2664.390693
[ 613.17479658 1149.99999871 1836.83467602 1399.99052869]
index 29, log likelihood -2625.841179
[ 655.23092526 1149.99999839 1793.67419024 1401.09488611]
index 30, log likelihood -2570.724441
[ 709.4085606  1149.99999793 1741.09910613 1399.49233533]
index 31, log likelihood -2497.122214
[ 773.45862031 1149.99999729 1682.02719971 1394.51418269]
index 32, log likelihood -2406.763618
[ 842.49057709 1149.99999645 1620.80983062 1386.69959584]
index 33, log likelihood -2301.998774
[ 910.83901083 1149.9999955  1561.73542083 1377.42557284]
index 34, log likelihood -2188.814428
[ 973.19529188 1149.99999457 1508.32843011 1368.47628345]
index 35, log likelihood -2082.601753
[1024.65287295 1149.99999384 1463.98573368 1361.36139953]
index 36, log likelihood -2006.034687
[1062.59718728 1149.99999346 1430.74073067 1356.66208859]
index 37, log likelihood -1968.437544
[1088.75814826 1149.99999334 1407.47964219 1353.76221621]
index 38, log likelihood -1954.693990
[1106.59286891 1149.99999327 1391.88124644 1351.52589138]
index 39, log likelihood -1949.858506
[1118.97485693 1149.99999317 1381.70521529 1349.31993461]
index 40, log likelihood -1947.959181
[1127.78707665 1149.99999306 1375.16996326 1347.04296702]
index 41, log likelihood -1947.122789
[1134.20779982 1149.99999298 1371.00469136 1344.78751583]
index 42, log likelihood -1946.719193
[1138.98103843 1149.99999292 1368.36206003 1342.65690863]
index 43, log likelihood -1946.509007
[1142.589615   1149.99999287 1366.69310582 1340.71728632]
index 44, log likelihood -1946.392227
[1145.35625181 1149.99999283 1365.64552897 1338.99822638]
index 45, log likelihood -1946.323806
[1147.50243001 1149.99999281 1364.99377606 1337.50380112]
index 46, log likelihood -1946.282033
[1149.18368366 1149.99999279 1364.59348535 1336.22283821]
index 47, log likelihood -1946.255744
[1150.51148373 1149.99999277 1364.35227103 1335.13625247]
index 48, log likelihood -1946.238844
[1151.56720095 1149.99999276 1364.21104918 1334.22175711]
index 49, log likelihood -1946.227821
[1152.41122685 1149.99999275 1364.13209462 1333.45668578]
Finish training
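
A note on numerical stability (my addition, not in the original): gaussian_distribution can underflow to 0 for points far from every component, making np.log return -inf. A log-space version avoids this; the sketch below assumes SciPy is available.

# Log-space log-likelihood that avoids underflow (assumes scipy is installed).
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def log_likelihood_stable():
    # log p(x_i) = logsumexp_j [ log(alpha_j) + log N(x_i | miu_j, sigma_j) ]
    log_prob = np.zeros((n, k))
    for j in range(k):
        log_prob[:, j] = np.log(alpha[j]) + multivariate_normal.logpdf(data, miu[j], sigma[j])
    return np.sum(logsumexp(log_prob, axis=1))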
import pylab
# %matplotlib inline 
%config InlineBackend.figure_format = 'png'
pylab.style.use('default')
pylab.plot(likelihood_list)
[Figure: log-likelihood versus iteration for this run.]
node_num = 500
_,gamma = E()                  # unnormalized responsibilities; argmax is unaffected by normalization
label = np.argmax(gamma,1)     # hard cluster assignment for every data point
selected_node_index = np.random.choice(n,size=node_num,replace=False)  # 500 distinct points
node_pos = data[selected_node_index]
label = label[selected_node_index]
pylab.scatter(node_pos[:,0],node_pos[:,1],marker='o',c=label,cmap=pylab.cm.Accent)

[Figure: scatter plot of the 500 sampled points, colored by cluster.]
