数据介绍
项目来自天池竞赛, 这是项目地址。
项目数据来自Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN), 对应的Kaggle竞赛地址。
该数据来自真实场景的门牌号。训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。
官方已经帮我们划分好训练集和验证集。
mchar_train.zip: 训练图片
mchar_val.zip: 验证图片
mchar_test_a.zip: 测试图片
mchar_train.json: 训练图片标注
mchar_val.json: 验证图片标注
mchar_sample_submit_A.csv: 提交格式文件
图片标注为json格式, 字段含义如下
Field | Description |
---|---|
top | 左上角坐标X |
height | 字符高度 |
left | 左上角坐标Y |
width | 字符宽度 |
label | 字符编码 |
查看数据
在构建数据集之前,我们先对数据进行一些可视化,对数据有一个大致的了解。
文件路径如下, 将其保存为字典
data_dir = {
'train_data': '/content/data/mchar_train/',
'val_data': '/content/data/mchar_val/',
'test_data': '/content/data/mchar_test_a/',
'train_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_train.json',
'val_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_val.json',
'submit_file': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_sample_submit_A.csv'
}
-
查看图片数量
def data_summary(): train_list = glob(data_dir['train_data']+'*.png') test_list = glob(data_dir['test_data']+'*.png') val_list = glob(data_dir['val_data']+'*.png') print('train image counts: %d'%len(train_list)) print('val image counts: %d'%len(val_list)) print('test image counts: %d'%len(test_list)) data_summary()
train image counts: 30000 val image counts: 10000 test image counts: 40000
-
查看标注文件信息
def look_train_json(): with open(data_dir['train_label'], 'r', encoding='utf-8') as f: content = f.read() # loads将字符串转为字典 content = json.loads(content) print(content['000000.png']) look_train_json()
{'height': [219, 219], 'label': [1, 9], 'left': [246, 323], 'top': [77, 81], 'width': [81, 96]}
-
查看结果文件提交格式
def look_submit(): df = pd.read_csv(data_dir['submit_file'], sep=',') print(df.head(5)) look_submit()
file_name file_code 0 000000.png 0 1 000001.png 0 2 000002.png 0 3 000003.png 0 4 000004.png 0
-
在图片上查看标注框
def plot_samples(): imgs = glob(data_dir['train_data']+'*.png') fig, ax = plt.subplots(figsize=(12, 8), ncols=2, nrows=2) marks = json.loads(open(data_dir['train_label'], 'r').read()) for i in range(4): img_name = os.path.split(imgs[i])[-1] mark = marks[img_name] img = Image.open(imgs[i]) img = np.array(img) bboxes = np.array( [mark['left'], mark['top'], mark['width'], mark['height']] ) ax[i//2, i%2].imshow(img) for j in range(len(mark['label'])): # 定义一个矩形 rect = patch.Rectangle(bboxes[:, j][:2], bboxes[:, j][2], bboxes[:, j][3], facecolor='none', edgecolor='r') ax[i//2, i%2].text(bboxes[:, j][0], bboxes[:, j][1], mark['label'][j]) # 绘制矩形 ax[i//2, i%2].add_patch(rect) plt.show() plot_samples()
-
查看训练图片的长宽分布
def img_size_summary(): sizes = [] for img in glob(data_dir['train_data']+'*.png'): img = Image.open(img) sizes.append(img.size) sizes = np.array(sizes) plt.figure(figsize=(10, 8)) plt.scatter(sizes[:, 0], sizes[:, 1]) plt.xlabel('Width') plt.ylabel('Height') plt.title('image width-height summary') plt.show() return np.mean(sizes, axis=0), np.median(sizes, axis=0) mean, median = img_size_summary() print('mean: ', mean) print('median: ', median)
可以看到,训练图片之间的尺寸差异非常大,且基本上都是宽要比高大,宽之间的差异大于高之间的差异。后续确定网络输入大小,可以结合中位数或平均值确定网络输入大小。 -
查看边界框大小分布
def bbox_summary(): marks = json.loads(open(data_dir['train_label'], 'r').read()) bboxes = [] for img, mark in marks.items(): for i in range(len(mark['label'])): bboxes.append([mark['left'][i], mark['top'][i], mark['width'][i], mark['height'][i]]) bboxes = np.array(bboxes) fig, ax = plt.subplots(figsize=(12, 8)) ax.scatter(bboxes[:, 2], bboxes[:, 3]) ax.set_title('bbox width-height summary') ax.set_xlabel('width') ax.set_ylabel('height') plt.show() bbox_summary()
-
查看不同字符类别的数目
def label_nums_summary(): marks = json.load(open(data_dir['train_label'], 'r')) dicts = {i: 0 for i in range(10)} for img, mark in marks.items(): for lb in mark['label']: dicts[lb] += 1 xticks = list(range(10)) fig, ax = plt.subplots(figsize=(10, 8)) ax.bar(x=list(dicts.keys()), height=list(dicts.values())) ax.set_xticks(xticks) plt.show() return dicts print(label_nums_summary())
可以看出不同类别之间差别总体差异不大,除了数字1的出现次数较大。没有出现极端不平衡的情况。后期分类可以考虑使用Weighted-CrossEntropy损失。
-
查看每个图片出现的数字个数
def label_summary(): marks = json.load(open(data_dir['train_label'], 'r')) dicts = {} for img, mark in marks.items(): if len(mark['label']) not in dicts: dicts[len(mark['label'])] = 0 dicts[len(mark['label'])] += 1 dicts = sorted(dicts.items(), key=lambda x: x[0]) for k, v in dicts: print('%d个数字的图片数目: %d'%(k, v)) label_summary()
1个数字的图片数目: 4636 2个数字的图片数目: 16262 3个数字的图片数目: 7813 4个数字的图片数目: 1280 5个数字的图片数目: 8 6个数字的图片数目: 1
可以看到,只有一个图片包含数字为6个,可能是异常值,可以不予考虑。几乎全部1~4个数字的图片几乎占了训练图片的全部。
构建数据集
这里,我们借鉴Datawhale提供的Baseline, 由于每个图片最多只包含不到6个数字,为了简化,将字符识别当做一个分类问题来处理。
这里自定义数据集,DigitsDataset继承自torch.utils.data.Dataset,数据增强使用自带的torchvison.transforms。这里只进行了常规的增强操作,比如旋转,随机转灰度,随机调整HSV等。
class DigitsDataset(Dataset):
"""
DigitsDataset
Params:
data_dir(string): data directory
label_path(string): label path
aug(bool): wheather do image augmentation, default: True
"""
def __init__(self, data_dir, label_path, size=(64, 128), aug=True):
super(DigitsDataset, self).__init__()
self.imgs = glob(data_dir+'*.png')
self.aug = aug
self.size = size
if label_path == None:
self.labels = None
else:
self.labels = json.load(open(label_path, 'r'))
self.imgs = [(img, self.labels[os.path.split(img)[-1]]) for img in self.imgs if os.path.split(img)[-1] in self.labels]
def __getitem__(self, idx):
if self.labels:
img, label = self.imgs[idx]
else:
img = self.imgs[idx]
label = None
img = Image.open(img)
trans0 = [
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
min_size = self.size[0] if (img.size[1] / self.size[0]) < ((img.size[0] / self.size[1])) else self.size[1]
trans1 = [
transforms.Resize(min_size),
transforms.CenterCrop(self.size)
]
if self.aug:
trans1.extend([
transforms.ColorJitter(0.1, 0.1, 0.1),
transforms.RandomGrayscale(0.1),
transforms.RandomAffine(10,translate=(0.05, 0.1), shear=5)
])
trans1.extend(trans0)
img = transforms.Compose(trans1)(img)
if self.labels:
return img, t.tensor(label['label'][:5] + (5 - len(label['label']))*[10]).long()
else:
return img, self.imgs[idx]
def __len__(self):
return len(self.imgs)
查看一下数据增强的效果
fig, ax = plt.subplots(figsize=(6, 12), nrows=4, ncols=2)
for i in range(8):
img, label = dataset[i]
# 这些需要进行逆标准化
img = img * t.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + t.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
ax[i//2, i%2].imshow(img.permute(1, 2, 0).numpy())
ax[i//2, i%2].set_xticks([])
ax[i//2, i%2].set_yticks([])
plt.show()
总结
这里主要介绍了数据的准备和数据集的构建,并未使用比较高级复杂的操作,目的是为了搭建一个基础的数据框架,后续可以更加方便的在此基础上增加其他的操作。
下一篇我会专门介绍数据增强,实现一些更复杂的操作。
Reference
[1] 天池项目地址
[2] Kaggle竞赛地址
[3] Pytorch数据集构建Tutorial
[4] Datawhale字符识别Baseline地址