遗传算法解决PlayTennis问题

机器学习课程结束,作为遗传算法小白,简单实现了遗传算法解决PlayTennis问题,参考汤姆.米切尔的《机器学习》。本人认为这个问题最大的难点在于如何设计适应度函数。

首先把天气状况表示为二进制位串

Outlook

Sunny

100

Overcast

010

Rain

001

Humidity

High

10

Normal

01

Wind

Strong

10

Weak

01

PlayTennis

Yes

1

No

0

遗传算法原型:

GA(Fitness,Fitness_threshold,p,r,m)

             Fitness:适应度评分函数,为给定假设赋予一个评估分数

             Fitness_Threshold:指定终止判据的阈值

             p:群体中包含的假设数量

             r:每一步中通过交叉取代群体成员的比例

             m:变异率

  •  初始化群体:p(随机产生的p个假设)
  • 评估:对于p中的每个h,计算Fitness(h)
  • 当[maxFitness(h)]< Fitness_Threshold,做:

产生新一代Ps:

(1)  选择:用概率方法选择p的(1-r)p个成员加入Ps。

(2)  交叉:按概率选择r*p/2对假设,对每对假设,应用交叉算子产生两个后代。把所有的后代加入Ps。

(3)  变异:使用均匀的概率从Ps中选择m%的成员。对于选出的每个成员,在它的表示中随机选择一个位取反。

(4)  更新p<-Ps

  • 从p中返回适应度最高的假设

