水塘抽样(应用场景+算法步骤+算法证明+Python实现)

应用场景

在以流形式获取 n n n 个元素中,随机抽取 k k k 个样本,满足 k ≤ n k \le n kn(通常有 k ≪ n k \ll n kn)要求:

  • 每个元素有相同的概率被选中;
  • 时间复杂度不超过 O ( n ) O(n) O(n),即遍历一次;空间复杂度不超过 O ( k ) O(k) O(k),即要抽取的样本数。

算法步骤

  • 初始化长度为 k k k 的缓存列表 L L L
  • 遍历第 i i i 个元素:
    • 如果 i ≤ k i \le k ik,则将该元素存入缓存列表;
    • 如果 i > k i>k i>k,则在 [ 0 , 1 ] [0,1] [0,1] 中取随机数 t t t,若 t < k i t < \frac{k}{i} t<ik,则保留第 i i i 个元素(保留概率为 k i \frac{k}{i} ik),并在缓存列表 L L L 中随机丢掉一个元素(每个元素被丢掉的概率为 1 i \frac{1}{i} i1);
  • 遍历完所有元素后,缓存列表 L L L 即为抽取的样本。

算法证明

1. 每个元素有相同的概率被选中

证明 使用数学归纳法。当 n = k n = k n=k 时,每个元素被选中的概率均为 1 1 1,结论显然成立。假设当 n = m n=m n=m m > = k m >= k m>=k)时结论成立,只需证明当 n = m + 1 n = m + 1 n=m+1 时结论仍成立。

因为当 n = m n = m n=m 时结论成立,所以在遍历第 m + 1 m+1 m+1 个元素前,前 m m m 个元素被选中的概率相同。因为前 m m m 个元素被选中的概率相同,所以每个元素被选中的概率为
P m ( i ) = k ÷ m = k m ( i = 1 , 2 , ⋯   , m ) (1) P_m(i) = k \div m = \frac{k}{m} \hspace{1em} (i = 1,2,\cdots,m) \tag{1} Pm(i)=k÷m=mk(i=1,2,,m)(1)
在遍历第 m + 1 m+1 m+1 个元素时,有 k m + 1 \frac{k}{m+1} m+1k 的概率会保留第 m + 1 m+1 m+1 个元素,所以第 m + 1 m+1 m+1 个元素被选中的概率为
P m + 1 ( m + 1 ) = k m + 1 (2) P_{m+1}(m+1) = \frac{k}{m+1} \tag{2} Pm+1(m+1)=m+1k(2)
在遍历第 m + 1 m+1 m+1 个元素时,有 k m + 1 \frac{k}{m+1} m+1k 的概率会在前 m m m 个元素选出的 k k k 个样本中随机丢掉一个,其中每个样本被丢掉的概率为
k m + 1 × 1 k = 1 m + 1 \frac{k}{m+1} \times \frac{1}{k} = \frac{1}{m+1} m+1k×k1=m+11
所以,将式 ( 1 ) (1) (1) 代入可得,在遍历第 m + 1 m+1 m+1 个元素后,前 m m m 个元素被选中的概率为
P m + 1 ( i ) = P m ( i ) × ( 1 − 1 m + 1 ) = k m × m m + 1 = k m + 1 ( i = 1 , 2 , ⋯   , m + 1 ) (3) P_{m+1}(i) = P_m(i) \times (1 - \frac{1}{m+1}) = \frac{k}{m} \times \frac{m}{m+1} = \frac{k}{m+1} \hspace{1em} (i = 1,2,\cdots,m+1) \tag{3} Pm+1(i)=Pm(i)×(1m+11)=mk×m+1m=m+1k(i=1,2,,m+1)(3)
综上所述,根据式 ( 2 ) (2) (2) 和式 ( 3 ) (3) (3) 可知,若当 n = m n=m n=m m > = k m >= k m>=k)时结论成立,则当 n = m + 1 n = m+1 n=m+1 时,前 m m m 个元素被选中的概率与第 m + 1 m+1 m+1 个元素被选中的概率相同,均为 k m + 1 \frac{k}{m+1} m+1k。得证。

2. 时间复杂度不超过 O ( n ) O(n) O(n),空间复杂度不超过 O ( k ) O(k) O(k)

根据算法步骤显然可知:

  • 时间复杂度为 O ( n ) O(n) O(n);遍历了所有元素;
  • 空间复杂度为 O ( k ) O(k) O(k),为缓存列表 L L L 的长度为 k k k

满足要求。

代码实现

代码满足 PEP-0008、PEP-0484、pylint 规范;使用 numpy 注释文档规范。

import random
from typing import Any, Iterator, List


def reservoir_sampling(population: Iterator[Any], n_sample: int):
    """水塘抽样

    Parameters
    ----------
    population : Iterator[Any]
        所有元素的可迭代对象(流状态)
    n_sample : int
        需要选取的样本数

    Returns
    -------
    sample : List[Any]
        选中的样本列表
    """
    cache: List[Any] = []  # 缓存列表
    i = 0
    for elem in population:
        if i < n_sample:
            cache.append(elem)
        elif random.random() < n_sample / i:
            cache[random.randint(0, n_sample - 1)] = elem
        i += 1
    return cache

测试用例:

reservoir_sampling(range(5), 5)  # [0, 1, 2, 3, 4]
reservoir_sampling(range(1000), 5)  # [385, 921, 590, 492, 243]

你可能感兴趣的:(Python,python,抽样,水塘抽样)