算法课实验:分支限界求解01背包&Sherwood算法

分支限界法求解 0-1 背包问题

/*
 * Copyright (c) 2019 Ng Kimbing, HNU, All rights reserved. May not be used, modified, or copied without permission.
 * @Author: Ng Kimbing, HNU.
 * @LastModified:2019-05-27 T 15:55:52.166 +08:00
 */
package ACMProblems.BFS;

import MyUtil.TestData;

import java.io.IOException;
import java.util.*;

/**
 * Best-first branch-and-bound solver for the 0-1 knapsack problem.
 *
 * <p>Items are sorted by value-per-weight in descending order. Each search node fixes the
 * take/skip decision for a prefix of the sorted items; live nodes are kept in a max-heap keyed
 * by an upper bound on the value reachable from them. The problem instance is held in the static
 * fields {@code things}, {@code num} and {@code maxWeight}.
 */
public class BFS_Knapsack {
    /** A candidate item: its original index, weight, value and value-per-weight ratio. */
    private static class Thing {
        int id;
        int weight;
        int value;
        double performance;

        public Thing(int id, int weight, int value) {
            this.id = id;
            this.weight = weight;
            this.value = value;
            // Floating-point division, so a zero weight gives Infinity/NaN instead of throwing.
            this.performance = (double) value / weight;
        }
    }

    /** Capacity of the knapsack. */
    private static int maxWeight;
    /** The items; sorted in place by descending {@code performance} before searching. */
    private static Thing[] things;
    /** Number of items considered by the search. */
    private static int num;

    /** Heap order: the node with the largest upper bound is polled first. */
    private static final Comparator<Node> BY_BOUND_DESC =
            (o1, o2) -> Double.compare(o2.upProfit, o1.upProfit);

    /** A search-tree node: the decisions for items {@code 0 .. level-1} are already fixed. */
    private static class Node {
        private int level; //consider things[level] when come to this condition.
        private int currValue;
        private int weightLeft;
        private double upProfit;
        private String path;

        /** Creates a node that does not record its decision path ({@code path} stays null). */
        Node(int level, int currValue, int weightLeft, double upProfit) {
            this(level, currValue, weightLeft, upProfit, null);
        }

        /**
         * @param level      index of the next item to decide
         * @param currValue  total value of the items taken so far
         * @param weightLeft remaining capacity
         * @param upProfit   upper bound on the value reachable from this node
         * @param path       one character per decided item: '1' = taken, '0' = skipped
         */
        Node(int level, int currValue, int weightLeft, double upProfit, String path) {
            this.level = level;
            this.currValue = currValue;
            this.weightLeft = weightLeft;
            this.upProfit = upProfit;
            this.path = path;
        }

        @Override
        public String toString() {
            return "value :" + currValue + "\tthingsList:   " + path;
        }
    }

    /**
     * Computes the profit upper bound of a node by greedily filling the remaining capacity with
     * items {@code i, i+1, ...} (already sorted by descending value-per-weight), taking a
     * fraction of the first item that no longer fits.
     *
     * @param i          index of the first undecided item
     * @param weightLeft remaining capacity (expected to be non-negative)
     * @param currValue  value already collected
     * @return the profit upper bound
     */
    private static double getBound(int i, int weightLeft, int currValue) {
        double bound = currValue;
        for (; i < num; ++i) {
            Thing item = things[i];
            if (item.weight <= weightLeft) {
                // The whole item fits: load it completely.
                bound += item.value;
                weightLeft -= item.weight;
            } else {
                // The item does not fit: load only the fraction that fills the knapsack.
                // The cast forces floating-point division; integer division would drop the
                // fraction entirely. item.weight > weightLeft >= 0 here, so it is non-zero.
                bound += (double) weightLeft / item.weight * item.value;
                break;
            }
        }
        return bound;
    }

