Torchvision包括很多流行的数据集、模型架构和用于计算机视觉的常见图像转换模块,它是PyTorch项目的一部分。
Pytorch官方提供的例子展示了如何使用Torchvision的MNIST数据集。
//构造一个MNIST数据集
data = datasets.MNIST('data', train = True, download = True,
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1037,), (0.3081,))
]))
本文的重点是分析datasets.MNIST
的设计,包含哪些要素,以及实现自己的dataset时,都要注意什么。
首先,MNIST继承了VisionDataset类。
class MNIST(VisionDataset):
# VisionDataset并没有做什么,只是规定要重写两个特殊方法。
class VisionDataset(data.Dataset):
"""
Base Class For making datasets which are compatible with torchvision.
It is necessary to override the ``__getitem__`` and ``__len__`` method.
然后,定义了数据集的镜像地址
,作用是提供在线下载地址和高可用,当第一个地址无法访问的时候,还可以访问第二个地址。
mirrors = [
'http://yann.lecun.com/exdb/mnist/',
'https://ossci-datasets.s3.amazonaws.com/mnist/',
]
然后是资源列表
,包含了这个数据集包含的所有数据资源,这里面就包括了训练数据、训练标签,测试数据、测试标签。
MNIST官网解释:The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples, and a test set of 10,000 examples.
来自:http://yann.lecun.com/exdb/mnist/
MNIST的测试数据集包含10k个样本,所以文件名是t10k-
开头。
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
]
然后定义了两个文件名,一个是训练数据文件名,另一个是测试数据文件名。
training_file = 'training.pt'
test_file = 'test.pt'
定义图像类别
,从0到9。
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
下面是定义获取训练和测试数据以及标签的方法,同时也可以通过属性方式访问。
这里需要注意的是每个方法中都增加了warning
,告诉大家不要再使用train_data
和test_data
了,要使用data
。
@property
def train_labels(self):
warnings.warn("train_labels has been renamed targets")
return self.targets
@property
def test_labels(self):
warnings.warn("test_labels has been renamed targets")
return self.targets
@property
def train_data(self):
warnings.warn("train_data has been renamed data")
return self.data
@property
def test_data(self):
warnings.warn("test_data has been renamed data")
return self.data
为什么这么设计呢?
因为在领域驱动设计(Domain Driven Design, DDD)中,有一个限界上下文的概念,所谓限界上下文,其实就是一个上下文范围,在这个范围内,使用一套统一语言,不同的范围内,统一语言可以重复,但是意义不同,比如两个范围都用data,但是一个是训练data,一个是测试data。没接触过DDD的同学可能不太好理解这段话,没关系,我们可利用下面的代码来理解。
最开始MNIST数据集这个类设计了train_data和test_data两个方法(属性),但是后来发现,训练(train)和测试(test)其实是两个分开的上下文(Context),完全可以独立使用(也就是发现了坏耦合)。也就是说在使用时,MNIST数据集要么代表训练集,要么代表测试集。于是,就在构造方法加入了train
参数,如果在创建对象时,train为True,就代表要创建训练集,否则创建测试集。
在这两个上下文内,都直接叫data
就行了,不用重复地说“测试上下文中的测试数据集了”,直接说“测试上下文的数据集”。
然后就是构造函数。
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
root
root (string): Root directory of dataset where
MNIST/processed/training.pt
andMNIST/processed/test.pt
exist.
当download是True的时候,这个root代表下载的数据存放的目录。
# 就是这几个文件
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
如果download是False呢,就直接读本地文件。
只有在兼容老的本地文件时,才会去读MNIST/processed/training.pt
和MNIST/processed/test.pt
。
train
train (bool, optional): If True, creates dataset from
training.pt
,
otherwise fromtest.pt
.
代表构建训练数据集还是测试数据集。
download
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
是否需要从网上下载数据集,如果本地目录已经有数据集文件了,就不会重复下载。
transform
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g,transforms.RandomCrop
对数据的转换。
target_transform
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
对标签的转换。
构造函数的内部代码并没有需要特别说明的。只是在创建对象的时候,就调用了download方法,似乎不是特别好的做法。
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
self.data, self.targets = self._load_data()
下面是检查遗留文件是否存在的方法。其中check_integrity
方法是调用了utils
模块内的函数。
def _check_legacy_exist(self):
# 如果root='data',那么self.processed_folder='data/MNIST/processed'
processed_folder_exists = os.path.exists(self.processed_folder)
if not processed_folder_exists:
return False
return all(
check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
)
这里调用了all函数,如果所有文件都验证通过,就返回True,否则返回False。
all()函数是Python中的一个内置函数,如果给定iterable(列表、字典、元组、集合等)的所有元素都为true,则返回true,否则返回False。如果iterable对象为空,它也会返回True。
然后是加载遗留数据文件。
def _load_legacy_data(self):
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return torch.load(os.path.join(self.processed_folder, data_file))
加载数据,简单明了,分别调用了read_image_file
和read_label_file
,返回是torch.Tensor
。
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 1)
return x.long()
def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 3)
return x
然后是实现python的特殊方法。
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.data)
这里面需要注意的是图像数据转换。在代码注释中说,和其他数据集保持一致,返回PIL Image。
PIL是Python Image Library的首字母缩写,实际使用的是pillow分支。
前面的load_data和read_image_file已经完成了原始图像的读取并转换成torch.Tensor
。这里又把数据转换成了PIL Image
。
从后面的代码可以看到,transform方法处理的都是PIL Image
。也就是说,接口使用了Python通用的格式,而没有采用pytorch框架自己的特殊格式,这是不错的。
img = Image.fromarray(img.numpy(), mode='L')
# mode L 表示8比特灰度值,范围0-255
# (8-bit pixels, black and white)
自定义数据转换,包括数据转换和标签转换。标准接口,很好的设计模式。
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)