目录
1.定义自己的dataset类
2.Sampler
3.定义自己的dataloader迭代器
4.遍历数据
需要继承torch.utils.data.Dataset类,并重写__init__(),__len__(), __getitem__()方法
数据增强操作(类或者函数)也在该类的__getitem__方法中被调用
import torch
import numpy as np
# 继承Dataset方法,并重写__getitem__()和__len__()方法
class my_dataset(torch.utils.data.Dataset):
    """Map-style dataset wrapping pre-loaded data and labels.

    Subclasses torch.utils.data.Dataset and implements __getitem__ and
    __len__ so instances can be consumed by torch.utils.data.DataLoader.
    """

    def __init__(self, data_root, data_label, transform=None):
        # data_root:  indexable collection of samples
        # data_label: indexable collection of labels, parallel to data_root
        # transform:  optional augmentation callable taking (data, labels)
        #             and returning the transformed pair
        self.data = data_root
        self.label = data_label
        self.transform = transform  # data augmentation hook

    def __getitem__(self, index):
        """Return the (data, labels) pair at *index*, transformed if a
        transform was supplied."""
        data = self.data[index]
        labels = self.label[index]
        # BUG FIX: the original read ``self.trannform`` (typo), which raised
        # AttributeError on every item access; it must match the attribute
        # assigned in __init__.
        if self.transform:
            data, labels = self.transform(data, labels)
        return data, labels

    def __len__(self):
        # DataLoader uses this length to partition the dataset into batches.
        return len(self.data)
# 随机生成数据,大小为10 * 20列
# Randomly generated data: 10 samples of 20 features each.
source_data = np.random.rand(10, 20)
# Randomly generated binary labels, shape (10, 1).
source_label = np.random.randint(0, 2, (10, 1))

# Wrap the arrays in the Dataset subclass defined above; the DataLoader
# later pulls (data, labels) pairs out of this object.
# FIX: the original read ``tranforms.compose([... ToTensOr()])`` — typos for
# torchvision's ``transforms.Compose`` and ``ToTensor``.
# NOTE(review): torchvision's Compose/ToTensor call each transform with a
# single argument, while my_dataset invokes transform(data, labels) with
# two — a two-argument Compose (matching the Rescale class in this
# tutorial) is presumably intended; verify against the original article.
torch_data = my_dataset(source_data,
                        source_label,
                        transform=transforms.Compose([
                            Rescale(256),
                            ToTensor()]))
class Rescale(object):
    """Rescale an image (and its landmark labels) to a given size.

    Args:
        output_size (tuple or int): Desired output size.  A tuple is used
            directly as (height, width).  An int makes the shorter image
            edge match it, scaling the longer edge to keep the original
            aspect ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, data, labels):
        height, width = data.shape[:2]
        target = self.output_size
        if isinstance(target, int):
            # int form: shorter side becomes ``target``, longer side is
            # scaled proportionally.
            if height > width:
                out_h, out_w = target * height / width, target
            else:
                out_h, out_w = target, target * width / height
        else:
            # tuple form: taken verbatim as (height, width).
            out_h, out_w = target
        out_h, out_w = int(out_h), int(out_w)
        img = transform.resize(data, (out_h, out_w))
        # Scale the landmark coordinates by the same factors; x pairs with
        # width and y with height.
        labels = labels * [out_w / width, out_h / height]
        return img, labels
数据增强可以定义成类,而不是函数,这样就不需要每次都传递参数,为此需要实现__call__方法和__init__方法
PyTorch为我们提供了几种现成的Sampler子类:
DataLoader()中的shuffle=True时,默认的是RandomSampler,shuffle=False时默认的是SequentialSampler,一般不需要指定sampler,使用DataLoader中默认指定的就行
# Excerpt from torch.utils.data.DataLoader.__init__ showing how the default
# sampler is chosen: RandomSampler when shuffle=True, SequentialSampler
# otherwise, and a BatchSampler wrapper when batch_size is given.
if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            # Cannot statically verify that dataset is Sized
            # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
            sampler = RandomSampler(dataset, generator=generator)  # type: ignore
        else:
            sampler = SequentialSampler(dataset)

if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
通过torch.utils.data.DataLoader实现
# FIX: the original passed the *class* ``my_dataset`` instead of the dataset
# instance ``torch_data`` built above, and referenced the bare name
# ``DataLoader`` even though only ``import torch`` appears in this file —
# use the fully-qualified name.
dataloader = torch.utils.data.DataLoader(torch_data, batch_size=4, shuffle=True, num_workers=4)
直接for循环
# Iterate the loader directly; enumerate yields (batch index, batch), and
# each batch unpacks into the images and their targets.
for step, batch in enumerate(dataloader):
    imgs, targets = batch