pytorch基础知识整理(二)数据加载

pytorch数据加载组件位于torch.utils.data中。

from torch.utils.data import DataLoader, Dataset, Sampler

1, torch.utils.data.DataLoader

pytorch提供的数据加载器,它返回一个可迭代对象。不使用这个DataLoader,直接手动把每batch数据导入显存当然也可以,但是DataLoader类可以使跑模型和加载数据并行进行,效率高且更加灵活,所以通常都应该用DataLoader来加载数据。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, 
			collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

基本上从字面就能看懂各参数的含义。其中num_workers是指开多进程加载数据,似乎在windows上不支持大于0的数字。pin_memory是指是否把数据固定到内存中。
然后再用分别分装数据集并用DataLoader调用:

train_dataset = DealDataset(mode='train')
test_dataset = DealDataset(mode='test')
kwargs = {"num_workers": 0, "pin_memory": True}
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size, **kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=batch_size, **kwargs)

训练或推理时从DataLoader中取数据的方法一般如下:

for i in range(epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_gpu:
            data, target = data.cuda(), target.cuda()

建议在训练模型前,先分别运行一次仅加载数据不跑模型的过程和仅跑模型不加载数据的过程,分别记录两个过程的时间以评估数据加载过程的耗时在训练过程中的比例,并据此考虑是否采取更复杂的措施提高数据加载速度。

2 torch.utils.data.Dataset

必须要先把数据构造成dataset类型才能被DataLoader调用,支持两种类型,一种是匹配型Dataset类,也就是其中定义了__getitem__()和__len__()方法,这种比较常用;另一种是迭代型IterabelDataset类,也就是其中定义了__iter__()方法的。

2.1 torch.utils.data.TensorDataset

把tensor直接包装成dataset,通常数据不需要处理可直接用,且数据量不是太大的情况下使用。
注意:新版本pytorch中data_tensor, target_tensor两个参数名已取消,直接放数据就可以,再指名参数名会报错。

# myDataset = TensorDataset(data_tensor=x_tensor, target_tensor=y_tensor) 报错
myDataset = TensorDataset(x_tensor, y_tensor)

2.2 torch.utils.data.Dataset

封装dataset的基本类,可实现各种情况下非常灵活的数据集加载,使用时需要重写它的__getitem__和__len__方法。

class DealDataset(Dataset):
    def __init__(self,mode='train'):
        X, y, Xt, yt = get_data()
        if mode=='train':
            self.x_data = X
            self.y_data = y
        elif mode=='test':
            self.x_data = Xt
            self.y_data = yt
        self.len = self.x_data.shape[0]
    
    def __getitem__(self, index):
        data = self.x_data[index]
        target = self.y_data[index]
        return data, target

    def __len__(self):
        return self.len

在完成一项较大的建模工程时,通常需要试验各种各样的数据处理方案,因此数据加载方案要被大量修改,为了便于修改的灵活性、代码的整洁性和避免修改的版本混乱,可以先使用一个BaseDataset基类确定肯定不会变的文件路径等内容,再使用子类继承来获得各种版本的数据处理方案。

3 torch.utils.data.Sampler

通常sampler不是必须的,但使用sampler可以更灵活的定义采样次序,可以使用SequentialSampler顺序采样;RandomSampler随机采样(有放回或无放回);WeightedRandomSampler按权重随机采样;BatchSampler在一个batch中封装一个其他的采样器,返回一个batch大小的index索引。
也可以通过重写Sampler类或其他子类中的__iter__()方法实现更灵活的自定义采样器。

4, torchvision.transforms

对图像类数据进行处理时经常用到trochvision.transforms
.Compose(transforms)用来把多种变换组合起来

img_transforms = transforms.Compose([transforms.Resize((224,224)),
									transforms.ToTensor(),
                                    transforms.Normalize((0.485, 0.456, 0.406),
                                                          (0.229, 0.224, 0.225))])
img = img_transforms(img)
#注:该标准化系数为ImageNet用的系数

各种变换有:
.CenterCrop(size) 中心切割
.Resize((224,224)) 尺寸变换
.RandomCrop(size, padding=0) 随机中心点切割
.RandomHorizontalFlip() 随机水平翻转
.RandomSizedCrop(size, interpolation=2) 随机大小切割,然后再resize到size大小
.Pad(padding, fill=0) 四周pad
.Normalize(mean, std) 标准化
.ToTensor() 将PIL.Image或np.ndarray转换为tensor
.Lambda(lambd) 函数式自定义变换

你可能感兴趣的:(随笔·各种知识点整理,pytorch,深度学习)