1.简介
一个长度为M的对象数组,每个对象有权重属性W(权重总和不要求为1),要求根据权重随机取出N个对象,被取出的概率服从权重分布(或者可按指定的分布规则服从)
2.原始(第一)想法
2.1 权重映射
先遍历一遍数组,找到每个权重的上下限Wmin与Wmax 并计算出总和Wtotal,在0~Wtotal中取随机数,再根据二分查找(可以根据Wtotal和size算出Waverage 使得二分查找更精确)找到对应范围内的对象
如果数量级比较小可以直接申请一段空间,简化回查复杂度。
Java代码实现如下:
package com.kowalski;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* Created by kowalski.zhang on 2018/5/14
*/
/**
 * Weighted random selection WITH replacement, based on cumulative weight ranges.
 *
 * <p>Every element is assigned the half-open range {@code [lowWeight, highWeight)} on the
 * cumulative weight axis; a uniform integer in {@code [0, total)} is drawn and the owning
 * element is located by binary search, so an element of weight {@code w} owns exactly
 * {@code w} of the {@code total} possible draws.
 */
public class Algorithm {

    public static void main(String... args) {
    }

    /**
     * Random source used for all draws. It may be replaced (for example with a seeded
     * {@link Random}) to obtain reproducible results.
     *
     * <p>A plain {@link Random} is used on purpose: an instance obtained from
     * {@code ThreadLocalRandom.current()} must not be stored in a static field and shared
     * between threads.
     */
    public static Random random = new Random();

    /**
     * Draws {@code takeNum} elements from {@code sourceList}; the same element may be drawn
     * more than once. As a side effect the low/high bounds of every element are rewritten.
     *
     * @param sourceList candidates; their order defines the cumulative ranges
     * @param takeNum    number of draws
     * @param <T>        element type
     * @return {@code null} when {@code sourceList} is null/empty or {@code takeNum <= 0}
     *         (kept for backward compatibility); a one-element list when there is a single
     *         candidate; an empty list when the total weight is zero; otherwise a list of
     *         {@code takeNum} drawn elements
     * @throws IllegalArgumentException if an element has a negative weight
     * @throws ArithmeticException      if the sum of the weights overflows an {@code int}
     */
    public static <T extends Weight> List<T> getRandomListByWeight(List<T> sourceList, int takeNum) {
        if (sourceList == null || sourceList.isEmpty() || takeNum <= 0) {
            return null;
        }
        List<T> resList = new ArrayList<>();
        int total = getTotalWeightAndFillWeight(sourceList);
        if (sourceList.size() == 1) {
            resList.add(sourceList.get(0));
            return resList;
        }
        if (total <= 0) {
            // Nothing can be selected, and Random.nextInt(0) would throw.
            return resList;
        }
        return IntStream.range(0, takeNum)
                .map(i -> search(sourceList, random.nextInt(total)))
                .filter(resIndex -> resIndex != -1)
                .mapToObj(sourceList::get)
                .collect(Collectors.toList());
    }

    /**
     * Binary search for the element whose range {@code [lowWeight, highWeight)} contains
     * {@code randomNum}. The bounds must have been filled by
     * {@link #getTotalWeightAndFillWeight(List)} beforehand.
     *
     * @param sourceList elements with filled bounds
     * @param randomNum  the drawn number
     * @param <T>        element type
     * @return index of the owning element, or {@code -1} when no range contains the number
     */
    public static <T extends Weight> int search(List<T> sourceList, int randomNum) {
        int low = 0;
        int high = sourceList.size() - 1;
        while (low <= high) {
            // Unsigned shift keeps the midpoint correct even if low + high overflows.
            int middle = (low + high) >>> 1;
            T candidate = sourceList.get(middle);
            if (candidate.getLowWeight() > randomNum) {
                // The range lies above the number: continue in the left half.
                high = middle - 1;
            } else if (candidate.getHighWeight() <= randomNum) {
                // The upper bound is exclusive: a number equal to it belongs to the next
                // element, so continue in the right half.
                low = middle + 1;
            } else {
                return middle;
            }
        }
        return -1;
    }

    /**
     * Fills the cumulative bounds of every element and returns the total weight.
     *
     * @param sourceList elements to process; may be null or empty
     * @param <T>        element type
     * @return the sum of all weights, {@code 0} for a null or empty list
     * @throws IllegalArgumentException if an element has a negative weight
     * @throws ArithmeticException      if the sum of the weights overflows an {@code int}
     */
    public static <T extends Weight> int getTotalWeightAndFillWeight(List<T> sourceList) {
        if (sourceList == null || sourceList.isEmpty()) {
            return 0;
        }
        int total = 0;
        for (T source : sourceList) {
            int weight = source.getWeight();
            if (weight < 0) {
                throw new IllegalArgumentException("weight must not be negative: " + weight);
            }
            source.setLowWeight(total);
            total = Math.addExact(total, weight);
            source.setHighWeight(total);
        }
        return total;
    }

