The logic of the algorithm is explained here:
http://www.cnblogs.com/Azhu/p/4131733.html
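Before the code, a few words. I originally planned to write my own implementation in MATLAB, but I could not produce anything as efficient and robust as what already exists. There are many implementations online, for example the one in the reference above, but that code clearly has problems. Because it matches the logic analysis there, I tried modifying it, and found that it is far too fragile. My dataset is a 70-by-2000 matrix (70 samples, 2000 dimensions), and the inverse of the covariance could not be computed at all: every entry came out as inf. Keeping only the first 50 dimensions did not help either. The logic is right, but computing the inverse directly breaks down in practice: with only 70 samples, a 2000-by-2000 sample covariance has rank at most 69 and is therefore singular.
So I had to use another approach. There is one very well-written implementation, posted below. It consists entirely of matrix operations. I can read it, but I could not have written the math myself. What follows is just me typing it out after the author, and even that took a long look before I understood the calculations. I am not very familiar with MATLAB's built-in functions either, so I am writing my notes down here along the way.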
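To make the failure concrete, here is a minimal sketch on synthetic data of the same shape. The random matrix, the variable names, and the ridge size `lambda` are assumptions for illustration; only the 70-by-2000 shape comes from the text above.

```matlab
rng(0);                                  % fix the seed so the run is repeatable
n = 70; d = 2000;                        % same shape as the dataset described above
X = randn(n, d);                         % synthetic stand-in data, one sample per row
Xc = bsxfun(@minus, X, mean(X, 1));      % center every dimension
S = (Xc' * Xc) / (n - 1);                % d-by-d sample covariance
r = rank(Xc);                            % rank(S) equals rank(Xc), at most n-1
[~, p] = chol(S);                        % p > 0 would mean S is not positive definite
lambda = 1e-6;                           % hypothetical ridge added to the diagonal
[~, pReg] = chol(S + lambda * eye(d));   % pReg == 0 means the factorization succeeded
fprintf('rank = %d, chol flag = %d, regularized flag = %d\n', r, p, pReg);
```

The rank can never exceed n-1 after centering, so no amount of care in the inversion step will rescue the unregularized matrix. Adding a small multiple of the identity is one common workaround, and the M-step later in this post does the same thing.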
Main function:
```matlab
function [label, model, llh] = emgm(X, init)
% Perform EM algorithm for fitting the Gaussian mixture model.
%   X: d x n data matrix
%   init: k (1 x 1) or label (1 x n, 1<=label(i)<=k) or center (d x k)
% Written by Michael Chen (sth4nth@gmail.com).
%% initialization
% fprintf('EM for Gaussian mixture: running ... \n');
% load('final_initlize');
% X = dataset(1).x';
% init = dataset(1).y';
% R is an n-by-k matrix; R(i,j) is the probability that sample i belongs to
% class j. It starts as 0/1 indicators and becomes soft weights after iterating.
R = initialization(X,init);
% label holds the class label of each of the n samples.
[~,label(1,:)] = max(R,[],2);
% This line handles non-contiguous class labels.
R = R(:,unique(label));

%pect = zeros(size(label));
% tol is the convergence threshold.
tol = 1e-10;
maxiter = 500;
% log-likelihood
llh = -inf(1,maxiter);
converged = false;
% index of the current iteration
t = 1;
while ~converged && t < maxiter
    t = t+1;
    model = maximization(X,R);
    [R, llh(t)] = expectation(X,model);

    [~,label(:)] = max(R,[],2);
    u = unique(label);   % non-empty components
    if size(R,2) ~= size(u,2)
        R = R(:,u);   % remove empty components
    else
        converged = llh(t)-llh(t-1) < tol*abs(llh(t));
    end

end
llh = llh(2:t);
% if converged
%     fprintf('Converged in %d steps.\n',t-1);
%     llh = t-1;
% else
%     fprintf('Not converged in %d steps.\n',maxiter);
%     llh = maxiter;
% end
```
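The step that drops empty components inside the loop is easy to miss, so here is a small sketch of it in isolation. The toy matrix `R` is made up for illustration; the logic mirrors the `unique(label)` lines in the main function.

```matlab
% Toy responsibilities: 5 samples, 3 components; no sample prefers component 2.
R = [0.9 0.0 0.1;
     0.8 0.1 0.1;
     0.2 0.1 0.7;
     0.6 0.0 0.4;
     0.1 0.2 0.7];
[~, label] = max(R, [], 2);   % hard label = most probable component per sample
label = label';               % row vector, as in emgm
u = unique(label);            % components that own at least one sample
if size(R, 2) ~= numel(u)
    R = R(:, u);              % keep only the columns of non-empty components
end
disp(u);                      % components that were kept
disp(size(R));                % size of R after the cut
```

Here `u` is `[1 3]` and `R` shrinks to 5-by-2. Note that the rows of the trimmed `R` no longer sum to one; they are renormalized only when the next E-step recomputes the responsibilities.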
Initialization function:
This function is simple and needs little explanation.
```matlab
%% init
function R = initialization(X, init)
% There are four ways to initialize: with initial GMM model parameters, with
% the number of components k, with a label for every sample, or with the
% class centers.
[d,n] = size(X);
if isstruct(init)   % initialize with a model
    R = expectation(X,init);
elseif length(init) == 1   % random initialization
    k = init;
    idx = randsample(n,k);
    m = X(:,idx);
    [~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1);
    [u,~,label] = unique(label);
    while k ~= length(u)
        idx = randsample(n,k);
        m = X(:,idx);
        [~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1);
        [u,~,label] = unique(label);
    end
    R = full(sparse(1:n,label,1,n,k,n));
elseif size(init,1) == 1 && size(init,2) == n   % initialize with labels
    label = init;
    k = max(label);
    R = full(sparse(1:n,label,1,n,k,n));
elseif size(init,1) == d   % initialize with only centers
    k = size(init,2);
    m = init;
    [~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1);
    R = full(sparse(1:n,label,1,n,k,n));
else
    error('ERROR: init is not valid.');
end
```
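One line in the initialization is worth unpacking: `max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1)` assigns each sample to its nearest center. Since ||x - m||^2 = ||x||^2 - 2 m'x + ||m||^2 and ||x||^2 is the same for every center, minimizing the distance is the same as maximizing m'x - ||m||^2/2. The sketch below checks this against explicit distances on random data; the data, sizes, and variable names are assumptions for illustration.

```matlab
rng(1);                                   % repeatable random data
d = 3; n = 8; k = 2;
X = randn(d, n);                          % d-by-n data, one sample per column
m = X(:, [1 2]);                          % two samples used as centers (arbitrary pick)
% Score from the initialization code: m'*x - ||m||^2/2, maximized over centers.
score = bsxfun(@minus, m' * X, dot(m, m, 1)' / 2);
[~, labelFast] = max(score, [], 1);
% Reference: explicit squared Euclidean distances, minimized over centers.
D2 = zeros(k, n);
for j = 1:k
    Xd = bsxfun(@minus, X, m(:, j));      % offsets from center j
    D2(j, :) = sum(Xd .^ 2, 1);
end
[~, labelRef] = min(D2, [], 1);
% One-hot responsibilities, built with the same sparse call as the post's code.
R = full(sparse(1:n, labelFast, 1, n, k, n));
```

Both label vectors should agree, and every row of `R` has a single 1 in the column of its assigned center. The score form avoids building the difference matrices, which is why it is used on large data.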
M-step function:
```matlab
%% m-step
function model = maximization(X, R)
[d,n] = size(X);
% k is the number of classes
k = size(R,2);
% number of samples in each class
nk = sum(R,1);
w = nk/n;
mu = bsxfun(@times, X*R, 1./nk);

Sigma = zeros(d,d,k);
% Square root of R, so that the product Xo*Xo' below is weighted by R.
sqrtR = sqrt(R);
for i = 1:k
    Xo = bsxfun(@minus,X,mu(:,i));
    Xo = bsxfun(@times,Xo,sqrtR(:,i)');
    Sigma(:,:,i) = Xo*Xo'/nk(i);
    Sigma(:,:,i) = Sigma(:,:,i)+eye(d)*(1e-6);   % add a prior for numerical stability
end

model.mu = mu;
model.Sigma = Sigma;
model.weight = w;
```
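Written out as formulas, the M-step code above appears to compute the following updates. The notation is my own reading of the code: R_{ij} is the responsibility of component j for sample x_i, and n_j corresponds to `nk`.

```latex
n_j = \sum_{i=1}^{n} R_{ij}, \qquad
w_j = \frac{n_j}{n}, \qquad
\mu_j = \frac{1}{n_j} \sum_{i=1}^{n} R_{ij}\, x_i, \qquad
\Sigma_j = \frac{1}{n_j} \sum_{i=1}^{n} R_{ij}\,(x_i - \mu_j)(x_i - \mu_j)^{\top} + 10^{-6} I
```

The `sqrtR` trick follows from this: scaling column i of the centered data by \sqrt{R_{ij}} makes `Xo*Xo'` equal to the weighted sum in \Sigma_j, because each outer product picks up the factor \sqrt{R_{ij}} twice.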
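E-step:
Gaussian density used in the E-step, where D is the dimension and |\Sigma| is the determinant of the covariance:
$$\mathcal{N}(x \mid \mu, \Sigma) = \frac{1}{(2\pi)^{D/2}} \, \frac{1}{|\Sigma|^{1/2}} \exp\!\left(-\frac{1}{2}(x-\mu)^{\top}\Sigma^{-1}(x-\mu)\right)$$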
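The `loggausspdf` helper in the E-step code never forms the inverse or the determinant. As far as I can tell, it evaluates the logarithm of this density through a Cholesky factor \Sigma = U^{\top}U with U upper triangular, which is what `chol` returns. A short derivation under that reading, with D being `d` in the code:

```latex
\log|\Sigma| = 2 \sum_{j=1}^{D} \log U_{jj}, \qquad
(x-\mu)^{\top}\Sigma^{-1}(x-\mu) = \left\| U^{-\top}(x-\mu) \right\|^{2}

\log \mathcal{N}(x \mid \mu, \Sigma)
  = -\frac{1}{2}\left[ D \log(2\pi) + 2 \sum_{j=1}^{D} \log U_{jj}
  + \left\| U^{-\top}(x-\mu) \right\|^{2} \right]
```

In the code, `Q = U'\X` is the triangular solve U^{-\top}(x-\mu), `q` is its squared norm, and `c` collects the first two terms in the bracket.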
```matlab
%% e-step
function [R, llh] = expectation(X, model)
mu = model.mu;
Sigma = model.Sigma;
w = model.weight;

n = size(X,2);
k = size(mu,2);
logRho = zeros(n,k);

for i = 1:k
    logRho(:,i) = loggausspdf(X,mu(:,i),Sigma(:,:,i));
end
logRho = bsxfun(@plus,logRho,log(w));
T = logsumexp(logRho,2);
llh = sum(T)/n;   % loglikelihood
logR = bsxfun(@minus,logRho,T);
R = exp(logR);
%% log pdf
function y = loggausspdf(X, mu, Sigma)

d = size(X,1);
X = bsxfun(@minus,X,mu);
[U,p]= chol(Sigma);
if p ~= 0
    error('ERROR: Sigma is not PD.');
end
Q = U'\X;
q = dot(Q,Q,1);   % quadratic term (M distance)
c = d*log(2*pi)+2*sum(log(diag(U)));   % normalization constant
y = -(c+q)/2;
```
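The E-step calls `logsumexp`, which is not a MATLAB built-in and is not listed in this post. Below is a minimal sketch of what such a helper could look like, using the standard shift-by-the-maximum trick. It is an assumption on my part, not necessarily the helper the original author used, and the default for `dim` is my own choice.

```matlab
function s = logsumexp(x, dim)
% Sketch of a numerically stable log(sum(exp(x), dim)).
% Assumed helper; the original one is not shown in this post.
if nargin < 2
    dim = 1;                      % assumed default: reduce along the first dimension
end
y = max(x, [], dim);              % largest entry along dim
x = bsxfun(@minus, x, y);         % shift so that the largest exponent is 0
s = y + log(sum(exp(x), dim));    % add the shift back outside the log
i = ~isfinite(y);                 % slices whose maximum is -Inf or Inf
if any(i(:))
    s(i) = y(i);                  % avoid the NaN produced by Inf - Inf
end
```

For example, `log(sum(exp([-1000 -1001])))` underflows to `-Inf`, while `logsumexp([-1000 -1001], 2)` returns `-1000 + log(1 + exp(-1))`. That matters here because `logRho` holds log-densities that can be very negative in high dimensions.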