Pytorch中数据采样方法Sampler(torch.utils.data)(二) —— WeightedRandomSampler & SubsetRandomSampler

WeightedRandomSampler加权随机采样

平衡不平衡数据的抽取

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

其中__iter__为:

iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

其中

  • weights为index权重,权重越大的取到的概率越高
  • num_samples: 生成的采样长度
  • replacement:是否为有放回取样
  • multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}

如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍

import torchvision
from torchvision import transforms
from torch.utils.data import sampler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import *

transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
 
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)

## 如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
weights = [2 if label == 1 else 1 for data, label in trainset]
sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)
dataloader = DataLoader(trainset, batch_size=16, sampler=sampler)

SubsetRandomSampler索引随机采样

根据index从数据集中抽取这些index对应的图片,然后随机排序

torch.utils.data.SubsetRandomSampler(indices)

其中__iter__为:

(self.indices[i] for i in torch.randperm(len(self.indices)))

其中

  • torch.randperm对数组随机排序
  • indices为给定的下标数组

所以SubsetRandomSampler的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样
 

如果我要划分train_set和test_set, 那么读进整个数据集来再split比较慢

不如我直接生成train_set的index和test_set的index这样就可以很快了,所以就出现了SubsetRandomSampler

import torchvision
from torchvision import transforms
from torch.utils.data import sampler
from torch.utils.data import DataLoader

transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
 
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)
 
testset = torchvision.datasets.MNIST(
    root='dataset/',
    train=False,
    download=True,
    transform=transform
)

split_num = int(len(trainset) * 0.8)
index_list = list(range(len(trainset)))
train_idx, test_idx = index_list[:split_num], index_list[split_num:]

train_sampler = sampler.SubsetRandomSampler(train_idx)
test_sampler = sampler.SubsetRandomSampler(test_idx)

loader_train = DataLoader(trainset, batch_size=100,
                          sampler=train_sampler)

loader_val = DataLoader(testset, batch_size=100,
                        sampler=test_sampler)

你可能感兴趣的:(pytorch,深度学习,python)