【PyTorch图像语义分割】1. PyTorch数据准备与预处理

PyTorch数据准备与预处理

  • .py源文件的结构
  • dataset.py的功能
  • 用来加载数据的类`torch.utils.data.DataLoader()`
  • 导入自己的dataset
    • 数据子类的基础框架
    • 使用UAV图像的dataset .py
  • 参考资料

.py源文件的结构

  1. 数据准备与预处理: dataset.py
  2. 模型:model.py
  3. 训练规则:train.py
  4. 测试(benchmark + predict):test.py

dataset.py的功能

统一将图像返回成torch能处理的[original_iamges.tensor,label.tensor]

用来加载数据的类torch.utils.data.DataLoader()

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

重点关注四个参数:
batch_size: 批处理数目
shuffle: 是否每个epoch都打乱
workers: 载入数据的线程数
dataset: 是经过变换的自己的数据集(即:一个继承了torch.utils.data.Dataset类的子类的实例),[original_iamges.tensor,label.tensor]之类的,定义的“dataset.py”就是产生这个dataset的。然后在train.py中调用。

导入自己的dataset

class UAVDataSet(torch.ultis.data.Dataset)
  • 继承了torch.utils.data.Dataset这个(抽象)类,我们看看这个类在中文文档中介绍:

  • 所有其他数据集都应该进行子类化。所有子类应该重载__len____getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0len(self)。当然还有个初始化__init__()

  • 类 = 属性+方法(变量 + 函数),__init__()就是定义自己的属性。

数据子类的基础框架

如上述,必须要重载的是__getitem__()__len__()

  • __len__()len(dataset)返回数据集的大小。
  • __getitem__():实现数据集的下标索引,使用dataset[i]来得到第i个样本(图像和标记)。
import torch.utils.data as data
import torch
from torchvision import transforms
 
class MyTrainData(torch.utils.data.Dataset) #子类化
  def __init__(self, root, transform=None, train=True): #第一步初始化各个变量
 
    self.root = root   
    self.train = train
 
  def __getitem__(self, idx): #第二步装载数据,返回[img,label],idx就是一张一张地读取
      # get item  获取  数据 
 
      img = imread(img_path) #img_path根据自己的数据自定义,灵活性很高
      img = torch.from_numpy(img).float() #需要转成float
 
      gt = imread(gt_path)  #读取gt,如果是分类问题,可以根据文件夹或命名赋值 0 1  
      gt = torch.from_numpy(gt).float()
 
      return img, gt #返回  一一对应
 
  def __len__(self):
    return len(self.imagenumber) #这个是必须返回的长度

在框架里面填写具体的东西:

  1. 是否transform如裁剪、归一化、旋转等?如果要transform则还需要区分test和train。比如train需要随机翻转,但是test则不需要操作.
  2. 如何做到一张一张对应读取图片? 可以自定义这些函数。

使用UAV图像的dataset .py

在根目录下创建:

  • 文件夹UAVSegImages:里面放所有的图像,“7.jpg” …;
  • 文件夹UAVSegLabels:里面放所有的标签,“7.png” …;(与图像的名字一样)
  • UAVImages_id.txt:包含图像的名字,每一行是一个名字。

【PyTorch图像语义分割】1. PyTorch数据准备与预处理_第1张图片

【PyTorch图像语义分割】1. PyTorch数据准备与预处理_第2张图片
【PyTorch图像语义分割】1. PyTorch数据准备与预处理_第3张图片

下面是完整的代码:

# -*- coding: utf-8 -*-
"""
Created on Sat Dec  1 09:50:08 2018

@author: dspslzbw
"""
#%%

import os
import numpy as np
#import random
import matplotlib.pyplot as plt
#import collections
import torch
import torchvision
#import cv2
from PIL import Image
#import torchvision.transforms as transforms

class UAVDataSet(torch.utils.data.Dataset):
    def __init__(self, root, list_path, ignore_label=255):
        super(UAVDataSet,self).__init__()
        self.root = root
        self.list_path = list_path
	self.img_ids = [i_id.strip() for i_id in open(list_path)]
 
        self.files = []
        for name in self.img_ids:
            img_file = os.path.join(self.root, "UAVSegImages/%s.jpg" % name)
            label_file = os.path.join(self.root, "UAVSegLabels/%s.png" % name)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })
 
    def __len__(self):
        return len(self.files)
 
 
    def __getitem__(self, index):
        datafiles = self.files[index]
 
        '''load the datas'''
        name = datafiles["name"]
        image = Image.open(datafiles["img"]).convert('RGB')
        label = Image.open(datafiles["label"]).convert('L')
        size_origin = image.size # W * H

		I = np.asarray(image,np.float32) 
        I = I.transpose((2,0,1))#transpose the  H*W*C to C*H*W
        L = np.asarray(np.array(label), np.int64)
        #print(I.shape,L.shape)
        return I.copy(), L.copy(), np.array(size_origin), name
  • 待改进:加入各种变换,写成transforms.Compose()插入到代码中。

  • 下面是一个测试函数,加在上面的代码后面,即代码写好后直接python运行当前py文件,就会执行以下代码的内容,以检测上面的代码是否有问题, 这其实就是方便调试,而不是每次都去run整个网络再看哪里报错。


if __name__ == '__main__':
    DATA_DIRECTORY = './'
    DATA_LIST_PATH = './images_id.txt'
    Batch_size = 2
    MEAN = (104.008, 116.669, 122.675)
    dst = UAVDataSet(DATA_DIRECTORY,DATA_LIST_PATH, mean=(0,0,0))
    # just for test,  so the mean is (0,0,0) to show the original images.
    # But when we are training a model, the mean should have another value
    trainloader = torch.utils.data.DataLoader(dst, batch_size = Batch_size)
    plt.ion()
    for i, data in enumerate(trainloader):
        imgs, labels,_,_= data
        if i % 1 == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = img.astype(np.uint8) # change the dtype from float32 to uint8, 
                                       # because the plt.imshow() need the uint8
            img = np.transpose(img, (1, 2, 0)) # transpose the C*H*W to H*W*C
            #img = img[:, :, ::-1]
            plt.imshow(img)
            plt.show()
            plt.pause(0.5)
 
#            label = torchvision.utils.make_grid(labels).numpy()
            labels = labels.astype(np.uint8) # change the dtype from float32 to uint8, 
#                                       # because the plt.imshow() need the uint8
            for i in range(labels.shape[0]):
                plt.imshow(labels[i],cmap='gray')
                plt.show()
                plt.pause(0.5)
 
            #input()
  • 输出:(只截了第一个batch_size,后面还有输出,上面的代码是遍历整个数据集。)
    【PyTorch图像语义分割】1. PyTorch数据准备与预处理_第4张图片

参考资料

[1] https://blog.csdn.net/woshicao11/article/details/78318156
[2] https://blog.csdn.net/Teeyohuang/article/details/82108203
[3] PyTorch文档中文版:https://pytorch-cn.readthedocs.io/zh/latest/
[4] https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

你可能感兴趣的:(数据集,PyTorch,数据加载与预处理)