写在前边:最近在博客上遇到了个不错的博主,通过读他的博客自己做一些笔记,与大家分享
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html这个是pytorch 的官方教程
数据加载和处理——>模型定义(包括损失函数的选择)——>训练——>测试
主要使用的是torchvision 中的一些模块
torchvision.datasets:加载一些深度学习比较常用的数据集,比如Mnist、Imagenet、Cifar10
torchvision.transforms:书要是对数据的处理
torchvision.models:一些模型
torchvision.utils:一些工具
import torch
import torchvision
import torchvision.transforms as transforms
#进行数据格式的转换
#因为通过 torchvision.datasets 读进来的数据是(0,1)之间的PILImage 的格式的数据
#所以要将数据进行归一化转成(-1,1)之间,因为transforms.Normalize处理的是tensor 数据,所以要先将数据转换为tensor
#Normalize 的第一个参数代表均值,第二个参数代表方差,每个参数 有三个代表 R/G/B 三个通道
transform=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
)
#进行数据的下载
trainset=torchvision.datasets.CIFAR10(root="./data",train=True,download=True,transform=transform)
#批训练数据
batch_size:每次feed到神经网络的数量
shuffle:是否乱序
num_works:表示的是进程的数目,进程的数目越多数据加载的越快,但是消耗CPU的资源越多
trainloader=torch.utils.data.Dataloader(trainset,batch_size=4,shuffle=True,num_workers=2)
testset=torchvision.datasets.CIFAR10(root="./data",train=False,download=True,transform=transform)
testloader=torch.utils.data.Dataloader(testset,batch_size=4,shuffle=True,num_workers=2)
#某些情况下 如果已经下载好数据集可以使用下边的方法直接打开
#dataset=torchvision.datasets.ImageFolder("datase_path",transform=transform)
#dataloader=torch.utils.data.DataLoader(dataset,batch_size=4,shuffle=True,num_workers=2)
#cifar10的10类
classes=("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")
加油,当你看到我的博客的时候,相信你一定在努力成为最好的自己的路上!
加油,别放弃,坚持,每天一点点,相信一年后的你一定会感谢现在的你!
如果你真的迷茫了,我愿当你的倾听者,但是你千万不能放弃,因为改变命运的机会真的不多呀!
QQ小号:1817780086