对于Pytorch自带的数据集,只需要调用torchvision.datasets.XXXX()即可,例如想要读取CIFAR10数据集:torchvision.datasets.CIFAR10()
'''导入读取图片数据所需要的工具包'''
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data
在这里先介绍一下torchvision.transforms
模块,这个模块对于图像数据来说是非常重要的,它提供了常用的一些数据增强的方法(裁剪、翻转、旋转、尺寸变换等)。
裁剪(Crop)
裁剪主要分为中心裁剪、随机裁剪等,我自己主要常用的有:中心裁剪transforms.CenterCrop(size)
,随机裁剪:transforms.RondomCrop(size)
,随机长宽比裁剪:transforms.RandomResizedCrop(size)
。
代码 | 作用 |
---|---|
transforms.CenterCrop(size) |
从图像中间裁剪出尺寸为size的图片 |
transforms.RondomCrop(size) |
从图像中随机裁剪出尺寸为size的图片 |
transforms.RandomResizedCrop(size) |
随机大小、长宽比裁剪出尺寸为size的图片 |
其中size就是输入图像通过transforms变换后的尺寸(通常也就是模型输入图片的尺寸),当然还有其他的参数,但是我用的比较少一般都是默认,一般只需要修改尺寸就行,size如果是一个int值(size=n),则最后裁剪得到的是一个正方形图像(尺寸为n×n);size如果是一个数对(h,w),则最后裁剪得到的图像为一个矩形(尺寸为h×w)
翻转、旋转(Flip and Rotation)
自己用的比较多的有:依概率p水平翻转:transforms.RandomHorizontalFlip(p)
依概率p垂直翻转:transforms.RandomVerticalFlip(p)
随机旋转:transforms.RandomRotation(degrees)
。
代码 | 作用 |
---|---|
transforms.RandomHorizontalFlip(p) |
按照概率p来对图像进行水平翻转 |
transforms.RandomVerticalFlip(p) |
按照概率p来对图像进行垂直翻转 |
transforms.RandomRotation(degrees) |
在给定角度范围内随机旋转图片 |
其中p为旋转的概率,这个根据不同需求选择不同的数值;而对于随机旋转有点不一样,degrees为旋转的角度范围,是一个数对(-degrees, +degrees)。
图像变换(resize)
自己在读取数据时常用的有:尺寸变换:transforms.Resize(size)
、标准化:transforms.Normalize(mean, std)
以及将图片数据转为tensor类型并归一化至[0, 1]:transforms.ToTensor()
。
代码 | 作用 |
---|---|
transforms.Resize(size) |
将图像尺寸变换成size |
transforms.Normalize(size) |
将输入图像按批次标准化 |
transforms.ToTensor() |
将输入图像归一化成[0, 1]之间的tensor |
这里详细讲解一下transforms.Normalize(mean, std)
的作用
transforms.ToTensor()
后,x=(x/255)∈[0, 1];transforms.Normalize(mean, std)
(假设mean=0.5,std=0.5),x=[(x-0.5)/0.5]∈[-1, 1] -> (x_mean=0,x_std=1)其中size为转换后预期的图像尺寸,mean代表图像的均值、std代表图像的方差。 这里要注意的是,transforms.ToTensor()
是直接将图片数据的每个像素值除以255。
上述针对图像的各类操作函数没有固定的搭配,都是根据具体需要具体去选择,最后,transforms模块的整体代码为:
data_transform = transforms.Compose([transforms.CenterCrop(32), # 对输入图像进行中心裁剪,裁剪后的图像尺寸为3*32*32(对于3通道RGB图像)
transforms.RandomHorizontalFlip(0.5), # 对输入图像按照0.5的概率进行水平翻转
transforms.ToTensor(), # 将PIL Image或者ndarray类型的数据转换为tensor类型(非常关键,一定要加)
transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]) # 对图像数据进行标准化
通过torchvision.transforms
编辑好所需要的图像变换规则后,就可以调用torchvision.datasets
和torch.utils.data.DataLoader
来读取与加载所需要的图片数据:
batchsize = 128
'''获取训练集与测试集数据'''
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=data_transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=data_transform, download=True)
'''加载训练集与测试集数据'''
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
Files already downloaded and verified
Files already downloaded and verified
训练集和测试集数据准备完毕,下面可以开始进行模型的训练与测试了
root=
:指出数据存放的路径,可以是绝对路径也可以是相对路径。若路径已存在,则数据直接保存在相应路径下;若路径不存在,则先创建路径然后将数据保存至路径下;train=
:bool型变量,若为True则表示获取训练集数据;若为False则表示获取测试集数据;transform=
:调用数据增强的各种方法,这里调用的是上文中所编写的data_transform
,可视情况而定;download=
:bool型变量,这里一般都设置为True,在第一次运行时会自动将所需数据集下载到指定路径下,后续运行不会重复下载(路径中已存在相应数据集)。train_dataset
:指出要加载的数据集,在这里为train_dataset表示要加载训练数据集,为test_dataset则表示要加载测试数据集;batch_size=
:设定一批次的图片数量,这个超参没有固定值,根据数据集、模型大小以及硬件配置自行设置;shuffle=
:bool型变量,若为True,则将打乱加载数据的顺序,若为False则不会打乱顺序;num_workers=
:可以理解为加载数据的通道数量,一般来说num_workers越多加载数据速度越快。若要可视化所读取和加载的数据,可以调用以下方法(这里只做简单的介绍,通常可视化是用来简单检测读取加载数据是否正确,训练时一般会去掉):
'''可视化数据集所包含的类别'''
classes = train_dataset.classes
print(classes)
print('-----------')
'''可视化数据集的数量与各个图片的尺寸'''
print(train_dataset.data.shape) # (50000, 32, 32, 3),50000表示数据集图片的数量;按顺序:32表示h,32表示w,3表示c。
print('-----------')
'''可视化加载后一批次的图片数量以及各图片尺寸'''
for data, label in train_loader:
print(data.shape) # 一批次的图片数量为128(batchsize),图片的尺寸为(3, 32, 32),。
print(label.shape) # 每个图片对应的标签
break
'''可视化一些样本图片'''
import matplotlib.pyplot as plt # 一个画图的包
plt.imshow(train_dataset.data[1])
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
-----------
(50000, 32, 32, 3)
-----------
torch.Size([128, 3, 32, 32])
torch.Size([128])
针对这种方式,用的相对就多了,其对应的图片数据储存方式应该为:
根路径/数据集文件夹(weather_5)/train(test)/类别文件夹(label)/图片
保存成这样的格式之后,就可以直接利用pytorch定义好的派生类ImageFolder来读取了。ImageFolder其实就是Dataset的派生类,专门被定义来读取特定格式的图片的,它也是torchvision库帮中为了我们方便读取文件夹类型的图片数据而创建的。
'''导入读取图片数据所需要的工具包'''
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data
batchsize = 2
'''同样先定义transform'''
data_transform = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])
'''获取训练集与测试集数据'''
train_dataset = datasets.ImageFolder(root='./data/weather_5/train/', transform=data_transform)
test_dataset = datasets.ImageFolder(root='./data/weather_5/test/', transform=data_transform)
'''加载训练集与测试集数据'''
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
这里只对datasets.ImageFolder
做一个简单的介绍。
通过train_dataset = datasets.ImageFolder
后返回的train_dataset
包含以下三种属性:
train_dataset.class
:用一个list保存数据集中类别名称train_dataset.class_to_idx
:类别对应的数字索引train_dataset.imgs
:保存(img_path, class)tuple的 list重点关注下面两块小代码,可以更清晰的了解datasets
和DataLoader
的差别
'''可视化读取到的数据集第一张图片的尺寸与其标签'''
for data, label in train_dataset:
print(data.shape) # 第一张图片的为(3, 224, 224)
print(label) # 第一张图片的标签为0
break
print('-----------')
'''可视化加载后一批次的图片数量以及各图片尺寸'''
for data, label in train_loader:
print(data.shape) # 一批次的图片数量为2(batchsize),图片的尺寸为(3, 224, 224)
print(label.shape) # 每个图片对应的标签
print(label[0]) # 可以将torch.utils.data.DataLoader()中的shuffle变量修改为False,比较输出有什么不同。
break
print(train_dataset.class_to_idx)
torch.Size([3, 224, 224])
0
-----------
torch.Size([2, 3, 224, 224])
torch.Size([2])
tensor(4)
{'cloudy': 0, 'haze': 1, 'rainy': 2, 'snow': 3, 'sunny': 4}
这种加载图像数据的方式应该是现在用的最多的一种方式,这种方式Pytorch就没有现成的方法让我们直接去加载数据了,但是我们可以基于Pytorch定义我们自己的Dataset
类。
这里以另一个天气分类模型所用到的数据集为例:
'''导入读取图片数据所需要的工具包'''
import torch.utils.data
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import os
import pandas as pd
from PIL import Image
'''定义transform'''
data_transform = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
'''定义自己的Dataset类'''
class MyDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
"""
csv_file: 标签文件的路径.
root_dir: 所有图片的路径.
transform: 一系列transform操作
"""
self.data_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
# print(self.data_frame.info) # 可以查看通过pandas读取后的data_frame具体包含些什么
def __getitem__(self, idx):
'''获取图片'''
img_path = os.path.join(self.root_dir,
self.data_frame.iloc[idx, 0]) #获取图片所在路径
img = Image.open(img_path).convert('RGB') # 防止有些图片不是RGB格式
'''获取标签'''
label_number = self.data_frame.iloc[idx, 1] # 获取图片的类别标签
'''判断是否要进行图像变换'''
if self.transform:
img = self.transform(img)
return img, label_number # 返回图片和标签
def __len__(self):
return len(self.data_frame) # 返回数据集长度用来建立索引idx
'''调用自己的Dataset类来读取数据'''
train_dataset = MyDataset(csv_file='./data/weather/Train_label.csv',
root_dir='./data/weather/Train',
transform=data_transform)
'''将读取好的数据进行加载'''
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=2, num_workers=8) # 加载数据集
'''可视化加载后一批次的图片数量以及各图片尺寸'''
for data, label in train_iter: # 迭代batchsize中的数据
print(data.shape) # torch.Size([128, 3, 224, 224])
print(label.shape) # torch.Size([128])
break
torch.Size([2, 3, 224, 224])
torch.Size([2])
总结一下,在定义自己的Dataset类的时候注意三点就可以:
def __init__(self)
:声明并初始化需要用到的变量,比如.csv路径、图片数据存放的文件夹路径等def __getitem__(self, idx)
:根据路径逐一读取样本的图像和标签,并返回一个元组def __len__(self)
:获取数据集中样本的数量这里简单介绍一下上一块代码用到的几个工具包:import os
,import pandas as pd
,from PIL import Image
:
import os
:在python下写程序,需要对文件以及文件夹或者其他的进行一系列的操作import pandas as pd
:pandas这个库是用来读取数据的,比如.csv文件以及Excel文件,上一块代码用到的pd.read_csv(csv_file)
就是读取.csv文件的指令from PIL import Image
:Python图像库PIL(Python Image Library)是Python的第三方图像处理库,可以做很多和图像处理相关的事情:图像归档、可视化、处理等,如上一块代码中用到的img = Image.open(img_path).convert('RGB')
来读取图片数据,并将其转化成RGB模式在这里详细介绍一下有关os模块常用的指令,对于其它两个模块由于涉及的内容太多,可以去官网查看具体用法。
import os
:在python下写程序,需要对文件以及文件夹或者其他的进行一系列的操作,就需要引入os模块,常用的指令有:
指令 | 作用 |
---|---|
os.path.join(path_1, path_2) |
将路径path_1和path_2拼接起来形成新路径 |
os.path.split(path) |
将path分割成目录和文件名并以元组方式返回 |
os.splitext(path) |
分离扩展名然后按照元组返回 |
os.path.exists(path) |
如果path是一个存在的路径,返回True,否则返回 False |
总的来说,对于计算机视觉的数据读取和加载,就是利用好torchvision
中的transforms和datasets以及torch.utils.data
中的Dataloader类,当然还有很多种读取图像、视频数据的方法。
torchvision.transforms
去设计自己的transform模块;完整项目在我上传的资源里面,需要的可以自取