Pytorch中文文档已出(http://pytorch-cn.readthedocs.io/zh/latest/)。第一篇博客献给了pytorch,主要是为了整理自己的思路。
原来使用caffe,总是要编译,经历了无数的坑。当开始接触pytorch时,果断拔草caffe。
学习Pytorch最好有一些深度学习理论基础才更好开,废话不多说,进入主题。
当训练一个神经网络的时候,我们需要有数据,有模型,并且需要设置训练的参数。为了不乱,我们最好分别定义三个文件,分别是:数据准备和预处理traindataset.py+编写模型model.py+如何训练main.py(xx.py,xx自己可任意取名)。
今天我们只讲数据准备与预处理阶段:traindataset.py(怎样命名无所谓,as you like)。这个文件的作用是什么呢?
统一将图像(或矩阵)返回成torch能处理的[original_iamges.tensor,label.tensor]
我们先跳跃一下看中文介绍是如何导入数据:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)
我们一般关注DataLoader四个参数:
dataset, batch_size, shuffle, num_workers=0
batch_size是你批处理数目,shuffle是否每个epoch都打乱,workers是载入数据的线程数(请查看中文文档对每个参数的解释)
这个dataset是 [original_iamges.tensor,label.tensor] 之类的,我们定义的“traindataset.py”就是产生这个dataset的。然后只需在main.py 文件import就可调用!
from traindataset import *
这个py文件一定要
1:能输入自己的数据路径 2:还得预处理吧,比如的裁剪啊~
step 1:先导入你肯定需要的库路径
import torch.utils.data
import torch
from tochvision import transforms
torch.utils.data模块是子类化你的数据
transforms库对数据预处理
step 2:自定义dataset类(子类化你的数据)
class MyTrainData(torch.utils.data.Dataset)
这里继承了torch.utils.data.Dataset这个类,我们看看这个类在中文文档中介绍:
所有其他数据集都应该进行子类化。所有子类应该override__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。当然还有个初始化__init__()
类:属性+方法,__init__()就是定义自己的属性
我们脸谱化py文件,再往里面加东西(以下为基础框架):
#encoding:utf-8
import torch.utils.data as data
import torch
from torchvision import transforms
class MyTrainData(torch.utils.data.Dataset) #子类化
def __init__(self, root, transform=None, train=True): #第一步初始化各个变量
self.root = root
self.train = train
def __getitem__(self, idx): #第二步装载数据,返回[img,label],idx就是一张一张地读取
# get item 获取 数据
img = imread(img_path) #img_path根据自己的数据自定义,灵活性很高
img = torch.from_numpy(img).float() #需要转成float
gt = imread(gt_path) #读取gt,如果是分类问题,可以根据文件夹或命名赋值 0 1
gt = torch.from_numpy(gt).float()
return img, gt #返回 一一对应
def __len__(self):
return len(self.imagenumber) #这个是必须返回的长度
现在往框框里面填
(1)是否transform如裁剪、归一化、旋转等?如果要transform则还需要区分test和train。比如我train需要 随机翻转,但是test则不需要操作
(2)如何做到一张一张对应读取图片? 可以自定义这些函数
以下贴出完整代码:
#encoding:utf-8
import torch.utils.data as data
import torch
from scipy.ndimage import imread
import os
import os.path
import glob
from torchvision import transforms
def make_dataset(root, train=True): #读取自己的数据的函数
dataset = []
if train:
dirgt = os.path.join(root, 'train_data/groundtruth')
dirimg = os.path.join(root, 'train_data/imgs')
for fGT in glob.glob(os.path.join(dirgt, '*.jpg')):
# for k in range(45)
fName = os.path.basename(fGT)
fImg = 'train_ori'+fName[8:]
dataset.append( [os.path.join(dirimg, fImg), os.path.join(dirgt, fName)] )
return dataset
#自定義dataset的框架
class MyTrainData(data.Dataset): #需要繼承data.Dataset
def __init__(self, root, transform=None, train=True): #初始化文件路進或文件名
self.train = train
if self.train:
self.train_set_path = make_dataset(root, train)
def __getitem__(self, idx):
if self.train:
img_path, gt_path = self.train_set_path[idx]
img = imread(img_path)
img = np.atleast_3d(img).transpose(2, 0, 1).astype(np.float32)
img = (img - img.min()) / (img.max() - img.min())
img = torch.from_numpy(img).float()
gt = imread(gt_path)
gt = np.atleast_3d(gt).transpose(2, 0, 1)
gt = gt / 255.0
gt = torch.from_numpy(gt).float()
return img, gt
def __len__(self):
return len(self.train_set_path)
这里的py文件需要在最后main.py文件中调用,所以root我并没有赋值,我会在main,py中赋值。
这里我并没有用到“transform”进行预处理,如果你想用的话,在__getitem__()下面,return img,gt前重新赋值
img = transforms.ToTensor(img)以及gt = transforms.ToTensor(gt)
这需要注意的是,查看中文文档transforms库有哪些变换,如果有需要涉及参数的如CenterCrop(size),需要先实参化,如
crop = transforms.CenterCrop(10);再使用:img = crop(img)