带权随机算法-根据权重随机选出N个对象研究历程

1.简介
给定一个长度为M的对象数组,每个对象有权重属性W(权重总和不要求归一化为1)。要求根据权重随机选出N个对象,使每个对象被选中的概率服从其权重占比(或服从某种可指定的分布)。
2.原始(第一)想法
2.1 权重映射
先遍历一遍数组,为每个对象计算它在累积权重轴上的下限Wmin与上限Wmax(即区间[Wmin, Wmax)),同时算出权重总和Wtotal;然后在[0, Wtotal)中取一个随机数,再用二分查找找到随机数所落区间对应的对象(也可以根据Wtotal和数组长度size算出平均权重Waverage,用来估计初始查找位置,使查找更快)。
如果Wtotal的数量级比较小,可以直接申请一段长度为Wtotal的空间,预先存好每个随机数对应的对象下标,把回查降为O(1)。
[图:各对象权重映射到[0, Wtotal)上的区间示意]

Java代码实现如下:

package com.kowalski;


import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Created by kowalski.zhang on 2018/5/14
 */
public class Algorithm {

  /** Placeholder entry point; intentionally empty. */
  public static void main(String... args) {

  }

  /**
   * Legacy shared random source, kept only so existing references keep compiling.
   *
   * @deprecated The {@link ThreadLocalRandom} Javadoc asks callers to use
   *     {@code ThreadLocalRandom.current().nextX(...)} at the point of use so that an instance is
   *     never shared across threads; caching it in a static field defeats that. The methods of this
   *     class no longer read this field.
   */
  @Deprecated
  public static Random random = ThreadLocalRandom.current();

  /**
   * Draws {@code takeNum} elements from {@code sourceList} <em>with replacement</em>; each draw
   * picks element i with probability {@code weight_i / totalWeight}.
   *
   * <p>Side effect: overwrites the lowWeight/highWeight bounds of every element (via
   * {@link #getTotalWeightAndFillWeight(List)}).
   *
   * @param sourceList candidates; weights must be non-negative
   * @param takeNum number of draws
   * @param <T> element type
   * @return {@code null} if {@code sourceList} is null/empty or {@code takeNum <= 0}; a one-element
   *     list holding the only candidate if {@code sourceList} has exactly one element; an empty list
   *     if every weight is 0; otherwise {@code takeNum} drawn elements (duplicates possible)
   * @throws IllegalArgumentException if any weight is negative
   * @throws ArithmeticException if the total weight overflows {@code int}
   */
  public static <T extends Weight> List<T> getRandomListByWeight(List<T> sourceList, int takeNum) {

    if (sourceList == null || sourceList.isEmpty() || takeNum <= 0) {
      return null;
    }

    int total = getTotalWeightAndFillWeight(sourceList);

    if (sourceList.size() == 1) {
      List<T> single = new ArrayList<>();
      single.add(sourceList.get(0));
      return single;
    }

    if (total <= 0) {
      // nextInt(0) would throw, and no element has a positive chance of being drawn.
      return new ArrayList<>();
    }

    return IntStream.range(0, takeNum)
        // ThreadLocalRandom.current() is called per draw on the executing thread, as documented.
        .map(i -> search(sourceList, ThreadLocalRandom.current().nextInt(total)))
        .filter(resIndex -> resIndex != -1)
        .mapToObj(sourceList::get)
        .collect(Collectors.toList());
  }

  /**
   * Binary-searches for the element whose half-open interval {@code [lowWeight, highWeight)}
   * contains {@code randomNum}. Requires the bounds to have been filled by
   * {@link #getTotalWeightAndFillWeight(List)}.
   *
   * <p>Zero-weight elements have an empty interval and are therefore never returned.
   *
   * @param sourceList elements with filled, ascending, contiguous intervals
   * @param randomNum value in {@code [0, totalWeight)}
   * @param <T> element type
   * @return index of the matching element, or -1 if {@code randomNum} is outside every interval
   */
  public static <T extends Weight> int search(List<T> sourceList, int randomNum) {
    int low = 0;
    int high = sourceList.size() - 1;

    while (low <= high) {
      // Unsigned shift avoids overflow of (low + high).
      int middle = (low + high) >>> 1;
      T candidate = sourceList.get(middle);
      if (candidate.getLowWeight() > randomNum) {
        // Interval lies right of the key: search the left half.
        high = middle - 1;
      } else if (candidate.getHighWeight() <= randomNum) {
        // Upper bound is exclusive, so a key equal to it belongs to a later element.
        low = middle + 1;
      } else {
        return middle;
      }
    }

    return -1; // not found
  }

  /**
   * Assigns each element the half-open interval {@code [lowWeight, highWeight)} on the cumulative
   * weight axis and returns the total weight.
   *
   * @param sourceList elements to fill; may be null or empty
   * @param <T> element type
   * @return the sum of all weights, or 0 for a null/empty list
   * @throws IllegalArgumentException if any weight is negative (it would invert its interval)
   * @throws ArithmeticException if the running total overflows {@code int}
   */
  public static <T extends Weight> int getTotalWeightAndFillWeight(List<T> sourceList) {
    if (sourceList == null || sourceList.isEmpty()) {
      return 0;
    }
    int total = 0;
    for (T source : sourceList) {
      if (source.getWeight() < 0) {
        throw new IllegalArgumentException("negative weight: " + source.getWeight());
      }
      // Fill the lower and upper bounds while accumulating the total.
      source.setLowWeight(total);
      total = Math.addExact(total, source.getWeight());
      source.setHighWeight(total);
    }
    return total;
  }

  /** A weighted item plus its cached interval {@code [lowWeight, highWeight)}. */
  public static class Weight {

    public Weight(int weight) {
      this.weight = weight;
    }

    public int getWeight() {
      return weight;
    }

    public void setWeight(int weight) {
      this.weight = weight;
    }

    public int getLowWeight() {
      return lowWeight;
    }

    public void setLowWeight(int lowWeight) {
      this.lowWeight = lowWeight;
    }

    public int getHighWeight() {
      return highWeight;
    }

    public void setHighWeight(int highWeight) {
      this.highWeight = highWeight;
    }

    /** Weight. */
    private int weight;
    /** Inclusive lower bound on the cumulative weight axis. */
    private int lowWeight;
    /** Exclusive upper bound on the cumulative weight axis. */
    private int highWeight;
  }
}

3.原始(第一)想法的破灭

3.1问题的出现:
3.1.1.取出N个不重复对象

如何取出N个不重复的对象?
* 把已取出的对象移除,再重新计算其余对象的上下限?每取一次就要O(M)地重建区间,效率肯定不允许。
* 把已取出的对象与末尾对象交换位置(把末尾部分设为禁区),下次随机数在0~(Wtotal-Wtaken)之间产生(Wtaken为已取出对象的权重之和)。如果被取出对象的区间小于尾部对象的区间,可以直接把当前对象置换为尾部对象;但如果被取出对象的区间较大,问题就会变得很复杂……需要调整数组等等,效率也不一定允许。

3.1.2.不只是想服从绝对的权重占总权重比

可能需要要求某些权重永远不能被取出(或者更大程度的缩小小权重出现概率),可以动态变化分布规则

4.新想法探究

如果想要以更高的效率取出N个不重复且服从权重分布的随机对象,以上方法基本已经无法满足。
那是一个明媚的下午……刚睡醒、还处在懵懵状态的我,在纸上画出了这么个玩意:
[图1:离散权重示意图]
图解:每个不同的权重对应不同的长度,最右边一条线在往左靠近的过程中,取出依次接触到的权重线,那怎么能让短的(像Obj4的W4)也有机会被取出呢,那么就再加一个随机数不就好啦~ 大的W有可能随机到小的数,小的有可能随机到大的数,然后取topk不就好了!!!!(脑瓜崩嗡嗡的…)
于是就引入了“离散权重”的概念:给每个对象的权重叠加一个随机扰动,得到它的离散权重,再按离散权重取TopK,这样小权重的对象也有机会被取出:

Java代码在这里:

package com.kowalski;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Created by kowalski.zhang on 2018/5/21
 */
public class FinalFinalAlgorithm {

  /**
   * Demo: builds 100,000 items with random weights in [0, 100), takes 10,000 of them with the
   * default dispersion mode, then prints the elapsed milliseconds, every picked item and the
   * largest picked weight.
   */
  public static void main(String... args) {

    /** build sample data */
    List<Demo> demos = IntStream.range(0, 100000)
        .mapToObj(i -> new Demo(0, StrictMath.random() * 100))
        .collect(Collectors.toList());

    long time1 = System.currentTimeMillis();
    /** take */
    List<Demo> topK = Taker.take(demos, 10000);
    System.out.println(System.currentTimeMillis() - time1);

    double max = 0;
    for (Demo demo : topK) {
      System.out.println("W:" + demo.getWeight() + " D:" + demo.getDispersedWeight());
      if (demo.getWeight() > max) {
        max = demo.getWeight();
      }
    }

    System.out.println("max " + max);

  }

  /**
   * Base type for items drawn by weight. Besides the raw weight it stores a randomized sort key
   * ("dispersed weight"); items with larger dispersed weight are taken first.
   *
   * <p>The natural ordering compares dispersed weight only and is therefore not consistent with
   * {@code equals}.
   */
  public static class Weight implements Comparable<Weight>, Serializable {

    private static final long serialVersionUID = 2816396154208421520L;

    public Weight(double weight) {
      this.weight = weight;
    }

    public double getWeight() {

      return weight;
    }

    public void setWeight(double weight) {
      this.weight = weight;
    }

    public double getDispersedWeight() {
      return dispersedWeight;
    }

    public void setDispersedWeight(Double dispersedWeight) {
      this.dispersedWeight = dispersedWeight;
    }

    /**
     * Raw weight.
     */
    private double weight;
    /**
     * Dispersed (randomized) weight used as the sort key.
     */
    private double dispersedWeight;

    /**
     * Orders by dispersed weight. {@link Double#compare} satisfies the Comparable contract
     * ({@code x.compareTo(x) == 0}, antisymmetry), which a "greater ? 1 : -1" comparison does not.
     */
    @Override
    public int compareTo(Weight o) {
      return Double.compare(this.getDispersedWeight(), o.getDispersedWeight());
    }
  }

  /***
   * Sample entity carrying an integer payload.
   */
  public static class Demo extends Weight {
    private static final long serialVersionUID = 3499760378453027078L;
    private Integer num;

    public Demo(Integer num, double weight) {
      super(weight);
      this.num = num;
    }

    public Integer getNum() {
      return num;
    }

    public void setNum(Integer num) {
      this.num = num;
    }
  }

  /***
   * Entry points for taking N distinct items by weight.
   */
  public static class Taker {

    /**
     * Takes up to {@code takeNum} distinct items, largest dispersed weight first.
     *
     * <p>When {@code isDestruction} is false the source list is deep-copied <em>before</em> any
     * dispersed weight is written, so neither the list order nor its elements are modified.
     *
     * @param source source data; null or empty yields an empty list
     * @param takeNum number of items to take; clamped to {@code [0, source.size()]}
     * @param isDestruction whether the source list may be reordered and its elements mutated
     * @param dispersedType dispersion mode (null means {@link DispersedTypeEnum#ABSOLUTE_FOLLOW_WEIGHT})
     * @param <T> element type
     * @return the taken items in descending dispersed-weight order
     * @throws IllegalStateException if the non-destructive deep copy fails (e.g. a subclass holds
     *     non-serializable state)
     */
    public static <T extends Weight> List<T> take(List<T> source, int takeNum, boolean isDestruction,
        DispersedTypeEnum dispersedType) {
      if (source == null || source.isEmpty()) {
        return Collections.emptyList();
      }
      DispersedTypeEnum type =
          dispersedType == null ? DispersedTypeEnum.ABSOLUTE_FOLLOW_WEIGHT : dispersedType;
      List<T> working;
      try {
        working = isDestruction ? source : deepCopy(source);
      } catch (IOException | ClassNotFoundException e) {
        throw new IllegalStateException("failed to deep-copy the source list", e);
      }
      fillDispersedWeight(working, type);
      return Sort.topKWithSortByQuickSort(working, takeNum);
    }

    /**
     * Takes items, reordering and mutating the source list.
     * @param source source data
     * @param takeNum number of items to take
     * @param dispersedType dispersion mode (null means ABSOLUTE_FOLLOW_WEIGHT)
     * @param <T> element type
     * @return the taken items in descending dispersed-weight order
     */
    public static <T extends Weight> List<T> take(List<T> source, int takeNum, DispersedTypeEnum dispersedType) {
      return take(source, takeNum, true, dispersedType);
    }

    /**
     * Takes items with the default mode, reordering and mutating the source list.
     * @param source source data
     * @param takeNum number of items to take
     * @param <T> element type
     * @return the taken items in descending dispersed-weight order
     */
    public static <T extends Weight> List<T> take(List<T> source, int takeNum) {
      return take(source, takeNum, true, null);
    }

    /**
     * Orders every item (default mode), reordering and mutating the source list.
     * @param source source data
     * @param <T> element type
     * @return all items in descending dispersed-weight order
     */
    public static <T extends Weight> List<T> take(List<T> source) {
      if (source == null || source.isEmpty()) {
        return Collections.emptyList();
      }
      return take(source, source.size(), true, null);
    }

    /**
     * Takes the top {@code takeNum} items of a deep copy, leaving the source untouched. Does NOT
     * fill dispersed weights; they must already be set.
     * @param source source data
     * @param takeNum number of items to take
     * @param <T> element type
     * @return the taken items in descending dispersed-weight order
     */
    public static <T extends Weight> List<T> takeWithOutDestruction(List<T> source, int takeNum)
        throws IOException, ClassNotFoundException {
      return Sort.topKWithSortByQuickSort(deepCopy(source), takeNum);
    }

    /**
     * Takes the top {@code takeNum} items, reordering the source list. Does NOT fill dispersed
     * weights; they must already be set.
     * @param source source data
     * @param takeNum number of items to take
     * @param <T> element type
     * @return the taken items in descending dispersed-weight order
     */
    public static <T extends Weight> List<T> takeWithDestruction(List<T> source, int takeNum) {
      return Sort.topKWithSortByQuickSort(source, takeNum);
    }

    /**
     * Average raw weight.
     * @param source items; null or empty yields 0
     * @param <T> element type
     * @return total weight divided by item count
     */
    public static <T extends Weight> double getAverageWeight(List<T> source) {
      if (source == null || source.isEmpty()) {
        return 0;
      }
      return getTotalWeight(source) / source.size();
    }

    /**
     * Total raw weight.
     * @param source items; null or empty yields 0
     * @param <T> element type
     * @return sum of all weights
     */
    public static <T extends Weight> double getTotalWeight(List<T> source) {
      if (source == null || source.isEmpty()) {
        return 0;
      }
      return source.stream().mapToDouble(Weight::getWeight).sum();
    }

    /**
     * Writes a dispersed (randomized) weight into every item according to {@code dispersedType}.
     * @param source items to fill; null or empty is a no-op
     * @param dispersedType dispersion mode; must not be null for a non-empty list
     * @param <T> element type
     */
    public static <T extends Weight> void fillDispersedWeight(List<T> source, DispersedTypeEnum dispersedType) {
      if (source == null || source.isEmpty()) {
        return;
      }
      switch (dispersedType) {
        case ABSOLUTE_FOLLOW_WEIGHT:
          for (T t : source) {
            t.setDispersedWeight(absoluteFollowWeightKey(t.getWeight()));
          }
          break;
        case SIMPLE_ADD:
        case SIMPLE_ADD_AVERAGE: {
          // A null dispersion amount means "use the average weight".
          Double configured = dispersedType.getDispersedNum();
          double dispersedNum = configured != null ? configured : getAverageWeight(source);
          for (T t : source) {
            t.setDispersedWeight((StrictMath.random() * dispersedNum) + t.getWeight());
          }
          break;
        }
        default:
          break;
      }
    }

    /**
     * Efraimidis–Spirakis sort key. Ordering by {@code log(u) / w} equals ordering by
     * {@code u^(1/w)} for uniform {@code u}, so taking the top k gives weighted sampling without
     * replacement: each pick has probability proportional to its weight among the items still left.
     * The log form avoids the underflow {@code u^(1/w)} suffers for small weights. Non-positive
     * (or NaN) weights get negative infinity, so they are taken only after every positive weight.
     */
    private static double absoluteFollowWeightKey(double weight) {
      if (!(weight > 0)) {
        return Double.NEGATIVE_INFINITY;
      }
      return StrictMath.log(StrictMath.random()) / weight;
    }

    /***
     * Deep-copies a list via Java serialization (the data is produced in-process, not untrusted).
     * @param src list to copy; null yields null
     * @param <T> element type
     * @return an independent copy of the list and its elements
     * @throws IOException if an element cannot be serialized
     * @throws ClassNotFoundException if a class cannot be resolved while deserializing
     */
    public static <T extends Weight> List<T> deepCopy(List<T> src) throws IOException, ClassNotFoundException {
      if (src == null) {
        return null;
      }
      ByteArrayOutputStream byteOut = new ByteArrayOutputStream();
      try (ObjectOutputStream out = new ObjectOutputStream(byteOut)) {
        // Copy into an ArrayList first: views such as ArrayList.subList are not Serializable.
        out.writeObject(new ArrayList<>(src));
      }
      try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(byteOut.toByteArray()))) {
        @SuppressWarnings("unchecked") // we just serialized a List<T>
        List<T> dest = (List<T>) in.readObject();
        return dest;
      }
    }
  }

  /**
   * Dispersion modes. Enum constants are singletons, so the setters below change the mode for
   * every caller.
   */
  public enum DispersedTypeEnum {

    /** Selection probability follows weight / total weight. */
    ABSOLUTE_FOLLOW_WEIGHT(0, "绝对服从权重", null),
    /** weight + uniform noise in [0, average weight). */
    SIMPLE_ADD_AVERAGE(1, "简单加法离散_平均离散值", null),
    /** weight + uniform noise in [0, 100). */
    SIMPLE_ADD(2, "简单加法离散", 100.0d);

    private static final Map<Integer, DispersedTypeEnum> map = new HashMap<>();

    static {
      for (DispersedTypeEnum enums : DispersedTypeEnum.values()) {
        map.put(enums.getType(), enums);
      }
    }

    private Integer type;
    private String desc;
    /** Dispersion amount; null means "use the average weight". */
    private Double dispersedNum;

    DispersedTypeEnum(Integer type, String desc, Double dispersedNum) {
      this.type = type;
      this.desc = desc;
      this.dispersedNum = dispersedNum;
    }

    public Integer getType() {
      return type;
    }

    public void setType(Integer type) {
      this.type = type;
    }

    public String getDesc() {
      return desc;
    }

    public void setDesc(String desc) {
      this.desc = desc;
    }

    public Double getDispersedNum() {
      return dispersedNum;
    }

    public void setDispersedNum(Double dispersedNum) {
      this.dispersedNum = dispersedNum;
    }

    /** @return the mode with this type code, or null if none */
    public static DispersedTypeEnum getEnumValue(int code) {
      return map.get(code);
    }

    public static String getDescByType(int code) {
      return map.get(code).getDesc();
    }

    public static Double getDispersedNumByType(int code) {
      return map.get(code).getDispersedNum();
    }
  }

  /***
   * Sorting utilities (descending by dispersed weight).
   */
  public static class Sort {
    /**
     * Partially sorts {@code source} in place so that its first k elements are the k largest by
     * dispersed weight, in descending order, and returns a {@code subList} view of them.
     * @param source list to reorder; null or empty yields an empty list
     * @param k number of elements wanted; clamped to {@code [0, source.size()]}
     * @param <T> element type
     * @return view of the first {@code min(k, size)} elements of {@code source}
     */
    public static <T extends Weight> List<T> topKWithSortByQuickSort(List<T> source, int k) {
      if (source == null || source.isEmpty()) {
        return Collections.emptyList();
      }
      // Clamp so subList never throws for k > size or k < 0.
      int limit = Math.max(0, Math.min(k, source.size()));
      int start = 0;
      int end = source.size() - 1;
      while (end > start) {
        int index = partition(source, start, end);
        int rank = index + 1;
        if (rank >= limit) {
          // Everything right of the pivot is outside the top-k: keep narrowing on the left.
          end = index - 1;
        } else if ((index - start) > (end - index)) {
          // Fully sort the smaller side now and keep iterating on the larger one.
          quickSort(source, index + 1, end);
          end = index - 1;
        } else {
          quickSort(source, start, index - 1);
          start = index + 1;
        }
      }
      return source.subList(0, limit);
    }

    /** Lomuto partition around {@code lst[start]}: larger elements move left; returns the pivot's final index. */
    public static <T extends Weight> int partition(List<T> lst, int start, int end) {
      T x = lst.get(start);
      int i = start;
      for (int j = start + 1; j <= end; j++) {
        if (lst.get(j).compareTo(x) > 0) {
          i = i + 1;
          swap(lst, i, j);
        }
      }
      swap(lst, start, i);
      return i;
    }

    /** Swaps the elements at {@code p} and {@code q}. */
    public static <T extends Weight> void swap(List<T> lst, int p, int q) {
      T temp = lst.get(p);
      lst.set(p, lst.get(q));
      lst.set(q, temp);
    }

    /** Sorts {@code lst[start..end]} in descending order. */
    public static <T extends Weight> void quickSort(List<T> lst, int start, int end) {
      if (start < end) {
        int index = partition(lst, start, end);
        quickSort(lst, start, index - 1);
        quickSort(lst, index + 1, end);
      }
    }
  }
}

代码中TopK是基于快速排序改造的(不只是取了topk 还把这topk给排了序,网上很多只是找到了第N个位置的数,然后就直接把前N个返回了,前N个并未排序,排序很关键!!!)
效率很满意,嘿嘿嘿~

5.待研究
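可以指定不同的离散方式来干扰最终结果的分布~ 比如卡方分布、正态分布之类的(数学不行了…很多东西算不出来了…),或许有更好的方案,有兴趣的同学一起研究~(数学好、用到了别的离散方式的同学欢迎带带我)。补充:如果要求严格服从权重比例的不放回抽样,可以参考 Efraimidis–Spirakis 的加权随机抽样方法——以 $u^{1/w}$(u 为 [0,1) 上的均匀随机数,w 为权重)作为离散权重,再取TopK即可。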
email:[email protected]

你可能感兴趣的:(web开发,算法,算法,带权随机,抽奖)