说明:搭建网络需要数据的读入和预处理,model的构建,损失函数和优化策略的选择。该博客以PIL读取图像,构建最基本的Dataset和DataLoader。
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as tr
import os
from PIL import Image
用PIL读取图像为例,还可以用matplotlib和cv2读取图像
class MyDataset(Dataset):
def __init__(self,root_path):
filenames = sorted(os.listdir(root_path))#filenames是列表 相当于存储的路径下的所有子文件名(查阅os.listdir)
self.transforms = tr.Compose([
tr.CenterCrop(256),#剪裁成等大的图像块 因为之后batch操作时必须保证每个batch中图像大小一致(尝试注释掉这行,看报错信息)
tr.ToTensor(),#将数据类型转换为tensor 之后才能加载到gpu上训练
tr.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))#这个是标准化处理,看具体数据情况
])
self.dataset=[]
for name in filenames:
file=os.path.join(root_path,name)#就相当于根路径和文件名连在一起
image=Image.open(file).convert('RGB')#用PIL读取并转换为RGB格式
self.dataset.append(self.transforms(image))#调用之前的变换(私有属性)不再赘述
def __len__(self): #这两个函数必须在Dataset里实现,一个返回数据集长度,一个通过索引来访问数据集中每个数据
return len(self.dataset)
def __getitem__(self,idx):
return self.dataset[idx]
root_path=r'/path/to/data'#换成你自己数据集存放位置(本博客用的数据集是Set14,由14张图片组成)
train_dataset=MyDataset(root_path)#实例化上述类
train_loader = DataLoader(dataset=train_dataset,batch_size=4,shuffle=True)#加载数据集,完成dataset里无法对数据进行的分批等操作
print(len(train_dataset))
# print(train_dataset.__len__())
# print(train_dataset.__getitem__(0))
结果 :14
如果我们注释掉第二部分中的__len__()这个方法后会报如下的错:
TypeError: object of type 'MyDataset' has no len()
print(len(train_loader))
结果: 4
那么为什么train_loader的长度是4? 我们将其中的元素大小打印出来:
for a in train_loader:
print(a.size())
结果如下:
torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256])
torch.Size([2, 3, 256, 256])
我们看看train_loader里到底存了什么,是4 * 3 * 256 * 256大小的张量(高维矩阵)。首先我们知道数据集里有14张图片,我们在DataLoader里设置的batch_size大小为4,没错第一维就是这个批的大小,因为有14张图片,4个为一批,所以最后会有两个剩余,默认人情况下会把不足的看作一批,当然也可以用DataLoader的参数使得不足的忽略掉,自己去查;第二维存的是通道数,我们读入的图片是RGB格式所以是三通道的,灰度图是单通道的,最后两个是图像的H和W(即高和宽)——原图像H和W不一定一致哦,这是我们crop(剪裁)的结果,为了进行批处理。
如果我们将剪裁的那行代码注释掉,最后在访问train_loader中的元素时会报如下的错:
RuntimeError: stack expects each tensor to be equal size, but got [3, 512, 512] at entry 0 and [3, 361, 250] at entry 1
也就是说,在分批处理的时候,一定要保证每批中的数据大小一致。
详细内容可参考pytorch官方文档:
torch.utils.data