    /**
     * Base type for weighted elements; carries the weight plus the cumulative bounds
     * written by {@link Algorithm#getTotalWeightAndFillWeight(List)}.
     */
    public static class Weight {

        /** Weight of this element. */
        private int weight;
        /** Lower cumulative bound (inclusive). */
        private int lowWeight;
        /** Upper cumulative bound (exclusive when used by {@link Algorithm#search}). */
        private int highWeight;

        public Weight(int weight) {
            this.weight = weight;
        }

        public int getWeight() {
            return weight;
        }

        public void setWeight(int weight) {
            this.weight = weight;
        }

        public int getLowWeight() {
            return lowWeight;
        }

        public void setLowWeight(int lowWeight) {
            this.lowWeight = lowWeight;
        }

        public int getHighWeight() {
            return highWeight;
        }

        public void setHighWeight(int highWeight) {
            this.highWeight = highWeight;
        }
    }
}
3.原始(第一)想法的破灭
3.1问题的出现:
3.1.1.取出N个不重复对象
如何取出N个不重复的对象?
*将已取出对象抛出重新分配上下限?这个效率肯定不允许
*将已取出对象与最后的对象进行位置交换(设立最后部分为禁区),下次随机数将在0~(Wtotal-Wtaken)间产生(Wtaken为已取出对象的权重之和),如果当前对象的范围小于尾部对象,则直接将当前对象置换为尾部对象,但如果抛出对象范围较大,则问题就会变得很复杂…数组调整,等等,效率也不一定允许。
3.1.2.不只是想服从绝对的权重占总权重比
可能需要要求某些权重永远不能被取出(或者更大程度的缩小小权重出现概率),可以动态变化分布规则
4.新想法探究
如果想要以更高的效率取出N个不重复的服从权重分布的随机对象,基本上以上方法已经无法满足。
那是一个明媚的下午…刚睡醒还在懵懵状态下的我,在纸上画出了个这么个玩意:
图解:每个不同的权重对应不同的长度,最右边一条线在往左靠近的过程中,取出依次接触到的权重线,那怎么能让短的(像Obj4的W4)也有机会被取出呢,那么就再加一个随机数不就好啦~ 大的W有可能随机到小的数,小的有可能随机到大的数,然后取topk不就好了!!!!(脑瓜崩嗡嗡的…)
然后就引入了个离散权重的概念:
Java代码在这里:
package com.kowalski;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* Created by kowalski.zhang on 2018/5/21
*/
/**
 * Weighted random sampling WITHOUT replacement.
 *
 * <p>Every element receives a random "dispersed weight" (a sort key derived from its weight),
 * and the {@code k} elements with the largest keys are returned, sorted by key in descending
 * order. How the key is derived is selected with {@link DispersedTypeEnum}.
 */
public class FinalFinalAlgorithm {

    /** Small demo: samples 10000 of 100000 randomly weighted items and prints them. */
    public static void main(String... args) {
        /* Build the demo data: weights are uniform in [0, 100). */
        List<Demo> demos = IntStream.range(0, 100000)
                .mapToObj(i -> new Demo(0, StrictMath.random() * 100))
                .collect(Collectors.toList());
        long time1 = System.currentTimeMillis();
        /* take */
        List<Demo> topK = Taker.take(demos, 10000);
        System.out.println(System.currentTimeMillis() - time1);
        double max = 0;
        for (Demo demo : topK) {
            System.out.println("W:" + demo.getWeight() + " D:" + demo.getDispersedWeight());
            if (demo.getWeight() > max) {
                max = demo.getWeight();
            }
        }
        System.out.println("max " + max);
    }

    /**
     * Base type for weighted elements. The natural ordering compares the dispersed weight
     * only, so it is not consistent with {@code equals}.
     */
    public static class Weight implements Comparable<Weight>, Serializable {
        private static final long serialVersionUID = 2816396154208421520L;

        /** Weight of the element. */
        private double weight;
        /** Randomised sort key derived from the weight. */
        private double dispersedWeight;

        public Weight(double weight) {
            this.weight = weight;
        }

        public double getWeight() {
            return weight;
        }

        public void setWeight(double weight) {
            this.weight = weight;
        }

        public double getDispersedWeight() {
            return dispersedWeight;
        }

