Pytorch库中有许多与深度学习有关的代码块,在进行学习时可以直接调用,十分有利于新手学习和使用。本次深度学习我就是采用pytorch库进行变成实现对CIFAR10数据集的分类处理
直接上python代码(编译器为jupyter)
import torch
import torchvision
import torchvision.transforms as transforms
transform1 = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root=r'C:\Users\dell\Desktop\Python', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
shuffle=True, num_workers=1)
testset = torchvision.datasets.CIFAR10(root='test_batch', train=False,
download=True, transform=transform1)
testloader = torch.utils.data.DataLoader(testset, batch_size=50,
shuffle=False, num_workers=1)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
transforms.ToTensor();
ToTensor()将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可。
transforms.Normalize
class torchvision.transforms.Normalize(mean, std)
给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化。即:Normalized_image=(image-mean)/std。
归一化是为了加快训练网络的收敛性,可以不进行归一化处理;
归一化的具体作用是归纳统一样本的统计分布性。归一化在0-1之间是统计的概率分布,归一化在-1–+1之间是统计的坐标分布。归一化有同一、统一和合一的意思。无论是为了建模还是为了计算,首先基本度量单位要同一,神经网络是以样本在事件中的统计分别几率来进行训练(概率计算)和预测的,归一化是同一在0-1之间的统计概率分布;当所有样本的输入信号都为正值时,与第一隐含层神经元相连的权值只能同时增加或减小,从而导致学习速度很慢。为了避免出现这种情况,加快网络学习速度,可以对输入信号进行归一化,使得所有样本的输入信号其均值接近于0或与其均方差相比很小。
torchvision.datasets.CIFAR10(root=r'C:\Users\dell\Desktop\Python', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
shuffle=True, num_workers=1)
pytorch中自带加载CIFAR10数据集的模块使用torchvision.datasets.CIFAR10加载CIFAR10数据集root是保存数据集的路径,train=true为训练集download为true要进行联网下载数据集如果路径文件夹中含有已经下载好的数据集则不用下载,transform为归一化处理在前面已经操作过了。
trainloadar作为一个容器在程序运行时装载数据集中的数据,trainset作为训练集,batch_size = 10为minibatch的数据量为100,shuffle = True 表明提取数据时,随机打乱顺序,因为我们都是基于随机梯度下降的方式进行训练优化,但测试的时候因为不需要更新参数,所以就无须打乱顺序了。
num_workers = 2 指定了工作线程的数量。
接下来的testset和testlodar为测试集。
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
classes规定数据集中的种类名称