torchvision.datasets源码地址:https://github.com/pytorch/vision/blob/master/torchvision/datasets
前两篇从搭建经典的ResNet,DenseNet入手简单的了解了下PyTorch搭建网络的方式,但训练一个模型光光搭建好一个网络是不够的,正所谓巧妇难为无米之炊,如何将数据处理成网络可以传递的Tensor也尤为重要,而数据准备过程最最最最最重要的就是Datasets和Dataloader两部分!
torchvision.datasets.ImageFolder就是官方给出的一个datasets的事例,具体使用直接贴上官方tutorial上的代码供参考:
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
但由于torchvision.datasets.ImageFolder函数的使用必须对数据的放置有要求,必须在data_dir目录下放置train和val两个文件夹,然后每个文件夹下,每一类图片单独放在一个文件夹里。官方的例子是ants和bees,所以在train和val文件夹下都有ants和bees这两个文件夹,分别放置相应的文件。
那么问题就来了,我们通常打完标签,是不会根据标签进行分类,而且在进行目标检测时一张图可能对应有多个标签,而是通过一个xml文件或者json文件用于记录label信息,所以是不满足ImageFolder的要求的。
所以根据实际数据情况,自定义Datasets就很关键,接下来我们就根据ImageFolder的函数形式,顺藤摸瓜从头来看如何自定义一个Datasets!
首先,可以看到ImageFolder类继承了DatasetFolder类,DatasetFolder类又继承了torch一个基础的抽象类torch.utils.data.Dataset类。
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
自定义Datasets的关键就是重载 "__len__"和"__getitem__"两个函数!而 "__add__"函数的作用是使得类定义对象拥有"object1 + object2"的功能,一般情况不需要重载该函数。
__len__函数:使得类对象拥有 "len(object)"功能,返回dataset的size。
__getitem__函数:使得类对象拥有"object[index]"功能,可以用索引i去获得第i+1个样本。
再来看看同样继承于torch.utils.data.Dataset的CocoDetection dataset是如何定义上述两个函数的!
def __init__(self, root, annFile, transform=None, target_transform=None):
# 从cocoapi导入pycocotools下的COCO类
from pycocotools.coco import COCO
self.root = root
# 初始化一个COCO对象
self.coco = COCO(annFile)
# 将每张图unique的id属性转化为list存储在self.ids中
self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform
(1)初始化函数可以接受四个参数:
(2)初始化COCO对象时,将.json文件解析为字典形式导入内存,并创建调用createIndex()创建索引。
(3)self.coco.imgs是以每张图unique的id作为key,json文件images下每一image信息作为value的一个字典。
def __len__(self):
# 因为图片的id是unique的,所以self.ids的长度就等于总图片数
return len(self.ids)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
"""
coco = self.coco
# 通过索引获得图片的id
img_id = self.ids[index]
# 再通过getAnnIds方法利用img_id找到对应的anno_id
ann_ids = coco.getAnnIds(imgIds=img_id)
# 根据anno_id和标签之间的映射关系,解析出标签target
target = coco.loadAnns(ann_ids)
path = coco.loadImgs(img_id)[0]['file_name']
# 根据每张图的file_name结合之前传入的图片放置的根目录读取图片信息
img = Image.open(os.path.join(self.root, path)).convert('RGB')
# 判断是否需要进行数据增强
if self.transform is not None:
img = self.transform(img)
# 判断标签是否需要进行变换
if self.target_transform is not None:
target = self.target_transform(target)
# 最终返回值形式可以根据自己需要进行设计。此处为一个tuple,包含一张图片以及对应的标签。
return img, target
以下这个例子就是自定义的FaceLandmarksDataset,效果是从.csv文件中读取每张图上的68个人脸面部关键点的坐标x,y,然后根据.csv文件中对应的图片名,读取相应的图片,然后返回值是一个sample字典,包含'image'和'landmarks'两个key。
class FaceLandmarksDataset(Dataset):
def __init__(self, root_dir, csv_file, transform=None):
self.root_dir = root_dir
self.landmarks_frame = pd.read_csv(csv_file)
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
数据准备阶段datasets部分就简单介绍完了,下篇继续介绍另一个关键部分dataloader!