在使用 Pytorch 做分类任务的时候,一般会用自带的torchvision.datasets.ImageFolder()
函数,但是这个对数据存储方式有要求,不一定适合自己,如果考虑加载自己的数据,就要考虑重写Dataset类
了。
ImageFolder
对数据存储方式要求: root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
... ...
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
一般我们不想来回移动数据,知道图片的路径即可,告诉模型在哪里自己去拿,是比较好的方式。所以我们只要继承Dataset
类,重新实现一下即可。
label
整理到文本中(什么文本都可以,方式也不限,但要方便自己解析)。__getitem__()
函数,读取每条数据和标签,并返回。 root/dog/xxx.png 0
root/dog/xxy.png 0
root/dog/xxz.png 0
root/cat/123.png 1
root/cat/nsdf3.png 1
root/cat/asd932_.png 1
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch
__all__ = ['MyDataset']
class MyDataset(Dataset):
def __init__(self, dataPath, transform=None, target_transform=None):
imgsPath = open(dataPath, 'r')
imgs = []
for line in imgsPath:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.transform(label)
return img, label
def __len__(self):
return len(self.imgs)
if __name__ == '__main__':
transform_train = transforms.Compose([transforms.Resize(256), # 重置图像分辨率
transforms.RandomResizedCrop(224), # 随机裁剪
transforms.RandomHorizontalFlip(), # 以概率p水平翻转
transforms.RandomVerticalFlip(), # 以概率p垂直翻转
transforms.ToTensor(),])
trainset = MyDataset(dataPath='train.txt', transform=transform_train) # 训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
for step, (tx, ty) in enumerate(trainloader, 0):
print('---test---', tx, ty)
声明: 总结学习,有问题或不当之处,可以批评指正哦,谢谢。
[1]:https://github.com/tensor-yu/PyTorch_Tutorial
[2]:https://blog.csdn.net/u011995719/article/details/85102770