AdaBoost算法和java实现

AdaBoost算法和java实现


算法描述

输入:训练数据集这里写图片描述,其中 xi χ Rn , yi {-1,+1};弱学习算法;
输出:最终分类器G(x)。

  1. 初始化训练集数据的权值分布
    D1 =( w11 ,…, wiN ), w1i =1/N, i=1,2…,N

  2. 对m=1,2,…,M

    • (a)使用具有权值分布 Dm 的训练数据集学习,得到基本分类器
      Gm(x):χ> {-1,+1}

    • (b) 计算 Gm(x) 在训练数据集上的分类误差率
      em= P(Gm(xi)yi)=Ni=1wmiI(Gm(xi)yi)


    • (c) 计算 Gx 的系数
      αm=12log1emem 这里的对数是自然对数。

  • (d)更新训练数据集的权值分布
    Dm+1=(wm+1,1,...wm+1,N)

    wm+1,i=wmiZmexp(αmyiGm(xi)),i=1,2,...,N
    , Zm 是规范化因子

    Zm=Ni=1wmiexp(αmyiGm(xi))
    它是 Dm+1 成为一个概率分布。


3. 构建基本分类器的线性组合

f(x)= Mm+1αmGm(x)

得到最终分类器

G(x)=sign(f(x))=sign(Mm=1αmGm(x))


举例说明

数据如下
这里写图片描述
当m=1时,
根据以上的公式有 D1 =( w1i,w2i,...,w2i ), w1i=0.1,i=1,2,...,10 然后在权值分布为 D1 的训练数据集上,阈值v取2.5时分类的误差率最低,故分类器为
注意
在训练集上的误差率 e1 =3*0.1(3表示有三个分类错误的数据,0.1对应权值数组 D1 上的值)

按照(c)中的公式据算 α1=12log1e1e1 =0.4236

更新数据的权值分布:
D2 =(0.07143,0.07143,0.07143,0.07143,0.07143,0.07143,0.16667,0.16667,0.16667,0.07143)()大家可以发现被错误分类的点的权值被加大了
f1(x)=α1G1(x) =0.4236 G1(x)
分类器sign[ f1(x) ]在训练数据集上有三个错误分类点。


当m=2时,
-在权值分布为 D2 的训练数据集上 ,阈值v是8.5时分类误差率最低,基本分类器为

这里写图片描述
- G2(x) 在训练数据集上的误差率 e2 =0.2143
- 计算 α2 =0.6496
-更新训练数据集权值分布:
D3 =(0.455,0.455,0.455,0.1667,0.1667,0.1667,0.1060,0.1060,0.1060,0.0455)
f2(x) =0.4236 G1(x)+0.6496G2(x)
分类器sign[ f2(x) ]在训练数据集上有三个错误分类点。


当m=3时,
-在权值分布为 D2 的训练数据集上 ,阈值v是8.5时分类误差率最低,基本分类器为

这里写图片描述
- G2(x) 在训练数据集上的误差率 e3 =0.1820
- 计算 α3 =0.7514
-更新训练数据集权值分布:
D3 =(0.125,0.125,0.125,0.102,0.102,0.102,0.065,0.065,0.065,0.125)
f2(x) =0.4236 G1(x)+0.6496G2(x)+0.7514G3(x)
分类器sign[ f2(x) ]在训练数据集上有0个错误分类点。
故: G(x)=sign[f3(x)]=sign[0.4236G1(x)+0.6496G2(x)+0.7514G3(x)]


import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

public class Test08 {
    public  ArrayList list=new ArrayList();
    public static final double k = 0.5;
    public static void main(String[] args) {
        // TODO Auto-generated method stub
        Test08 test=new Test08();
        Map map = new HashMap();
        map.put(0, 1);
        map.put(1, 1);
        map.put(2, 1);
        map.put(7, 1);
        map.put(5, -1);
        map.put(6, 1);
        map.put(8, 1);
        map.put(9, -1);
        map.put(3, -1);
        map.put(4, -1);
        System.out.println(test.adaBoost(test.sortMapByKey(map)));
    }

