PyTorch - 36 - PyTorch DataLoader源代码 - Debugging Session

PyTorch - 36 - PyTorch DataLoader源代码 - Debugging Session

  • Short Program To Debug PyTorch Source
  • Debugging The PyTorch Source Code
    • The Sampler: To Shuffle Or Not
  • How The Batch Size Is Used
  • Normalizing The Dataset

Short Program To Debug PyTorch Source

在开始调试之前,我们只想简要介绍一下我们编写的程序,使我们可以进入并查看数据集的规范化,并确切地了解如何在引擎盖和PyTorch下完成它。

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

from torch.utils.data import DataLoader

从torch.utils.data导入DataLoader
正如我们在上一集中讨论的那样,我们具有均值和标准差值。现在,不必计算它们,我们只需要提取它们并将它们硬编码到此处的程序中即可。

mean = 0.2860347330570221
std = 0.3530242443084717

如果要脱机获取这些值,我们将做这种事情。

我们不想麻烦重新计算这些值,因此我们在这里很难对其进行描述。我们有均值和标准差,我们知道我们需要这两个值才能对数据集的每个成员或每个像素进行归一化。

接下来,我们使用FashionMNIST类构造函数初始化训练集。这里要注意或要注意的关键是变换。我们有变换的组成。

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
        , transforms.Normalize(mean, std)
    ])
)

合成的第一个将药丸图像转换为张量,然后第二个为归一化转换,这将对我们的数据进行归一化。我们的目标是在源代码中验证此特定转换的工作方式。

最后,我们创建一个DataLoader并使用它。

loader = DataLoader(train_set, batch_size=1)
image, label = next(iter(loader))

Debugging The PyTorch Source Code

好的,现在我们可以进行实际调试了。要进行调试,我们将继续进行,只需确保选择了我的python运行配置,然后单击开始调试。

使用此链接可以访问PyTorch DataLoader类的当前源代码。本讨论假定PyTorch版本1.5.0。

The Sampler: To Shuffle Or Not

采样器是获取索引值的对象,该索引值将用于从基础数据集中获取实际值。

我们可以看到,有两个相关的特定采样器,随机采样器和顺序采样器。

  1. 随机取样器
  2. 顺序采样器

如果混洗值是true,则采样器将是随机采样器,否则将是连续采样器。

How The Batch Size Is Used

我们发现采样器用于在以下代码中收集索引值:

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

在这里,我们可以看到batch_size参数在起作用,因为它限制了所收集索引值的数量。

请注意,此处的yield关键字使此迭代器成为所谓的生成器。

获取索引值后,它们将通过以下方式用于获取数据:

def fetch(self, possibly_batched_index):
    if self.auto_collation:
        data = [self.dataset[idx] for idx in possibly_batched_index]
    else:
        data = self.dataset[possibly_batched_index]
    return self.collate_fn(data)       

从基础数据集中提取每个样本的工作就像这样做。

data = [self.dataset[idx] for idx in possibly_batched_index]

该语法或符号称为列表理解。

这将返回数据元素列表,然后使用collat​​e_fn()方法将其提取并放入单个批处理张量中。

Normalizing The Dataset

最后,我们发现使用功能性api的normalize()函数对返回到批处理中的每个元素进行了规范化。

def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation.
    tensor.sub_(mean).div_(std)
    return tensor

请注意,数据集类调用一个转换,然后调用功能性api。我们还遇到了一些错误的设计,需要进行一些修改才能保持一致。

请注意,此处使用“黑客”一词是指我们看到代码在进行不必要的转换。

你可能感兴趣的:(PyTorch,深度学习,pytorch,机器学习,数据挖掘,神经网络)