        public void setDispersedWeight(Double dispersedWeight) {
            this.dispersedWeight = dispersedWeight;
        }

        /**
         * Orders by dispersed weight, ascending. Returns {@code 0} for equal keys so that the
         * {@link Comparable} contract (sign symmetry, {@code x.compareTo(x) == 0}) holds.
         */
        @Override
        public int compareTo(Weight o) {
            return Double.compare(this.getDispersedWeight(), o.getDispersedWeight());
        }
    }

    /** Sample entity. */
    public static class Demo extends Weight {
        private static final long serialVersionUID = 3499760378453027078L;

        private Integer num;

        public Demo(Integer num, double weight) {
            super(weight);
            this.num = num;
        }

        public Integer getNum() {
            return num;
        }

        public void setNum(Integer num) {
            this.num = num;
        }
    }

    /** Entry points for sampling. */
    public static class Taker {

        /**
         * Assigns dispersed weights to {@code source} and returns the {@code takeNum} elements
         * with the largest dispersed weight, sorted in descending order.
         *
         * @param source        candidates; their dispersed weights are always overwritten
         * @param takeNum       number of elements to take; values above the list size return
         *                      every element, values {@code <= 0} return an empty list
         * @param isDestruction {@code true} to reorder {@code source} in place (the result is
         *                      then a view of it); {@code false} to work on a serialized deep
         *                      copy and leave the order of {@code source} untouched
         * @param dispersedType dispersion mode; {@code null} means
         *                      {@link DispersedTypeEnum#ABSOLUTE_FOLLOW_WEIGHT}
         * @param <T>           element type
         * @return the selected elements, never {@code null}
         * @throws IllegalStateException if the deep copy fails (for example because an element
         *                               is not serializable)
         */
        public static <T extends Weight> List<T> take(List<T> source, int takeNum, boolean isDestruction,
                DispersedTypeEnum dispersedType) {
            fillDispersedWeight(source,
                    dispersedType == null ? DispersedTypeEnum.ABSOLUTE_FOLLOW_WEIGHT : dispersedType);
            try {
                return Sort.topKWithSortByQuickSort(isDestruction ? source : deepCopy(source), takeNum);
            } catch (IOException | ClassNotFoundException e) {
                // An empty result would be indistinguishable from "no candidates"; fail loudly.
                throw new IllegalStateException("Unable to deep-copy the source list", e);
            }
        }

        /**
         * Destructive take with an explicit dispersion mode.
         *
         * @param source        candidates (reordered in place)
         * @param takeNum       number of elements to take
         * @param dispersedType dispersion mode; {@code null} means
         *                      {@link DispersedTypeEnum#ABSOLUTE_FOLLOW_WEIGHT}
         * @param <T>           element type
         * @return the selected elements
         */
        public static <T extends Weight> List<T> take(List<T> source, int takeNum,
                DispersedTypeEnum dispersedType) {
            return take(source, takeNum, true, dispersedType);
        }

        /**
         * Destructive take using {@link DispersedTypeEnum#ABSOLUTE_FOLLOW_WEIGHT}.
         *
         * @param source  candidates (reordered in place)
         * @param takeNum number of elements to take
         * @param <T>     element type
         * @return the selected elements
         */
        public static <T extends Weight> List<T> take(List<T> source, int takeNum) {
            return take(source, takeNum, true, null);
        }

        /**
         * Destructive weighted shuffle of the whole list.
         *
         * @param source candidates (reordered in place)
         * @param <T>    element type
         * @return all elements ordered by descending dispersed weight; empty for null/empty input
         */
        public static <T extends Weight> List<T> take(List<T> source) {
            if (source == null || source.isEmpty()) {
                return Collections.emptyList();
            }
            return take(source, source.size(), true, null);
        }

        /**
         * Top-k on a deep copy. Does NOT assign dispersed weights: the values currently stored
         * in the elements are used as sort keys.
         *
         * @param source  candidates (left untouched)
         * @param takeNum number of elements to take
         * @param <T>     element type
         * @return the selected copies
         * @throws IOException            if serialization fails
         * @throws ClassNotFoundException if deserialization fails
         */
        public static <T extends Weight> List<T> takeWithOutDestruction(List<T> source, int takeNum)
                throws IOException, ClassNotFoundException {
            return Sort.topKWithSortByQuickSort(deepCopy(source), takeNum);
        }

        /**
         * Top-k in place. Does NOT assign dispersed weights: the values currently stored in
         * the elements are used as sort keys.
         *
         * @param source  candidates (reordered in place)
         * @param takeNum number of elements to take
         * @param <T>     element type
         * @return the selected elements
         */
        public static <T extends Weight> List<T> takeWithDestruction(List<T> source, int takeNum) {
            return Sort.topKWithSortByQuickSort(source, takeNum);
        }