    /**
     * Runs the best-first search and returns the best total value found.
     *
     * @return the largest value of any feasible selection encountered
     */
    private static int BFS() {
        int bestValue = 0;
        PriorityQueue<Node> heap = new PriorityQueue<>(BY_BOUND_DESC);
        heap.add(new Node(0, 0, maxWeight, getBound(0, maxWeight, 0)));
        while (!heap.isEmpty()) {
            Node currNode = heap.poll();
            int level = currNode.level;
            // A complete node at the top of the heap has a bound no smaller than that of any
            // remaining node, so nothing left in the heap can improve on it.
            if (level == num) {
                break;
            }
            int currValue = currNode.currValue;
            int weightLeft = currNode.weightLeft;
            Thing currItem = things[level];
            //left child: take the item; the greedy bound already includes it, so it is inherited
            if (weightLeft >= currItem.weight) {
                int newValue = currValue + currItem.value;
                bestValue = Math.max(bestValue, newValue);
                heap.add(new Node(level + 1, newValue, weightLeft - currItem.weight, currNode.upProfit));
            }
            //right child: skip the item; keep it only if it could still match the best value
            double rightUp = getBound(level + 1, weightLeft, currValue);
            if (rightUp >= bestValue) {
                heap.add(new Node(level + 1, currValue, weightLeft, rightUp));
            }
        }
        return bestValue;
    }

    /**
     * Same search as {@link #BFS()}, but every node records its decision path and each polled
     * node is printed to standard output as a trace.
     *
     * @return the first complete node polled from the heap, or {@code null} if the heap runs
     *         empty before a complete node is reached
     */
    private static Node BFSWithSolution() {
        int bestValue = 0;
        PriorityQueue<Node> heap = new PriorityQueue<>(BY_BOUND_DESC);
        heap.add(new Node(0, 0, maxWeight, getBound(0, maxWeight, 0), ""));
        while (!heap.isEmpty()) {
            Node currNode = heap.poll();
            System.out.println(currNode + "\tup profit: " + currNode.upProfit);
            int level = currNode.level;
            if (level == num) {
                return currNode;
            }
            int currValue = currNode.currValue;
            int weightLeft = currNode.weightLeft;
            Thing currItem = things[level];
            //left child: take the item
            if (weightLeft >= currItem.weight) {
                int newValue = currValue + currItem.value;
                bestValue = Math.max(bestValue, newValue);
                heap.add(new Node(level + 1, newValue, weightLeft - currItem.weight,
                        currNode.upProfit, currNode.path + "1"));
            }
            //right child: skip the item
            double rightUp = getBound(level + 1, weightLeft, currValue);
            if (rightUp >= bestValue) {
                heap.add(new Node(level + 1, currValue, weightLeft, rightUp, currNode.path + "0"));
            }
        }
        return null;
    }

    /**
     * Sorts {@code things} by descending value-per-weight and runs {@link #BFS()}.
     *
     * @return the best total value found
     */
    private static int solve() {
        Arrays.sort(things, (o1, o2) -> Double.compare(o2.performance, o1.performance));
        return BFS();
    }

    /** The outcome of a search: the total value and a printable list of the chosen item ids. */
    static class Result {
        int maxValue;
        String thingsList;

        public Result(int maxValue, String thingsList) {
            this.maxValue = maxValue;
            this.thingsList = thingsList;
        }

        @Override
        public String toString() {
            return "Max value: " + maxValue + "\twe take: " + thingsList;
        }
    }

    /**
     * Sorts {@code things}, runs {@link #BFSWithSolution()} and maps the winning path back to
     * the original item ids.
     *
     * @return the value and chosen ids, or {@code null} if the search returned no node
     */
    private static Result solveWithSolution() {
        Arrays.sort(things, (o1, o2) -> Double.compare(o2.performance, o1.performance));
        Node ans = BFSWithSolution();
        if (ans == null) {
            return null;
        }
        // path.charAt(i) is the decision for things[i] in sorted order; id is the original index.
        char[] decisions = ans.path.toCharArray();
        Set<Integer> set = new HashSet<>();
        for (int i = 0; i < decisions.length; ++i) {
            if (decisions[i] == '1') {
                set.add(things[i].id);
            }
        }
        return new Result(ans.currValue, set.toString());
    }

    /**
     * Builds the item array from parallel weight/value arrays; item {@code i} gets id {@code i}.
     *
     * @param w weights
     * @param v values, expected to have the same length as {@code w}
     * @return one {@code Thing} per index of {@code w}
     */
    private static Thing[] getThings(int[] w, int[] v) {
        assert w.length == v.length;
        Thing[] t = new Thing[w.length];
        for (int i = 0; i < t.length; ++i) {
            t[i] = new Thing(i, w[i], v[i]);
        }
        return t;
    }

    public static void main(String[] args) throws Exception {
        randomTest();
    }

