weibu的深度学习篇之pytorch(4)——utils.data和torchvision

数据处理工具箱的概述

Pytorch涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相关关系如下:

pytorch数据处理工具包概述

torch.utils.data工具包

        1)Dataset:抽象类,其他数据集要继承这个类,包含两个方法__getitem__和__len__。

        2)DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并提供加速功能。

        3)random_split:把数据集随机拆分为给定长度的非重叠的新数据集。

        4)*sampler:多种采样函数

torchvision工具包

安装:pip install torchvision #或conda install torchvision

        1)datasets:提供常用的数据集加载,设计上都是继承torch.utils.data.Dataset,主要包括MNIST、CIFAR10/100、ImageNet和COCO数据集。

       2)models:提供深度学习中经典网络结构,以及训练好的模型,例如AlexNet、VGG、ResNet等。

       3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作

       4)utils:两个函数,一个是make_grid,将多个图片拼割在一个网格中,一个是save_img,他将Tensor保存成图片。

utils.data

        utils.data包括Dataset和DataLoader。torch.utils.data.Dataset为抽象类。自定义数据集合要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__,前者提供了数据大小,后者通过给定的索引获取数据和标签。

       __getitem__一次只能获取一个数据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。

1.使用Dataset

import torch 
from torch.utils import data
import numpy as np
class TestDataset(data.Dataset):
def __init__(self):
#一些由2维向量表示的数据集
        self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]]) 
#这是数据集合对应的标签
        self.Label = np.asarray([0,1,0,1,2])  
    def __getitem__(self,index):
        #numpy 转换成 Tensor
        txt = torch.from_numpy(self.Data[index])
        label = torch.tensor(self.Label[index])
        return txt,label
    def __len__(self):
        return len(self.Data)
#获取数据集中的数据
Test = TestDataset()
#相当于调用__getitem__(2),输出[2,1]
print(Test[2])
print(Test.__len__())

2. 使用DataLoader

其中Dataset只负责数据的抽取,调用一次__getitiem__只返回一次样本。如果希望批量处理(batch),还需要同时进行shuffle和并行加速处理,可以选择Dataloader。DataLoader的格式为:

data.DataLoader(
		dataset,
		bactch_size = 1,
		shuffle = False,
		sample = None,
		bactch_sampler = None,
		num_workers = 0,
		collate_fn = ,
		pin_memory = False,
		drop_last = False,
		timeout = 0,
		worker_init_fn = None,
)
# 主要参数说明:
'''

       dataset:加载数据集
       batch_size:批大小
       shuffle:是否将数据打乱
       sampler:样本抽样
       num_workers:使用多进程加载的进程数,0表示不适应多进程
       collate_in:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
       pin_memory:是否将数据保存在pin_memory区,pin_memory中的数据转换GPU会快一些。
       drop_last:dataset中的数据可能不是bacth_size的整数倍,drop_last为True会将多出来的不足一个bacth的数据丢弃
'''

结合上述程序,给出示例:

test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=0)
for i,traindata in enumerate(test_loader):
    print('i:',i)
    Data,Label=traindata
    print('data:',Data)
    print('Label:',Label)

torchvision

        torchvision有4个功能模块:model、datasets、transforms和utils

transforms

       transforms提供了对PIL Image对象和Tensor对象的常用操作

1. 对PIL Image的常见操作

       Scale/Resize:调整尺寸,长宽比例保持不变

       CenterCrop、RandomCrop、RandomSizeCrop:裁剪图片,CenterCrop和RandomCrop在crop时是固定size,RandomResizeCrop则是random size的crop。

       Pad:填充

       ToTensor:把一个取值范围为[0,255]的PIL Image转换成Tensor。形状为(H,W,C)的Numpy.ndarry转换成形状[C,H,W],取值范围是[0,1.0]的torch.FloatTensor.

       RandomHorizontalFlip:图像的随机水平翻转,翻转概率为0.5

       RandomVerticalFlip:图像随机垂直翻转

       ColorJitter:修改宽高、对比度和饱和度

2.对Tensor的常见操作

       Normalize:标准化,即减均值,除以标准差

       ToPILImage:将Tensor转换为PIL Image。

       如果要对数据进行多个处理,可以使用Compose将这些操作像管道一样拼接起来,类似于nn.Sequential()

       transforms.Compose( [

              #将给定的PIL Image进行中心切割,得到给定的size

              #size可以是tuple, (target_height, targht_width)

              #size可以是一个Integer,在这种情况下,切出来的图片是正方形的。

              transforms.CenterCrop(10),

              #切割中心点的位置随机选取

              transforms.RandomCrop(20,padding = 0),

              #把一个取值范围为[0,255]的PIL Image或者shape为(H,W,C)的numpy.ndarray转换成形状为(C,H,W),取值范围为[0,1]的torch.FloatTensor

              transforms.ToTensor(),

              #规范化到[-1,1]

              transforms.Normalize(mean = (0.5,0.5,0.5),std = (0.5,0.5,0.5))

])

ImageFolder

当文件依据标签处于不同的文件下时,如

——data

       |——zhangliu

       |    |——001.jpg

       |    |——002.jpg

       |——wuhua 

       |      |——001.jpg

       |      |——002.jpg

       我们可以利用torchvision.datasets.ImageFolder来直接构造dataset,代码如下:

       loader = datasets.ImageFold(path)

       loader = data.DataLoader(datasets)

       ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。

       e.g:利用ImageFolder读取不同目录下的图片数据,然后使用transforms进行预处理,预处理有多个,使用Compose将这些操作拼接一起,使用DataLoader加载。

        ###其中数据集合trochvision_data可以使用任意的图像建立,并放置在程序运行的文件夹下,命名为torchvision_data

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt 
from torch.utils import data
my_trans=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
train_data = datasets.ImageFolder('./torchvision_data', transform=my_trans)
train_loader = data.DataLoader(train_data,batch_size=8,shuffle=True,)
                                            
for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        print(img[1])
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
        utils.save_image(grid,'test01.png')
    break

运行结果:

        weibu的深度学习篇之pytorch(4)——utils.data和torchvision_第1张图片

 

你可能感兴趣的:(Pytorch基础,pytorch,深度学习,目标检测,人工智能)