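// BFPRT vs. randomized quickselect (the "modified quicksort" selection): the ONLY difference is
// how the pivot is chosen. Quickselect picks a random pivot (expected O(N)); BFPRT deliberately
// picks the median of the medians of small groups, which guarantees O(N) in the worst case.
// 1. Group size: groups of 7 (or any larger odd size) also give O(N), but groups of 3 do not --
//    with the standard median-of-medians scheme the recurrence becomes O(N * log N). The
//    algorithm is named after its five inventors (Blum, Floyd, Pratt, Rivest, Tarjan), which makes
//    "groups of five" easy to remember; the real reason is that 5 is the smallest odd group size
//    that yields linear time.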
public static class MaxHeapComparator implements Comparator {
@Override
public int compare(Integer o1, Integer o2) {
return o2 - o1;
}
}
// Approach 1: max-heap holding the k smallest values seen so far. O(N * log K) time, O(K) space.
/**
 * Returns the k-th smallest value of {@code arr} (1-based: k = 1 is the minimum).
 *
 * <p>The max-heap keeps the k smallest values seen so far; its top is the k-th smallest.
 * {@code arr} itself is not modified.
 *
 * @param arr input values; must be non-null
 * @param k   rank to select, {@code 1 <= k <= arr.length}
 * @return the k-th smallest value
 * @throws IllegalArgumentException if {@code arr} is null or {@code k} is out of range
 */
public static int minKth1(int[] arr, int k) {
    if (arr == null) {
        throw new IllegalArgumentException("arr must not be null");
    }
    if (k < 1 || k > arr.length) {
        throw new IllegalArgumentException("k must be in [1, " + arr.length + "], got " + k);
    }
    // Typed (not raw) so peek() yields an Integer that can be compared with and returned as int.
    PriorityQueue<Integer> maxHeap = new PriorityQueue<>(new MaxHeapComparator());
    for (int i = 0; i < k; i++) {
        maxHeap.add(arr[i]);
    }
    for (int i = k; i < arr.length; i++) {
        // Only a value smaller than the current k-th smallest can enter the k smallest.
        if (arr[i] < maxHeap.peek()) {
            maxHeap.poll();
            maxHeap.add(arr[i]);
        }
    }
    return maxHeap.peek();
}
// Approach 2: randomized quickselect (quicksort's partition step, recursing into one side only).
// Expected O(N) time; an unlucky sequence of random pivots degrades it to O(N^2).
/**
 * Returns the k-th smallest value of {@code array} (1-based: k = 1 is the minimum).
 * Works on a copy, so {@code array} itself is not modified.
 *
 * @param array input values; must be non-null
 * @param k     rank to select, {@code 1 <= k <= array.length}
 * @return the k-th smallest value
 * @throws IllegalArgumentException if {@code array} is null or {@code k} is out of range
 */
public static int minKth2(int[] array, int k) {
    if (array == null) {
        throw new IllegalArgumentException("array must not be null");
    }
    if (k < 1 || k > array.length) {
        throw new IllegalArgumentException("k must be in [1, " + array.length + "], got " + k);
    }
    int[] arr = copyArray(array);
    // Rank k corresponds to index k - 1 in sorted order.
    return process2(arr, 0, arr.length - 1, k - 1);
}
/**
 * Returns an independent copy of {@code arr}, or {@code null} when {@code arr} is null.
 *
 * @param arr array to copy; may be null
 * @return a new array with the same contents, or null
 */
public static int[] copyArray(int[] arr) {
    if (arr == null) {
        return null;
    }
    // For a primitive array, clone() is a full element copy into a new array.
    return arr.clone();
}
// Recursive helper for minKth2 (randomized quickselect).
/**
 * Returns the value that would sit at position {@code index} if {@code arr[L..R]} were sorted,
 * without fully sorting the range. minKth2 calls it as {@code process2(arr, 0, N - 1, k - 1)}
 * to get the k-th smallest value.
 *
 * @param arr   array to search; elements inside {@code [L..R]} are reordered by partition
 * @param L     left bound of the active range (inclusive)
 * @param R     right bound of the active range (inclusive)
 * @param index target sorted position, expected to satisfy {@code L <= index <= R}
 * @return the value of sorted rank {@code index} within {@code arr[L..R]}
 */
public static int process2(int[] arr, int L, int R, int index) {
    if (L == R) {
        // A single element left: here L == R == index.
        return arr[L];
    }
    // Pivot drawn uniformly from arr[L..R].
    int width = R - L + 1;
    int pivot = arr[L + (int) (Math.random() * width)];
    int[] equalZone = partition(arr, L, R, pivot);
    int first = equalZone[0];
    int last = equalZone[1];
    if (index < first) {
        return process2(arr, L, first - 1, index);
    }
    if (index > last) {
        return process2(arr, last + 1, R, index);
    }
    // index lies inside the run of values equal to the pivot.
    return arr[index];
}
/**
 * Three-way partition of {@code arr[L..R]} around {@code pivot}. Afterwards the range is laid
 * out as {@code [< pivot][== pivot][> pivot]}.
 *
 * @param arr   array to rearrange in place
 * @param L     left bound (inclusive)
 * @param R     right bound (inclusive)
 * @param pivot value to partition around
 * @return {@code {first, last}}, the inclusive bounds of the {@code == pivot} zone; when no
 *         element equals the pivot, {@code first == last + 1}
 */
public static int[] partition(int[] arr, int L, int R, int pivot) {
    int lessEnd = L - 1;   // last index of the "< pivot" zone
    int moreStart = R + 1; // first index of the "> pivot" zone
    int scan = L;          // next element to classify
    while (scan < moreStart) {
        int value = arr[scan];
        if (value < pivot) {
            // Grow the "<" zone and move value into it. The displaced element was already
            // classified as "== pivot" (or is value itself), so the scan advances.
            lessEnd++;
            arr[scan] = arr[lessEnd];
            arr[lessEnd] = value;
            scan++;
        } else if (value > pivot) {
            // Move value to the front of the ">" zone. The element swapped in has not been
            // examined yet, so the scan stays put.
            moreStart--;
            arr[scan] = arr[moreStart];
            arr[moreStart] = value;
        } else {
            scan++;
        }
    }
    return new int[] {lessEnd + 1, moreStart - 1};
}
/**
 * Swaps the elements at positions {@code i1} and {@code i2} of {@code arr} in place.
 * When {@code i1 == i2} the array is left unchanged.
 */
public static void swap(int[] arr, int i1, int i2) {
    int first = arr[i1];
    int second = arr[i2];
    arr[i1] = second;
    arr[i2] = first;
}
// Approach 3: BFPRT selection (median-of-medians pivot), O(N) time.
/**
 * Returns the k-th smallest value of {@code array} (1-based: k = 1 is the minimum) using BFPRT.
 * Works on a copy, so {@code array} itself is not modified.
 *
 * @param array input values; must be non-null
 * @param k     rank to select, {@code 1 <= k <= array.length}
 * @return the k-th smallest value
 * @throws IllegalArgumentException if {@code array} is null or {@code k} is out of range
 */
public static int minKth3(int[] array, int k) {
    if (array == null) {
        throw new IllegalArgumentException("array must not be null");
    }
    if (k < 1 || k > array.length) {
        throw new IllegalArgumentException("k must be in [1, " + array.length + "], got " + k);
    }
    int[] arr = copyArray(array);
    // Rank k corresponds to index k - 1 in sorted order.
    return bfprt(arr, 0, arr.length - 1, k - 1);
}
/**
 * Returns the value that would be at position {@code index} if {@code arr[L..R]} were sorted.
 *
 * <p>Same shape as quickselect, except the pivot comes from {@code medianOfMedians}: split the
 * range into groups of five, sort each group, collect the group medians, and take the median
 * of those.
 *
 * @param arr   array to search; elements inside {@code [L..R]} are reordered
 * @param L     left bound (inclusive)
 * @param R     right bound (inclusive)
 * @param index target sorted position, expected to satisfy {@code L <= index <= R}
 * @return the value of sorted rank {@code index} within {@code arr[L..R]}
 */
public static int bfprt(int[] arr, int L, int R, int index) {
    if (L == R) {
        return arr[L];
    }
    int pivot = medianOfMedians(arr, L, R);
    int[] equalZone = partition(arr, L, R, pivot);
    int lo = equalZone[0];
    int hi = equalZone[1];
    boolean goLeft = index < lo;
    boolean goRight = index > hi;
    if (!goLeft && !goRight) {
        // index is inside the block of pivot-equal values.
        return arr[index];
    }
    return goLeft ? bfprt(arr, L, lo - 1, index) : bfprt(arr, hi + 1, R, index);
}
// BFPRT pivot selection:
//   1. split arr[L..R] into consecutive groups of five (the last group may be shorter),
//   2. sort each group in place and take its median,
//   3. return the median of those medians, found recursively with bfprt.
/**
 * Returns the median of the group medians of {@code arr[L..R]}.
 * Side effect: every group of five inside {@code arr[L..R]} is sorted in place.
 */
public static int medianOfMedians(int[] arr, int L, int R) {
    int size = R - L + 1;
    int groupCount = size / 5;
    if (size % 5 != 0) {
        groupCount++; // trailing partial group
    }
    int[] medians = new int[groupCount];
    int groupStart = L;
    for (int g = 0; g < groupCount; g++, groupStart += 5) {
        int groupEnd = Math.min(R, groupStart + 4);
        medians[g] = getMedian(arr, groupStart, groupEnd);
    }
    return bfprt(medians, 0, groupCount - 1, groupCount / 2);
}
/**
 * Sorts {@code arr[L..R]} in place and returns its median. For an even-sized range this is
 * the lower of the two middle elements.
 *
 * @param arr array containing the range
 * @param L   left bound (inclusive)
 * @param R   right bound (inclusive), {@code R >= L}
 * @return the (lower) median of {@code arr[L..R]}
 */
public static int getMedian(int[] arr, int L, int R) {
    insertionSort(arr, L, R);
    // Overflow-safe midpoint: (L + R) / 2 overflows once L + R exceeds Integer.MAX_VALUE.
    // For 0 <= L <= R this selects the same index as floor((L + R) / 2).
    return arr[L + ((R - L) >> 1)];
}
/**
 * Sorts {@code arr[L..R]} (inclusive) in ascending order, in place, with insertion sort.
 * In this file it is used on groups of at most five elements.
 */
public static void insertionSort(int[] arr, int L, int R) {
    for (int i = L + 1; i <= R; i++) {
        int key = arr[i];
        int j = i - 1;
        // Shift strictly larger elements one slot right to open a hole for key.
        while (j >= L && arr[j] > key) {
            arr[j + 1] = arr[j];
            j--;
        }
        arr[j + 1] = key;
    }
}
// Test helper.
/**
 * Returns an array of random length in {@code [1, maxSize]} whose values are uniformly random
 * in {@code [0, maxValue]}.
 */
public static int[] generateRandomArray(int maxSize, int maxValue) {
    int length = 1 + (int) (Math.random() * maxSize);
    int[] values = new int[length];
    int bound = maxValue + 1;
    for (int idx = 0; idx < length; idx++) {
        values[idx] = (int) (Math.random() * bound);
    }
    return values;
}
/** Randomized cross-check that the heap, quickselect and BFPRT selectors return the same value. */
public static void main(String[] args) {
    final int testTime = 1000000;
    final int maxSize = 100;
    final int maxValue = 100;
    System.out.println("test begin");
    for (int round = 0; round < testTime; round++) {
        int[] sample = generateRandomArray(maxSize, maxValue);
        int rank = (int) (Math.random() * sample.length) + 1;
        int byHeap = minKth1(sample, rank);
        int byQuickSelect = minKth2(sample, rank);
        int byBfprt = minKth3(sample, rank);
        boolean agree = byHeap == byQuickSelect && byQuickSelect == byBfprt;
        if (!agree) {
            System.out.println("Oops!");
        }
    }
    System.out.println("test finish");
}
// Approach 1: sort, then collect from the end. O(N * log N).
/**
 * Returns the {@code k} largest values of {@code arr} in descending order.
 *
 * <p>Note: sorts {@code arr} in place as a side effect (when there is work to do).
 *
 * @param arr input values; null or empty yields an empty result
 * @param k   how many values to return; clamped to {@code arr.length}, and {@code k <= 0}
 *            yields an empty result
 * @return a new array of {@code min(k, arr.length)} values, largest first
 */
public static int[] maxTopK1(int[] arr, int k) {
    if (arr == null || arr.length == 0 || k <= 0) {
        // Nothing to return; also keeps a negative k away from new int[k].
        return new int[0];
    }
    int N = arr.length;
    k = Math.min(N, k);
    Arrays.sort(arr);
    int[] ans = new int[k];
    // Walk backwards from the largest value.
    for (int i = N - 1, j = 0; j < k; i--, j++) {
        ans[j] = arr[i];
    }
    return ans;
}
// Approach 2: in-place max-heap. O(N + K * log N).
/**
 * Returns the {@code k} largest values of {@code arr} in descending order.
 *
 * <p>Builds a max-heap over the whole array bottom-up in O(N), then moves only the k largest
 * values to the end of the array, one O(log N) heapify each.
 * Note: reorders {@code arr} in place.
 *
 * @param arr input values; null or empty yields an empty result
 * @param k   how many values to return; clamped to {@code arr.length}, and {@code k <= 0}
 *            yields an empty result
 * @return a new array of {@code min(k, arr.length)} values, largest first
 */
public static int[] maxTopK2(int[] arr, int k) {
    if (arr == null || arr.length == 0 || k <= 0) {
        // Nothing to return; also keeps a negative k away from new int[k].
        return new int[0];
    }
    int N = arr.length;
    k = Math.min(N, k);
    // Bottom-up heap construction: O(N).
    for (int i = N - 1; i >= 0; i--) {
        heapify(arr, i, N);
    }
    // Move the current maximum behind the shrinking heap k times: O(K * log N).
    // Afterwards arr[N-k..N-1] holds the k largest values, the largest at arr[N-1].
    int heapSize = N;
    swap(arr, 0, --heapSize);
    int count = 1;
    while (heapSize > 0 && count < k) {
        heapify(arr, 0, heapSize);
        swap(arr, 0, --heapSize);
        count++;
    }
    int[] ans = new int[k];
    for (int i = N - 1, j = 0; j < k; i--, j++) {
        ans[j] = arr[i];
    }
    return ans;
}
/**
 * Sifts {@code arr[index]} up a max-heap stored in {@code arr} (the parent of i is (i - 1) / 2)
 * until its parent is not smaller than it.
 */
public static void heapInsert(int[] arr, int index) {
    int child = index;
    int parent = (child - 1) / 2;
    // At the root (child == 0) the parent is also 0, because (0 - 1) / 2 == 0 in Java,
    // so the comparison is false and the loop stops.
    while (arr[child] > arr[parent]) {
        int tmp = arr[child];
        arr[child] = arr[parent];
        arr[parent] = tmp;
        child = parent;
        parent = (child - 1) / 2;
    }
}
/**
 * Sifts {@code arr[index]} down within the max-heap occupying {@code arr[0..heapSize-1]}.
 * The children of node i are at 2i+1 and 2i+2.
 */
public static void heapify(int[] arr, int index, int heapSize) {
    int node = index;
    for (int child = node * 2 + 1; child < heapSize; child = node * 2 + 1) {
        int right = child + 1;
        // On a tie between the two children, the left one is chosen.
        int bigger = (right < heapSize && arr[right] > arr[child]) ? right : child;
        if (arr[bigger] <= arr[node]) {
            return; // heap property already holds at this node
        }
        int tmp = arr[bigger];
        arr[bigger] = arr[node];
        arr[node] = tmp;
        node = bigger;
    }
}
/** Exchanges {@code arr[i]} and {@code arr[j]} in place; a no-op when {@code i == j}. */
public static void swap(int[] arr, int i, int j) {
    final int fromI = arr[i];
    arr[i] = arr[j];
    arr[j] = fromI;
}
// Approach 3: selection + sorting only the k winners. Expected O(N + K * log K).
/**
 * Returns the {@code k} largest values of {@code arr} in descending order.
 *
 * <p>Finds the k-th largest value (sorted index N - k) with quickselect, collects every value
 * strictly greater than it, and pads with copies of that threshold to account for ties. It then
 * sorts the k values ascending and reverses them.
 * Note: minKth reorders {@code arr} in place.
 *
 * @param arr input values; null or empty yields an empty result
 * @param k   how many values to return; clamped to {@code arr.length}, and {@code k <= 0}
 *            yields an empty result
 * @return a new array of {@code min(k, arr.length)} values, largest first
 */
public static int[] maxTopK3(int[] arr, int k) {
    if (arr == null || arr.length == 0 || k <= 0) {
        // Also guards minKth: with k <= 0 the rank N - k would be past the end of arr.
        return new int[0];
    }
    int N = arr.length;
    k = Math.min(N, k);
    // The value at sorted index N - k is the k-th largest. Expected O(N).
    int num = minKth(arr, N - k);
    int[] ans = new int[k];
    int index = 0;
    for (int i = 0; i < N; i++) {
        if (arr[i] > num) {
            ans[index++] = arr[i];
        }
    }
    // At most k - 1 values exceed the threshold; fill the remaining slots with it.
    for (; index < k; index++) {
        ans[index] = num;
    }
    // O(K * log K)
    Arrays.sort(ans);
    // Reverse the ascending result into descending order.
    for (int lo = 0, hi = k - 1; lo < hi; lo++, hi--) {
        int tmp = ans[lo];
        ans[lo] = ans[hi];
        ans[hi] = tmp;
    }
    return ans;
}
// Iterative randomized quickselect; expected O(N).
/**
 * Returns the value that would be at position {@code index} if {@code arr} were sorted
 * ascending. Reorders {@code arr} in place.
 *
 * @param arr   array to search (expected non-empty)
 * @param index 0-based sorted position, expected in {@code [0, arr.length - 1]}
 * @return the value of sorted rank {@code index}
 */
public static int minKth(int[] arr, int index) {
    int lo = 0;
    int hi = arr.length - 1;
    while (lo < hi) {
        int pivot = arr[lo + (int) (Math.random() * (hi - lo + 1))];
        int[] eq = partition(arr, lo, hi, pivot);
        boolean inEqualZone = index >= eq[0] && index <= eq[1];
        if (inEqualZone) {
            // Every slot of the equal zone holds the pivot, so it is the answer.
            return pivot;
        }
        if (index < eq[0]) {
            hi = eq[0] - 1;
        } else {
            lo = eq[1] + 1;
        }
    }
    return arr[lo];
}
/**
 * Rearranges {@code arr[L..R]} into three zones in a single pass: values below {@code pivot},
 * values equal to it, then values above it.
 *
 * @return {@code {start, end}}, the inclusive bounds of the equal zone; when no value equals
 *         {@code pivot}, {@code start == end + 1}
 */
public static int[] partition(int[] arr, int L, int R, int pivot) {
    int lt = L; // arr[L..lt-1]  <  pivot
    int gt = R; // arr[gt+1..R]  >  pivot
    int i = L;  // arr[lt..i-1] == pivot; arr[i..gt] not yet examined
    while (i <= gt) {
        int v = arr[i];
        if (v < pivot) {
            arr[i] = arr[lt];
            arr[lt] = v;
            lt++;
            i++;
        } else if (v > pivot) {
            // The element pulled in from gt is unexamined, so i does not advance.
            arr[i] = arr[gt];
            arr[gt] = v;
            gt--;
        } else {
            i++;
        }
    }
    return new int[] {lt, gt};
}
// Test helper.
/**
 * Returns an array of random length in {@code [0, maxSize]} (it may be empty). Each value is
 * the difference of two random ints, so for {@code maxValue >= 1} it lies in
 * {@code [-(maxValue - 1), maxValue]}.
 */
public static int[] generateRandomArray(int maxSize, int maxValue) {
    int length = (int) ((maxSize + 1) * Math.random());
    int[] out = new int[length];
    for (int i = 0; i < length; i++) {
        int plus = (int) ((maxValue + 1) * Math.random());
        int minus = (int) (maxValue * Math.random());
        out[i] = plus - minus;
    }
    return out;
}
// Test helper.
/** Returns an independent copy of {@code arr}, or {@code null} when {@code arr} is null. */
public static int[] copyArray(int[] arr) {
    return arr == null ? null : arr.clone();
}
// Test helper.
/**
 * Returns true when both arrays are null, or when both are non-null with the same length and
 * the same elements in the same order; false otherwise.
 */
public static boolean isEqual(int[] arr1, int[] arr2) {
    // Arrays.equals(int[], int[]) treats two null references as equal and one null as unequal.
    return Arrays.equals(arr1, arr2);
}
// Test helper.
/**
 * Prints every element of {@code arr} followed by a space, then ends the line.
 * Prints nothing at all when {@code arr} is null.
 */
public static void printArray(int[] arr) {
    if (arr == null) {
        return;
    }
    StringBuilder line = new StringBuilder();
    for (int value : arr) {
        line.append(value).append(' ');
    }
    System.out.println(line);
}
// Randomized comparison of the three top-K implementations. It stops at the first mismatch and
// prints the three results; the final line reports the number of rounds and whether all passed.
public static void main(String[] args) {
    int testTime = 500000;
    int maxSize = 100;
    int maxValue = 100;
    boolean pass = true;
    System.out.println("测试开始,没有打印出错信息说明测试通过");
    for (int i = 0; pass && i < testTime; i++) {
        int k = (int) (Math.random() * maxSize) + 1;
        int[] arr = generateRandomArray(maxSize, maxValue);
        // Each implementation gets its own copy because they may reorder their input.
        int[] ans1 = maxTopK1(copyArray(arr), k);
        int[] ans2 = maxTopK2(copyArray(arr), k);
        int[] ans3 = maxTopK3(copyArray(arr), k);
        if (!isEqual(ans1, ans2) || !isEqual(ans1, ans3)) {
            pass = false;
            System.out.println("出错了!");
            printArray(ans1);
            printArray(ans2);
            printArray(ans3);
        }
    }
    System.out.println("测试结束了,测试了" + testTime + "组,是否所有测试用例都通过?" + (pass ? "是" : "否"));
}
/**
 * Reservoir sampling. It keeps a uniform random sample of up to {@code capacity} values from a
 * stream whose length is not known in advance. After n calls to {@link #add}, each value
 * offered so far is in the sample with probability {@code min(1, capacity / n)}.
 */
public static class RandomBox {
    private final int[] bag;
    private final int N;
    // Number of values offered so far; a long so that very long streams cannot overflow it.
    private long count;

    public RandomBox(int capacity) {
        bag = new int[capacity];
        N = capacity;
        count = 0;
    }

    // Uniform random integer in [1, max].
    private long rand(long max) {
        return (long) (Math.random() * max) + 1;
    }

    /** Offers {@code num} to the sample. */
    public void add(int num) {
        count++;
        if (count <= N) {
            // Reservoir not full yet: always keep the value.
            bag[(int) (count - 1)] = num;
        } else if (rand(count) <= N) {
            // Keep the new value with probability N / count, evicting a uniformly random slot.
            bag[(int) rand(N) - 1] = num;
        }
    }

    /**
     * Returns a copy of the current sample. If fewer than {@code capacity} values have been
     * added, only the values actually added are returned, so unfilled slots do not show up as
     * zeros.
     */
    public int[] choices() {
        int size = (int) Math.min(count, N);
        int[] ans = new int[size];
        System.arraycopy(bag, 0, ans, 0, size);
        return ans;
    }
}
// Returns an integer drawn uniformly at random from [1, i].
public static int random(int i) {
    double scaled = Math.random() * i;
    return 1 + (int) scaled;
}
// Two empirical checks of reservoir sampling. Each prints, per value, how often it was kept:
//  1) an inline version keeping 10 of the balls 1..17, repeated 10000 times;
//  2) RandomBox keeping 10 of the numbers 1..100, repeated 50000 times.
public static void main(String[] args) {
    System.out.println("hello");
    final int test = 10000;
    final int ballNum = 17;
    final int bagSize = 10;
    int[] count = new int[ballNum + 1];
    for (int round = 0; round < test; round++) {
        int[] bag = new int[bagSize];
        for (int num = 1; num <= ballNum; num++) {
            if (num <= bagSize) {
                // The first 10 balls always go into the bag.
                bag[num - 1] = num;
            } else if (random(num) <= bagSize) {
                // Ball num is kept with probability 10/num, replacing a uniformly random ball.
                bag[(int) (Math.random() * bagSize)] = num;
            }
        }
        for (int kept : bag) {
            count[kept]++;
        }
    }
    for (int i = 0; i <= ballNum; i++) {
        System.out.println(count[i]);
    }

    System.out.println("hello");
    final int all = 100;
    final int choose = 10;
    final int testTimes = 50000;
    int[] counts = new int[all + 1];
    for (int round = 0; round < testTimes; round++) {
        RandomBox box = new RandomBox(choose);
        for (int num = 1; num <= all; num++) {
            box.add(num);
        }
        for (int picked : box.choices()) {
            counts[picked]++;
        }
    }
    for (int i = 0; i < counts.length; i++) {
        System.out.println(i + " times : " + counts[i]);
    }
}