机器学习课程结束后,作为遗传算法的初学者,我参考汤姆·米切尔(Tom Mitchell)的《机器学习》,简单实现了用遗传算法求解 PlayTennis 问题。我认为这个问题最大的难点在于如何设计适应度函数。
首先,把天气状况和目标属性 PlayTennis 表示为二进制位串,编码如下:
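| 属性       | 取值     | 编码 |
|------------|----------|------|
| Outlook    | Sunny    | 100  |
|            | Overcast | 010  |
|            | Rain     | 001  |
| Humidity   | High     | 10   |
|            | Normal   | 01   |
| Wind       | Strong   | 10   |
|            | Weak     | 01   |
| PlayTennis | Yes      | 1    |
|            | No       | 0    |

例如,训练样例 10010010 表示 Outlook=Sunny、Humidity=High、Wind=Weak、PlayTennis=No。在假设(规则)中,某个属性字段的多个位同时为 1 表示该属性可取其中任一值,例如 Outlook 字段为 011 表示 Overcast 或 Rain。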
GA(Fitness, Fitness_threshold, p, r, m),各参数含义如下:
Fitness:适应度评分函数,为给定假设赋予一个评估分数。
Fitness_threshold:指定终止判据的阈值。
p:群体中包含的假设数量。
r:每一步中通过交叉取代群体成员的比例。
m:变异率。
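产生新一代群体 $P_s$ 的步骤如下:
(1) 选择:用概率方法从当前群体 $P$ 中选择 $(1-r)p$ 个成员加入 $P_s$,其中假设 $h_i$ 被选中的概率为 $\Pr(h_i)=\frac{\mathrm{Fitness}(h_i)}{\sum_{j=1}^{p}\mathrm{Fitness}(h_j)}$(即轮盘赌选择)。
(2) 交叉:按上述概率从 $P$ 中选择 $r\cdot p/2$ 对假设,对每对假设 $\langle h_1, h_2\rangle$ 应用交叉算子产生两个后代,并把所有后代加入 $P_s$。
(3) 变异:使用均匀的概率从 $P_s$ 中选择 $m\%$ 的成员;对选出的每个成员,在它的表示中随机选择一个位取反。
(4) 更新:$P \leftarrow P_s$。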
/**
 * Main code: a small genetic algorithm that searches for a single conjunctive rule for the
 * PlayTennis problem, following the GA outline in Mitchell's "Machine Learning".
 *
 * <p>Every hypothesis and training example is an 8-character string of '0'/'1'. Bits 0-2
 * encode Outlook (Sunny, Overcast, Rain), bits 3-4 Humidity (High, Normal), bits 5-6 Wind
 * (Strong, Weak) and bit 7 the PlayTennis class (1 = Yes, 0 = No). In a hypothesis, a '1' in
 * an attribute field means the rule accepts that attribute value.
 *
 * <p>All state lives in static fields without synchronization, so the class is not thread-safe.
 */
public class Main {
    /** Training examples, encoded as described in the class comment. */
    private static final String[] trainDatas = {
        "10010010", "10010100", "10001011", "10001101",
        "01001011", "01010011", "01001101", "01010101",
        "00101100", "00110100", "00101011", "00110011",
    };

    /** Length of every encoded hypothesis / example. */
    private static final int CODE_LENGTH = 8;
    /** Index of the class (PlayTennis) bit; it is never mutated. */
    private static final int CLASS_BIT = CODE_LENGTH - 1;
    /** Start offsets of the Outlook, Humidity, Wind and class fields, followed by the end. */
    private static final int[] FIELD_BOUNDS = {0, 3, 5, 7, CODE_LENGTH};
    /** Size of the randomly generated initial population. */
    private static final int INITIAL_POPULATION = 300;
    /** Number of parents kept by roulette-wheel selection in every generation. */
    private static final int SELECTION_SIZE = 20;
    /** Shared random source (one instance instead of a new Random per call). */
    private static final Random RAND = new Random();

    /** Current population. */
    static List<String> pool = new ArrayList<>();
    /** Scratch list that collects the next population during crossover and mutation. */
    static List<String> newPool = new ArrayList<>();
    /** Number of generations to evolve. */
    static int GEN = 5;
    /** Crossover rate: probability that each field is exchanged between two parents. */
    static float p_c = 0.5f;
    /** Mutation rate: probability that an individual is mutated. */
    static float p_y = 0.5f;

    /**
     * Runs the GA: builds a random initial population, evolves it for {@link #GEN} generations
     * and prints the fittest hypothesis of the final population together with its fitness.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // Start from a clean state so repeated invocations do not accumulate individuals.
        pool.clear();
        newPool.clear();
        for (int i = 0; i < INITIAL_POPULATION; i++) {
            pool.add(randomCode());
        }
        System.out.println("--------初始化--------");
        print(pool);

        for (int i = 0; i < GEN; i++) {
            System.out.println("--------演化" + i + "--------");
            evolution(pool);
        }

        String bestChoice = bestHypothesis(pool);
        float bestFitness = calculateFitness(bestChoice);
        System.out.println("--------最佳假设--------");
        System.out.println(bestChoice);
        System.out.println("--------适应度--------");
        System.out.println(bestFitness);
    }

    /** Returns a uniformly random bit string of length {@link #CODE_LENGTH}. */
    private static String randomCode() {
        StringBuilder code = new StringBuilder(CODE_LENGTH);
        for (int j = 0; j < CODE_LENGTH; j++) {
            code.append(RAND.nextInt(2));
        }
        return code.toString();
    }

    /**
     * Returns the candidate with the highest fitness; on ties the earliest candidate wins.
     * Each candidate is compared with the best one found so far, not merely with its neighbour.
     *
     * @param candidates hypotheses to compare
     * @return the fittest hypothesis
     * @throws IllegalArgumentException if {@code candidates} is empty
     */
    static String bestHypothesis(List<String> candidates) {
        if (candidates.isEmpty()) {
            throw new IllegalArgumentException("no candidate hypotheses");
        }
        String best = candidates.get(0);
        float bestFitness = calculateFitness(best);
        for (int i = 1; i < candidates.size(); i++) {
            String candidate = candidates.get(i);
            float fitness = calculateFitness(candidate);
            if (fitness > bestFitness) {
                best = candidate;
                bestFitness = fitness;
            }
        }
        return best;
    }

    /**
     * Scores a hypothesis by the F1 measure of its predictions on the training examples.
     *
     * <p>Precision is the number of covered examples whose class matches the hypothesis' class
     * bit, divided by all covered examples. Recall is the same count divided by the number of
     * training examples of that class.
     *
     * @param code hypothesis bit string
     * @return F1 score in [0, 1]; 0 when no covered example has the predicted class
     */
    static float calculateFitness(String code) {
        char predicted = code.charAt(CLASS_BIT);
        int accCount = 0;
        int wrongCount = 0;
        int classTotal = 0; // examples whose class equals the predicted class (recall denominator)
        for (String data : trainDatas) {
            boolean sameClass = data.charAt(CLASS_BIT) == predicted;
            if (sameClass) {
                classTotal++;
            }
            if (covers(code, data)) {
                if (sameClass) {
                    accCount++;
                } else {
                    wrongCount++;
                }
            }
        }
        if (accCount == 0) {
            return 0;
        }
        float precision = (float) accCount / (accCount + wrongCount);
        float recall = (float) accCount / classTotal;
        return 2 * precision * recall / (precision + recall);
    }

    /**
     * Returns whether rule {@code code} matches example {@code data}: every attribute bit set in
     * the example must also be set in the rule. The class bit is ignored.
     */
    private static boolean covers(String code, String data) {
        for (int i = 0; i < CLASS_BIT; i++) {
            if (data.charAt(i) == '1' && code.charAt(i) != '1') {
                return false;
            }
        }
        return true;
    }

    /**
     * Produces the next generation in place. It runs roulette-wheel selection of
     * {@link #SELECTION_SIZE} parents, crossover of every pair of parents, then mutation, and
     * prints the population after each stage.
     *
     * @param pool population to evolve; its contents are replaced by the new generation
     */
    private static void evolution(List<String> pool) {
        List<String> parents = rouletteSelect(pool, SELECTION_SIZE);
        pool.clear();
        pool.addAll(parents);
        System.out.println("--------轮盘赌--------");
        print(pool);

        // Crossover: pairs that predict the same class are recombined; pairs that predict
        // different classes are carried over unchanged.
        newPool.clear();
        for (int i = 0; i < pool.size(); i++) {
            for (int j = i + 1; j < pool.size(); j++) {
                String first = pool.get(i);
                String second = pool.get(j);
                if (first.charAt(CLASS_BIT) == second.charAt(CLASS_BIT)) {
                    cross(first, second);
                } else {
                    newPool.add(first);
                    newPool.add(second);
                }
            }
        }
        replaceWithNewPool(pool);
        System.out.println("--------交叉--------");
        print(pool);

        // Mutation: each individual is mutated with probability p_y, otherwise copied as is.
        for (String individual : pool) {
            if (RAND.nextFloat() < p_y) {
                variation(individual);
            } else {
                newPool.add(individual);
            }
        }
        replaceWithNewPool(pool);
        System.out.println("--------变异--------");
        print(pool);
    }

    /**
     * Draws up to {@code count} individuals without replacement, each with probability
     * proportional to its fitness. If every remaining individual has zero fitness, the draw is
     * uniform instead.
     *
     * @param candidates population to draw from; it is not modified
     * @param count number of individuals to select
     * @return the selected individuals, in selection order
     */
    private static List<String> rouletteSelect(List<String> candidates, int count) {
        List<String> remaining = new ArrayList<>(candidates);
        List<Float> weights = new ArrayList<>(remaining.size());
        for (String candidate : remaining) {
            weights.add(calculateFitness(candidate));
        }
        int picks = Math.min(count, remaining.size());
        List<String> selected = new ArrayList<>(picks);
        for (int n = 0; n < picks; n++) {
            double total = 0;
            for (float w : weights) {
                total += w;
            }
            int chosen = -1;
            if (total > 0) {
                double pick = RAND.nextDouble() * total;
                double cumulative = 0;
                for (int j = 0; j < weights.size(); j++) {
                    float w = weights.get(j);
                    if (w <= 0) {
                        continue;
                    }
                    // Remember the last positive-weight index so rounding can never leave
                    // the draw without a winner.
                    chosen = j;
                    cumulative += w;
                    if (pick < cumulative) {
                        break;
                    }
                }
            } else {
                chosen = RAND.nextInt(remaining.size());
            }
            selected.add(remaining.remove(chosen));
            weights.remove(chosen); // remove(int): drop the weight at that index
        }
        return selected;
    }

    /** Moves the contents of {@link #newPool} into {@code target}, leaving newPool empty. */
    private static void replaceWithNewPool(List<String> target) {
        target.clear();
        target.addAll(newPool);
        newPool.clear();
    }

    /**
     * Uniform field-wise crossover. For each field (Outlook, Humidity, Wind, class), the two
     * children exchange the parents' values with probability {@link #p_c}. Both children are
     * appended to {@link #newPool}.
     *
     * @param parent1 first parent
     * @param parent2 second parent
     */
    private static void cross(String parent1, String parent2) {
        StringBuilder child1 = new StringBuilder(parent1.length());
        StringBuilder child2 = new StringBuilder(parent2.length());
        for (int f = 0; f + 1 < FIELD_BOUNDS.length; f++) {
            String gene1 = parent1.substring(FIELD_BOUNDS[f], FIELD_BOUNDS[f + 1]);
            String gene2 = parent2.substring(FIELD_BOUNDS[f], FIELD_BOUNDS[f + 1]);
            boolean swap = RAND.nextFloat() < p_c;
            child1.append(swap ? gene2 : gene1);
            child2.append(swap ? gene1 : gene2);
        }
        newPool.add(child1.toString());
        newPool.add(child2.toString());
    }

    /**
     * Point mutation: flips one randomly chosen attribute bit of {@code parent} and appends the
     * result to {@link #newPool}. The last (class) bit is never mutated.
     *
     * @param parent individual to mutate
     */
    private static void variation(String parent) {
        // nextInt(length - 1) yields 0..length-2, i.e. every bit except the last one.
        int location = RAND.nextInt(parent.length() - 1);
        newPool.add(flipBit(parent, location));
    }

    /**
     * Returns a copy of {@code code} with the character at {@code location} flipped between
     * '0' and '1'. Any character other than '0' becomes '0'.
     *
     * @param code bit string
     * @param location index of the bit to flip
     * @return the mutated bit string
     */
    static String flipBit(String code, int location) {
        char[] bits = code.toCharArray();
        bits[location] = bits[location] == '0' ? '1' : '0';
        return new String(bits);
    }

    /**
     * Prints every individual on its own line.
     *
     * @param pool individuals to print
     */
    private static void print(List<String> pool) {
        for (String individual : pool) {
            System.out.println(individual);
        }
    }
}