pytorch学习之旅(一)——自定义数据读取

最近在研究显著性检测,学着使用pytorch框架,以下纯属个人见解,如有错误请指出

(一)自定义数据读取

首先官方案例:

PyTorch读取图片,主要是通过Dataset类,所以先简单了解一下Dataset类。Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类。

class Dataset(object):
     def __getitem__(self, index):
          raise NotImplementedError
    def __len__(self):
          raise NotImplementedError
    def __add__(self, other):
          return ConcatDataset([self, other])

这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。 那么读取自己数据的基本流程就是: 1. 制作存储了图片的路径和标签信息的txt 2. 将这些信息转化为list,该list每一个元素对应一个样本 3. 通过getitem函数,读取数据和标签,并返回数据和标签

在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的iter(self),后面会详细讲解读取过程。在本小节,主要讲Dataset子类。 因此,要让PyTorch能读取自己的数据集,只需要两步: 1. 制作图片数据的索引 2. 构建Dataset子类

下面是我做显著性检测时自定义的(我纠结label的定义足足两天,总算明白了:label 在官网给出的是分类问题,因此标签是对应的类别要么是文字要么手写体表示的数字,而我需要的是图片,这里就发一下他们之间的对比,就很容易理解到pytorch这个自定义的类是有多么方便)
下面是分类问题常用模板(显著性检测用的比较少,所以我就没有运行过代码,仅作为对比帮助理解)

from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
    fh = open(txt_path, 'r')
    imgs = []
    for line in fh:
        line = line.rstrip()
        words = line.split()
        imgs.append((words[0], int(words[1])))
        self.imgs = imgs 
        self.transform = transform
        self.target_transform = target_transform
def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = Image.open(fn).convert('RGB') 
    if self.transform is not None:
        img = self.transform(img) 
    return img, label
def __len__(self):
    return len(self.imgs)

下面是我自己的数据读取,最后生成一个dataset的类

主要思路将地址对应的image,label,通过地址列表形式,一个一个的导入,不过也有一个弊端,这个只能一张图片的输入到网络中,正好我们的batch_size = 1,最后一个代码我将用官方给出的例子改写,这样方便后续设置出我们需要的batch_size, 这样还有一个坏处,我的内存会溢出,一次性把全部图片读取出来,内存不够用,后续可以考虑把图片一张一张的读取,然后再一张一张的送进去,这样内存应该会轻松些

(最后的代码由于时间紧张,后续再补,其实很简单的说一下思路:

1.在__init__()中改写代码,最后返回index
2.打开image和label存放的txt,读取里面的地址生成list,两个list具有相同的index,最后return index就好,比较简单
3.在__getitem__()改写代码,把返回的index打开相应的地址,把对应的image和label转换成tensor,同时返回
4__len__()不变都行
可以在我的代码基础上,不相关的模块改写进去就好

def readtxt_into_list(address):
    file = open(address)
    addressMat = []
    namelMat = []
    for line in file.readlines():
        curLine = line.strip().split(" ")
        addressMat.append(curLine[0])
        namelMat.append(curLine[2])
    number_of_lines = len(namelMat)
    # 返回值包括图片地址名,文件名,已经这个list的大小
    return  addressMat, namelMat,number_of_lines

def img_tensor(address):
    img = Image.open(address).convert('RGB')
    img_np1 = numpy.transpose(img, (2, 0, 1))
    img3_tensor = torch.Tensor(img_np1)
    four_dims = img3_tensor.unsqueeze(0)
    return four_dims

# 取出lable和img的相关信息
dataset = []  # 用来存放lable 和img 的tensor 四维格式(B x C x H x W)
add_img = 'F:\data\MSRA10K_Imgs_GT\dir.txt'
address_img, name_img,lines = readtxt_into_list(add_img)
add_lable = 'F:\data\MSRA10K_Imgs_GT\dir1.txt'
address_lable, name_lable,lines = readtxt_into_list(add_lable)

for index in range(lines):
    # 取出地址
    img_add = str(address_img[index] + name_img[index])
    address1 = img_add

    lable_add = str(address_lable[index] + name_lable[index])
    address2 = lable_add

    # 读取图片转化成tensor
    input =img_tensor(address1)
    lable = img_tensor(address2)

    dataset.append([input, lable])

有问题,有错误,请指正,大家一起学习一起进步!

你可能感兴趣的:(pytorch框架)