数据的输入-pytorch

无论在图像识别,目标检测,语义分割哪一种网络,数据输入都是通过两个类对数据进行处理
1、打包数据,进行数据增强,输出标签以及图像
2、对打包的数据进行包括分批处理,是否打乱顺序,是否多线程加载数据等操作
需要注意的是,图像识别数据以及标签较为简单,一般不需要自己重写这两个类,但在目标检测和语义分割模型中需要自己重写

打包数据的类一定会包含三个方法: __ init__、__ len__、__ getitem__
第一个方法不做介绍,第二个方法可以使该类可以作为len()函数的参数查看数据长度,第三个方法可以使该类根据索引获取数据

这个类的结构为:

import torch.utils.data as data

class My_dataset(data.Dataset):

	def __init__(self, ...):
		super(My_dataset, self).__init__()
		pass
	def __len_():
		pass
	def __getitem__(self,index)...
		return image target

对打包的数据进行处理的这个类不做具体介绍,只说明用法

一、图像分类

from torchvision import datasets

打包数据
train_dataset = datasets.ImageFolder(data_path,transform)

data_path:   数据路径,该路径下存放多个文件夹,每个文件夹是一类数据

transform:  对数据做数据增强

打开这个内置类看一下,可以看到里面确实有上面说的方法
数据的输入-pytorch_第1张图片

后面系统会自动调用__getitem__这个方法,返回图像以及对应标签

标签就是按照文件夹顺序进行排序,可以看debug结果,下面第一张图是数据文件夹的情况,第二张图可以看到datasets.ImageFolder()将这五个文件夹排序并标了序号,这个序号就是图像标签,第三张图通过输入一个索引来调用__getitem__方法,返回数据集的第二张图片信息(图像以及标签)
数据的输入-pytorch_第2张图片
数据的输入-pytorch_第3张图片

数据的输入-pytorch_第4张图片
第二个类为:

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    train_dataset:刚才第一个类的名称
    batch_size:批次大小
    shuffle:是否打乱数据顺序
    num_workers:几个进程进行数据加载

二、目标检测、语义分割

这两个模型需要自己重写Datasets,因为这里的__getitem__与pytorch自定义的不一样,所以需要自己写
from torch.utils.data.dataset import Dataset

class My_Dataset(Dataset):
    def __init__(self, annotation_lines):
        super(My_Dataset, self).__init__()
        self.annotation_lines   = annotation_lines

    def __len__(self):
        return len(self.annotation_lines)

    def __getitem__(self, index):
       
		image, box  = self.annotation_lines[index]
		image = Image.open(image)
		# 这里可以加入对image和box的一些变换,也就是数据增强,但前提是将image转换为numpy格式
        return image, box
 
self.annotation_lines中含有图像路径以及图像anchor坐标信息,或者是语义分割标签

第二个类为:

from torch.utils.data import DataLoader

data = DataLoader(train_dataset, 
				shuffle     =  shuffle, 
				batch_size  =  batch_size, 
				num_workers =  num_workers, 
				pin_memory  =  True,
                drop_last   =  True, 
                collate_fn  =  dataset_collate, 
                sampler     =  train_sampler)

前四个参数是一样的,这里只介绍collate_fn函数,因为这个需要自己重写
def dataset_collate(batch):
    images = []
    bboxes = []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)
    images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
    return images, bboxes

这个函数就是将第一个类的图像以及标签按照batch_size进行打包,方便后续调用,这个和pytorch默认的collate不一样,所以需要自己重写,语义分割的话和目标检测一样的,只不过把box换一下

你可能感兴趣的:(笔记,pytorch,深度学习,python)