2020年5月的DW组队学习选择了天池的街景字符编码识别,在这个入门竞赛中,数据集来自Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN),并根据一定方式采样得到比赛数据集。评测标准为测试集预测结果的准确率,即编码识别正确的数量占测试集图片数量的比率。
组队学习的第二个任务是学习PyTorch的自定义数据集制作方法,并利用torchvision.transforms中的数据扩增函数对样本进行变换,以增强模型的泛化能力。
本章学习手册内容由 王程伟 编写,而本篇博客则是这章内容的笔记,在这里对作者表示感谢,受益匪浅!
在PyTorch中,我们可以自定义数据集,即建立一个类,该类继承 torch.utils.data.dataset 的 Dataset类,并需要重载__getitem__()方法,同时可选择重载__len__()方法。
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite
__getitem__()
, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite__len__()
, which is expected to return the size of the dataset by manySampler
implementations and the default options ofDataLoader
.—————— pytorch.org/docs
于是,我们可以如下定义一个名为SVHNDataset的类:
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 假设最长字符为5个, 提供的数据集中0为0(原数据集中0为类别10),所以我们可以用10来表示空字符串
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
这段代码有几个值得注意的地方:
在上一小节中提到类初始化中有3个参数,传入参数时可以按下列方法传入:
train_path = glob.glob(r'data\mchar_train\*.png')
# glob库读取所有文件
train_path.sort()
train_json = json.load(open(r'data\mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]
data = SVHNDataset(train_path, train_label, transforms.Compose([
# 缩放到固定尺寸 PIL的resize函数(可选算法)
transforms.Resize((64, 128)),
# 随机颜色变换(4个参数 亮度 对比度 饱和度 色相)
transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
# 加入随机旋转(输入的是旋转的度数,从-5度到5度)
transforms.RandomRotation(5),
# 将图片转换为pytorch tensor
transforms.ToTensor(),
# 将图像像素归一化(对每个通道做z-score)做了之后效果不好
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
这段代码有几个值得注意的地方:
SVHNDataset类对数据集进行了封装,按照索引可获取样本与标签,但训练时往往是批量训练,因此还需要构建Dataloader,以实现对样本的批量读取。具体代码如下:
train_loader = torch.utils.data.DataLoader(data, batch_size=32, shuffle=False)
其中,batch_size表示一批中的样本数,可自定义;shuffle表示是否将样本打乱次序(类似于洗牌);原教程中还有一个num_workers来设置读取的线程个数,但是在windows下运行会报错,只需要使用默认值0即可。
可以用下面的语句来查看train_loader的元素的格式以及某种样本图片:
for data in train_loader:
print(data[0].shape)
# batch_size*channels*height*width
img = data[0][3].numpy()
img = np.transpose(img, (1, 2, 0))
plt.imshow(img)
break
data[0].shape的输出为:
torch.Size([32, 3, 64, 128])
可见第一个参数为batch_size,第二个为通道数(RGB),第三个为图像高 height,第四个为图像宽 width。
数据扩增是对原始图像作出一系列的变换,以增强模型的泛化能力。一般来说,对于图像分类,数据扩增不会改变标签;对于物体检测,数据扩增会改变物体坐标位置;对于图像分割,数据扩增会改变像素标签。
在torchvision.transforms中定义了多种数据扩增方法:
对于这个赛题,不能进行翻转操作,否则会使某些数字发生变化,如6变成9等等。另外,除了使用torchvision对图片进行数据扩增,还可以使用imgaug、albumentations库进行数据扩增。
在上面的代码中已经包含了数据扩增的操作,我们再来回顾一下:
data = SVHNDataset(train_path, train_label, transforms.Compose([
# 缩放到固定尺寸 PIL的resize函数(可选算法)
transforms.Resize((64, 128)),
# 随机颜色变换(4个参数 亮度 对比度 饱和度 色相)
transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
# 加入随机旋转(输入的是旋转的度数,从-5度到5度)
transforms.RandomRotation(5),
# 将图片转换为pytorch tensor
transforms.ToTensor(),
# 将图像像素归一化(对每个通道做z-score)做了之后效果不好
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
学习手册中提供了五个方法,可以简要了解一下:
transforms.Resize :根据官方文档,这个函数调用了PIL库中的resize方法,可以选择各种算法进行插值取样,默认是最邻近插值。
This can be one of PIL.Image.NEAREST (use nearest neighbour), PIL.Image.BILINEAR (linear interpolation in a 2x2 environment), PIL.Image.BICUBIC (cubic spline interpolation in a 4x4 environment), or PIL.Image.ANTIALIAS (a high-quality downsampling filter).
If omitted, or if the image has mode “1” or “P”, it is set PIL.Image.NEAREST.
transforms.ColorJitter:这个函数有4个参数,分别代表亮度 对比度 饱和度 色相。其具体值设置规则比较复杂,可参考官方文档。这里直接使用学习手册提供的参考值。
transforms.RandomRotation:对图像进行随机旋转,传入的是度数,如传入5代表随机旋转-5到5度。
transforms.ToTensor:转化为PyTorch的tensor,以方便后续训练。
transforms.Normalize:对每个通道进行z-score归一化,第一个列表是均值,列表元素个数对应通道数,第二个参数是方差。正则化的效果如图(训练集第44张,索引43):
感觉什么都看不到了,去掉正则化后,效果如图:
可以看出是有一个轻微的旋转。至此,task2的任务已经全部完成!
此次学习的教程由Datawhale提供,学习手册的链接为:点这里。