Torchvision中datasets.MNIST设计方法分析

文章目录

  • 前言
  • 逐行分析MNIST代码
  • 设计要点小结

前言

Torchvision包括很多流行的数据集、模型架构和用于计算机视觉的常见图像转换模块,它是PyTorch项目的一部分。

Pytorch官方提供的例子展示了如何使用Torchvision的MNIST数据集。

//构造一个MNIST数据集
data = datasets.MNIST('data', train = True, download = True,
          transform = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1037,), (0.3081,))
          ]))

本文的重点是分析datasets.MNIST的设计,包含哪些要素,以及实现自己的dataset时,都要注意什么。

逐行分析MNIST代码

首先,MNIST继承了VisionDataset类。

class MNIST(VisionDataset):

# VisionDataset并没有做什么,只是规定要重写两个特殊方法。
class VisionDataset(data.Dataset):
    """
    Base Class For making datasets which are compatible with torchvision.
    It is necessary to override the ``__getitem__`` and ``__len__`` method.

然后,定义了数据集的镜像地址,作用是提供在线下载地址和高可用,当第一个地址无法访问的时候,还可以访问第二个地址。

mirrors = [
        'http://yann.lecun.com/exdb/mnist/',
        'https://ossci-datasets.s3.amazonaws.com/mnist/',
    ]

然后是资源列表,包含了这个数据集包含的所有数据资源,这里面就包括了训练数据、训练标签,测试数据、测试标签。

MNIST官网解释:The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples, and a test set of 10,000 examples.
来自:http://yann.lecun.com/exdb/mnist/

MNIST的测试数据集包含10k个样本,所以文件名是t10k-开头。

resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

然后定义了两个文件名,一个是训练数据文件名,另一个是测试数据文件名。

training_file = 'training.pt'
test_file = 'test.pt'

定义图像类别,从0到9。

classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

下面是定义获取训练和测试数据以及标签的方法,同时也可以通过属性方式访问。
这里需要注意的是每个方法中都增加了warning,告诉大家不要再使用train_datatest_data了,要使用data

@property
def train_labels(self):
    warnings.warn("train_labels has been renamed targets")
    return self.targets

@property
def test_labels(self):
    warnings.warn("test_labels has been renamed targets")
    return self.targets

@property
def train_data(self):
    warnings.warn("train_data has been renamed data")
    return self.data

@property
def test_data(self):
    warnings.warn("test_data has been renamed data")
    return self.data

为什么这么设计呢?

因为在领域驱动设计(Domain Driven Design, DDD)中,有一个限界上下文的概念,所谓限界上下文,其实就是一个上下文范围,在这个范围内,使用一套统一语言,不同的范围内,统一语言可以重复,但是意义不同,比如两个范围都用data,但是一个是训练data,一个是测试data。没接触过DDD的同学可能不太好理解这段话,没关系,我们可利用下面的代码来理解。

最开始MNIST数据集这个类设计了train_data和test_data两个方法(属性),但是后来发现,训练(train)和测试(test)其实是两个分开的上下文(Context),完全可以独立使用(也就是发现了坏耦合)。也就是说在使用时,MNIST数据集要么代表训练集,要么代表测试集。于是,就在构造方法加入了train参数,如果在创建对象时,train为True,就代表要创建训练集,否则创建测试集。
在这两个上下文内,都直接叫data就行了,不用重复地说“测试上下文中的测试数据集了”,直接说“测试上下文的数据集”。

然后就是构造函数。

def __init__(
    self,
    root: str,
    train: bool = True,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
) -> None:

root

root (string): Root directory of dataset where MNIST/processed/training.pt
and MNIST/processed/test.pt exist.

当download是True的时候,这个root代表下载的数据存放的目录。

# 就是这几个文件
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz

如果download是False呢,就直接读本地文件。

只有在兼容老的本地文件时,才会去读MNIST/processed/training.ptMNIST/processed/test.pt

train

train (bool, optional): If True, creates dataset from training.pt,
otherwise from test.pt.

代表构建训练数据集还是测试数据集。

download

download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

是否需要从网上下载数据集,如果本地目录已经有数据集文件了,就不会重复下载。

transform

transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, transforms.RandomCrop

对数据的转换。

target_transform

target_transform (callable, optional): A function/transform that takes in the
target and transforms it.

对标签的转换。

构造函数的内部代码并没有需要特别说明的。只是在创建对象的时候,就调用了download方法,似乎不是特别好的做法。


def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
    super(MNIST, self).__init__(root, transform=transform,
                                target_transform=target_transform)
    self.train = train  # training set or test set

    if self._check_legacy_exist():
        self.data, self.targets = self._load_legacy_data()
        return

    if download:
        self.download()

    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')

    self.data, self.targets = self._load_data()

下面是检查遗留文件是否存在的方法。其中check_integrity方法是调用了utils模块内的函数。

def _check_legacy_exist(self):
    # 如果root='data',那么self.processed_folder='data/MNIST/processed'
    processed_folder_exists = os.path.exists(self.processed_folder)
    if not processed_folder_exists:
        return False

    return all(
        check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
    )

这里调用了all函数,如果所有文件都验证通过,就返回True,否则返回False。

all()函数是Python中的一个内置函数,如果给定iterable(列表、字典、元组、集合等)的所有元素都为true,则返回true,否则返回False。如果iterable对象为空,它也会返回True。

然后是加载遗留数据文件。

def _load_legacy_data(self):
    # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
    # directly.
    data_file = self.training_file if self.train else self.test_file
    return torch.load(os.path.join(self.processed_folder, data_file))

加载数据,简单明了,分别调用了read_image_fileread_label_file,返回是torch.Tensor

def _load_data(self):
    image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
    data = read_image_file(os.path.join(self.raw_folder, image_file))

    label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
    targets = read_label_file(os.path.join(self.raw_folder, label_file))

    return data, targets

def read_label_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 1)
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 3)
    return x

然后是实现python的特殊方法。

def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

def __len__(self) -> int:
    return len(self.data)

这里面需要注意的是图像数据转换。在代码注释中说,和其他数据集保持一致,返回PIL Image。

PIL是Python Image Library的首字母缩写,实际使用的是pillow分支。

前面的load_data和read_image_file已经完成了原始图像的读取并转换成torch.Tensor。这里又把数据转换成了PIL Image

从后面的代码可以看到,transform方法处理的都是PIL Image。也就是说,接口使用了Python通用的格式,而没有采用pytorch框架自己的特殊格式,这是不错的。

img = Image.fromarray(img.numpy(), mode='L')

# mode L 表示8比特灰度值,范围0-255
# (8-bit pixels, black and white)

自定义数据转换,包括数据转换和标签转换。标准接口,很好的设计模式。

if self.transform is not None:
        img = self.transform(img)

if self.target_transform is not None:
    target = self.target_transform(target)

设计要点小结

  1. 数据集分为训练集和测试集,实现上相同,但概念上分开。
  2. 一个数据集包含了下载地址(镜像地址)、资源列表、数据标签,是数据集的元数据。
  3. 要实现数据集的存储和加载。
    最简单的就是保存到本地文件系统。如果是上云的环境,比如私有云,还可以保存或者读取对象存储、分布式存储等。这就需要同步资源的地址和下载逻辑。
  4. 要进行基本的数据转换,并使用Python通用的数据格式(PIL)。
  5. 要提供自定义数据变换的接口(Transform)。
  6. 实现必要的特殊方法(getitem__和__len)。

你可能感兴趣的:(Python,深度学习,深度学习,pytorch,python)