在以流形式获取 n n n 个元素中,随机抽取 k k k 个样本,满足 k ≤ n k \le n k≤n(通常有 k ≪ n k \ll n k≪n)要求:
证明 使用数学归纳法。当 n = k n = k n=k 时,每个元素被选中的概率均为 1 1 1,结论显然成立。假设当 n = m n=m n=m( m > = k m >= k m>=k)时结论成立,只需证明当 n = m + 1 n = m + 1 n=m+1 时结论仍成立。
因为当 n = m n = m n=m 时结论成立,所以在遍历第 m + 1 m+1 m+1 个元素前,前 m m m 个元素被选中的概率相同。因为前 m m m 个元素被选中的概率相同,所以每个元素被选中的概率为
P m ( i ) = k ÷ m = k m ( i = 1 , 2 , ⋯ , m ) (1) P_m(i) = k \div m = \frac{k}{m} \hspace{1em} (i = 1,2,\cdots,m) \tag{1} Pm(i)=k÷m=mk(i=1,2,⋯,m)(1)
在遍历第 m + 1 m+1 m+1 个元素时,有 k m + 1 \frac{k}{m+1} m+1k 的概率会保留第 m + 1 m+1 m+1 个元素,所以第 m + 1 m+1 m+1 个元素被选中的概率为
P m + 1 ( m + 1 ) = k m + 1 (2) P_{m+1}(m+1) = \frac{k}{m+1} \tag{2} Pm+1(m+1)=m+1k(2)
在遍历第 m + 1 m+1 m+1 个元素时,有 k m + 1 \frac{k}{m+1} m+1k 的概率会在前 m m m 个元素选出的 k k k 个样本中随机丢掉一个,其中每个样本被丢掉的概率为
k m + 1 × 1 k = 1 m + 1 \frac{k}{m+1} \times \frac{1}{k} = \frac{1}{m+1} m+1k×k1=m+11
所以,将式 ( 1 ) (1) (1) 代入可得,在遍历第 m + 1 m+1 m+1 个元素后,前 m m m 个元素被选中的概率为
P m + 1 ( i ) = P m ( i ) × ( 1 − 1 m + 1 ) = k m × m m + 1 = k m + 1 ( i = 1 , 2 , ⋯ , m + 1 ) (3) P_{m+1}(i) = P_m(i) \times (1 - \frac{1}{m+1}) = \frac{k}{m} \times \frac{m}{m+1} = \frac{k}{m+1} \hspace{1em} (i = 1,2,\cdots,m+1) \tag{3} Pm+1(i)=Pm(i)×(1−m+11)=mk×m+1m=m+1k(i=1,2,⋯,m+1)(3)
综上所述,根据式 ( 2 ) (2) (2) 和式 ( 3 ) (3) (3) 可知,若当 n = m n=m n=m( m > = k m >= k m>=k)时结论成立,则当 n = m + 1 n = m+1 n=m+1 时,前 m m m 个元素被选中的概率与第 m + 1 m+1 m+1 个元素被选中的概率相同,均为 k m + 1 \frac{k}{m+1} m+1k。得证。
根据算法步骤显然可知:
满足要求。
代码满足 PEP-0008、PEP-0484、pylint 规范;使用 numpy 注释文档规范。
import random
from typing import Any, Iterator, List
def reservoir_sampling(population: Iterator[Any], n_sample: int):
"""水塘抽样
Parameters
----------
population : Iterator[Any]
所有元素的可迭代对象(流状态)
n_sample : int
需要选取的样本数
Returns
-------
sample : List[Any]
选中的样本列表
"""
cache: List[Any] = [] # 缓存列表
i = 0
for elem in population:
if i < n_sample:
cache.append(elem)
elif random.random() < n_sample / i:
cache[random.randint(0, n_sample - 1)] = elem
i += 1
return cache
测试用例:
reservoir_sampling(range(5), 5) # [0, 1, 2, 3, 4]
reservoir_sampling(range(1000), 5) # [385, 921, 590, 492, 243]