上一节大致了解了一下赛题的相关背景:天池实战-街景字符编码识别-task1赛题理解,从这节开始真正的实操,这节主要介绍数据读取、图像处理以及如何在Pytorch下进行数据的批量加载和处理
图片的位置标签存储在对应的json文件中,所以,针对每一张图片的位置标签需要通过解析json来获取图片的具体位置数据以及标签数据
直接贴json 解析的代码
def parse_json(data):
"""
解析单个图像的json内容,返回二维数组输出
二维数组的行分别对应:top、height、left、width、label,每一列都对应图像中的一个数字
@param data:
@return:
"""
arr = np.array([
data['top'], data['height'], data['left'], data['width'], data['label']
])
arr = arr.astype(int)
return arr
ok,我们需要传入的是每一张图片对应的json数据,这个也直接贴代码
# 加载训练集的json数据
train_json = json.load(open('dataset/train.json'))
# 读取某一张照片并获取位置标签数据
img = cv2.imread('dataset/train/000023.png')
arr = parse_json(train_json['000023.png'])
# 输出
[[ 59 56]
[ 34 34]
[233 248]
[ 19 21]
[ 8 9]]
二维数组代表的意义想必上节已经说的很清楚了,我们直接通过位置数据将图片中的字符截取出来
为了方便看出效果,我们将原图和截取后的图片进行可视化展示
相应的代码是这样的:
# 画出原图片
plt.figure(figsize=(10, 10))
plt.subplot(1, arr.shape[1] + 1, 1)
plt.imshow(img)
# 设置刻度为空
plt.xticks([])
plt.yticks([])
# 画出截取的数字图片
for idx in range(arr.shape[1]):
plt.subplot(1, arr.shape[1] + 1, idx + 2)
plt.imshow(img[arr[0, idx]:arr[0, idx] + arr[1, idx], arr[2, idx]:arr[2, idx] + arr[3, idx]])
plt.title(arr[4, idx])
# 设置刻度为空
plt.xticks([])
plt.yticks([])
plt.show()
对于上面的操作稍微解释一下,通过OpenCV读取图片,然后通过图片中的位置数据,截取到相应的字符,并将json中该字符对应的label显示在其上方。
当然对于图片读取还可以通过Pillow去读取,这里就不多介绍两者的区别,后面我们也会再次用到Pillow
这里需要注意一个问题:图片中的字符个数
我们上面实例中的图片是有两个字符:8和9。但是实际图片中的字符个数是不可控的,有1个字符的图片、2个字符的图片等等。官方给的数据,一张图片中最多不超过6个字符。
那么在进行图片字符处理的时候,我们就需要将图片中的字符进行处理,比如:补齐字符为6位或者通过物体检测法检测字符个数再进行识别
这里我们介绍以下第一种方法,补齐法。
因为官方已经说明了一张图片中字符不会超过6个,那么我们就所有的图片都补齐成6位,这样每张图片都会成为六个字符的多分类问题。
补齐的时候,因为SVHN数据集将10指定为数字0的标签,所以我们可以通过数字10进行空位的填充
例如,下图
字符23填充为23XXXX,字符231填充为231XXX,这里的X我们使用10填充,具体会在下文中体现
在Pytorch中实现图像读取主要基于两个基类:Dataset和DataLoader,Dataset主要是通过索引加载图片并进行相应的处理,而DataLoader则可以进行batch Dataset(想象成图片批处理)
torch.utils.data.DataLoader() :构建可迭代的数据装载器,我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据。
batch_size是什么?这个就涉及到DataLoader的参数了,DataLoader的主要主要参数有5个:
最后一个参数比较有意思,比如说现在又101个样本,batch_size是10,那么drop_last为False的时候,表示不舍弃最后一批数据,那么将会循环取10+1次数据;相反,如果drop_last为True的时候,只会取10次数据
上面还有一个参数是Dataset,也就是我们下面要说的
torch.utils.data.Dataset() :Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()
这个类方法。
__getitem__
方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。
另外在Dataset中还可以定义transforms,可以通过transforms进行图像的预处理
直接看代码吧
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transforms=None):
self.img_path = img_path
self.img_label = img_label
self.transforms = transforms
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transforms is not None:
img = self.transforms(img)
# 设置最长的字符长度为6个
lbl = np.array(self.img_label[index], dtype=np.int)
# 使用10填充剩余的位置
lbl = list(lbl) + (6 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:6]))
def __len__(self):
return len(self.img_path)
SVHNDataset 继承了 Dataset 基类,并且重写了__getitem__
和__len__
方法
其中在__getitem__
方法中,我们通过索引进行图片的读取,并且使用填充方法对不足6位的图像字符进行填充
另外,可以看到在__init__
方法中,我们需要传入单个图片的路径和图片中字符对应的标签,并且我们可以手动传入相对应的transforms
方法
既然说到了transforms,一起来看下在Pytorch中存在哪些图片的预处理操作
transforms 是torchvision计算机视觉工具包最常用的图像预处理方法。 在torchvision中,有三个主要的模块,分别是transforms、datasets和models
其中transforms包括实现图像裁剪、图像的旋转和图像变换,并通过transforms方法实现更多的图像操作,包括但不限于:数据中心化,数据标准化,缩放,裁剪,旋转,翻转,填充,噪声添加,灰度变换,线性变换,仿射变换,亮度、饱和度及对比度变换等等
图像裁剪
图像的翻转和旋转
图像变换(加粗的为常用)
transforms.Compose: 将一系列的transforms方法进行有序的组合包装,具体实现的时候,依次的用包装的方法对图像进行操作。
transforms.Resize: 改变图像大小
transforms.RandomCrop: 对图像进行裁剪
transforms.ToTensor: 将图像转换成张量,同时会进行归一化的一个操作,将张量的值从0-255转到0-1
transforms.Normalize: 将数据进行标准化
transforms.Pad(padding, fill=0, padding_mode=‘constant’): 对图片边缘进行填充
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):调整亮度、对比度、饱和度和色相。brightness是亮度调节因子,contrast对比度参数,saturation饱和度参数,hue是色相因子。
transfor.RandomGrayscale(num_output_channels, p=0.1):依概率将图片转换为灰度图,第一个参数是通道数,只能1或3,p是概率值,转换为灰度图像的概率
transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):对图像进行仿射变换,反射变换是二维的线性变换 由五中基本原子变换构成,分别是旋转,平移,缩放,错切和翻转。 degrees表示旋转角度,translate表示平移区间设置,scale表示缩放比例,fill_color填充颜色设置,shear表示错切
transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 对图像进行随机遮挡,p概率值,scale遮挡区域的面积,ratio遮挡区域长宽比,value遮挡像素。 随机遮挡有利于模型识别被遮挡的图片。这个是对张量进行操作,所以需要先转成张量才能做
在本次项目中,实际上对图片预处理只用到了部分方法,代码如下:
def train_transform():
"""
针对训练集数据的图像处理
@return:
"""
return transforms.Compose([
# 改变图像大小
transforms.Resize((64, 128)),
# 调整亮度、对比度、饱和度和色相。brightness是亮度调节因子,contrast对比度参数,saturation饱和度参数,hue是色相因子。
transforms.ColorJitter(0.3, 0.3, 0.2),
# 随机旋转图片,degrees表示旋转角度,resample表示重采样方法,expand表示是否扩大图片,以保持原图信息。
transforms.RandomRotation(5),
# 将图像转换成张量,同时会进行归一化的一个操作,将张量的值从0-255转到0-1
transforms.ToTensor(),
# 数据进行标准化
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
到此,关于Pytorch读取数据、进行图像预处理已经搞定了,我们在使用的过程中传入相关参数即可,代码如下:
"""获取训练集相关数据"""
# 获取训练集目录下的所有图片
train_path = glob.glob('dataset/train/*.png')
train_path.sort()
# 获取训练集图片对应的位置标签和label数据
train_json = json.load(open('dataset/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
# 通过Pytorch加载并处理数据
train_loader = DataLoader(SVHNDataset(train_path, train_label, train_transform()),
batch_size=40, shuffle=True, num_workers=10,
)
本节通过对图片数据的读取和可视化展示了相关数据的使用,并通过Pytorch 实现了数据批读取和图像的相应预处理操作。