PyTorch(一)——数据处理

目录连接
(1) 数据处理
(2) 搭建和自定义网络
(3) 使用训练好的模型测试自己图片
(4) 视频数据的处理
(5) PyTorch源码修改之增加ConvLSTM层
(6) 梯度反向传递(BackPropogate)的理解
(7) 模型的训练和测试、保存和加载
(8) pyTorch-To-Caffe
(总) PyTorch遇到令人迷人的BUG

PyTorch学习和使用(一)

PyTorch的安装比caffe容易太多了,一次就成功了,具体安装多的就不说了,PyTorch官方讲的很详细,还有PyTorch官方(中文)中文版本。
PyTorch的使用也比较简单,具体教程可以看Deep Learning with PyTorch: A 60 Minute Blitz, 讲的通俗易懂。

要使学会用一个框架,只会运行其测试实验是不行的,所以现在打算把caffe中的Siamese模型使用PyTorch实现,来巩固自己对PyTroch的熟练使用。

数据预处理

首先是数据处理这一块,PyTorch使用了torchvision来完成数据的处理,其只实现了一些数据集的处理,如果处理自己的工程则需要修改增加内容。

把原始数据处理为模型使用的数据需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分别可以理解为数据处理格式的定义、数据处理和数据加载。

Compose() 代码中给出的解释是Composes several transforms together. 就是通过Compose把一些对图像处理的方法集中起来。比如先中心化,然后转换为张量(PyTorch的数据结构),其代码为:transforms.Compose([transform.CenterCrop(10), transofrms.ToTensor()])又比如先转换为张量,然后正则化,代码为:`transforms.Compose([transofrms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), 其具体的参数调用在源码中可以看到,在此不多说了。还要注意的是Compose的代码是:

def __call_(self, img):
    for t in self.transforms:
        img = t(img)
    return img

这就是把输入到Compose的操作按顺序进行执行。先执行第一个,然后第二个……。如果需要处理自己的数据,可以把具体的操纵放在这个类中实现。

torchvision.datasets ;里实现了不同针对数据集的处理方法,主要用来加载数据和处理数据。比如在mnist.py 和cifar.py 中用来处理mnist和cifar数据集。类的实现需要继承父类data.Dataset,其主要方法有2个:

  • __init__(self, root, train=Ture, transform=None, traget_transform=None, download=False):该方法用来初始化类和对数据进行加载(有时需要定义一些开关来防止重复处理)。数据的加载就是针对不同的数据,把其data和label(分为训练数据和测试数据)读入到内存中。

  • --getitem__(self, index):该方法是把读入的输出传给PyTorch(迭代器的方式)。**注意:**上面定义的transform.Compose在次数进行调用,通过index确定需要访问的数据,然后对其格式进行转换,最后返回处理后的数据。也就是说数据在定义时只是定义了一个类,其具体的数据传出在需要使用时使用该方法完成。

至此,对数据进行加载,然后处理传给PyTorch已经完成,如果需要对自己的数据进行处理,也是通过修改和增加此部分完成。接下来需要对训练的数据进行处理,比如分批次的大小,十分随机处理等等。

torch.utils.data.DataLoader() Data loder, Combines a dataset and and a sampler, and provides single, or multi-process iterators over the dataset. 就是把合成数据并且提供迭代访问。输入参数有:

  • dataset(Dataset)。输入加载的数据,就是上面的torchvision.datasets.myData()的实现,所以需要继承data.Dataset,满足此接口。

  • **batch-size, shuffle, sampler, num_workers, collate_fn, pin_memory, drop_last.**这些参数比较好理解,看名字就知道其作用了。分别为:

  1. batch-size。样本每个batch的大小,默认为1。
  2. shuffle。是否打乱数据,默认为False。
  3. sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。
  4. num_workers。数据分为几批处理(对于大数据)。
  5. collate_fn。整理数据,把每个batch数据整理为tensor。(一般使用默认调用default_collate(batch))。
  6. pin_memory。针对不同类型的batch进行处理。比如为Map或者Squence等类型,需要处理为tensor类型。
  7. drop_last。用于处理最后一个batch的数据。因为最后一个可能不能够被整除,如果设置为True,则舍弃最后一个,为False则保留最后一个,但是最后一个可能很小。

迭代器(DataLoaderIter)的具体处理就是根据这些参数的设置,分别进行不同的处理。

补充2017/8/10:

torch.utils.data.DataLoader类主要使用torch.utils.data.sampler实现,sampler是所有采样器的基础类,提供了迭代器的迭代(__iter__)和长度(__len__)接口实现,同时sampler也是通过索引对数据进行洗牌(shuffle)等操作。因此,如果DataLoader不适用于你的数据,需要重新设计数据的分批次,可以充分使用所提供的smapler

你可能感兴趣的:(PyTorch)