目录
1. 需要用到的库
2. 数据扩充定义
3. 自定义Dataset
4. 测试
开始一个新的系列,基于Kaggle比赛的猫狗大战数据集,基于PyTorch实现猫狗图像分类。
数据集地址在:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/overview。
下面是第一部分,主要介绍如何使用Pytorch自定义Dataloader。
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
image_transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
数据扩充主要分为以下几步:
1)将图像的短边resize到256;
2)然后随即裁减224x224;
3)再进行随机水平翻转;
4)最后将图像转为Tensor并且标准化。
class DogVsCatDataset(Dataset):
"""Dog vs Cat dataset."""
def __init__(self, root_dir, train=True, transform=None):
"""
Args:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.root_dir = root_dir
self.img_path = os.listdir(self.root_dir)
if train:
self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))
else:
self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
self.transform = transform
def __len__(self):
return len(self.img_path)
def __getitem__(self, idx):
image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1
if self.transform:
image = self.transform(image)
label = torch.from_numpy(np.array([label]))
return image, label
数据集初始化时要设置图片目录;是否是训练集或者是验证集,图片编号小于10000的为训练集,大于等于10000的为验证集;及数据扩充方式;猫的标签为0,狗的标签为1。
if __name__ == '__main__':
catanddog_dataset = DogVsCatDataset(root_dir='../dogs-vs-cats-redux-kernels-edition/train', train=False,
transform=image_transform)
train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)
image, label = iter(train_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
测试的时候使用“if __name__ == '__main__':”可以在其他文件import时,不执行这些语句。执行代码后,显示的图片和打印的标签如下所示:
Label is: [0]
Label is: [1]