pytorch对数据集进行重新采样

背景
当不同类型数据的数量差别巨大的时候,比如猫有200张训练图片,而狗有2000张,很容易出现模型只能学到狗的特征,导致准确率无法提升的情况。

这时候,一种可行的方法就是对原始数据集进行采样,从而生成猫、狗图片数量接近的新数据集。这个新数据集中可能猫、狗图片都各有500张,其中猫的图片有一部分重复的,而狗的2000张图片中有一部分没有被采样到,但是这时候新数据集的数据分布是均衡的,就可以比较好的训练了。

操作方法
我们知道pytorch训练一般都是用的DataLoader加载数据的,我们可以通过给Dataloader传入一个sampler的采样器进行采样操作。

train_loader = DataLoader( train_dataset, batch_size=256, num_workers=2, sampler=sampler)

采样器sampler有多种,大家可以根据自己需要研究一下,这里我们使用一个按权重采样的WeightedRandomSampler。其作用是:我们可以人为的给每张图片定一个被抽取到的概率,一般每一类的所有图片的概率可以一样,然后就按每个图片的这个概率对整个数据集进行重新采样。

比如:猫只有200张图片,我们设置取到每张猫的图片的概率为1/200,而狗有2000张图片,我们设置取到每张狗的概率为1/2000。这样虽然狗的图片比较多,但我们取到猫和狗的概率是一样的,只是猫会有一些重复,而狗有一些不会取到,最终形成的新数据集就平衡了。

参考 https://blog.csdn.net/tyfwin/article/details/108435756

pytorch对数据集进行重新采样_第1张图片
注意
注意上图的replacement参数,为True表示有放回的采样,也就是我们上边说的那种采样,有部分数据重复,有部分数据没有出现;为False表示不放回的采样,即采样后的数据集跟原来一样,只是内部数据的顺序有些变化,概率大的可能会在前边,这主要作用于有序的数据。num_samples定义采样的次数,也即采样后的数据集数目,一般设为跟原来一样。

代码

参考 https://www.cnblogs.com/king-lps/p/11004653.html

# 定义每个类别采样的权重,这个只做参考,可以根据自己需要随便定义
target = train_dataset.targets
class_sample_count = np.array([len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
# 读取原始数据为datasets对象                                                                              
dataset_train = datasets.ImageFolder(traindir)                                                                                                         

# 在DataLoader的时候传入采样器即可                                                                                
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, sampler = sampler) 

你可能感兴趣的:(pytorch使用,pytorch)