天池实战-街景字符编码识别-task2数据预处理

写在前面的话

上一节大致了解了一下赛题的相关背景:天池实战-街景字符编码识别-task1赛题理解,从这节开始真正的实操,这节主要介绍数据读取、图像处理以及如何在Pytorch下进行数据的批量加载和处理


数据读取

根据位置标签获取单个字符

图片的位置标签存储在对应的json文件中,所以,针对每一张图片的位置标签需要通过解析json来获取图片的具体位置数据以及标签数据

直接贴json 解析的代码

def parse_json(data):
    """
    解析单个图像的json内容,返回二维数组输出
    二维数组的行分别对应:top、height、left、width、label,每一列都对应图像中的一个数字
    @param data:
    @return:
    """
    arr = np.array([
        data['top'], data['height'], data['left'], data['width'], data['label']
    ])
    arr = arr.astype(int)

    return arr

ok,我们需要传入的是每一张图片对应的json数据,这个也直接贴代码

# 加载训练集的json数据
train_json = json.load(open('dataset/train.json'))
# 读取某一张照片并获取位置标签数据
img = cv2.imread('dataset/train/000023.png')
arr = parse_json(train_json['000023.png'])

# 输出
[[ 59  56]
 [ 34  34]
 [233 248]
 [ 19  21]
 [  8   9]]

二维数组代表的意义想必上节已经说的很清楚了,我们直接通过位置数据将图片中的字符截取出来

为了方便看出效果,我们将原图和截取后的图片进行可视化展示

天池实战-街景字符编码识别-task2数据预处理_第1张图片

相应的代码是这样的:

 # 画出原图片
 plt.figure(figsize=(10, 10))
 plt.subplot(1, arr.shape[1] + 1, 1)
 plt.imshow(img)
 # 设置刻度为空
 plt.xticks([])
 plt.yticks([])

# 画出截取的数字图片
for idx in range(arr.shape[1]):
	plt.subplot(1, arr.shape[1] + 1, idx + 2)
    plt.imshow(img[arr[0, idx]:arr[0, idx] + arr[1, idx], arr[2, idx]:arr[2, idx] + arr[3, idx]])
    plt.title(arr[4, idx])
    # 设置刻度为空
    plt.xticks([])
    plt.yticks([])

plt.show()

对于上面的操作稍微解释一下,通过OpenCV读取图片,然后通过图片中的位置数据,截取到相应的字符,并将json中该字符对应的label显示在其上方。

当然对于图片读取还可以通过Pillow去读取,这里就不多介绍两者的区别,后面我们也会再次用到Pillow

图片中的字符个数

这里需要注意一个问题:图片中的字符个数

我们上面实例中的图片是有两个字符:8和9。但是实际图片中的字符个数是不可控的,有1个字符的图片、2个字符的图片等等。官方给的数据,一张图片中最多不超过6个字符。

那么在进行图片字符处理的时候,我们就需要将图片中的字符进行处理,比如:补齐字符为6位或者通过物体检测法检测字符个数再进行识别

这里我们介绍以下第一种方法,补齐法。

因为官方已经说明了一张图片中字符不会超过6个,那么我们就所有的图片都补齐成6位,这样每张图片都会成为六个字符的多分类问题。

补齐的时候,因为SVHN数据集将10指定为数字0的标签,所以我们可以通过数字10进行空位的填充

例如,下图

天池实战-街景字符编码识别-task2数据预处理_第2张图片

字符23填充为23XXXX,字符231填充为231XXX,这里的X我们使用10填充,具体会在下文中体现


基于Pytorch实现图像读取

在Pytorch中实现图像读取主要基于两个基类:Dataset和DataLoader,Dataset主要是通过索引加载图片并进行相应的处理,而DataLoader则可以进行batch Dataset(想象成图片批处理)

先来看DataLoader

torch.utils.data.DataLoader() :构建可迭代的数据装载器,我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据。

batch_size是什么?这个就涉及到DataLoader的参数了,DataLoader的主要主要参数有5个:

  • dataset:Dataset类, 数据的读取和图片预处理
  • batch_size:批大小,默认为1
  • num_workers:是否多进程读取机制,默认为0
  • shuffle:每个epoch是否乱序,默认为False
  • drop_last:当样本数不能被batch_size整除时, 是否舍弃最后一批数据。默认为False

最后一个参数比较有意思,比如说现在又101个样本,batch_size是10,那么drop_last为False的时候,表示不舍弃最后一批数据,那么将会循环取10+1次数据;相反,如果drop_last为True的时候,只会取10次数据

上面还有一个参数是Dataset,也就是我们下面要说的


再来看Dataset

torch.utils.data.Dataset() :Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。

__getitem__方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。

另外在Dataset中还可以定义transforms,可以通过transforms进行图像的预处理

直接看代码吧

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transforms=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transforms = transforms

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        # 设置最长的字符长度为6个
        lbl = np.array(self.img_label[index], dtype=np.int)
        # 使用10填充剩余的位置
        lbl = list(lbl) + (6 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:6]))

    def __len__(self):
        return len(self.img_path)

SVHNDataset 继承了 Dataset 基类,并且重写了__getitem____len__方法

