EM算法求解混合伯努利模型

本文记录了EM算法求解混合伯努利分布的推导,并提供了matlab实验代码。

混合伯努利分布

单个 D 维伯努利分布的分布率为:

p(x|μ)=d=1Dμxdd(1μd)1xd,

这里 x=(x1,,xD)T D 维0-1向量, μ=(μ1,,μD)T 为对应维度上事件发生的概率。
混合伯努利分布是指由 K 个单个 D 维伯努利分布混合构成的分布,分布率如下:
p(x|μ,π)=k=1Kπkp(x|μk),

这里 μ={μ1,,μK},π={π1,,πK} ,
p(x|μk)=d=1Dμxdkd(1μkd)1xd.

EM算法混合伯努利分布

现在我们有一个来自于混合伯努利分布的数据集 X={x1,,xN} ,我们要最大化其似然函数:

maxμ,πlnp(X|μ,π)=n=1Nln{k=1Kπkp(xn|μk)}.

由于 ln 函数中的求和项,无法得到其闭式解,我们用EM算法求其一个局部数值解。关于EM算法的内容见上一篇博客。对于样本 xn ,我们构造一个隐变量 zn 用来表示 xn 来自于哪个模型。也即, zn=k 当且仅当 xn 来自于第 k 个伯努利模型。令 Z={z1,,zn} ,则 X,Z 联合分布为:
p(X,Z|μ,π)=n=1Nπznp(xn|μzn).

M步我们要求解
maxμ,πZp(Z|X,μ¯,π¯)lnp(X,Z|μ,π)=n=1Nk=1Kp(k|xn,μ¯,π¯)lnp(xn|μk,πk)(1)

观察到 p(k|xn,μ¯,π¯) 可看做关于 n,k 的常数,因此令
γnk=p(k|xn,μ¯,π¯)=πkp(xn|μk)i=1Kπip(xn|μi).

则式(1)中的优化问题可以进一步写作
maxμ,π=f(π,μ)=s.t.n=1Nk=1Kγnk[lnπk+xTnlnμk+(1xn)Tln(1μk)]k=1Kπk=1,0<μk<1(k=1,,K).

这里 lnx 表示对向量 x 逐元素操作。对 μk 求偏导并令其等于0,
μk=n=1Nγnk(xnμk1xn1μk)=n=1Nγnk(xnμk)μk(1μk)=0.

这里 xy xy 分别表示两个向量的表示逐元素除法和逐元素乘法。解得
μk=1Nkn=1Nγnkxn,

其中 Nk=n=1Nγnk 。 为了求解 π ,构造拉格朗日函数
L(π,λ)=f(π,μ)+λ(k=1Kπk1),

利用KKT条件得
πkL(π,λ)=1πkn=1Nγnk+λ=0k=1Kπk=1.k=1,,K

利用 n=1Nk=1Kγnk=N 解得
πk=NkN.

实验部分

这里我们利用MINIST手写数字数据集进行实验,这里有处理好的matlab数据(60000*785矩阵,每一行为一个样本,其中28*28=784为灰度图片拉成的行向量,最后一列为对应数字标签)。将灰度向量二值化就得到了0-1向量,我们把这个0-1向量当做由混合伯努利分布生成的一个样本,并且希望最后学习到的混合分布中的不同模型对应不同的数字向量概率分布。由于计算后验概率分布时涉及到概率的连乘操作,要注意精度问题。

EM算法求解混合伯努利模型_第1张图片

