深度学习中主要分为两大任务,分类和回归。
1、 分类即classification,就是将具有相同属性的样本划分为同一类,具有不同属性的样本划分为不同类。
以往我们需要通过对样本打标签来划分类别,用0,1,2,3,…表示类别。而在Pytorch中只需要将同一类别的样本图片放在同一文件夹下,会自动将文件夹作为类别的区分。详细的操作与代码在之前的博客(Pytorch学习笔记(I)——预训练模型(一):加载与使用)中有介绍。即通过torchvision.datasets
中已经封装好的ImageFolder
载入分类任务的数据集。样例如下:
train_data=torchvision.datasets.ImageFolder('/disk2/lockonlxf/pin/trainData',transform=transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()
]))
train_loader = DataLoader(train_data,batch_size=20,shuffle=True)
2、 回归即regression,就是通过学习,将输入样本转变为另一种形式。那么标签就不再是0,1,2,3,…这样的分类了,而是每一个样本对应的GT(groundtruth)。以下简单举了几个例子。
任务类型 | 输入(样本) | 输出(GT) |
---|---|---|
关键点检测 | 图片 | 所有关键点的坐标(x,y) |
目标检测 | 图片 | 边框坐标或边框尺寸 |
显著性检测 | 图片 | 目标掩膜图片 |
人类重建 | 图片 | 图片 |
因此,回归任务就不能
用ImageFolder
载入数据集,大多情况下需要自定义数据集载入方式来满足自己的任务要求。
半年前,我写了一篇博客Pytorch学习笔记(II)——自定义数据集载入方式(一),介绍了一种能够应用于大多数任务(多输入或多输出)的数据集载入方法。可以说,这一种方法是一种傻瓜式教学,因为要把所有文件的路径保存到txt文件中,还得读取txt,略有点麻烦。
本文将介绍一种简易的方法,但是不能保证适用于所有的任务。
torch.utils.data.Dataset
是一个表示数据集的抽象类,自定义数据集需要继承这个类,并且重写其以下内容:
__init__ :数据初始化
__len__ :返回数据库的大小
__getitem__ :支持使用下标的方式 如dataset[i] 来获取第i个样本
我这里还是以我自己的实验为例,一般我们做回归任务,都会有与输入样本配对的GT。然后,将对应的输入和GT放在一个文件夹里
我一共有122450个训练样本,于是我就有122450个文件夹。当然如果你看完博客后有更好的存放方式,欢迎交流。
接着,我们点开其中2个文件夹来看。可以看到,一共有4个文件,第一个是原图,第二个是经过裁剪的图片,第三个是图片的特征,第四个是由特征恢复的点云。
这种方法对文件命名很讲究,同样为jpg文件,第二张图会有crop
的前缀,同样为csv文件,后面两个用feature
和vertex
两个前缀来区分。总的来说,属于同一类别需要用同样的关键词明明。如果文件夹里只有一个jpg图片,那么就直接用文件类型检索就好,具体可以看后续代码。
接下来,将裁剪后的图片作为输入,特征文件作为GT。
初始化不需要写太多,除了要载入的母文件夹路径之外,transform一定要加!!!transform一定要加!!!transform一定要加!!!
def __init__(self,path,transform=None):
self.path = path
self.transform = transform
相比前一篇需要载入多个txt文件,这里只要载入我们的母文件夹路径,即只需一个输入。
文件载入后,可以在这一步对文件进行处理,比如提取信息或数据转化等等
def __getitem__(self, index):
#image_path = os.path.join(self.face, str(index + 1), '*.jpg')
image_path = os.path.join(self.face, str(index + 1), 'crop_*.jpg')
image_name = glob.glob(image_path)[0]
I_face = Image.open(image_name)
##上面是载入图片,下面是载入csv,使用时根据个人情况修改
feature_path = os.path.join(self.path, str(index + 1), 'feature_*.csv')
feature_name = glob.glob(feature_path)[0]
with open(feature_name) as feature_file:
feat_reader = csv.reader(feature_file) # Return an iterable reader object.
label = []
for element in feat_reader:
label.append(float(element[0])) # Each element is a list containing single string.
mm228 = torch.tensor(label).reshape(-1, 1)
if self.transform:
I_face = self.transform(I_face)
return I_face, mm228
这一步是计算样本的数量,其实只要计算母文件夹下有多少个文件夹就行了。
def __len__(self):
return len(os.listdir(self.face))
在对应位置,写上母文件夹的绝对路径即可
train_data = MyDataset('/home/xxxxx/300w',
transform=transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()
]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import torch
import os
import glob
import csv
#######自定义dataset
class MyDataset(Dataset):
def __init__(self,path,transform=None):
self.path = path
self.transform = transform
def __getitem__(self, index):
#image_path = os.path.join(self.face, str(index + 1), '*.jpg')
image_path = os.path.join(self.face, str(index + 1), 'crop_*.jpg')
image_name = glob.glob(image_path)[0]
I_face = Image.open(image_name)
##上面是载入图片,下面是载入csv,使用时根据个人情况修改
feature_path = os.path.join(self.path, str(index + 1), 'feature_*.csv')
feature_name = glob.glob(feature_path)[0]
with open(feature_name) as feature_file:
feat_reader = csv.reader(feature_file) # Return an iterable reader object.
label = []
for element in feat_reader:
label.append(float(element[0])) # Each element is a list containing single string.
mm228 = torch.tensor(label).reshape(-1, 1)
if self.transform:
I_face = self.transform(I_face)
return I_face, mm228
def __len__(self):
return len(os.listdir(self.face))
train_data = MyDataset('/home/xxxxx/300w',
transform=transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()
]))
train_loader = DataLoader(train_data, batch_size=10, shuffle=False)
#注意看这里!!!如果自定义没有问题,下面的循环是可以跑通的,如果有问题,第一行for就会报错
#如果出错,可以将上面的shuffle设置为False,就是不打乱,然后在debug的时候看看是哪一个数据出了问题
for step, data in enumerate(train_loader):
I_face, mm228 = data
#还可以看看,载入的尺寸是否正确,一般是会比原来多一维,代表的是batch_size
print(I_face.shape)
print(mm228.shape)