目录
1.torchvision中加载数据集
2.重写Dataset类加载数据集
3.transforms
4.Dataloader对数据进一步处理
官方文档给出的数据
下面以CIFAR数据集为例子:
torchvision.datasets.CIFAR10(root: str, train: bool = True,
transform: Optional[Callable] = None, target_transform:
Optional[Callable] = None, download: bool = False)
import os
import torch
import numpy as np
from PIL import Image
from torchvision import datasets,transforms
#数据集的预处理
transform=transforms.Compose([
transforms.ToTensor()
])
#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
train_data=datasets.CIFAR10(root=root,train=True,transform=transform,download=True)
test_data=datasets.CIFAR10(root=root,train=False,transform=transform,download=True)
print('trainSize: {}'.format(len(train_data)))
print('testSize: {}'.format(len(test_data)))
#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
test_data=datasets.CIFAR10(root=root,train=False,download=True)
#显示图片的类别
print('图片包含类别: {}'.format(train_data.classes))
#显示图片
imgOne,target=test_data[0]
imgOne.show()
#查看图片所属类别
print('class: {}'.format(test_data.classes[target]))
官方文档torch.utls.data.Dataset
import os
import pathlib
from PIL import Image
from torch.utils.data import Dataset
class myDataset(Dataset):
def __init__(self,img_path):
self.data_dir=pathlib.Path(img_path)
self.dataset=list(self.data_dir.glob('*/*.jpg'))
#根据索引index获取数据,index是根据数据集的顺序来获取的
def __getitem__(self, index):
img=self.dataset[index]
imgTo=Image.open(img)
return imgTo
#获取数据集的大小
def __len__(self):
#统计flower_photos文件夹下面所有的图片数据集数量
self.len=len(list(self.data_dir.glob('*/*jpg')))
return self.len
if __name__ == '__main__':
mydataset=myDataset(img_path=r'E:\myDataset\flower_photos')
print('dataSize: {}'.format(len(mydataset)))
#获取第1张图片数据
img=mydataset[0]
print('imgsize: {}'.format(img.size))
#显示图片
img.show('img')
打开transforms.py文件可以看到其中包含的对数据的处理方法:(关于这些方法要使用的时候都可以直接查询)
torchvision官网查看功能:
https://pytorch.org/vision/stable/transforms.html
transforms.compose使用
#compose中包含一个数组,数组中包含的是对图片数据集进行处理的过程
#比如下面,首先对一张图片进行中心的裁剪,其次将PIL数据类型转换为Tensor数据类型,
#最后是将数据转换为浮点类型
transf=transforms.Compose([
transforms.CenterCrop(10),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float)
])
print('imgSize: {}'.format(img.size))
print(type(transf(img)))
如果读者不使用.compose的话,也可以使用下面的方法一步一步的进行数据转化:
import os
import torch
from PIL import Image
from torchvision import transforms,datasets
img_path="myDataset/flower_photos/daisy/5547758_eea9edfd54_n.jpg"
#读取图片数据
img=Image.open(img_path)
#显示图片大小
print('imgSize: {}'.format(img.size))
#第一步:对图片数据进行中心裁剪
centerCut=transforms.CenterCrop(100)
img_cut=centerCut(img)
img_cut.show('img_cut')
print('imgCutSize: {}'.format(img_cut.size))
#第二步:将图片数据集转换为Tensor
ToTensor=transforms.ToTensor()
img_ToTensor=ToTensor(img_cut)
print(type(img_ToTensor))
#第三步:将Tensor数据转换为浮点类型
FloatData=transforms.ConvertImageDtype(dtype=torch.float)
img_Float=FloatData(img_ToTensor)
print(type(img_Float))
关于上面一些自己比较常用的一些方法 :但是读者应该注意的是,在Compose中使用这些方法时,对数据的处理先后顺序注意,因为有些方法要传入的是Tensor数据类型,所以将数据转换为Tensor类型方法可能得放在其他方法的前面,注意报错的问题所在。
transform=transforms.Compose([
transforms.Resize(size=[224,224]),
transforms.CenterCrop(100),
transforms.ToTensor(),
#output[channel] = (input[channel] - mean[channel]) / std[channel],由于图片是三通道的,所以平均值和方差都是分别给出三个值
transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]),
#p表示水平翻转的概率
transforms.RandomHorizontalFlip(p=0.5),
#垂直翻转
transforms.RandomVerticalFlip(p=0.5),
#随机旋转,degrees表示旋转度数,center表示旋转中心坐标,还有其他的参数可以自行选择
transforms.RandomRotation(degrees=45,center=[50,50],)
])
官网解释:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
以下给出的是一些常见设置参数:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, num_workers=0)
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
#数据集的预处理
transform=transforms.Compose([
transforms.ToTensor()
])
#下载数据集
root='myDataset/CIFAR10'
#将数据集下载之后保存到root路径下(如果数据集已经存在,则不从网上下载,直接从给定的root路径下获取)
train_data=datasets.CIFAR10(root=root,train=True,transform=transform,download=True)
test_data=datasets.CIFAR10(root=root,train=False,transform=transform,download=True)
#获取图片类别
classes=train_data.classes
#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
test_loader=DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
for data in train_loader:
imgs,targets=data
print('imgs: {}'.format(imgs.shape))
print('target: {}'.format(targets))
#打印前四张打包的图片类别
for stop,i in enumerate(targets):
print('target[{}]---->{}'.format(i,classes[i]))