DataWhale-天池街景数字识别竞赛-task2-数据载入

背景

2020年5月的DW组队学习选择了天池的街景字符编码识别,在这个入门竞赛中,数据集来自Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN),并根据一定方式采样得到比赛数据集。评测标准为测试集预测结果的准确率,即编码识别正确的数量测试集图片数量的比率。

组队学习的第二个任务是学习PyTorch的自定义数据集制作方法,并利用torchvision.transforms中的数据扩增函数对样本进行变换,以增强模型的泛化能力。

本章学习手册内容由 王程伟 编写,而本篇博客则是这章内容的笔记,在这里对作者表示感谢,受益匪浅!

自定义数据集

在PyTorch中,我们可以自定义数据集,即建立一个类,该类继承 torch.utils.data.dataset 的 Dataset类,并需要重载__getitem__()方法,同时可选择重载__len__()方法。

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

                                                                                                                         —————— pytorch.org/docs

 于是,我们可以如下定义一个名为SVHNDataset的类:

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        
        # 假设最长字符为5个, 提供的数据集中0为0(原数据集中0为类别10),所以我们可以用10来表示空字符串
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl) + (5 - len(lbl)) * [10]

        return img, torch.from_numpy(np.array(lbl[:5]))
    
    def __len__(self):
        return len(self.img_path)

这段代码有几个值得注意的地方:

  1. 类初始化传入的参数为图像所在路径(方便用索引打开图片),图像的标签(即正确的数字)以及数据扩增中用到的图像转换方法。
  2. getitem中Image打开图片后转化为RGB格式,我想是为了保险起见,因为按道理PIL库打开图片默认是按RGB来排列通道。但是另一个图像处理库OpenCV则并非如此,OpenCV打开图片默认是BGR通道,这主要是与其起源时的规则有关,一直没有改变沿用至今,参考这篇博文。
  3. lbl = list(lbl) + (5 - len(lbl)) * [10]  一句的意思是在原标签基础上补足字符至5个,也就是空字符设为标签10,然后使得所有标签长度相等,化为定长字符识别问题。在原数据集中,数字0的标签是10,但是天池这次提供的数据集中0的标签是0,所以我们可以用10来标记空字符。另外,上一章中提到化为6个定长字符,这里因为考虑到6个字符的数据只有一个,所以把它忽略了,以后统一按5个定长字符处理

类初始化

在上一小节中提到类初始化中有3个参数,传入参数时可以按下列方法传入:

train_path = glob.glob(r'data\mchar_train\*.png')
# glob库读取所有文件
train_path.sort()
train_json = json.load(open(r'data\mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label, transforms.Compose([
    # 缩放到固定尺寸 PIL的resize函数(可选算法)
    transforms.Resize((64, 128)),
    # 随机颜色变换(4个参数 亮度 对比度 饱和度 色相)
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    # 加入随机旋转(输入的是旋转的度数,从-5度到5度)
    transforms.RandomRotation(5),
    # 将图片转换为pytorch tensor
    transforms.ToTensor(),
    # 将图像像素归一化(对每个通道做z-score)做了之后效果不好
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))

这段代码有几个值得注意的地方:

  1. 第一行用glob库一次读取所有路径,这个操作值得学习。
  2. 第3个参数时数据扩增的一些方法,可以用Compose函数进行集成,作为一个参数传入。关于数据扩增的一些介绍将在下文描述

制作批量读取类DataLoader

SVHNDataset类对数据集进行了封装,按照索引可获取样本与标签,但训练时往往是批量训练,因此还需要构建Dataloader,以实现对样本的批量读取。具体代码如下:

train_loader = torch.utils.data.DataLoader(data, batch_size=32, shuffle=False)

其中,batch_size表示一批中的样本数,可自定义;shuffle表示是否将样本打乱次序(类似于洗牌);原教程中还有一个num_workers来设置读取的线程个数,但是在windows下运行会报错,只需要使用默认值0即可。