主要代码public class Main {

      static List pool = new ArrayList<>();//假设池

      static List newPool = new ArrayList<>();//子代

      static int GEN = 5;//一共演化5代

      static float p_c = (float) 0.5;//交叉率

      static float p_y = (float) 0.5;//变异率

         

      public static void main(String[] args) {

             Random rand = new Random();

             //初始化

             for(int i=0;i<300;i++){

                    String code = "";

                    for(int j=0;j<8;j++)

                           code += rand.nextInt(2);

                    pool.add(code);

             }

             System.out.println("--------初始化--------");

             print(pool);

             //演化

             for(int i=0;i

                    System.out.println("--------演化"+i+"--------");

                    evolution(pool);

             }

             //得到最佳假设

             String bestChoise = pool.get(0);

             float bestFitness = calculateFitness(pool.get(0));

             for(int i=0;i

                    if(calculateFitness(pool.get(i))

                           bestChoise = pool.get(i+1);

                           bestFitness =calculateFitness(pool.get(i+1));

                    }

             }

             System.out.println("--------最佳假设--------");

             System.out.println(bestChoise);

             System.out.println("--------适应度--------");

             System.out.println(bestFitness);

      }

     

      /**

       * 适应度计算

       * @param code

       * @return

       */

      private static float calculateFitness(String code) {

             int accCount = 0;

             int wrongCount = 0;

             for (String data : trainDatas) {

                    int equalCount = 0;

                    for(int i=0;i

                           if(data.charAt(i)=='1' &&code.charAt(i)=='1')

                                  equalCount++;

                    }

                    if(equalCount==3)

                           if(data.charAt(data.length()-1)==code.charAt(code.length()-1))

                                  accCount++;

                           else

                                  wrongCount++;

             }

             if(accCount == 0)

                    return 0;

             float acc = (float)accCount/(accCount+wrongCount);

             if(code.charAt(code.length()-1)=='0'){

                    float recall = (float)accCount/4;

                    return 2*acc*recall/(acc+recall);

             }else{

                    float recall = (float)accCount/8;

                    return 2*acc*recall/(acc+recall);

             }

      }

      /**

       * 演化

       * @param pool

       */

      private static void evolution(List pool){

             //计算适应度

             Random rand = new Random();

             List fitness = new ArrayList<>();

             fitness.add(0f);

             for(int i=0;i

                    fitness.add(

                           fitness.get(i)+calculateFitness(pool.get(i))

                    );

             }

             fitness.remove(0);

             //轮盘

             for(int i=0;i<20;i++){//选20个

                    float pick = rand.nextFloat()*fitness.get(fitness.size()-1);

                    boolean selectDone = false;

                    float selectFit = 0;

                    for(int j=0; j

                           if(selectDone)

                                  fitness.set(j,fitness.get(j)-selectFit);

                           if(!selectDone &&pick<=fitness.get(j)){

                                  newPool.add(pool.get(j));

                                  if(j==0){

                                         selectFit = fitness.get(j);

                                  }else{

                                         selectFit = fitness.get(j) -fitness.get(j-1);

                                  }

                                  fitness.remove(j);

                                  pool.remove(j);

                                  j--;

                                  selectDone = true;

                           }

                          

                    }

             }

             pool.clear();

             pool.addAll(newPool);

             newPool.clear();

             System.out.println("--------轮盘赌--------");

             print(pool);

             //交叉

             for(int i=0;i

                    for(int j=i+1;j

                           if(pool.get(i).charAt(7)==pool.get(j).charAt(7)){

                                  cross(pool.get(i),pool.get(j));

                           }else{

                                  newPool.add(pool.get(i));

                                  newPool.add(pool.get(j));

                           }

                    }

             }

             pool.clear();

             pool.addAll(newPool);

             newPool.clear();

             System.out.println("--------交叉--------");

             print(pool);

             //变异

             for(int i=0;i

                    float v = rand.nextFloat();

                    if(v

                           variation(pool.get(i));

                    }else{

                           newPool.add(pool.get(i));

                    }

             }

             pool.clear();

             pool.addAll(newPool);

             newPool.clear();

             System.out.println("--------变异--------");

             print(pool);

      }

      /**

       * 交叉

       * @param parent1

       * @param parent2

       */

      private static void cross(String parent1,String parent2){

             Random rand = new Random();

             StringBuilder child1 = new StringBuilder();

             StringBuilder child2 = new StringBuilder();

             float cro = rand.nextFloat();

             if(cro>p_c){

                    child1.append(parent2.substring(0,3));

                    child2.append(parent1.substring(0,3));

             }else{

                    child1.append(parent1.substring(0,3));

                    child2.append(parent2.substring(0,3));

             }

             cro = rand.nextFloat();

             if(cro>p_c){

                    child1.append(parent2.substring(3,5));

                    child2.append(parent1.substring(3,5));

             }else{

                    child1.append(parent1.substring(3,5));

                    child2.append(parent2.substring(3,5));

             }

             cro = rand.nextFloat();

             if(cro>p_c){

                    child1.append(parent2.substring(5,7));

                    child2.append(parent1.substring(5,7));

             }else{

                    child1.append(parent1.substring(5,7));

                    child2.append(parent2.substring(5,7));

             }

             cro = rand.nextFloat();

             if(cro>p_c){

                    child1.append(parent2.substring(7,8));

                    child2.append(parent1.substring(7,8));

             }else{

                    child1.append(parent1.substring(7,8));

                    child2.append(parent2.substring(7,8));

             }

             newPool.add(child1.toString());

             newPool.add(child2.toString());

      }

      /**

       * 变异,采用点变异

       * @param children

       */

      private static void variation(String parent){

             Random rand = new Random();

             int location = rand.nextInt(parent.length()-2);//变异位置,对最后一位不进行变异

             char v = parent.charAt(location);

             String children;

             if(v == 0){

                    children = parent.substring(0,location)+"1"+parent.substring(location+1, parent.length());

             }else{

                    children = parent.substring(0,location)+"0"+parent.substring(location+1, parent.length());

             }

             newPool.add(children);

      }

      /**

       * 输出

       * @param pool

       */

      private static void print(List pool){

             for(int i=0;i

                    System.out.println(pool.get(i));

             }

      }

      private static String[] trainDatas = {

                    "10010010",

                    "10010100",

                    "10001011",

                    "10001101",

                    "01001011",

                    "01010011",

                    "01001101",

                    "01010101",

                    "00101100",

                    "00110100",

                    "00101011",

                    "00110011",

      };

}

 


你可能感兴趣的:(统计学习方法)