Pytorch系列入门6----数据处理

系列文章目录


文章目录

  • 系列文章目录
  • 前言
  • 一、神经网络的流程
  • 二、为什么要数据处理
  • 三、拿来主义的公开数据集
  • 三、数据处理的流程
    • 1.制作数据集 Dataset
    • 2.数据增强torchvision.transforms
    • 3.数据批处理DataLoader
  • 总结


前言

相信大家通过前几篇博文的阅读已经能看懂、了解Pytorch中主要API的作用和使用了,今天开始我们另起一锅,来看看Pytorch是如何勾画一副神经网络图。


一、神经网络的流程

需要声明的是,我们所介绍的是神经网络在计算机视觉方面,具体是物体检测方面的应用,故整个流程基本遵从一下几点:

数据处理(主要是图片视频等数据)
模型搭建
损失计算
模型训练
模型部署

二、为什么要数据处理

对深度学习来说,就是从海量的数据中去预测未知数据,因此数据是很重要的,数据处理的好坏也决定了深度学习的上限,因此数据处理是必要的!!!

三、拿来主义的公开数据集

一般数据处理都做成数据集,里面包含数据的image(原图)、label(标签)、annotation(标注)等
公认的三大数据集:ImageNet数据集、PASCAL VOC数据集、COCO数据集
随着自动驾驶领域的快速发展,也出现了众多自动驾驶领域的数据集,如KITTI、Cityscape和Udacity等

三、数据处理的流程

from torchvision import datasets, transforms, models

1.制作数据集 Dataset

from torch.utils.data import Dataset
代码如下(示例):

	# 建立data类,即可以方便地进行数据集的迭代
    class my_data(Dataset):
          def __init__(self, image_path, annotation_path, transform):
              # 初始化,读取数据集
          def __len__(self):
              # 获取数据集的总大小
          def __getitem__(self, id):
              # 对于指定的id,读取该数据并返回
    # 实例化
    dataset = my_data("your image path", "your annotation path",data_transforms) # 实例化该类
    dataset.classes #数据集包含种类名
    #迭代获取每一组数据
    for data in dataset:
        print(data)

2.数据增强torchvision.transforms

注意:我们个人去拿一个具体场景的图片资料训练神经网络,首先一个突出问题就是,数据量不够。较少的数据一、不满足神经网络训练的海量要求。二、不具备代表性。三、数据集中的图片有可能存在大小不一的情况,并且原始图片像素RGB值较大(0~255),这些都不利于神经网络的训练收敛,因此还需要进行一些图像变换工作。为此数据增强是非常必要的

数据增强的方法主要有以下几种:
图片缩放、旋转、遮挡、裁剪、翻转等
transforms.Compose()用来把一系列的增强操作集合起来按顺序执行

data_transforms = transforms.Compose([
        transforms.Resize([96, 96]), #缩放
        transforms.RandomRotation(45),#随机旋转,-4545度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        Cutout(0.4),#随机遮挡的概率
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),#矩阵转换为张量必须的数据转换
        transforms.Normalize([0.535, 0.473, 0.572], [0.189, 0.276, 0.209])#均值,标准差进行归一化
    ])

3.数据批处理DataLoader

DataLoader模块直接读取batch数据
from torch.utils.data import Dataloader

 # 使用Dataloader进一步封装Dataset
 dataloader = Dataloader(dataset, batch_size=10, shuffle=True, num_workers=8)
 #参数含义 (Dataset的实例,批量batch的大小,是否打乱数据参数,使用几个线程来加载数据)

总结

以上就是今天介绍的有关数据处理方面的内容,从当前较为主流的公开数据集,然后介绍数据处理流程,从制作数据集、数据增强、数据批处理3个方面介绍PyTorch中相关的使用方法。数据已准备妥当,搭建合适的模型才能让安静的数据说话,接下来我们一起期待Pytorch系列入门7----网络结构。

你可能感兴趣的:(pytorch,深度学习,计算机视觉)