说明:算法来自于《集体智慧编程》-第五章
原书代码用 Python 实现,这两天看这章书,改用 Java 实现。
问题描述:Glass 一家六人分散在全国各地,要到 LGA 碰头聚会。求总花费最少的航班安排。
与原书代码的不同之处:成本中额外计入了旅途(飞行)时间。注意:原注写作 0.5/h,但 scheduleCost 实际按飞行分钟数 × 0.5 计费。
/** * * FILENAME: Optimization.java * AUTHOR: vivizhyy[at]gmail.com * STUID: whu * DATE: 2010-4-12 * USAGE : */ package ch5.optimization; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.Random; import org.apache.log4j.Logger; import org.apache.log4j.PropertyConfigurator; import org.joda.time.DateTime; import org.joda.time.LocalTime; public class Optimization { private HashMap<String, String> people = new HashMap<String, String>(); private static String[] family = {"Seymour", "Pranny", "Zooey", "Wait", "Buddy", "Les"}; private String destination = "LGA"; private Flights flights = new Flights(); private static final int MEMBER_NUM = 6; private Logger log = Logger.getLogger(Optimization.class.getName()); private void initPeople() { this.people.put("Seymour", "BOS"); this.people.put("Pranny", "DAL"); this.people.put("Zooey", "CAK"); this.people.put("Wait", "MIA"); this.people.put("Buddy", "ORD"); this.people.put("Les", "OMA"); } /** * * @param times */ public void printSchedule(int[] times) { StringBuilder scheduleResult = new StringBuilder(); int index = 0; for (String mem : family) { scheduleResult.append(mem + "\t" + people.get(mem) + "\t"); Flights f = flights.getFlightByOriginAndDest(people.get(mem), destination)[times[index * 2]]; scheduleResult.append(f.getDepart() + "-" + f.getArraive() + "\t$" + f.getPrice() + "\t"); f = flights.getFlightByOriginAndDest(destination, people.get(mem))[times[index * 2 + 1]]; scheduleResult.append(f.getDepart() + "-" + f.getArraive() + "\t$" + f.getPrice() + "\n"); index++; } System.out.println(scheduleResult); } /** * * @param t * @return */ public int getMinutes(LocalTime t) { return (t.getMinuteOfHour() + t.getHourOfDay() * 60); } /** * * @param sol * @return */ public double scheduleCost(int[] sol) { PropertyConfigurator.configure("D:/Documents/NetBeansProjects/CollectiveProgramming/src/log4j.properties"); double 
totalPrice = 0.0; int lastArrival = 0; int earliestDep = 24 * 60; int totalTravel = 0; Flights[] outBound = new Flights[MEMBER_NUM * 2]; Flights[] returnFlight = new Flights[MEMBER_NUM * 2]; for (int i = 0; i < MEMBER_NUM; i++) { //得到往返航班 outBound[i] = flights.getFlightByOriginAndDest(people.get(family[i]), destination)[sol[i * 2]]; returnFlight[i] = flights.getFlightByOriginAndDest(destination, people.get(family[i]))[sol[i * 2 + 1]]; //log.info("price:" + outBound[i].getPrice() + "\t" + returnFlight[i].getPrice()); //加航班价格 totalPrice += outBound[i].getPrice(); totalPrice += returnFlight[i].getPrice(); //加旅行时间 totalTravel += getMinutes(outBound[i].getArraive()) - getMinutes(outBound[i].getDepart()); totalTravel += getMinutes(returnFlight[i].getArraive()) - getMinutes(returnFlight[i].getDepart()); //记录最晚到达时间和最早离开时间 if (lastArrival < getMinutes(outBound[i].getArraive())) { lastArrival = getMinutes(outBound[i].getArraive()); } if (earliestDep > getMinutes(returnFlight[i].getDepart())) { earliestDep = getMinutes(returnFlight[i].getDepart()); } } int totalWait = 0; for (int i = 0; i < MEMBER_NUM; i++) { totalWait += lastArrival - getMinutes(outBound[i].getArraive()); totalWait += getMinutes(returnFlight[i].getDepart()) - earliestDep; } //要多付一天的汽车租用金吗? 
if (lastArrival > earliestDep) { totalPrice += 50; } totalPrice = totalPrice + totalWait + totalTravel * 0.5; return totalPrice; } /** * 随机在 <code>loopTimes</code> 次中找出最值 * @param loopTimes * @return */ public int[] randomOptimize(int loopTimes) { int[] bestr = new int[MEMBER_NUM * 2]; double best = 999999999; for (int i = 0; i < loopTimes; i++) { int[] r = randomResult(); double price = scheduleCost(r); if (best > price) { best = price; bestr = r; } } System.out.println("total cost: " + best); return bestr; } /** * 爬山法 * * @return */ public int[] hillclimb() { int[] sol = randomResult(); double best = 999999999; int count = 0; while (true) { count++; int[][] neighbors = new int[MEMBER_NUM * 2][MEMBER_NUM * 2]; for (int j = 0; j < MEMBER_NUM * 2; j = j + 2) { if (sol[j] > 0) { for (int m = 0; m < MEMBER_NUM * 2; m++) { if (m == j && sol[j] <= 9) { neighbors[j][m] = sol[m] + 1; } else { neighbors[j][m] = sol[m]; } } } if (sol[j] <= 9) { for (int m = 0; m < MEMBER_NUM * 2; m++) { if (m == j && sol[j] != 0) { neighbors[j + 1][j] = sol[m] - 1; } else { neighbors[j + 1][m] = sol[m]; } } } } double currentCost = scheduleCost(sol); for (int m = 0; m < MEMBER_NUM; m++) { double cost = scheduleCost(neighbors[m]); //System.out.println("cost: " + cost); if (cost < best) { best = cost; sol = neighbors[m]; } } if (best == currentCost) { System.out.println("best: " + best); System.out.println("loop: " + count); return sol; } } } /** * 退火算法 * * @param T * @param cool * @param step * @return */ public int[] annealingoptimize(double T, double cool, int step) { int[] vec = randomResult(); long seed = System.nanoTime(); Random random = new Random(seed); while (T > 0.1) { int index = random.nextInt(MEMBER_NUM * 2 - 1); int dir = (int) (random.nextInt(step) * Math.pow(-1, random.nextInt())); // System.out.println("dir: " + dir); int[] vecb = new int[MEMBER_NUM * 2]; for (int i = 0; i < MEMBER_NUM * 2; i++) { if (i == index) { if (vec[i] + dir < 0) { vecb[i] = 0; continue; } if (vecb[i] 
> 9) { vecb[i] = 9; continue; } vecb[i] = vec[i] + dir; } else { vecb[i] = vec[i]; } } double ea = scheduleCost(vec); double eb = scheduleCost(vecb); if (eb < ea || random.nextDouble() < Math.pow(Math.E, -(eb - ea) / T)) { vec = vecb; } T *= cool; } System.out.println(scheduleCost(vec)); return vec; } public int[] geneticoptimize(int popSize, int step, double mutprob, double elite, int maxiter) { long seed = System.nanoTime(); Random random = new Random(seed); //构造初始种群 ArrayList<int[]> pop = new ArrayList<int[]>(popSize); for (int i = 0; i < popSize; i++) { pop.add(randomResult()); } //每一代胜出者数目 int topelite = (int) (elite * popSize); ArrayList<Score> scores = new ArrayList<Score>(); for (int j = 0; j < maxiter; j++) { for (int x = 0; x < popSize; x++) { Score s = new Score(); s.list = pop.get(x); s.price = scheduleCost(pop.get(x)); scores.add(s); } CompareScore cs = new CompareScore(); Collections.sort(scores, cs); ArrayList<int[]> ranked = new ArrayList<int[]>(); for (int m = 0; m < popSize; m++) { pop.remove(0); ranked.add(scores.get(m).list); } for (int n = 0; n < topelite; n++) { pop.add(ranked.get(n)); } while (pop.size() < popSize) { if (random.nextDouble() < mutprob) { //变异 int c = random.nextInt(topelite); pop.add(mutate(ranked.get(c), step)); } else { int c1 = random.nextInt(topelite); int c2 = random.nextInt(topelite); pop.add(crossOver(ranked.get(c1), ranked.get(c2))); } } System.out.println(scores.get(0).price); } return scores.get(0).list; } /** * 变异 * * @param vec * @param step * @return */ public int[] mutate(int[] vec, int step) { int[] mutateR = new int[MEMBER_NUM * 2]; long seed = System.nanoTime(); Random random = new Random(seed); int index = random.nextInt(MEMBER_NUM * 2 - 1); if (random.nextDouble() < 0.5 && vec[index] > 0) { for (int i = 0; i < MEMBER_NUM * 2; i++) { if (i == index && (vec[i] - step) > 0) { mutateR[i] = vec[i] - step; } else { mutateR[i] = vec[i]; } } } else { for (int i = 0; i < MEMBER_NUM * 2; i++) { if (i == index && 
(vec[i] + step) < 9) { mutateR[i] = vec[i] + step; } else { mutateR[i] = vec[i]; } } } return mutateR; } /** * 交叉 * * @param r1 * @param r2 * @return */ public int[] crossOver(int[] r1, int[] r2) { int[] crossOverR = new int[MEMBER_NUM * 2]; long seed = System.nanoTime(); Random random = new Random(seed); int index = random.nextInt(MEMBER_NUM * 2 - 2); for (int i = 0; i < MEMBER_NUM * 2; i++) { if (i < index) { crossOverR[i] = r1[i]; } else { crossOverR[i] = r2[i]; } } return crossOverR; } /** * 产生一个长度为 MEMBER_NUM*2 的随机数组,数组中每个数字的值范围:(0, 9) * * @return r[<code>MEMBER_NUM</code> * 2] */ public int[] randomResult() { long seed = System.nanoTime(); Random random = new Random(seed); int[] r = new int[MEMBER_NUM * 2]; for (int j = 0; j < MEMBER_NUM * 2; j++) { r[j] = random.nextInt(9); } return r; } public static void main(String[] args) { //int[] s = {1, 4, 3, 2, 7, 3, 6, 3, 2, 4, 5, 3}; Optimization o = new Optimization(); //System.out.println(o.getMinutes(o.flights.getTimeByValue("1:21"))); o.initPeople(); //int[] best = o.randomOptimize(100); //int[] best = o.hillclimb(); // int[] best = o.annealingoptimize(10000.0, 0.95, 2); int[] best = o.geneticoptimize(50, 2, 0.2, 0.2, 100); o.printSchedule(best); } }
/** * * FILENAME: Flights.java * AUTHOR: vivizhyy[at]gmail.com * STUID: whu * DATE: 2010-4-12 * USAGE : */ package ch5.optimization; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import org.apache.log4j.Logger; import org.apache.log4j.PropertyConfigurator; import org.joda.time.LocalTime; public class Flights { private String origin; private String dest; private LocalTime depart; private LocalTime arraive; private float price; public static String SCHEDULE = "schedule.txt"; public Logger log = Logger.getLogger(Flights.class.getName()); public Flights() { } public LocalTime getArraive() { return arraive; } public LocalTime getDepart() { return depart; } public String getDest() { return dest; } public Logger getLog() { return log; } public String getOrigin() { return origin; } public float getPrice() { return price; } public void setArraive(LocalTime arraive) { this.arraive = arraive; } public void setDepart(LocalTime depart) { this.depart = depart; } public void setDest(String dest) { this.dest = dest; } public void setLog(Logger log) { this.log = log; } public void setOrigin(String origin) { this.origin = origin; } public void setPrice(float price) { this.price = price; } /** * * @return */ public static Flights[] getFlights() { Flights flights[] = new Flights[120]; for(int i = 0; i < 120; i++) flights[i] = new Flights(); PropertyConfigurator.configure("D:/Documents/NetBeansProjects/CollectiveProgramming/src/log4j.properties"); BufferedReader reader = null; String line; try { reader = new BufferedReader(new FileReader(SCHEDULE)); int count = 0; while ((line = reader.readLine()) != null && count < 120) { //init flights String result[] = line.split(","); if (result.length == 5) { flights[count].origin = result[0]; flights[count].dest = result[1]; flights[count].depart = getTimeByValue(result[2]); flights[count].arraive = getTimeByValue(result[3]); flights[count].price = new Integer(result[4]); count++; } else { 
System.err.println("schedule format wrong."); } } } catch (IOException ex) { ex.printStackTrace(); } finally { if (reader != null) { try { reader.close(); } catch (IOException ex) { ex.printStackTrace(); } } } return flights; } /** * * @param time * @return */ public static LocalTime getTimeByValue(String time) { int mark = time.indexOf(":"); int hour = new Integer(time.substring(0, mark)); int minute = new Integer(time.substring(mark + 1, time.length())); LocalTime timeResult = new LocalTime(hour, minute); return timeResult; } /** * * @param strOrigin * @param strDest * @return */ public Flights[] getFlightByOriginAndDest(String strOrigin, String strDest) { Flights[] resultSet = new Flights[15]; Flights flights[] = getFlights(); int count = 0; for (Flights f : flights) { if(count > 14) { log.error("flight time lager than define."); break; } if (f.origin.equals(strOrigin) && f.dest.equals(strDest)) { resultSet[count++] = f; } } return resultSet; } public String toString() { return this.origin + "\t" + this.dest + "\t" + this.depart + "-" + this.arraive + "\t" + "$" + this.price; } }
/**
 *
 * FILENAME: Score.java
 * AUTHOR: vivizhyy[at]gmail.com
 * STUID: whu200732580127
 * DATE: 2010-4-13
 * USAGE : Pairs a candidate solution with its cost, for ranking in the genetic optimizer.
 */
package ch5.optimization;

import java.util.Comparator;

/** A candidate solution vector together with its cost. */
public class Score {
    /** Flight-index solution vector (12 entries by default). */
    public int[] list = new int[12];
    /** Cost of {@link #list}; lower is better. */
    public double price;
}

/**
 * Orders {@link Score}s by ascending {@code price}.
 *
 * <p>The previous version returned only 0 or 1, never a negative value. That broke the
 * antisymmetry required by {@link Comparator}, so {@code Collections.sort} could misorder the
 * list or throw.
 */
class CompareScore implements Comparator<Score> {
    @Override
    public int compare(Score s1, Score s2) {
        return Double.compare(s1.price, s2.price);
    }
}