pytorch 在载入数据时用torchvision.datasets.ImageFolder
配合 torch.utils.data.DataLoader
很方便,但是只能遍历图片和图片的标签,无法灵活的获取图片的其他信息,比如图片的名字,本文介绍如何定义自己的 ImageFolder,在使用 Dataloader
时实现获取图片名字的功能!
以分类为例,用 pytorch 的 torchvision.datasets.ImageFolder
配合 torch.utils.data.DataLoader
即可对数据按类别进行读取、预处理、分成 batch
import torchvision
import torch
train_dataset = torchvision.datasets.ImageFolder(
train_data_pth,
transforms.Compose([
transforms.Resize(input_size,interpolation=2), # resize
transforms.ToTensor(), # ToTensor
normalize,])) # Normalization
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size, # set batchsize
shuffle=False,
num_workers=n_worker,
pin_memory=True)
参考
ImageFolder
中 train_data_pth
是存放数据集的文件夹,文件结构应该如下
train_data_pth
class1
xxx.jpg
...
class2
xxx.jpg
...
...
classn
xxx.jpg
...
Dataloader
的参数介绍如下
- dataset:加载的数据集(Dataset对象)
- batch_size:batch size
- shuffle:是否将数据打乱
- sampler: 样本抽样,后续会详细介绍
- num_workers:使用多进程加载的进程数,0 代表不使用多进程
- collate_fn: 如何将多个样本数据拼接成一个 batch,一般使用默认的拼接方式即可’
- pin_memory:是否将数据保存在pin memory 区,pin memory 中的数据转到 GPU 会快一些
- drop_last:dataset中的数据个数可能不是 batch_size 的整数倍,drop_last 为 True 会将多出来不足一个batch的数据丢弃
参考 pytorch之DataLoader()函数
官网中 Dataloader
的介绍如下(https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
在训练和测试时,可用如下循环来对数据进行操作
for batch_images, batch_labels in train_loader:
pass
把数据集的文件夹建立好,直接调用 ImageFolder
和 DataLoader
来进行数据的载入分批读取确实很方便,但是如果我们想知道哪些图片分类错误了, train_loader
loader 中仅有 image
(图片) 和 label
属性,没有 image name
(图片名称) 属性,有些力不从心!
因此,我们可以自己写 ImageFolder
来实现读取 image、label、image name 的功能,当然熟悉这个流程后,以后可以进行更个性化的操作!
自己写数据读取和预处理,来替代 torchvision.datasets.ImageFolder
的功能,具体实现如下 class Own_Dataset
所示
class Own_Dataset(Dataset):
def __init__(self, image_label_list, transform=None):
super().__init__()
self.samples_list = image_label_list # xxx.jpg class1
self.transform = transform # pre-processing of data
def __getitem__(self, index):
img_name = self.samples_list[index][0] # absolute path of image name
with open(img_name,"rb") as f:
img = Image.open(f).convert("RGB") # load image
label = self.samples_list[index][1] # image label
if img is None:
print(img_name)
if self.transform is not None:
img = self.transform(img)
return img, label, img_name
def __len__(self):
return len(self.samples_list)
其中 image_label_list
为列表,存放着图片的绝对路径以及标签信息,格式如下
[(/train_data_pth/calss1/1.jpg,class1),
(/train_data_pth/calss1/2.jpg,class1),
...,
(/train_data_pth/calssn/m.jpg,classn)]
想实现更多功能,在 def __getitem__(self, index):
中定义即可,
__getitem__:实例[idx] 时触发
参考:【python】类(11)
配合 DataLoader
使用
train_loader = torch.utils.data.DataLoader(
Own_Dataset(image_label_list=val_list,
transform=transforms.Compose([
transforms.Resize(input_size,interpolation=2), # resize
transforms.ToTensor(),
normalize,])),
batch_size=test_batch_size,
shuffle=False,
num_workers=n_worker,
pin_memory=True)
训练测试时,就可以访问图片,类别以及图片名信息了,如下所示
for batch_images, batch_labels,batch_names in train_loader:
pass
下面介绍部分 torchvision.transforms
方法
更多的 torchvision.transforms
方法可以参考官网介绍
https://pytorch.org/docs/stable/torchvision/transforms.html
train_dataset = datasets.ImageFolder(
train_data_pth,
transforms.Compose([
transforms.Resize(scale_size,interpolation=2),
transforms.RandomRotation(5),
transforms.ColorJitter(brightness=0.1,contrast=0.1,
saturation=0.1,hue=0.1),
transforms.FiveCrop(input_size),
transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(
mean = [0.5,0.5,0.5],
std = [0.5,0.5,0.5])(crop) for crop in crops]))
]))
input_size 和 scale_size 写成元组的形式,eg,(224,224) 和 (256,256)
Normalize 时注意 mean 和 std 一定要除以 255,值介于 0~1 之间
FiveCrop 或者 TenCrop 时,测试代码也需要进行相应的调整,如下
原来
out = net(batch_images)
现在
bs, ncrops, c, h, w, = batch_images.size()
result = net(batch_images.view(-1,c,h,w))
out = result.view(bs,ncrops,-1).mean(1)