    /**
     * Generates a data file through {@code TestData}, loads it, solves the instance and prints
     * the result and the elapsed time. The first {@code size} loaded numbers are used as weights
     * and the next {@code size} as values.
     */
    private static void randomTest() throws IOException {
        int size = 10000;
        num = size;
        maxWeight = size;
        int[] data = new int[2 * size];
        int[] w = new int[size];
        int[] v = new int[size];
        String path = "BFS_Knapsack" + size + ".txt";
        // NOTE(review): presumably writes 2*size random integers in [1, 300*size] to the file;
        // confirm against TestData.makeData.
        TestData.makeData(2 * size, path, false, 1, 300 * size);
        TestData.loadData(data, path);
        System.arraycopy(data, 0, w, 0, size);
        System.arraycopy(data, size, v, 0, size);
        System.out.println("data done");
        things = getThings(w, v);
        long start = System.currentTimeMillis();
        int ans = solve();
        long end = System.currentTimeMillis();
        System.out.println("size: " + size + "\tMax weight: " + maxWeight);
        System.out.println("-------------------------------------\nthe result is " + ans);
        long duration = end - start;
        System.out.println("It takes " + duration + " milliseconds");
    }
}

Sherwood 随机选择算法


/*
 * Copyright (c) 2019 Ng Kimbing, HNU, All rights reserved. May not be used, modified, or copied without permission.
 * @Author: Ng Kimbing, HNU.
 * @LastModified:2019-05-28 T 12:01:53.909 +08:00
 */

package MyUtil;

import StupidCode.Test.Test;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

/**
 * Sherwood-style (randomized-pivot) selection of the k-th smallest element.
 *
 * <p>All {@code select} methods reorder the supplied array in place while searching.
 * {@code k} is 1-based: {@code k == 1} selects the smallest element of the range.
 */
public class SherwoodSelect {
    /**
     * Validates an inclusive index range and a 1-based rank within it.
     *
     * @throws IllegalArgumentException if {@code [l, r]} is empty or not inside an array of the
     *                                  given length, or if {@code k} is outside
     *                                  {@code [1, r - l + 1]}
     */
    private static void checkArgs(int length, int l, int r, int k) {
        if (l < 0 || r >= length || l > r) {
            throw new IllegalArgumentException(
                    "invalid range [" + l + ", " + r + "] for array of length " + length);
        }
        if (k < 1 || k > r - l + 1) {
            throw new IllegalArgumentException(
                    "k must be in [1, " + (r - l + 1) + "] but was " + k);
        }
    }

    /**
     * Core loop for {@code int} arrays: finds the k-th smallest element of {@code a[l..r]}.
     * Callers are expected to have validated the arguments.
     */
    private static int select(int[] a, int l, int r, int k) {
        Random random = new Random();
        while (true) {
            if (l >= r) {
                return a[l];
            }
            //randomly choose a pivot (a[pivotIndex])
            int pivotIndex = l + random.nextInt(r - l + 1); // pivotIndex = rd(l, r) l, r inclusive
            int pivot = a[pivotIndex];
            //after the partition, a[pivotIndex] = pivot
            pivotIndex = partition(a, l, r, pivotIndex, pivot);
            //If the pivot happens to be the kth number in the interval
            if (pivotIndex + 1 == l + k) {
                return pivot;
            }
            if (pivotIndex + 1 < l + k) {
                //there are pivotIndex - l + 1 numbers in the left part (pivot included)
                k -= pivotIndex - l + 1;
                l = pivotIndex + 1;
            } else {
                r = pivotIndex - 1;
            }
        }
    }

    /**
     * Partitions {@code a[l..r]} around {@code pivot}, which is first moved to {@code a[l]}.
     *
     * @return the final index j of the pivot; elements left of j are {@code <= pivot} and
     *         elements right of j are {@code >= pivot}
     */
    private static int partition(int[] a, int l, int r, int pivotIndex, int pivot) {
        swap(a, l, pivotIndex);
        int i = l;
        int j = r + 1;
        while (true) {
            // i scans right for an element >= pivot; the i < r test keeps it inside the range.
            while (i < r && a[++i] < pivot) ;
            // j scans left for an element <= pivot; a[l] == pivot stops it at the latest.
            while (a[--j] > pivot) ;
            if (i >= j) {
                break;
            }
            swap(a, i, j);
        }
        //swap a[l] and a[j] (put pivot to a[j])
        a[l] = a[j];
        a[j] = pivot;
        return j;
    }