        /**
         * @param source elements; may be null or empty
         * @param <T>    element type
         * @return the arithmetic mean of the weights, {@code 0} for null/empty input
         */
        public static <T extends Weight> double getAverageWeight(List<T> source) {
            if (source == null || source.isEmpty()) {
                return 0;
            }
            return getTotalWeight(source) / source.size();
        }

        /**
         * @param source elements; may be null or empty
         * @param <T>    element type
         * @return the sum of the weights, {@code 0} for null/empty input
         */
        public static <T extends Weight> double getTotalWeight(List<T> source) {
            if (source == null || source.isEmpty()) {
                return 0;
            }
            return source.stream().mapToDouble(Weight::getWeight).sum();
        }

        /**
         * Writes a fresh random dispersed weight into every element.
         *
         * <ul>
         *   <li>{@code ABSOLUTE_FOLLOW_WEIGHT}: key {@code U^(1/w)} with {@code U} uniform in
         *       [0, 1) (Efraimidis–Spirakis). Taking the largest keys is then equivalent to
         *       drawing successively without replacement with probability proportional to the
         *       weight among the remaining elements. Elements whose weight is not positive
         *       get key {@code 0} and are therefore ranked last.</li>
         *   <li>{@code SIMPLE_ADD} / {@code SIMPLE_ADD_AVERAGE}: key
         *       {@code U * dispersedNum + w}; when the mode carries no dispersion amount the
         *       average weight of {@code source} is used.</li>
         * </ul>
         *
         * @param source        elements to update; null or empty is a no-op
         * @param dispersedType dispersion mode; {@code null} means
         *                      {@link DispersedTypeEnum#ABSOLUTE_FOLLOW_WEIGHT}
         * @param <T>           element type
         */
        public static <T extends Weight> void fillDispersedWeight(List<T> source,
                DispersedTypeEnum dispersedType) {
            if (source == null || source.isEmpty()) {
                return;
            }
            DispersedTypeEnum type =
                    dispersedType == null ? DispersedTypeEnum.ABSOLUTE_FOLLOW_WEIGHT : dispersedType;
            switch (type) {
                case ABSOLUTE_FOLLOW_WEIGHT: {
                    for (T t : source) {
                        double w = t.getWeight();
                        // w > 0 is false for zero, negative and NaN weights alike.
                        t.setDispersedWeight(w > 0 ? StrictMath.pow(StrictMath.random(), 1.0 / w) : 0.0);
                    }
                    break;
                }
                case SIMPLE_ADD:
                case SIMPLE_ADD_AVERAGE: {
                    Double configured = type.getDispersedNum();
                    double spread = configured != null ? configured : getAverageWeight(source);
                    for (T t : source) {
                        t.setDispersedWeight((StrictMath.random() * spread) + t.getWeight());
                    }
                    break;
                }
            }
        }

        /**
         * Deep-copies a list through Java serialization; the list implementation and every
         * element must be serializable.
         *
         * @param src list to copy
         * @param <T> element type
         * @return an independent copy
         * @throws IOException            if serialization fails
         * @throws ClassNotFoundException if a class cannot be resolved while reading back
         */
        public static <T extends Weight> List<T> deepCopy(List<T> src)
                throws IOException, ClassNotFoundException {
            ByteArrayOutputStream byteOut = new ByteArrayOutputStream();
            // Closing the object stream flushes everything into byteOut before it is read.
            try (ObjectOutputStream out = new ObjectOutputStream(byteOut)) {
                out.writeObject(src);
            }
            try (ObjectInputStream in =
                    new ObjectInputStream(new ByteArrayInputStream(byteOut.toByteArray()))) {
                @SuppressWarnings("unchecked")
                List<T> dest = (List<T>) in.readObject();
                return dest;
            }
        }
    }

    /** Available ways of turning a weight into a random sort key. */
    public enum DispersedTypeEnum {
        /** Selection probability follows weight / total weight. */
        ABSOLUTE_FOLLOW_WEIGHT(0, "绝对服从权重", null),
        /** Additive noise whose amplitude is the average weight of the input. */
        SIMPLE_ADD_AVERAGE(1, "简单加法离散_平均离散值", null),
        /** Additive noise with a fixed amplitude. */
        SIMPLE_ADD(2, "简单加法离散", 100.0d);

        private Integer type;
        private String desc;
        /** Dispersion amount; {@code null} means "use the average weight". */
        private Double dispersedNum;

