混合高斯模型&AIC-BIC挑选中心个数

实验室项目中要把数据按正态分布分成几类,但是又不知道有几类,估计不超过三类。然后就用了BIC准则选择类的个数,效果出奇的好  哈哈哈哈

GMM初值对结果会有很大影响:把各个正态分布的初始均值在数据的min、max之间等间距取值,然后用整体数据的标准差作为每个分量的初始标准差,完美地解决了这个问题。可能是我们数据本身的原因。

研一学的总算用上一点  好开心

// Upper bound on the number of iterations of the fitting loop in
// GMM::fixCenterGmm.
static const int MAX_ITERATOR = 1000;
// NOTE(review): presumably the convergence threshold checked between two
// successive parameter estimates; its use is not visible in this part of the
// file -- confirm.
static const double END_THR = 0.0001;
// NOTE(review): looks like a similarity threshold; no use is visible in this
// part of the file -- confirm.
static const double SIM_THR = 0.2;
// Pi and Euler's number at full double precision (the previous literals were
// truncated to 9 and 6 significant digits respectively).
static const double PI = 3.14159265358979323846;
static const double EE = 2.71828182845904523536;

// One weighted univariate normal component of a mixture.
//
//   mean   - centre of the component.
//   dalta  - spread of the component; getProbability() treats it as the
//            standard deviation (it is squared in the exponent and divides
//            the normalising factor).
//   weight - mixing weight multiplied into the density.
struct Gaussian{
    double mean, dalta;
    double weight;

    // Defaults give mean 0, dalta 0 and weight 1. Note that a component with
    // dalta == 0 cannot be evaluated meaningfully by getProbability().
    Gaussian(double m=0, double v=0, double w=1.0): mean(m), dalta(v), weight(w){
    }

    // Weighted normal density at x:
    //   weight * exp(-(x-mean)^2 / (2*dalta^2)) / (sqrt(2*pi) * dalta)
    // dalta must be non-zero; a zero dalta divides by zero and yields a
    // non-finite result.
    double getProbability(double x) const {
        // sqrt(2*pi) at full double precision.
        const double sqrtTwoPi = 2.50662827463100050242;
        const double offset = x - mean;
        // std::exp is exact to the last bit of e, unlike raising a truncated
        // literal for e to a power.
        return weight * std::exp(-(offset * offset) / (2 * dalta * dalta)) / (sqrtTwoPi * dalta);
    }

    friend std::ostream& operator<<(std::ostream& os, const Gaussian & x);
};

// Streams a Gaussian as "mean: <mean> dalta: <dalta> weight: <weight>" and
// returns the stream so that insertions can be chained.
std::ostream& operator<<(std::ostream& os, const Gaussian & x) {
    os << "mean: " << x.mean;
    os << " dalta: " << x.dalta;
    os << " weight: " << x.weight;
    return os;
}

class GMM {
public:
    void gmm(const std::vector & data, int mxCenter, std::vector< Gaussian > &re) {
        double BIC = DBL_MAX;
        std::vector< Gaussian > tmpResult;
        for(int i = 1; i <= mxCenter; ++i) {
            std::vector< Gaussian > tmp;
            double newBIC = fixCenterGmm(data, i, tmp);
            if( newBIC < BIC) {
                BIC = newBIC;
                tmpResult = tmp; 
            }
        }
        for(int i=0; i operator()(const std::pair &a, double x) {
            return std::make_pair(a.first + x, a.second + x*x);
        }
    };

    // Fits a single Gaussian to `data`: mean = E[x], dalta = sqrt(E[x^2] - E[x]^2)
    // (the population standard deviation), weight = 1.0.
    //
    // For empty `data` there is nothing to fit and Gaussian(0.0, 0.0, 1.0) is
    // returned instead of dividing by zero.
    Gaussian getGaussian(const std::vector<double> & data) {
        if (data.empty()) {
            return Gaussian(0.0, 0.0, 1.0);
        }
        double sum = 0.0;
        double sumSq = 0.0;
        for (std::vector<double>::const_iterator it = data.begin(); it != data.end(); ++it) {
            sum += *it;
            sumSq += (*it) * (*it);
        }
        const double count = static_cast<double>(data.size());
        const double mean = sum / count;
        const double variance = sumSq / count - mean * mean;
        // Rounding can push the variance of (nearly) constant data slightly
        // below zero; clamp so the square root never produces NaN.
        const double spread = variance > 0.0 ? std::sqrt(variance) : 0.0;
        return Gaussian(mean, spread, 1.0);
    }
    // Population standard deviation of `data`, computed in one pass as
    // sqrt(E[x^2] - E[x]^2).
    //
    // Returns 0.0 for empty `data` (instead of dividing by zero) and whenever
    // rounding makes the computed variance non-positive.
    double getDalta(const std::vector<double> & data) {
        if (data.empty()) {
            return 0.0;
        }
        double sum = 0.0;
        double sumSq = 0.0;
        for (std::vector<double>::const_iterator it = data.begin(); it != data.end(); ++it) {
            sum += *it;
            sumSq += (*it) * (*it);
        }
        const double count = static_cast<double>(data.size());
        const double mean = sum / count;
        const double variance = sumSq / count - mean * mean;
        // Clamp: a slightly negative variance from rounding would otherwise
        // turn into NaN under the square root.
        return variance > 0.0 ? std::sqrt(variance) : 0.0;
    }

    double fixCenterGmm(const std::vector & data, int centers, std::vector< Gaussian > &re ) {
        if( centers <= 1 ) {
            re.push_back( getGaussian(data) );
            return caculateBIC(data, re);
        }
        double mx = *max_element(data.begin(), data.end());
        double mn = *min_element(data.begin(), data.end());
        double diff = mx - mn;
        double dalta = getDalta(data);
        for(int i = 0; i < centers; ++i) {
            re.push_back( Gaussian(mn + i*diff/(centers-1), dalta, 1.0 / centers) );
        }
        std::vector< std::vector > beta( data.size(), std::vector(centers, 0.0) );
        std::vector< Gaussian > tmp(centers, Gaussian() );
        int itera = 0;
        while( itera++ < MAX_ITERATOR && !ok(re, tmp) ) {
            tmp = re;
            for(int i=0; i& re, const std::vector< Gaussian >& tmp) {
        double diff = 0.0;
        double sum = 0.0;
        for(int i=0; i &data, const std::vector< Gaussian >& gau) {
        double BIC = (2 * gau.size() ) * log( data.size() );
        for(int i=0; i


你可能感兴趣的:(工程)