    private static void swap(Object[] a, int i, int j) {
        Object temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }

    private static void swap(int[] a, int i, int j) {
        int temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }

    /**
     * Returns the k-th smallest element of {@code a}.
     *
     * @param a the array to search; reordered in place
     * @param k 1-based rank, {@code 1 <= k <= a.length}
     * @throws IllegalArgumentException if {@code a} is empty or {@code k} is out of range
     */
    public static int select(int[] a, int k) {
        checkArgs(a.length, 0, a.length - 1, k);
        return select(a, 0, a.length - 1, k);
    }

    /**
     * Returns the k-th smallest element among the first {@code num} elements of {@code a}.
     *
     * @param a   the array to search; its first {@code num} elements are reordered in place
     * @param num number of leading elements to consider, {@code 1 <= num <= a.length}
     * @param k   1-based rank, {@code 1 <= k <= num}
     * @throws IllegalArgumentException if {@code num} or {@code k} is out of range
     */
    public static int select(int[] a, int num, int k) {
        checkArgs(a.length, 0, num - 1, k);
        return select(a, 0, num - 1, k);
    }

    /**
     * Returns the k-th smallest element of {@code a} under {@code comparator}.
     *
     * @throws IllegalArgumentException if {@code a} is empty or {@code k} is out of range
     */
    public static <T> T select(T[] a, int k, Comparator<? super T> comparator) {
        return select(a, 0, a.length - 1, k, comparator);
    }

    /**
     * Returns the k-th smallest element among the first {@code num} elements of {@code a} under
     * {@code comparator}.
     *
     * @throws IllegalArgumentException if {@code num} or {@code k} is out of range
     */
    public static <T> T select(T[] a, int num, int k, Comparator<? super T> comparator) {
        return select(a, 0, num - 1, k, comparator);
    }

    /**
     * Returns the k-th smallest element of {@code a[l..r]} (both inclusive) under
     * {@code comparator}.
     *
     * @param a          the array to search; {@code a[l..r]} is reordered in place
     * @param l          first index of the range
     * @param r          last index of the range
     * @param k          1-based rank within the range, {@code 1 <= k <= r - l + 1}
     * @param comparator ordering used to compare elements
     * @throws IllegalArgumentException if the range or {@code k} is invalid
     */
    public static <T> T select(T[] a, int l, int r, int k, Comparator<? super T> comparator) {
        checkArgs(a.length, l, r, k);
        Random random = new Random();
        //a[l+k-1] is the k-th element.
        while (true) {
            if (l >= r) {
                return a[l];
            }
            // Move a randomly chosen pivot to a[l], then partition around it.
            int i = l;
            int j = l + random.nextInt(r - l + 1);
            swap(a, i, j);
            j = r + 1;
            T pivot = a[l];
            while (true) {
                while (i < r && comparator.compare(a[++i], pivot) < 0) ;
                while (comparator.compare(a[--j], pivot) > 0) ;
                if (i >= j) {
                    break;
                }
                swap(a, i, j);
            }
            // j is the pivot's final position; if that is rank k we are done.
            if (j + 1 == l + k) {
                return pivot;
            }
            a[l] = a[j];
            a[j] = pivot;
            if (j + 1 < l + k) {
                k -= j - l + 1;
                l = j + 1;
            } else {
                r = j - 1;
            }
        }
    }

    /**
     * Generates a data file through {@code TestData}, loads it, selects a randomly chosen rank
     * and prints the result and the elapsed time.
     */
    public static void main(String[] args) throws Exception {
        int size = 10000000;
        String path = "Sherwood" + size + ".txt";
        System.out.println("making data");
        TestData.makeData(size, path, false);
        System.out.println("done");
        int[] a = new int[size];
        TestData.loadData(a, path);
        // NOTE(review): presumably uniform in [1, size] inclusive; confirm against TestData.random.
        int k = TestData.random(1, size);
        long start = System.currentTimeMillis();
        int result = select(a, k);
        long end = System.currentTimeMillis();
        System.out.println("from " + size + " numbers, select the " + k + "-th one");
        System.out.println("The result is :" + result);
        System.out.println("It takes " + (end - start) + " milliseconds");
    }
}

你可能感兴趣的:(算法作业)