pytorch WeightRandomSampler要提供两个参数

朋友torch,单独提供了一个sampler模块,用来对数据进行采样,常用的有随机采样器randomsampler,当shuffle的参数为true,系统自动调用这个采样器,实现打乱数据。默认的是sequential sampler,他会按照顺序一个一个进行采样,还有一个WeightRandomSampler,他会根据每个样本的权重选取数据,在样本比例不均衡问题中,可以用它进行重采样。

WeightRandomSampler要提供两个参数,每个样本的权重weights,共选取的样本总数num_samples,以及一个可选参数replacement,。权重越大的样本被选中的概率越大,待选取的 样本数一般小于全部的样本总数。replacement用于指定是否可以重复选取某一个样本,默认为true,即允许一个epoch中重复采样某一个数据。

replacement为true,会覆盖dataset的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples

 

#在数据处理中,
from  torch.utils.data.sampler import   WeightedRandomSampler
#狗的图片被取出的概率是猫的两倍
#两类图片被取出的概率与weights的绝对大小无关,至于比值有关
wights=[2 if label==1 else 1 for data, label in datasets]
wights=[2,2,1,1,1,1,2,2]
sampler=WeightedRandomSampler(wights,num_samples=9,replacement=True)

 

你可能感兴趣的:(深度学习,pytorch)