    public TreeMap  sortMapByKey(Map oriMap) {
        if (oriMap == null || oriMap.isEmpty()) {
            return null;}
        TreeMap sortedMap = new TreeMap(new Comparator() {
            public int compare(Integer o1, Integer o2) {
                // 如果有空值,直接返回0
                if (o1 == null || o2 == null)
                    return 0;
                return String.valueOf(o1).compareTo(String.valueOf(o2));
            }
        });
        sortedMap.putAll(oriMap);
        return sortedMap;
    }

    public Map adaBoost(TreeMap data) {
        Map result=new HashMap();
        int dataLenght=data.size();
        double[] weight = new double[dataLenght];
        //初始化权值数组
        for (int i = 0; i < dataLenght; i++) {
            weight[i]=1.0/dataLenght;
        }

        double grade1 = 0;
        double grade2 = 0;
        //double flag = 0;
        String f=null;
        double current=0;
        double ah=0;
        double low=data.firstKey();//选取最小的特征值
        double high=data.lastKey();//选取最大的特征值
        //迭代50次
        for(int it=0;it<50;it++){
            double min=1000;
            double flag=low;//用来标记比较优的特征的值
            while(flag<=high){
                int index = 0;// 用来索引权值数组
                grade1=0;
                grade2=0;
                for (Integer en : data.keySet()) {
                    //大于某一个特征值则为一时
                    if(GreatToOne(en, flag)!=data.get(en)){
                        grade1+=weight[index];
                    }
                    //小于某一个特征值则为一时
                    if(LessToOne(en, flag)!=data.get(en)){
                        grade2+=weight[index];
                    }   
                    index++;
                }
                //选取最优的特征值
                if (grade1 < min) {
                    min = grade1;
                    current = flag;
                    f="great";//用来标记采用的哪一个函数(GreatToOne or LessToOne)
                }
                if(grade2"less";
                }
                flag+=k;//将用来分类的特征值增加k
            }
            ah=0.5*Math.log((1-min)/min);
            double totle=0;
            int j=0;
            //
            for(Integer en:data.keySet()){
                if(f.equals("great")){
                    totle+=weight[j++]*Math.exp(-ah*data.get(en)*GreatToOne(en,current));
                }
                else{
                    totle+=weight[j++]*Math.exp(-ah*data.get(en)*LessToOne(en,current));
                }

            }
            j=0;
            for(Integer en:data.keySet()){
                if(f.equals("great")){
                    weight[j]=weight[j]*Math.exp(-ah*data.get(en)*GreatToOne(en,current))/totle;
                }
                else{
                    weight[j]=weight[j]*Math.exp(-ah*data.get(en)*LessToOne(en,current))/totle;
                }
                j++;
            }
        result.put(ah, current);
        list.add(f);
        //错误率为零,则退出
        if(calc(result,data)==0) break;
        }
    return result;
    }
    private int calc(Map result, TreeMap data) {
        // TODO Auto-generated method stub
        int count=0;

        for(Integer en:data.keySet()){
            double sum=0;int index=0;
            for(Double d:result.keySet()){
                if(list.get(index).equals("great")){
                    sum+=d*GreatToOne(en,result.get(d));
                }
                else{
                    sum+=d*LessToOne(en,result.get(d));
                }   
                index++;            
            }
            if(sum>0&&data.get(en)==-1) {
                count++;
            }
            if(sum<0&&data.get(en)==1){
                count++;
            }

        }
        if(count==0){
            return 0;
            }
        else{
            return 1;
        }
    }

    public int GreatToOne(int x,double flag){
        if(x>flag) {
            return 1;
        }else{
            return -1;
        }
    }
    public int LessToOne(int x,double flag){
        if(xreturn 1;
        }else{
            return -1;
        }
    }
}

结果如下:
这里写图片描述

统计学习方法(李航)

你可能感兴趣的:(机器学习)