pytorch中DataLoader浅析

1.这里不解释源码(因为我也看不懂)

2. 简单解释

dataloader本质是一个可迭代(iterable)对象,使用iter()访问,不能使用next()访问;

( iter和next:迭代器和生成器)
下面的代码为DataLoader的源码:

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

3.DataLoader()实例化对象的访问

既然叫做可迭代的对象,其实就可以当做 迭代器 处理

方法一:

使用iter(dataloader)返回迭代器,
再使用next()访问;

方法二:

 for inputs, labels in dataloaders :

进行可迭代对象的访问;

一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;

3.构建DataLoader的pipeline

  1. 构建一个 Dataset 对象
  2. 构建一个 DataLoader 对象
  3. 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
class Diabetes(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# 生成一个Diabetes对象,自动调用构造函数,数据就写入了self.x_data,self.y_data中



dataset = Diabetes('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

...
...
# 下面的 i 的作用就是看一下每个epoch分成了多少个batch_size
for epoch in range(10):
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        print('epoch:', epoch, 'i:', i, 'loss:', loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

4.DataLoader作用总结

将自定义的Dataset封装成一个Batch Size大小的Tensor,用于后面的训练。

参考来源:
pytorch之dataloader深入剖析

你可能感兴趣的:(pytorch系列)