function[]=main()
    RR = 1;        
    CC = 3;         
    R = 10;
    C = 10;
    r = 28;
    c = 28;
    Eps = 1e-5;
    scale = 0.8;        %每个digit图片的初始大小为28*28,这里为了计算效率和精度长宽都缩放为0.8倍
    digit = [1 3 4];    %选用的digits
    K = RR*CC;          %混合的模型个数
    N = R*C*K;          %总共样本个数,即每个digit有R*C个样本
    D = r*c;
    load('Data.mat');   %导入MNIST数据集,Data为60000*785的矩阵,每一行对应一个样本,样本最后一个维度
    label = Data(:,D+1);

    index = zeros(N,1);
    for k = 1:K;
        t = find(label==digit(k));
        index((k-1)*R*C+1:k*R*C) = (t:t+R*C-1)';
    end
    X = Data(index, 1:D);
    tX = zeros(N, ceil(r*scale)*ceil(c*scale));
    for n = 1:N %缩放并二值化
        t = imresize(reshape(X(n,:), r, c), [ceil(r*scale) ceil(c*scale)]);
        tX(n, :) = reshape(t, 1, ceil(r*scale)*ceil(c*scale));
        tX(n,:) = im2bw(tX(n, :), graythresh(tX(n,:)));
    end
    X = tX;
    r = ceil(r*scale);
    c = ceil(c*scale);
    D = r*c;


    Pi = rand(K, 1);
    Pi = Pi/(ones(1, K)*Pi);
    Mu = rand(K, D);
    logLikeBar = calLogLikeBound(X, K, Pi, Mu);
    fprintf('Initial LogLikelihoodBound: %.6f\n', logLikeBar);

    cnt = 0;
    Gamma = zeros(N, K);
    while true
        %E step:
         for n=1:N
            t = 0;
            %tic;
            for k = 1:K
                digits(100);
                Gamma(n, k) = vpa(exp(vpa(X(n, :))*log(Mu(k, :)')+vpa(ones(1,D)-X(n,:))*log(ones(D,1)-Mu(k, :)'))); %这里用了ln变换和高精度运算
                t = t+ Gamma(n, k);
            end
            Gamma(n,:) = Gamma(n,:)/t;
            %toc;
         end
         Nk = (ones(1, N)*Gamma)';

         %M step:
         Mu = (X'*Gamma*(diag(Nk))^(-1))';
         Pi = Nk/N;
         Mu(Mu <= 0) = 1e-5;
         Mu(Mu >= 1) = 1-1e-5;

         cnt = cnt+1;
         logLike = calLogLikeBound(X, K, Pi, Mu);
         relGap = abs(logLike-logLikeBar)/abs(logLikeBar);
         disp(['Iteration: ', num2str(cnt), ' RelativeGap: ', sprintf('%.6f LogLikelihoodBound: %.6f', relGap, logLike)]);
         logLikeBar = logLike;
         label = zeros(N, 1);
         for k = 1:K
            label(Gamma(:,k)'==max(Gamma'))=k;
         end
         showRes(RR, CC, R, C, r, c, X, cnt, label, lambda);

        if relGap < Eps
            break;
        end
    end

end
function [logLike] = calLogLikeBound(X, K, Pi, Mu) %计算对数似然函数值
    logLike = 0;
    N = size(X,1);
    D = size(X,2);
    for n =1:N
        t = exp(X(n,:)*log(Mu')+(ones(1, D)-X(n, :))*log(ones(D, K)-Mu'))';
        logLike = logLike+log(Pi'*t);
    end
end

function [] = showRes(RR, CC, R, C, r, c, X, cnt, label, lambda) %绘制图片
    lambda = 0.35;  %控制图片颜色程度
    colorNum = 20;
    mycolor = colorcube(colorNum);
    for II = 1:RR
        for JJ = 1:CC
           for I = 1:R
               for J = 1:C
                   t1 = zeros(r, c, 3);
                   t2 = (II-1)*CC*(R*C)+(JJ-1)*(R*C)+(I-1)*C+J;
                   for i = 1:r
                        t1(i,:,1)  = X(t2, (i-1)*c+1:i*c);
                        t1(i,:,2) = t1(i,:,1);
                        t1(i,:,3) = t1(i,:,1);
                   end
                   t1(:,:,1) = t1(:,:,1)*lambda+mycolor(mod(label(t2)-1, colorNum)+1, 1)*(1-lambda);
                   t1(:,:,2) = t1(:,:,2)*lambda+mycolor(mod(label(t2)-1, colorNum)+1, 2)*(1-lambda);
                   t1(:,:,3) = t1(:,:,3)*lambda+mycolor(mod(label(t2)-1, colorNum)+1, 3)*(1-lambda);
                   tr = (II-1)*R*r+(I-1)*r;
                   tc = (JJ-1)*C*c+(J-1)*c;
                   RGBMat(tr+1:tr+r, tc+1:tc+c,:)=t1;
               end
           end
        end
    end
    figure(cnt);
    imshow(RGBMat);
    saveas(gcf,sprintf('%d.jpg', cnt));
end

参考文献

Christopher M.. Bishop. Pattern recognition and machine learning. pp. 423-450

你可能感兴趣的:(实验室学习,读论文)