使用 PyTorch 搭建网络 - dataset_py篇

dataset.py

文件目录如下:

  • 导包
  • dataset
  • __init__
  • __len__
  • __getitem__
  • Legacy

参考内容:

  • wz专研6

导包

我们需要导入PIL库用以打开图像,numpy(可选)用以对图像处理,Dataset为需要继承的基类。

import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset

dataset

在该模块中,需要告诉程序如何读入你的数据,并且做一些预处理。你的DIYDataset类继承自Dataset类,并重写__init____len____getitem__三个魔法方法。

  1. __init__方法向外索取3个输入数据:
  • 读取数据路径
  • 是训练集还是验证集(因为你训练集和验证集往往是在两个不同的文件夹)
  • 你使用的预处理方法(以transform为主,transform也可以根据验证集还是测试集调用不同的trans)
  1. __getitem__需要返回image和对应label的Tensor
  2. __len__用来返回集合中图片个数

案例如下:

class MyDataset(Dataset):
	"""什么什么数据集"""
	
	def __init__(self, root: str, train: bool, transforms: object=None):
		"""初始化函数"""
		pass
	
	def __getitem__(self, index):
		pass
	
	def __len__(self):
		pass

__init__

在该方法中需要根据传入的path参数和train参数找到你的测试集或者训练集的物理地址,并将集合中的images和labels的物理地址存储在list中,以供后面方法使用。案例如下:

def __init__(self, root: str, train: bool, transforms: object=None):
    super(DriveDataset, self).__init__()
    self.flag = "training" if train else "test" # 由 train: bool 的布尔值来判断是取train还是test
    data_root = os.path.join(root, "DRIVE", self.flag)
    assert os.path.exists(data_root), f"path '{data_root}' does not exists."
    self.transforms = transforms
    img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
    self.img_list = [os.path.join(data_root, "images", i) for i in img_names]               # 返回所有images地址的list
    self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")  # 返回所有manuals地址的list
                   for i in img_names]
    # check manual files
    for i in self.manual:
        if os.path.exists(i) is False:
            raise FileNotFoundError(f"file {i} does not exists.")
    self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")    # 返回所有mask地址的list
                     for i in img_names]
    # check mask files
    for i in self.roi_mask:
        if os.path.exists(i) is False:
            raise FileNotFoundError(f"file {i} does not exists.")

__len__

返回测试集或者验证集中Image数量,因为Image和Label往往是一一对应的,所以返回哪个其实都一样。三维数据可能存在多对一情况。

案例如下:

def __len__(self):
	return len(self.img_list)

__getitem__

该方法根据image和label的物理地址,用PIL打开图片,再用transforms处理Image返回Tensor,最后返回处理过的Tensor类型元组(image, label)。

在该方法中,你可以使用PIL处理图像(mode),也可以将PIL转为numpy使用numpy处理图片(元素类型dtype),也可以使用Transforms处理图片(Normalization)等。

__getitem__案例见下:

def __getitem__(self, idx):
    """将Image转为RGB, 将label转为L"""
    img = Image.open(self.img_list[idx]).convert('RGB')
    manual = Image.open(self.manual[idx])
    manual = manual.convert('L')
    manual = np.array(manual) / 255
    roi_mask = Image.open(self.roi_mask[idx]).convert('L')
    roi_mask = 255 - np.array(roi_mask)
    # 将manual图片和Imae进行处理
    mask = np.clip(manual + roi_mask, a_min=0, a_max=255)
    # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
    mask = Image.fromarray(mask)
    if self.transforms is not None:
        img, mask = self.transforms(img, mask)
    return img, mask

使用PIL处理

参看链接https://blog.csdn.net/qq_43369406/article/details/127781871

使用Numpy处理

参看链接https://blog.csdn.net/qq_43369406/article/details/127781871

使用transforms处理

参看链接[coming soon]

这段代码我们常写在train.py中。在进行transforms累加时候,我们常将所需要的transforms全部添加至一个list中,再将这个list给transforms.Compose掉,注意一定要添加transforms.ToTensor方法。transform的更多内容可以参考笔者的transforms博客。

在调用transforms时完整逻辑如下:

# 获取dataset
train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))
# 获取tranforms
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
	"""对要获取的transforms进行判断,看是测试集的dataset还是验证集"""
    base_size = 565
    crop_size = 480

    if train:
        return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
    else:
        return SegmentationPresetEval(mean=mean, std=std)
# 测试集transforms
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(1.2 * base_size)

        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        if vflip_prob > 0:
            trans.append(T.RandomVerticalFlip(vflip_prob))
        trans.extend([
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)

增加了魔法方法__call__只是为了对image/input,label/target一块进行transforms。

测试代码

在测试代码中,我们通过查看返回的Tensor的size来查看是否成功。

案例如下:

if __name__=="__main__":
    train_dataset = DriveDataset("/home/yingmuzhi/_data",
                                 train=True,
                                 transforms=None)
    a = train_dataset[0]

Legacy

迭代版本残留文件,不必看。

一个函数如果不好理解,记住输入和输出就行了,不需要记忆黑匣子里面的逻辑

这个模块的主要目的就是加载图像文件,如输入tiff/jpg/png等格式的图像,输出torch.Tensor类型的tensor对象。

该模块往往需要导入torch.utils.data.dataset.Dataset类,PIL.Image模块(为什么模块名大写,我也不知道),torchvision.transforms目录。

  1. 需要注意的是import直接导入最小只能小到直接导入.py模块,即只能直接导入目录或者.py模块,如果要导入模块中的类,必须使用from ... import。如要导入Dataset类要使用from torch.utils.data.dataset import Dataset
  2. 如果直接import torchimport torch.utils.data.dataset是一致的,在后面使用模块中的Dataset类都需要这么写torch.utils.data.dataset.Dataset

该模块主要要编写LeNetDataSet类,该类要继承自torch.utils.data.dataset.Dataset类,需要重写Dataset类的__init__方法和__getitem__()方法。

你可能感兴趣的:(人工智能,pytorch,深度学习,python)