统一将图像返回成torch能处理的[original_iamges.tensor,label.tensor]
torch.utils.data.DataLoader()
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
重点关注四个参数:
batch_size: 批处理数目
shuffle: 是否每个epoch都打乱
workers: 载入数据的线程数
dataset: 是经过变换的自己的数据集(即:一个继承了torch.utils.data.Dataset
类的子类的实例),[original_iamges.tensor,label.tensor]
之类的,定义的“dataset.py”就是产生这个dataset的。然后在train.py中调用。
class UAVDataSet(torch.ultis.data.Dataset)
继承了torch.utils.data.Dataset
这个(抽象)类,我们看看这个类在中文文档中介绍:
所有其他数据集都应该进行子类化。所有子类应该重载__len__
和__getitem__
,前者提供了数据集的大小,后者支持整数索引,范围从0
到len(self)
。当然还有个初始化__init__()
。
类 = 属性+方法(变量 + 函数),__init__()
就是定义自己的属性。
如上述,必须要重载的是__getitem__()
和__len__()
。
__len__()
:len(dataset)
返回数据集的大小。__getitem__()
:实现数据集的下标索引,使用dataset[i]
来得到第i个样本(图像和标记)。import torch.utils.data as data
import torch
from torchvision import transforms
class MyTrainData(torch.utils.data.Dataset) #子类化
def __init__(self, root, transform=None, train=True): #第一步初始化各个变量
self.root = root
self.train = train
def __getitem__(self, idx): #第二步装载数据,返回[img,label],idx就是一张一张地读取
# get item 获取 数据
img = imread(img_path) #img_path根据自己的数据自定义,灵活性很高
img = torch.from_numpy(img).float() #需要转成float
gt = imread(gt_path) #读取gt,如果是分类问题,可以根据文件夹或命名赋值 0 1
gt = torch.from_numpy(gt).float()
return img, gt #返回 一一对应
def __len__(self):
return len(self.imagenumber) #这个是必须返回的长度
在框架里面填写具体的东西:
在根目录下创建:
下面是完整的代码:
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 1 09:50:08 2018
@author: dspslzbw
"""
#%%
import os
import numpy as np
#import random
import matplotlib.pyplot as plt
#import collections
import torch
import torchvision
#import cv2
from PIL import Image
#import torchvision.transforms as transforms
class UAVDataSet(torch.utils.data.Dataset):
def __init__(self, root, list_path, ignore_label=255):
super(UAVDataSet,self).__init__()
self.root = root
self.list_path = list_path
self.img_ids = [i_id.strip() for i_id in open(list_path)]
self.files = []
for name in self.img_ids:
img_file = os.path.join(self.root, "UAVSegImages/%s.jpg" % name)
label_file = os.path.join(self.root, "UAVSegLabels/%s.png" % name)
self.files.append({
"img": img_file,
"label": label_file,
"name": name
})
def __len__(self):
return len(self.files)
def __getitem__(self, index):
datafiles = self.files[index]
'''load the datas'''
name = datafiles["name"]
image = Image.open(datafiles["img"]).convert('RGB')
label = Image.open(datafiles["label"]).convert('L')
size_origin = image.size # W * H
I = np.asarray(image,np.float32)
I = I.transpose((2,0,1))#transpose the H*W*C to C*H*W
L = np.asarray(np.array(label), np.int64)
#print(I.shape,L.shape)
return I.copy(), L.copy(), np.array(size_origin), name
待改进:加入各种变换,写成transforms.Compose()
插入到代码中。
下面是一个测试函数,加在上面的代码后面,即代码写好后直接python运行当前py文件,就会执行以下代码的内容,以检测上面的代码是否有问题, 这其实就是方便调试,而不是每次都去run整个网络再看哪里报错。
if __name__ == '__main__':
DATA_DIRECTORY = './'
DATA_LIST_PATH = './images_id.txt'
Batch_size = 2
MEAN = (104.008, 116.669, 122.675)
dst = UAVDataSet(DATA_DIRECTORY,DATA_LIST_PATH, mean=(0,0,0))
# just for test, so the mean is (0,0,0) to show the original images.
# But when we are training a model, the mean should have another value
trainloader = torch.utils.data.DataLoader(dst, batch_size = Batch_size)
plt.ion()
for i, data in enumerate(trainloader):
imgs, labels,_,_= data
if i % 1 == 0:
img = torchvision.utils.make_grid(imgs).numpy()
img = img.astype(np.uint8) # change the dtype from float32 to uint8,
# because the plt.imshow() need the uint8
img = np.transpose(img, (1, 2, 0)) # transpose the C*H*W to H*W*C
#img = img[:, :, ::-1]
plt.imshow(img)
plt.show()
plt.pause(0.5)
# label = torchvision.utils.make_grid(labels).numpy()
labels = labels.astype(np.uint8) # change the dtype from float32 to uint8,
# # because the plt.imshow() need the uint8
for i in range(labels.shape[0]):
plt.imshow(labels[i],cmap='gray')
plt.show()
plt.pause(0.5)
#input()
[1] https://blog.csdn.net/woshicao11/article/details/78318156
[2] https://blog.csdn.net/Teeyohuang/article/details/82108203
[3] PyTorch文档中文版:https://pytorch-cn.readthedocs.io/zh/latest/
[4] https://pytorch.org/tutorials/beginner/data_loading_tutorial.html