Pytorch:使用官网提供数据集的相关参数设置,以CIFAR10为例进行说明

文章目录

  • 前言
  • 一、Dataset
    • 定义-组成
    • 分类
  • 二、获取数据集
    • 1.参数说明
    • 2.相关Demo


前言

本文记录笔者关于Dataset的相关学习记录,以Pytorch官网文档为主进行学习

一、Dataset

定义-组成

所谓Dataset,指的是我们在学习神经网络中要接触的数据集,一般由原始数据,标注Label及相关索引构成
这里笔者给出基于自己的理解所进行的论述,比方说,我们要训练一个识别猫和狗的神经网络,我们获取到的原始图像数据集组成大致如下

  • 一张狗的照片->二进制图像,可以理解为二维数组
  • 照片的Label->区分这张图片到底是猫还是狗,一般为int值,对应0或者1
  • 相关索引->一维数组或者二维数组的下标,根据下标Index获取到数据集的数据

分类

数据集按照其功能用途分,可分为训练集(train set)、测试集(validation set)和验证集(test set)

  • 训练集:训练集用来训练模型,即确定模型的权重和偏置这些参数,通常我们称这些参数为学习参数。
  • 测试集:而验证集用于模型的选择,更具体地来说,验证集并不参与学习参数的确定,也就是验证集并没有参与梯度下降的过程。验证集只是为了选择超参数,比如层数、网络节点数、迭代次数、学习率这些都叫超参数。比如在k-NN算法中,k值就是一个超参数。所以可以使用验证集来求出误差率最小的k。换句话说,笔者理解验证集并不参与反向传播的过程,也就是不会对网络的参数产生直接的负反馈指导,只对网络的结构产生参考影响。
  • 验证集:测试集只使用一次,即在训练完成后评价最终的模型时使用。它既不参与学习参数过程,也不参数超参数选择过程,而仅仅使用于模型的评价。值得注意的是,千万不能在训练过程中使用测试集,而后再用相同的测试集去测试模型。这样做其实是一个cheat,使得模型测试时准确率很高。

二、获取数据集

在使用Pytorch框架进行神经网络开发过程中,在引入torch包后,我们可以使用官网为我们提供的数据集相关命令进行数据集的获取和下载

1.参数说明

这里以CIFAR10为例,进行说明
Pytorch:使用官网提供数据集的相关参数设置,以CIFAR10为例进行说明_第1张图片

root:表示dataset所存放的路径
train:表示训练集的种类划分,如果设置为true,则是训练集,如果为false,则为验证集
transform:表示对数据集中的数据进行处理获取到的变化后的数据,一般为ToTensor()
target_transform:一个函数,输入为target,输出对其的转换。例子,输入的是图片标注的string,输出为word的索引。
download:表示是否要进行下载,设置为True的时候,如果数据集不在本地,则会从云端对数据集进行下载

2.相关Demo

获取CIFAT10相关数据集的代码如下

import torchvision.datasets

train_data=torchvision.datasets.CIFAR10("../data",train=True,transform=torchvision.transforms.ToTensor(),
                                        download=True)

test_data=torchvision.datasets.CIFAR10("../data",train=False,transform=torchvision.transforms.ToTensor(),
                                        download=True)

# 数据集长度
train_lens=len(train_data)
test_lens=len(test_data)

print(train_lens)
print(test_lens)

运行效果
Pytorch:使用官网提供数据集的相关参数设置,以CIFAR10为例进行说明_第2张图片

你可能感兴趣的:(#,Python,#,神经网络,pytorch,深度学习)