pytorch图像处理:读取数据集Dataset和ImageFolder

 1、重写Dataset类:

#源码
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
 
#这个函数就是根据索引,迭代的读取路径和标签。因此我们需要有一个路径和标签的 ‘容器’供我们读
def __getitem__(self, index):
	raise NotImplementedError
 
#返回数据的长度
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

想制作自己的图像数据集供DataLoader拿取,首先要重写Datasets类,主要用来完成从哪里读取数据和标签的功能。主要是__getitem()__(返回数据集和标签)和__len__(返回数据的长度)这两个方法。

完成Datasets类的这两个主要功能后,训练的时候可以把数据集传送给DataLoader就可以获取自己想要的batch数据。

例1:通过包含 数据路径 和 标签 的TXT文件读取

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
 
#集成Dataset类
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
    """
    tex_path : txt文本路径,该文本包含了图像的路径信息,以及标签信息
    transform:数据处理,对图像进行随机剪裁,以及转换成tensor
    """
	fh = open(txt_path, 'r')  #读取文件
	imgs = []  #用来存储路径与标签
    #一行一行的读取
	for line in fh:
		line = line.rstrip()  #这一行就是图像的路径,以及标签  
        
		words = line.split()
		imgs.append((words[0], int(words[1])))  #路径和标签添加到列表中
		self.imgs = imgs                        
		self.transform = transform
		self.target_transform = target_transform
 
def __getitem__(self, index):
	fn, label = self.imgs[index]   #通过index索引返回一个图像路径fn 与 标签label
	img = Image.open(fn).convert('RGB')  #把图像转成RGB
	if self.transform is not None:
		img = self.transform(img) 
	return img, label              #这就返回一个样本
 
def __len__(self):
	return len(self.imgs)          #返回长度,index就会自动的指导读取多少
 

例2:通过标签文件读取

#首先集成Dataset这个类
class DealDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """
    def __init__(self):
 
        #这里xy 就是一个容器,通过读取一个包含有数据和标签信息的文件
        xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32)
 
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
        
        #长度,可以给__len__返回用。
        self.len = xy.shape[0]
    
    def __getitem__(self, index):
        
        #通过索引index,索引到指定的数据以及对应的标签
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len
 

例3:没有标签文件,代码根据文件夹分类自己构造

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform
 
    def __getitem__(self, index):
     
        path_img, label = self.data_info[index]       #索引读取图像路径和标签
        img = Image.open(path_img).convert('RGB')     # 读取图像,返回Image 类型 0~255
 
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,把图像转为tensor等等
 
        return img, label
 
    def __len__(self):
        return len(self.data_info)
 
    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
 
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))
 
        return data_info  ##返回的也就是图像路径 和 标签

2、文件夹读取ImageFolder

# 预处理 转为tensor 以及 标准化
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
#使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
traindata = torchvision.datasets.ImageFolder('data/rmb_split/train/', transform=transform)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=4, shuffle=True, num_workers=1)
 
 
testset = torchvision.datasets.ImageFolder('data/rmb_split/test/', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
 

 

你可能感兴趣的:(计算机视觉,Pytorch)