pytorch中DataLoader生成器的使用记录

在深度学习训练数据集中,采用批量训练时候基本都要使用生成器一批次一批次地把数据送入网络,节省内存。在keras中有ImageDataGenerator,使用很方便。所以pytorch也有对应的生成器,这里记录一下学习笔记。个人感觉pytorch的生成器并没有keras的使用方便。

keras中有ImageDataGenerator使用:https://blog.csdn.net/qq_35054151/article/details/101178662

pytorch中数据提取模块主要有Dataset和DataLoader两个部分:

1. DataLoader的函数定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

dataset:加载的数据集(Dataset对象) 
      batch_size:batch size 
      shuffle::是否将数据打乱 
      sampler: 样本抽样,后续会详细介绍 
      num_workers:使用多进程加载的进程数,0代表不使用多进程 
      collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可 
      pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些 
      drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

2. dataset

PyTorch读取图片,主要是通过Dataset类

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
	raise NotImplementedError
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

这里重点是getitem字典函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。一般的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。(这设计也不嫌麻烦!!!!)

搬运别人的代码:

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
	fh = open(txt_path, 'r')
	imgs = []
	for line in fh:
		line = line.rstrip()
		words = line.split()
		imgs.append((words[0], int(words[1])))
		self.imgs = imgs 
		self.transform = transform
		self.target_transform = target_transform
def __getitem__(self, index):
	fn, label = self.imgs[index]
	img = Image.open(fn).convert('RGB') 
	if self.transform is not None:
		img = self.transform(img) 
	return img, label
def __len__(self):
	return len(self.imgs)

第一行:self.imgs 是一个list,也就是一开始提到的list,self.imgs的一个元素是一个str,包含图片路径,图片标签,这些信息是从txt文件中读取

第二行:利用Image.open对图片进行读取,img类型为 Image ,mode=‘RGB’

第三行与第四行: 对图片进行处理,这个transform里边可以实现 减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等等操作,这个放在后面会详细讲解。

当Mydataset构建好,剩下的操作就交给DataLoder,在DataLoder中,会触发Mydataset中的getiterm函数读取一张图片的数据和标签,并拼接成一个batch返回,作为模型真正的输入。下一小节将会通过一个小例子,介绍DataLoder是如何获取一个batch,以及一张图片是如何被PyTorch读取,最终变为模型的输入的。

参考链接:https://blog.csdn.net/weixin_40766438/article/details/100750633

               https://blog.csdn.net/u011995719/article/details/85102770

             https://blog.csdn.net/wwwww_bw/article/details/102911957

https://www.cnblogs.com/leokale-zz/p/11275800.html

你可能感兴趣的:(pytorch)