一、数据类型
训练集的数据都是带标注的图片形式,本质是一个分类问题,预测图片中的数字。
二、赛题难点
赛题的目的是识别图片中的字符,即input = 图片,output = 字符。但是给定的数据字符的长度是不一致的。解决的方法,暂时可以考虑三种办法:
1. 标准化一致的长度
也就是说取所有数据中字符长度最长为目标长度字符,比如,最长的字符长度为6,表示为123456,那么字符12则需要表示成12XXXX,这样相当于每个字符都有11种可能。模型训练的时候相当于对每张图片都以11个字符对待。
这种处理方法比较简单,但我觉得缺陷可能在于训练的模型不能很好的处理更长的字符。
2.不定长字符识别
CRNN模型暂时理解不了
3.先检测再识别
三、理解数据
1.先导库
import os, sys, glob, shutil, json
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm, tqdm_notebook
import torch
import matplotlib.pyplot as plt
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import json
2.查看一下json文件
json文件中存放的是训练集所有图片的标注
train_json = json.load(open('C:\\Users\\pc\\Desktop\\tianchi\\mchar_train.json'))
train_json["000000.png"]
{'000000.png': {'height': [219, 219],
'label': [1, 9],
'left': [246, 323],
'top': [77, 81],
'width': [81, 96]},
}
可以看到每张图片都会存放5个标注的内容,包括“height”,“label”,“left”,“top”,“width”
“label”很好理解,就是每张图片要识别的字符是什么,也就是数据的标签
剩下的4个就是字符的坐标,可以用一张图很好的表示
3.理解一下读取图片
img = cv2.imread('C:\\Users\\pc\\Desktop\\tianchi\\000000.png')
图片的读取这里使用的是openCV的imread函数进行读取
imread这个函数有两个参数,即filename和flag
filename就是你的图片的对应的路径
flag指的是你读取图片采用的模式,可以有很多选择
再看看imread返回的是什么
array([[[ 98, 112, 108],
[ 97, 112, 108],
[ 98, 114, 107],
...,
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]],
[[100, 114, 110],
[ 99, 114, 110],
[ 99, 115, 108],
...,
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]],
[[101, 116, 112],
[101, 116, 112],
[101, 117, 110],
...,
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]],
...,
[[ 25, 21, 20],
[ 24, 22, 21],
[ 26, 24, 23],
...,
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]],
[[ 24, 23, 19],
[ 23, 24, 20],
[ 22, 23, 19],
...,
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]],
[[ 21, 22, 18],
[ 22, 23, 19],
[ 21, 22, 18],
...,
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]]], dtype=uint8)
它返回的是一个三维数组
这里就需要补充一点图像的相关知识了,一张图片有许多的像素点构成。我们常见的图片一般是三通道图片,即由红、绿、蓝三个通道构成,每个像素点的颜色就可以用三个数值表示了,范围在0~255。imread返回的通道顺序是BGR,因此很好理解了,imread返回的就是一个个像素点,每个像素点由三个值代表该点的颜色,这三个值可以看成是蓝、绿、红的深度的数值。
img.shape
(350, 741, 3)
同样可以知道的是,该图片是一个长350,宽741,三通道的图片。
我们对图片的读取核心也就是获得这些信息,这些信息进一步处理就可以转换成对应的样本的特征了。
另外,img既然可以用一个三维数组表示,这也就意味着我们可以通过切片的方式对图片进行截取,同时根据图片对应的标注的坐标信息,我们也就可以将对应的字符截取下来。
plt.subplot(1, 1, 1)
plt.imshow(img[77:77+219,246:246+81])
4.Task01总结
pytorch安装以及单个数据读取都没什么问题,数据理解缺了一点基础知识,已经查阅资料理解了。第一阶段先这样,接下来进行数据的读取和扩增。