其中在__getitem__方法中,我们通过索引进行图片的读取,并且使用填充方法对不足6位的图像字符进行填充

另外,可以看到在__init__方法中,我们需要传入单个图片的路径和图片中字符对应的标签,并且我们可以手动传入相对应的transforms方法

既然说到了transforms,一起来看下在Pytorch中存在哪些图片的预处理操作


最后看transforms

transforms 是torchvision计算机视觉工具包最常用的图像预处理方法。 在torchvision中,有三个主要的模块,分别是transforms、datasets和models

其中transforms包括实现图像裁剪、图像的旋转和图像变换,并通过transforms方法实现更多的图像操作,包括但不限于:数据中心化,数据标准化,缩放,裁剪,旋转,翻转,填充,噪声添加,灰度变换,线性变换,仿射变换,亮度、饱和度及对比度变换等等

图像裁剪

  • transforms.CenterCrop(size):图像中心裁剪图片, size是所需裁剪的图片尺寸,如果比原始图像大了, 会默认填充0。
  • transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant):从图片中位置随机裁剪出尺寸为size的图片,size是尺寸大小,padding设置填充大小,pad_if_need: 若图像小于设定的size, 则填充。 padding_mode表示填充模型, 有4种:constant像素值由fill设定,edge像素值由图像边缘像素设定,reflect镜像填充,symmetric也是镜像填充。
  • transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3/4, 4/3), interpolation):随机大小,长宽比裁剪图片。 scale表示随机裁剪面积比例,ratio随机长宽比,interpolation表示插值方法。
  • FiveCrop, TenCrop:在图像的上下左右及中心裁剪出尺寸为size的5张图片,后者还在这5张图片的基础上再水平或者垂直镜像得到10张图片。

图像的翻转和旋转

  • RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5):依概率水平或者垂直翻转图片,p表示翻转概率
  • RandomRotation(degrees, resample=False, expand=False, center=None):随机旋转图片,degrees表示旋转角度,resample表示重采样方法,expand表示是否扩大图片,以保持原图信息。

图像变换(加粗的为常用)

  • transforms.Compose: 将一系列的transforms方法进行有序的组合包装,具体实现的时候,依次的用包装的方法对图像进行操作。

  • transforms.Resize: 改变图像大小

  • transforms.RandomCrop: 对图像进行裁剪

  • transforms.ToTensor: 将图像转换成张量,同时会进行归一化的一个操作,将张量的值从0-255转到0-1

  • transforms.Normalize: 将数据进行标准化

  • transforms.Pad(padding, fill=0, padding_mode=‘constant’): 对图片边缘进行填充

  • transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):调整亮度、对比度、饱和度和色相。brightness是亮度调节因子,contrast对比度参数,saturation饱和度参数,hue是色相因子。

  • transfor.RandomGrayscale(num_output_channels, p=0.1):依概率将图片转换为灰度图,第一个参数是通道数,只能1或3,p是概率值,转换为灰度图像的概率

  • transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):对图像进行仿射变换,反射变换是二维的线性变换 由五中基本原子变换构成,分别是旋转,平移,缩放,错切和翻转。 degrees表示旋转角度,translate表示平移区间设置,scale表示缩放比例,fill_color填充颜色设置,shear表示错切

  • transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 对图像进行随机遮挡,p概率值,scale遮挡区域的面积,ratio遮挡区域长宽比,value遮挡像素。 随机遮挡有利于模型识别被遮挡的图片。这个是对张量进行操作,所以需要先转成张量才能做


在本次项目中,实际上对图片预处理只用到了部分方法,代码如下:

def train_transform():
    """
    针对训练集数据的图像处理
    @return:
    """
    return transforms.Compose([
        # 改变图像大小
        transforms.Resize((64, 128)),
        # 调整亮度、对比度、饱和度和色相。brightness是亮度调节因子,contrast对比度参数,saturation饱和度参数,hue是色相因子。
        transforms.ColorJitter(0.3, 0.3, 0.2),
        # 随机旋转图片,degrees表示旋转角度,resample表示重采样方法,expand表示是否扩大图片,以保持原图信息。
        transforms.RandomRotation(5),
        # 将图像转换成张量,同时会进行归一化的一个操作,将张量的值从0-255转到0-1
        transforms.ToTensor(),
        # 数据进行标准化
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

到此,关于Pytorch读取数据、进行图像预处理已经搞定了,我们在使用的过程中传入相关参数即可,代码如下:

"""获取训练集相关数据"""
# 获取训练集目录下的所有图片
train_path = glob.glob('dataset/train/*.png')
train_path.sort()
# 获取训练集图片对应的位置标签和label数据
train_json = json.load(open('dataset/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
# 通过Pytorch加载并处理数据
train_loader = DataLoader(SVHNDataset(train_path, train_label, train_transform()),
						  batch_size=40, shuffle=True, num_workers=10,
                          )

总结

本节通过对图片数据的读取和可视化展示了相关数据的使用,并通过Pytorch 实现了数据批读取和图像的相应预处理操作。

你可能感兴趣的:(天池项目实战)