DataWhale街景字符编码识别项目-数据准备

数据介绍

项目来自天池竞赛, 这是项目地址
项目数据来自Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN), 对应的Kaggle竞赛地址

该数据来自真实场景的门牌号。训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。
官方已经帮我们划分好训练集和验证集。

DataWhale街景字符编码识别项目-数据准备_第1张图片
如下图所示,所示的文件分别为

mchar_train.zip: 训练图片
mchar_val.zip: 验证图片
mchar_test_a.zip: 测试图片
mchar_train.json: 训练图片标注
mchar_val.json: 验证图片标注
mchar_sample_submit_A.csv: 提交格式文件

DataWhale街景字符编码识别项目-数据准备_第2张图片

图片标注为json格式, 字段含义如下

Field Description
top 左上角坐标X
height 字符高度
left 左上角坐标Y
width 字符宽度
label 字符编码

查看数据

在构建数据集之前,我们先对数据进行一些可视化,对数据有一个大致的了解。

文件路径如下, 将其保存为字典

data_dir = {
    'train_data': '/content/data/mchar_train/',
    'val_data': '/content/data/mchar_val/',
    'test_data': '/content/data/mchar_test_a/',
    'train_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_train.json',
    'val_label': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_val.json',
    'submit_file': '/content/drive/My Drive/Data/Datawhale-DigitsRecognition/mchar_sample_submit_A.csv'
}
  • 查看图片数量

    def data_summary():
        train_list = glob(data_dir['train_data']+'*.png')
        test_list = glob(data_dir['test_data']+'*.png')
        val_list = glob(data_dir['val_data']+'*.png')
        print('train image counts: %d'%len(train_list))
        print('val image counts: %d'%len(val_list))
        print('test image counts: %d'%len(test_list))
    
    data_summary()
    train image counts: 30000
    val image counts: 10000
    test image counts: 40000
    
  • 查看标注文件信息

    def look_train_json():
        with open(data_dir['train_label'], 'r', encoding='utf-8') as f:
            content = f.read()
        # loads将字符串转为字典
        content = json.loads(content)
    
        print(content['000000.png'])
    
    look_train_json()
    {'height': [219, 219], 'label': [1, 9], 'left': [246, 323], 'top': [77, 81], 'width': [81, 96]}
    
  • 查看结果文件提交格式

    def look_submit():
        df = pd.read_csv(data_dir['submit_file'], sep=',')
        print(df.head(5))
    
    look_submit()
        file_name  file_code
    0  000000.png          0
    1  000001.png          0
    2  000002.png          0
    3  000003.png          0
    4  000004.png          0
    
  • 在图片上查看标注框

    
    def plot_samples():
        imgs = glob(data_dir['train_data']+'*.png')
        fig, ax = plt.subplots(figsize=(12, 8), ncols=2, nrows=2)
        marks = json.loads(open(data_dir['train_label'], 'r').read())
        for i in range(4):
    
            img_name = os.path.split(imgs[i])[-1]
            mark = marks[img_name]
            img = Image.open(imgs[i])
            img = np.array(img)
    
            bboxes = np.array(
                [mark['left'],
                mark['top'],
                mark['width'],
                mark['height']]
            )
            ax[i//2, i%2].imshow(img)
            for j in range(len(mark['label'])):
            
            # 定义一个矩形
            rect = patch.Rectangle(bboxes[:, j][:2], bboxes[:, j][2], bboxes[:, j][3], facecolor='none', edgecolor='r')
            ax[i//2, i%2].text(bboxes[:, j][0], bboxes[:, j][1], mark['label'][j])
            # 绘制矩形
            ax[i//2, i%2].add_patch(rect)
        plt.show()
    
    plot_samples()

    DataWhale街景字符编码识别项目-数据准备_第3张图片

  • 查看训练图片的长宽分布

    def img_size_summary():
        sizes = []
    
        for img in glob(data_dir['train_data']+'*.png'):
            img = Image.open(img)
    
            sizes.append(img.size)
    
        sizes = np.array(sizes)
    
        plt.figure(figsize=(10, 8))
        plt.scatter(sizes[:, 0], sizes[:, 1])
        plt.xlabel('Width')
        plt.ylabel('Height')
    
        plt.title('image width-height summary')
        plt.show()
        return np.mean(sizes, axis=0), np.median(sizes, axis=0)
    
    mean, median = img_size_summary()
    print('mean: ', mean)
    print('median: ', median)
    

    DataWhale街景字符编码识别项目-数据准备_第4张图片
    可以看到,训练图片之间的尺寸差异非常大,且基本上都是宽要比高大,宽之间的差异大于高之间的差异。后续确定网络输入大小,可以结合中位数或平均值确定网络输入大小。

  • 查看边界框大小分布

    def bbox_summary():
        marks = json.loads(open(data_dir['train_label'], 'r').read())
        bboxes = []
    
        for img, mark in marks.items():
            for i in range(len(mark['label'])):
            bboxes.append([mark['left'][i], mark['top'][i], mark['width'][i], mark['height'][i]])
    
        bboxes = np.array(bboxes)
    
        fig, ax = plt.subplots(figsize=(12, 8))
        ax.scatter(bboxes[:, 2], bboxes[:, 3])
        ax.set_title('bbox width-height summary')
        ax.set_xlabel('width')
        ax.set_ylabel('height')
        plt.show()
    
    bbox_summary()

    DataWhale街景字符编码识别项目-数据准备_第5张图片
    如果采用目标检测的思路实现字符识别,可以使用Kmeans聚类的方式来对边界框来确定anchor尺寸。

  • 查看不同字符类别的数目

    def label_nums_summary():
        marks = json.load(open(data_dir['train_label'], 'r'))
    
        dicts = {i: 0 for i in range(10)}
        for img, mark in marks.items():
            for lb in mark['label']:
            dicts[lb] += 1
    
        xticks = list(range(10))
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.bar(x=list(dicts.keys()), height=list(dicts.values()))
        ax.set_xticks(xticks)
        plt.show()
        return dicts
    
    print(label_nums_summary())

    DataWhale街景字符编码识别项目-数据准备_第6张图片

    可以看出不同类别之间差别总体差异不大,除了数字1的出现次数较大。没有出现极端不平衡的情况。后期分类可以考虑使用Weighted-CrossEntropy损失。

  • 查看每个图片出现的数字个数

    def label_summary():
        marks = json.load(open(data_dir['train_label'], 'r'))
    
        dicts = {}
        for img, mark in marks.items():
            if len(mark['label']) not in dicts:
            dicts[len(mark['label'])] = 0
            dicts[len(mark['label'])] += 1
        dicts = sorted(dicts.items(), key=lambda x: x[0])
        for k, v in dicts:
            print('%d个数字的图片数目: %d'%(k, v))
    
    label_summary()
    1个数字的图片数目: 4636
    2个数字的图片数目: 16262
    3个数字的图片数目: 7813
    4个数字的图片数目: 1280
    5个数字的图片数目: 8
    6个数字的图片数目: 1
    

    可以看到,只有一个图片包含数字为6个,可能是异常值,可以不予考虑。几乎全部1~4个数字的图片几乎占了训练图片的全部。

构建数据集

这里,我们借鉴Datawhale提供的Baseline, 由于每个图片最多只包含不到6个数字,为了简化,将字符识别当做一个分类问题来处理。

这里自定义数据集,DigitsDataset继承自torch.utils.data.Dataset,数据增强使用自带的torchvison.transforms。这里只进行了常规的增强操作,比如旋转,随机转灰度,随机调整HSV等。

class DigitsDataset(Dataset):
    """
    
    DigitsDataset
    
    Params:
        data_dir(string): data directory
    
        label_path(string): label path
    
        aug(bool): wheather do image augmentation, default: True
    """
    def __init__(self, data_dir, label_path, size=(64, 128), aug=True):
        super(DigitsDataset, self).__init__()
        self.imgs = glob(data_dir+'*.png')
    
        self.aug = aug
    
        self.size = size
        if label_path == None:
            self.labels = None
        else:
            self.labels = json.load(open(label_path, 'r'))
            self.imgs = [(img, self.labels[os.path.split(img)[-1]]) for img in self.imgs if os.path.split(img)[-1] in self.labels]
        
    def __getitem__(self, idx):
        if self.labels:
            img, label = self.imgs[idx]
        else:
            img = self.imgs[idx]
            label = None
        
            img = Image.open(img)
        
            trans0 = [                
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
            
            min_size = self.size[0] if (img.size[1] / self.size[0]) < ((img.size[0] / self.size[1])) else self.size[1]
            trans1 = [
                transforms.Resize(min_size),    
                transforms.CenterCrop(self.size)
                ]
    
        if self.aug:
            trans1.extend([
                    transforms.ColorJitter(0.1, 0.1, 0.1),
                    transforms.RandomGrayscale(0.1),
                    transforms.RandomAffine(10,translate=(0.05, 0.1), shear=5)
            ])
    
        trans1.extend(trans0)
        
        img = transforms.Compose(trans1)(img)
    
        if self.labels:
            return img, t.tensor(label['label'][:5] + (5 - len(label['label']))*[10]).long()
        else:
            return img, self.imgs[idx]
    
    
    def __len__(self):
        return len(self.imgs)

查看一下数据增强的效果

fig, ax = plt.subplots(figsize=(6, 12), nrows=4, ncols=2)
for i in range(8):
    img, label = dataset[i]
    # 这些需要进行逆标准化
    img = img * t.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + t.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    ax[i//2, i%2].imshow(img.permute(1, 2, 0).numpy())
    
    ax[i//2, i%2].set_xticks([])
    ax[i//2, i%2].set_yticks([])

plt.show()

DataWhale街景字符编码识别项目-数据准备_第7张图片

总结

这里主要介绍了数据的准备和数据集的构建,并未使用比较高级复杂的操作,目的是为了搭建一个基础的数据框架,后续可以更加方便的在此基础上增加其他的操作。

下一篇我会专门介绍数据增强,实现一些更复杂的操作。

Reference

[1] 天池项目地址
[2] Kaggle竞赛地址
[3] Pytorch数据集构建Tutorial
[4] Datawhale字符识别Baseline地址

你可能感兴趣的:(python,图像,神经网络)