在使用自己数据集训练网络时,往往需要定义自己的dataloader。
一般将dataloader封装为一个类,这个类继承自 torch.utils.data.dataset
from torch.utils.data import dataset
class LoadData(Dataset): # 注意父类的名称,不能写dataset
pass
需要注意的是dataset是模块名,而Dataset是类名,在python中模块名和类名是完全独立的命名空间,因此这里的父类需要写成 dataset.Dataset。
在我们定义的LoadData中,至少需要有三个方法:
整体大致架构:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class LoadData(dDataset):
def __init__(self):
pass
def __getitem__(self,index):
pass
def __len__(self):
pass
dataset = Loaddata()
train_loader = DataLoader(dataset = dataset,batch_size = 32,shuffle = Ture,num_workers=2)
__init__方法需要传入至少两个参数:
def __init__(self, txt_path, train=True):
super(LoadData, self).__init__()
self.img_info = self.get_img(txt_path)
self.train = train
# train预处理
self.train_transforms = transforms.Compose([
transforms.Resize(20),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# test预处理
self.test_transforms = transforms.Compose([
transforms.Resize(20),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# 这个函数是用来读txt文档的
def get_img(self, txt_path):
with open(txt_path, 'r', encoding='utf-8') as f:
imgs_info = f.readlines()
imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
return imgs_info
__getitem__方法只需要根据index返回数据的item和label。
def __getitem__(self, index):
img_path, label = self.img_info[index]
img = Image.open(img_path)
label = int(label)
# 注意区分预处理
if self.train:
img = self.train_transforms(img)
else:
img = self.test_transforms(img)
return img, label
__len__方法最简单,仅返回数据项个数。
def __len__(self):
return len(self.img_info)
以训练数据为例,调用dataloader需要两步:
from torch.utils.data import Dataloader
train_dataset = LoadData(txt_path='XXXX', train=True)
train_loader = dataloader.Dataloader(
dataset=train_dataset,
batch_size=8,
shuffle=True
)
至此,一个最简单的dataloader就完成了!
可以用以下代码测试:
for image, label in train_loader:
print(image.shape)
print(label)
https://zhuanlan.zhihu.com/p/399447239