本文以pytorch1.10进行解读:torch — PyTorch 1.10 documentation
文本的操作在github上都有Shirley-Xie/pytorch_exercise · GitHub,且有运行结果。
torch.utils.data.
Dataset
(*args, **kwds)
所有表示从键到数据样本映射的数据集都应该将其子类化。所有子类都应该覆盖__getitem__(),支持为给定的键获取数据样本。子类还可以选择性地覆盖__len__(),许多Sampler实现和DataLoader的默认选项都期望它返回数据集的大小。
Dataset定义数据集的内容,类似于列表的数据结构,长度确定,能够用索引获取数据集中的元素。
Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()
这个类方法,作用是接收一个索引, 返回一个样本。
DataLoader定义了按batch加载数据集的方法,它是一个实现了`__iter__`方法的可迭代对象,每次迭代输出一个batch的数据。Dataset数据集相当于一种列表结构不同,IterableDataset相当于一种迭代器结构。 它更加复杂,一般较少使用。
能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
在绝大部分情况下,用户只需实现Dataset的`__len__`方法和`__getitem__`方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。
函数签名如下:
torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
)
常用dataset, batch_size, shuffle, num_workers,pin_memory, drop_last这六个参数。
一般实现的代码如下:
ds = TensorDataset(torch.randn(1000,3),
torch.randint(low=0,high=2,size=(1000,)).float())
dl = DataLoader(ds,batch_size=4,drop_last = False)
features,labels = next(iter(dl))
print("features = ",features )
print("labels = ",labels )
结果:
features = tensor([[-0.3192, -1.7329, -1.7346],
[-0.7792, 1.2145, -0.5208],
[ 0.5105, -1.4158, 1.0757],
[-1.3785, -1.3909, -0.7086]])
labels = tensor([0., 0., 0., 1.])
获取一个batch数据的步骤
假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m 。
总而言之,在一个确定数据集中,按照batch的大小确定索引,然后根据索引取出对应的数据。最后整理成特征和标签在一起的样子。
具体内部方法拆解如下:
# step1: 确定数据集长度 (Dataset的 __len__ 方法实现)
ds = TensorDataset(torch.randn(1000,3),
torch.randint(low=0,high=2,size=(1000,)).float())
print("n = ", len(ds)) # len(ds)等价于 ds.__len__()
# step2: 确定抽样indices (DataLoader中的 Sampler和BatchSampler实现)
sampler = RandomSampler(data_source = ds)
batch_sampler = BatchSampler(sampler = sampler,
batch_size = 4, drop_last = False)
for idxs in batch_sampler:
indices = idxs
break
print("indices = ",indices)
# step3: 取出一批样本batch (Dataset的 __getitem__ 方法实现)
batch = [ds[i] for i in indices] # ds[i] 等价于 ds.__getitem__(i)
print("batch = ", batch)
# step4: 整理成features和labels (DataLoader 的 collate_fn 方法实现)
def collate_fn(batch):
features = torch.stack([sample[0] for sample in batch])
labels = torch.stack([sample[1] for sample in batch])
return features,labels
features,labels = collate_fn(batch)
print("features = ",features)
print("labels = ",labels)
结果:
n = 1000
indices = [426, 137, 471, 292]
batch = [(tensor([1.5614, 0.6875, 1.7250]), tensor(1.)), (tensor([ 0.2853, -1.4416, -0.5672]), tensor(1.)), (tensor([ 0.1800, 0.2652, -0.5301]), tensor(0.)), (tensor([-0.9303, 0.7461, 0.2575]), tensor(1.))]
features = tensor([[ 1.5614, 0.6875, 1.7250],
[ 0.2853, -1.4416, -0.5672],
[ 0.1800, 0.2652, -0.5301],
[-0.9303, 0.7461, 0.2575]])
labels = tensor([1., 1., 0., 1.])
Dataset创建数据集常用的方法有:
此外,还可以通过
此处代码是常见的自定义方法:
class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
rmb面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.data_info)
@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = rmb_label[sub_dir]
data_info.append((path_img, int(label)))
return data_info
其中get_img_info做的是拿到数据的位置和标签,也就是元祖列表格式。 有了这个list,然后又给了data_info一个index, data_info[index] 就取出了某个(样本i_loc, label_i)。
__getitem__()
这个方法, 是不是很容易理解了, 第一行我们拿到了一个样本的图片路径和标签。然后第二行就是去找到图片,然后转成RGB数值。 第三行就是做了图片的数据预处理,最后返回了这张图片的张量形式和它的标签。
参考文章:
torch — PyTorch 1.10 documentation
GitHub - lyhue1991/eat_pytorch_in_20_days: Pytorch is delicious, just eat it! 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)_dataloader 输入两个变量-CSDN博客