        DispersedTypeEnum(Integer type, String desc, Double dispersedNum) {
            this.type = type;
            this.desc = desc;
            this.dispersedNum = dispersedNum;
        }

        public Integer getType() {
            return type;
        }

        /** Note: the code lookup table is built once and is not updated by this setter. */
        public void setType(Integer type) {
            this.type = type;
        }

        public String getDesc() {
            return desc;
        }

        public void setDesc(String desc) {
            this.desc = desc;
        }

        public Double getDispersedNum() {
            return dispersedNum;
        }

        public void setDispersedNum(Double dispersedNum) {
            this.dispersedNum = dispersedNum;
        }

        private static final Map<Integer, DispersedTypeEnum> map = new HashMap<>();

        static {
            for (DispersedTypeEnum enums : DispersedTypeEnum.values()) {
                map.put(enums.getType(), enums);
            }
        }

        /** @return the constant registered for {@code code}, or {@code null} if unknown */
        public static DispersedTypeEnum getEnumValue(int code) {
            return map.get(code);
        }

        /** @return the description for {@code code}, or {@code null} if the code is unknown */
        public static String getDescByType(int code) {
            DispersedTypeEnum found = map.get(code);
            return found == null ? null : found.getDesc();
        }

        /** @return the dispersion amount for {@code code}, or {@code null} if unknown/unset */
        public static Double getDispersedNumByType(int code) {
            DispersedTypeEnum found = map.get(code);
            return found == null ? null : found.getDispersedNum();
        }
    }

    /** Sorting helpers (descending by {@link Weight#compareTo}). */
    public static class Sort {

        /**
         * Moves the {@code k} largest elements to the front of {@code source}, sorted in
         * descending order, using quicksort-style partitioning. {@code source} is reordered
         * in place and the result is a view of its first {@code k} elements.
         *
         * @param source list to process
         * @param k      number of elements wanted; clamped to the list size
         * @param <T>    element type
         * @return the top elements; empty when {@code source} is null/empty or {@code k <= 0}
         */
        public static <T extends Weight> List<T> topKWithSortByQuickSort(List<T> source, int k) {
            if (source == null || source.isEmpty() || k <= 0) {
                return Collections.emptyList();
            }
            int limit = Math.min(k, source.size());
            int start = 0;
            int end = source.size() - 1;
            while (end > start) {
                int index = partition(source, start, end);
                int rank = index + 1;
                if (rank >= limit) {
                    // Everything right of the pivot is outside the top-k: drop it.
                    end = index - 1;
                } else if ((index - start) > (end - index)) {
                    // Recurse into the smaller (right) side, keep looping on the left.
                    quickSort(source, index + 1, end);
                    end = index - 1;
                } else {
                    // Recurse into the smaller (left) side, keep looping on the right.
                    quickSort(source, start, index - 1);
                    start = index + 1;
                }
            }
            return source.subList(0, limit);
        }

        /**
         * Partitions {@code lst[start..end]} around its first element: strictly greater
         * elements end up before the pivot, all others after it.
         *
         * @return the final index of the pivot
         */
        public static <T extends Weight> int partition(List<T> lst, int start, int end) {
            T pivot = lst.get(start);
            int i = start;
            for (int j = start + 1; j <= end; j++) {
                if (lst.get(j).compareTo(pivot) > 0) {
                    i = i + 1;
                    swap(lst, i, j);
                }
            }
            swap(lst, start, i);
            return i;
        }

        /** Swaps the elements at positions {@code p} and {@code q}. */
        public static <T extends Weight> void swap(List<T> lst, int p, int q) {
            T temp = lst.get(p);
            lst.set(p, lst.get(q));
            lst.set(q, temp);
        }

        /** Sorts {@code lst[start..end]} in descending order (recursive quicksort). */
        public static <T extends Weight> void quickSort(List<T> lst, int start, int end) {
            if (start < end) {
                int index = partition(lst, start, end);
                quickSort(lst, start, index - 1);
                quickSort(lst, index + 1, end);
            }
        }
    }
}
代码中TopK是基于快速排序改造的(不只是取了topk 还把这topk给排了序,网上很多只是找到了第N个位置的数,然后就直接把前N个返回了,前N个并未排序,排序很关键!!!)
效率很满意,嘿嘿嘿~
5.待研究
可以指定不同的离散方式干扰最终结果的分布~ 什么卡方分布,正态分布之类的(数学不行了…很多东西算不出来了…) 或许有更好的方案 有兴趣的同学一起研究~ (数学好的用到了别的离散方式的同学欢迎带带我)
email:[email protected]