可以用下面的语句来查看train_loader的元素的格式以及某种样本图片:

for data in train_loader:
    print(data[0].shape)
    # batch_size*channels*height*width
    img = data[0][3].numpy()
    img = np.transpose(img, (1, 2, 0))
    plt.imshow(img)
    break

data[0].shape的输出为:

torch.Size([32, 3, 64, 128])

可见第一个参数为batch_size,第二个为通道数RGB),第三个为图像高 height,第四个为图像宽 width

数据扩增

数据扩增是对原始图像作出一系列的变换,以增强模型的泛化能力。一般来说,对于图像分类,数据扩增不会改变标签;对于物体检测,数据扩增会改变物体坐标位置;对于图像分割,数据扩增会改变像素标签。

torchvision.transforms中定义了多种数据扩增方法:

  • CenterCrop :对图片中心进行裁剪
  • FiveCrop :对图片四个角和中心进行裁剪得到五分图像
  • ColorJitter :对图像颜色的亮度、对比度、饱和度和色相进行变换
  • Grayscale :对图像进行灰度变换
  • Pad :使用固定值进行像素填充
  • RandomAffine :随机仿射变换
  • RandomCrop :随机区域裁剪
  • RandomHorizontalFlip :随机水平翻转
  • RandomRotation :随机旋转
  • RandomVerticalFlip :随机垂直翻转

对于这个赛题,不能进行翻转操作,否则会使某些数字发生变化,如6变成9等等。另外,除了使用torchvision对图片进行数据扩增,还可以使用imgaug、albumentations库进行数据扩增。

在上面的代码中已经包含了数据扩增的操作,我们再来回顾一下:

data = SVHNDataset(train_path, train_label, transforms.Compose([
    # 缩放到固定尺寸 PIL的resize函数(可选算法)
    transforms.Resize((64, 128)),
    # 随机颜色变换(4个参数 亮度 对比度 饱和度 色相)
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    # 加入随机旋转(输入的是旋转的度数,从-5度到5度)
    transforms.RandomRotation(5),
    # 将图片转换为pytorch tensor
    transforms.ToTensor(),
    # 将图像像素归一化(对每个通道做z-score)做了之后效果不好
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))

学习手册中提供了五个方法,可以简要了解一下:

  • transforms.Resize :根据官方文档,这个函数调用了PIL库中的resize方法,可以选择各种算法进行插值取样,默认是最邻近插值。

This can be one of PIL.Image.NEAREST (use nearest neighbour), PIL.Image.BILINEAR (linear interpolation in a 2x2 environment), PIL.Image.BICUBIC (cubic spline interpolation in a 4x4 environment), or PIL.Image.ANTIALIAS (a high-quality downsampling filter).

If omitted, or if the image has mode “1” or “P”, it is set PIL.Image.NEAREST.

  • transforms.ColorJitter:这个函数有4个参数,分别代表亮度 对比度 饱和度 色相。其具体值设置规则比较复杂,可参考官方文档。这里直接使用学习手册提供的参考值。

  • transforms.RandomRotation:对图像进行随机旋转,传入的是度数,如传入5代表随机旋转-5到5度。

  • transforms.ToTensor:转化为PyTorch的tensor,以方便后续训练。

  • transforms.Normalize:对每个通道进行z-score归一化,第一个列表是均值,列表元素个数对应通道数,第二个参数是方差。正则化的效果如图(训练集第44张,索引43):

DataWhale-天池街景数字识别竞赛-task2-数据载入_第1张图片

感觉什么都看不到了,去掉正则化后,效果如图:

DataWhale-天池街景数字识别竞赛-task2-数据载入_第2张图片

可以看出是有一个轻微的旋转。至此,task2的任务已经全部完成!

最后

此次学习的教程由Datawhale提供,学习手册的链接为:点这里。

你可能感兴趣的:(Datawhale)