PubTabNet数据集介绍(有干货)

前言

PubTabNet是IBM公司公布的基于图像的表格识别数据集。
其包含了568k+表格图片,其标注数据是HTML的表格结构,下载压缩包磁盘存储大小10G+。
GitHub相关地址
IBM的下载地址
相关论文:
Image-based table recognition: data, model, and evaluation
此篇论文的核心在于通过Encoder - double Decoder实现表格结构与单元格的识别

2020-12-24 15_28_04-Start.png

目前暂时没有GitHub复现项目。

PubTabNet数据

表格图片

该数据集的表格都是PDF截图,清晰度不是很高
示例如下:


PMC2838834_005_00.png

训练结构数据

训练结构数据位于.jsonl文件中,该文件每一行都是一条json数据,可以通过jsonlines库逐条读取。
数据结构:

  • imgid: 图像id
  • html: html单元格详细描述
    cells:单元格列表
    cells列表成员:
    tokens:单元格文本字符级别信息,
    bbox:单元格文本范围的bounding box,这个并不是单元格范围,而是单元格内文本范围!数字坐标说明:x_min, y_min, x_max, y_max,即文本范围对角线两个点的坐标。
    该坐标就是图片本身的坐标,可以放心食用
    structure字典
    structure字典成员:
    tokens:对应cells的HTML格式
  • split:表示训练数据或验证数据,分别为train与val
  • filename:图像名称,如PMC2838834_005_00.png

示例如下(只贴出两个单元格以及部分HTML结构数据):

{
  "imgid": 4,
  "html": {
    "cells": [
      {
        "tokens": [
          "",
          "M",
          "a",
          "i",
          "n",
          " ",
          "c",
          "e",
          "l",
          "l",
          "u",
          "l",
          "a",
          "r",
          " ",
          "p",
          "r",
          "o",
          "c",
          "e",
          "s",
          "s",
          ""
        ],
        "bbox": [
          1,
          4,
          76,
          13
        ]
      },
      {
        "tokens": [
          "",
          "M",
          "o",
          "d",
          "u",
          "l",
          "a",
          "t",
          "e",
          "d",
          " ",
          "p",
          "a",
          "t",
          "h",
          "w",
          "a",
          "y",
          "s",
          ""
        ],
        "bbox": [
          92,
          4,
          167,
          13
        ]
      }
    ],
    "structure": {
      "tokens": [
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "",
        "...",
        "..."
      ]
    }
  },
  "split": "train",
  "filename": "PMC2838834_005_00.png"
}

位于GitHub的PubTabNet相关代码,只有读取数据,将数据转换为HTML的功能,并没有表格识别相关的代码。

PubTabNet转换为SciTSR训练数据格式

因为之前在研究SciTSR数据,所以需要将PubTabNet的训练集数据转换为SciTSR的格式,从而扩大训练集,即我们需要根据.jsonl中的数据,提炼出chunk, structure数据来。
SciTSR的介绍链接在此
这一段是干货,因为这意味着不同数据集可以通用了。

import tqdm
import os
import jsonlines
import json
import logging
import sys
import re

sys.path.insert(0, os.path.abspath('../'))


class Transform:
    def transform(self, pubtabnet_data_file: str):
        """Convert a PubTabNet ``.jsonl`` annotation file into per-image files.

        Each JSON line is read with ``jsonlines``. For every record, two files
        are written next to the annotation file, under a folder named after
        the record's ``split`` value:

        * ``<split>/structure/<image stem>.json`` - the structure dictionary
        * ``<split>/chunk/<image stem>.chunk``   - the chunk dictionary

        Both dictionaries come from ``self.transfer_GFTE_data``; nothing is
        written for a record unless both results are ``dict`` instances.

        :param pubtabnet_data_file: path of the PubTabNet ``.jsonl`` file.
        :return: ``None``. An error is logged and the method returns early
            when the path is ``None`` or does not exist.
        """
        file_available = (pubtabnet_data_file is not None and
                          os.path.exists(pubtabnet_data_file))
        if not file_available:
            logging.error('No PubTabNet data file!')
            return

        # Output folders live beside the annotation file.
        base_dir = os.path.dirname(pubtabnet_data_file)

        with open(pubtabnet_data_file, encoding='utf-8') as handle:
            for record in jsonlines.Reader(handle):
                image_name = record['filename']
                split_name = record['split']

                # One structure/ and one chunk/ folder per split.
                structure_dir = '{0}/{1}/structure/'.format(base_dir, split_name)
                if not os.path.exists(structure_dir):
                    os.makedirs(structure_dir)
                chunk_dir = '{0}/{1}/chunk/'.format(base_dir, split_name)
                if not os.path.exists(chunk_dir):
                    os.makedirs(chunk_dir)

                print('Handle image: {0}'.format(image_name))
                html_info = record['html']
                cell_list = html_info['cells']
                structure_tokens = html_info['structure']['tokens']
                positions = self.get_row_col_position(structure_tokens)

                # Output names reuse the image name without its '.png' part.
                stem = image_name.replace('.png', '')
                structure_path = os.path.join(structure_dir, '{0}.json'.format(stem))
                chunk_path = os.path.join(chunk_dir, '{0}.chunk'.format(stem))

                structure_payload, chunk_payload = self.transfer_GFTE_data(cell_list, positions)
                if not (isinstance(structure_payload, dict) and
                        isinstance(chunk_payload, dict)):
                    continue

                with open(structure_path, "w", encoding='utf-8') as out_file:
                    json.dump(structure_payload, out_file)

                with open(chunk_path, "w", encoding='utf-8') as out_file:
                    json.dump(chunk_payload, out_file)

    def get_row_col_position(self, structure_list: list):
        row_index = 0
        col_index = 0
        gfte_structure = []
        for index, table_element in enumerate(structure_list):
            if table_element == '':
                col_index = 0
            if table_element == '':
                row_index += 1
            if table_element == '':
                structure_dict = {'start_row': row_index,
                                  'end_row': row_index,
                                  'start_col': col_index,
                                  'end_col': col_index}
                col_index += 1
                gfte_structure.append(structure_dict)
                continue
            if table_element == '', '', raw_text).strip()
            structure_data['content'] = pure_text.split()
            structure_data.update(structure)
            structure_data_dict['cells'].append(structure_data)

            chunk_data = {}
            if cell.get('bbox', None) is not None:
                chunk_data['pos'] = [cell['bbox'][0],
                                     cell['bbox'][2],
                                     cell['bbox'][1],
                                     cell['bbox'][3]]
            else:
                chunk_data['pos'] = [0, 0, 0, 0]
            chunk_data['text'] = raw_text
            chunk_data_dict['chunks'].append(chunk_data)
        for index, (structure_data, chunk_data) in enumerate(zip(structure_data_dict['cells'],
                                                                 chunk_data_dict['chunks'])):
            # x0, x1, y0, y1
            if chunk_data['pos'] == [0, 0, 0, 0]:
                id = structure_data['id']
                start_row = structure_data['start_row']
                end_row = structure_data['end_row']
                start_col = structure_data['start_col']
                end_col = structure_data['end_col']
                x0 = None
                x1 = None
                y0 = None
                y1 = None
                for structure, chunk in zip(structure_data_dict['cells'],
                                            chunk_data_dict['chunks']):
                    if structure['id'] != id and chunk['pos'] != [0, 0, 0, 0]:
                        if y0 is None and structure['start_row'] == start_row:
                            y0 = chunk['pos'][2]
                        if y1 is None and structure['end_row'] == end_row:
                            y1 = chunk['pos'][3]
                        if x0 is None and structure['start_col'] == start_col:
                            x0 = chunk['pos'][0]
                        if x1 is None and structure['end_col'] == end_col:
                            x1 = chunk['pos'][1]
                    if x0 is not None and x1 is not None and y0 is not None and y1 is not None:
                        chunk_data['pos'] = [x0, x1, y0, y1]
                        break
        return structure_data_dict, chunk_data_dict


if __name__ == "__main__":
    # Script entry point: run the conversion on the annotation file at the
    # hard-coded path below.
    transform = Transform()
    transform.transform('/data/pubtabnet/PubTabNet_2.0.0.jsonl')
    # Alternative input kept for reference (small example file):
    # transform.transform('/data/scitsr/examples/PubTabNet_Examples.jsonl')

UNET所需数据准备

UNET非常适合图像分割,对于表格来说,如果能够通过UNET将表格中行与列进行标注,则能够方便根据表格结构提取各个单元格信息。
事实上,UNET预测的是“线”,训练数据类似于如下格式,即两点一线,标签为0代表横线,标签为1代表竖线

{
   "label": "0",
   "line_color": [
    0,
    0,
    128
   ],
   "fill_color": [
    0,
    0,
    128
   ],
   "points": [
    [
     0,
     0
    ],
    [
     503,
     0
    ]
   ]
  },
  {
   "label": "0",
   "line_color": [
    0,
    0,
    128
   ],
   "fill_color": [
    0,
    0,
    128
   ],
   "points": [
    [
     0,
     275
    ],
    [
     503,
     275
    ]
   ]
  },
  {
   "label": "1",
   "line_color": [
    0,
    0,
    128
   ],
   "fill_color": [
    0,
    0,
    128
   ],
   "points": [
    [
     0,
     0
    ],
    [
     0,
     276
    ]
   ]
  }

训练数据转换代码,则是根据SciTSR的训练数据:chunk与structure得到。
如下代码实现了:

  1. 根据每行每列的最大与最小坐标,以及邻接单元格信息,尝试绘制表格中的横线与竖线。而这仅仅是根据原数据中的单元格中文本框的坐标计算得到的
  2. 将得到的线绘制到单独一张图片,构成训练数据示意图片
import os
import numpy as np
from PIL import Image
import cv2
import json
from glob import glob
import traceback


def sameTable(ymin_1, ymin_2, ymax_1, ymax_2):
    """Tell whether two vertical extents are close enough to match.

    The absolute difference of the two ``ymin`` values and of the two
    ``ymax`` values is compared against three tolerance pairs; one matching
    pair is enough.

    :return: ``True`` when the differences fit within (5, 5), (4, 7) or
        (7, 4) (ymin tolerance, ymax tolerance), otherwise ``False``.
    """
    top_gap = abs(ymin_1 - ymin_2)
    bottom_gap = abs(ymax_1 - ymax_2)

    # (allowed ymin difference, allowed ymax difference)
    tolerance_pairs = ((5, 5), (4, 7), (7, 4))
    return any(top_gap <= top_limit and bottom_gap <= bottom_limit
               for top_limit, bottom_limit in tolerance_pairs)


def draw_cell_line(root_path: str = r'/data/scitsr/train',
                   img_path: str = 'img',
                   is_scitsr: bool = True,
                   output_unet: bool = False,
                   draw_img_amount: int = None,
                   file_list: list = None,
                   is_middle: bool = False,
                   draw_thickness: bool = False,
                   draw_line_img: bool = True):
    """
    Derive table ruling lines from SciTSR-style chunk/structure files and
    save them as one JSON file per image (the format needed by table_net).

    According to the original note, the SciTSR-style files are expected to
    come from the ``transform`` method of pubtabnet_format_transform.py, and
    the images from PubTabNet.

    For each image under ``root_path/img_path`` the function reads
    ``root_path/chunk/<name>.chunk`` and ``root_path/structure/<name>.json``,
    computes per-row y bounds and per-column x bounds from the text boxes,
    and builds a list of line "shapes": label '0' is a horizontal line and
    label '1' is a vertical line. Four border lines are always added first.
    The shapes are passed to ``fix_table_net_data`` and written as JSON into
    ``table_net_data`` (or ``table_net_data_middle`` when ``is_middle``).
    Any exception raised while handling one file is caught, its traceback is
    printed, and processing continues with the next file.

    :param root_path: dataset folder containing the image, chunk and
        structure sub-folders; all output folders are created inside it.
    :param img_path: name of the image sub-folder inside ``root_path``.
    :param is_scitsr: when True, y coordinates are flipped against the image
        height (SciTSR y axis runs bottom-up); otherwise they are used as-is.
    :param output_unet: when True, ``img_unet`` and ``mask_unet`` folders are
        created and the line mask image is written to ``mask_unet``.
    :param draw_img_amount: optional limit on processed files, only honoured
        when ``file_list`` is None.
        NOTE(review): ``count`` starts at 1 and the loop stops as soon as
        ``count == draw_img_amount``, so this looks like it handles
        ``draw_img_amount - 1`` files - confirm whether that is intended.
    :param file_list: explicit list of image file names; when None, every
        entry returned by ``os.listdir`` on the image folder is used.
    :param is_middle: whether to place each line midway between two
        neighbouring text regions, which is closer to real table rulings.
    :param draw_thickness: forwarded unchanged to ``draw_line``.
    :param draw_line_img: when True, also writes a line-only mask image and a
        copy of the source image with the lines drawn on it.
    :return: list of paths of the JSON files that were written.
    """
    directory = os.path.join(root_path, img_path)
    chunk_directory = os.path.join(root_path, 'chunk')
    structure_directory = os.path.join(root_path, 'structure')

    # Output folder names get a '_middle' suffix in is_middle mode.
    if is_middle:
        final_cell_directory = os.path.join(root_path, 'cell_mask_img_middle')
        final_table_cell_directory = os.path.join(root_path, 'table_cell_mask_img_middle')
        table_net_data_directory = os.path.join(root_path, 'table_net_data_middle')
    else:
        final_cell_directory = os.path.join(root_path, 'cell_mask_img')
        final_table_cell_directory = os.path.join(root_path, 'table_cell_mask_img')
        table_net_data_directory = os.path.join(root_path, 'table_net_data')

    if not os.path.exists(final_cell_directory):
        os.makedirs(final_cell_directory)
    if not os.path.exists(final_table_cell_directory):
        os.makedirs(final_table_cell_directory)
    if not os.path.exists(table_net_data_directory):
        os.makedirs(table_net_data_directory)
    # This folder is created but not written to anywhere in this function.
    final_table_directory = os.path.join(root_path, 'table_mask_img')
    if not os.path.exists(final_table_directory):
        os.makedirs(final_table_directory)

    unet_img_directory = os.path.join(root_path, 'img_unet')
    if output_unet:
        if not os.path.exists(unet_img_directory):
            os.makedirs(unet_img_directory)
        # In UNET mode the mask images go to 'mask_unet' instead.
        final_cell_directory = os.path.join(root_path, 'mask_unet')
        if not os.path.exists(final_cell_directory):
            os.makedirs(final_cell_directory)
    if file_list is None:
        files = os.listdir(directory)
    else:
        files = file_list
    count = 1
    table_net_data_file_list = []
    for index, file in enumerate(files):
        # Stop early once the optional file limit is reached (directory mode only).
        if file_list is None and \
                draw_img_amount is not None and \
                isinstance(draw_img_amount, int) and \
                draw_img_amount > 0 and \
                count == draw_img_amount:
            break
        print('Handle the {0} file: {1}'.format(count, file))
        try:
            file_path = os.path.join(directory, file)
            # Annotation and output files share the image's base name.
            chunk_file = os.path.join(chunk_directory,
                                      file.replace('.png', '.chunk').replace('.jpg', '.chunk'))
            structure_file = os.path.join(structure_directory,
                                          file.replace('.png', '.json').replace('.jpg', '.json'))
            table_net_data_file = os.path.join(table_net_data_directory,
                                               file.replace('.png', '.json').replace('.jpg', '.json'))
            if not os.path.exists(chunk_file) or not os.path.exists(structure_file):
                print('No structure data for {0}'.format(file))
                continue
            count += 1
            # NOTE(review): cv2.imread presumably returns None for an unreadable
            # file; the .shape access would then raise and the file is skipped by
            # the outer except - confirm.
            img = cv2.imread(file_path)
            height, width = img.shape[:2]
            # Create grayscale image array
            col_mask = np.ones((height, width), dtype=np.int32) * 255

            with open(chunk_file, encoding='utf-8') as f:
                chunk_data_list = json.load(f)['chunks']
            with open(structure_file, encoding='utf-8') as f:
                structure_data = json.load(f)
                structure_data_list = structure_data['cells']

            # is_changed = fix_structure_data(structure_data_list)
            # if is_changed:
            #     with open(structure_file,"w", encoding='utf-8') as json_file:
            #         json.dump(structure_data, json_file, ensure_ascii=False)

            # The table is taken to span the whole image.
            table_xmin = 0
            table_xmax = width
            table_ymin = 0
            table_ymax = height

            # width_ratio = (table_xmax - table_xmin) / width
            # height_ration = (table_ymax - table_ymin) / height

            row_col_bound_dict = {'row': {}, 'col': {}}
            # First collect ymin/ymax for every row and xmin/xmax for every column.
            for s_index, structure in enumerate(structure_data_list):
                # The structure 'id' is used as the index into the chunk list.
                index = structure['id']
                # 'pos' is read here as [x0, x1, y0, y1].
                bndbox = chunk_data_list[index]['pos']
                xmin = int(int(bndbox[0]) - table_xmin)
                xmax = int(int(bndbox[1]) - table_xmin)
                # In SciTSR coordinates x0, x1 are in order but y0, y1 are reversed,
                # i.e. the y axis runs from bottom to top.
                # So the relative coordinates are:
                # cell_ymin = -(y2 - ymax), cell_ymax = -(y1 - ymax)
                if is_scitsr:
                    ymin = int(abs(int(bndbox[3]) - table_ymax))
                    ymax = int(abs(int(bndbox[2]) - table_ymax))
                else:
                    ymin = int((int(bndbox[2]) - table_ymin))
                    ymax = int((int(bndbox[3]) - table_ymin))

                # Row top bound: smallest ymin among cells starting in this row.
                start_row = structure['start_row']
                if row_col_bound_dict['row'].get(start_row, None) is None:
                    row_col_bound_dict['row'][start_row] = {'ymin': ymin}
                else:
                    if row_col_bound_dict['row'][start_row].get('ymin', None) is None:
                        row_col_bound_dict['row'][start_row]['ymin'] = ymin
                    else:
                        if row_col_bound_dict['row'][start_row]['ymin'] > ymin:
                            row_col_bound_dict['row'][start_row]['ymin'] = ymin

                # Row bottom bound: largest ymax among cells ending in this row.
                end_row = structure['end_row']
                if row_col_bound_dict['row'].get(end_row, None) is None:
                    row_col_bound_dict['row'][end_row] = {'ymax': ymax}
                else:
                    if row_col_bound_dict['row'][end_row].get('ymax', None) is None:
                        row_col_bound_dict['row'][end_row]['ymax'] = ymax
                    else:
                        if row_col_bound_dict['row'][end_row]['ymax'] < ymax:
                            row_col_bound_dict['row'][end_row]['ymax'] = ymax

                # Column left bound: smallest xmin among cells starting in this column.
                start_col = structure['start_col']
                if row_col_bound_dict['col'].get(start_col, None) is None:
                    row_col_bound_dict['col'][start_col] = {'xmin': xmin}
                else:
                    if row_col_bound_dict['col'][start_col].get('xmin', None) is None:
                        row_col_bound_dict['col'][start_col]['xmin'] = xmin
                    else:
                        if row_col_bound_dict['col'][start_col]['xmin'] > xmin:
                            # NOTE(review): for start_col == 0 this guard fails, so an
                            # already stored xmin of column 0 is never lowered.
                            if start_col - 1 >= 0:
                                if row_col_bound_dict['col'].get(start_col - 1, None) is not None:
                                    # NOTE(review): only 'xmin'/'xmax' keys are ever stored in
                                    # this function, so the 'max' lookup is always None; both
                                    # branches assign the same value anyway - was 'xmax' meant?
                                    if row_col_bound_dict['col'][start_col - 1].get('max', None) is not None and \
                                            row_col_bound_dict['col'][start_col - 1]['xmax'] < xmin:
                                        row_col_bound_dict['col'][start_col]['xmin'] = xmin
                                    else:
                                        row_col_bound_dict['col'][start_col]['xmin'] = xmin
                                else:
                                    row_col_bound_dict['col'][start_col]['xmin'] = xmin

                # Column right bound: largest xmax among cells ending in this column.
                end_col = structure['end_col']
                if row_col_bound_dict['col'].get(end_col, None) is None:
                    row_col_bound_dict['col'][end_col] = {'xmax': xmax}
                else:
                    if row_col_bound_dict['col'][end_col].get('xmax', None) is None:
                        row_col_bound_dict['col'][end_col]['xmax'] = xmax
                    else:
                        if row_col_bound_dict['col'][end_col]['xmax'] < xmax:
                            row_col_bound_dict['col'][end_col]['xmax'] = xmax

            # Rows that have a top bound but no bottom bound borrow it from the
            # next row's top bound (minus a 5 px gap when there is room).
            for row_index, location in row_col_bound_dict['row'].items():
                if location.get('ymin', None) is not None and \
                        location.get('ymax', None) is None and \
                        row_col_bound_dict['row'].get(row_index + 1, None) is not None and \
                        row_col_bound_dict['row'].get(row_index + 1, {}).get('ymin', None) is not None:
                    next_ymin = row_col_bound_dict['row'].get(row_index + 1, {}).get('ymin', None)
                    current_ymin = location.get('ymin', None)
                    if next_ymin > current_ymin + 5:
                        location['ymax'] = next_ymin - 5
                    else:
                        location['ymax'] = next_ymin

            for col_index, location in row_col_bound_dict['col'].items():
                # Same fill-in for columns that lack a right bound.
                if location.get('xmin', None) is not None and \
                        location.get('xmax', None) is None and \
                        row_col_bound_dict['col'].get(col_index + 1, None) is not None and \
                        row_col_bound_dict['col'].get(col_index + 1, {}).get('xmin', None) is not None:
                    next_xmin = row_col_bound_dict['col'].get(col_index + 1, {}).get('xmin', None)
                    current_xmin = location.get('xmin', None)
                    if next_xmin > current_xmin + 5:
                        location['xmax'] = next_xmin - 5
                    else:
                        location['xmax'] = next_xmin
                # Pull a right bound back when it comes within 5 px of (or passes)
                # the next column's left bound.
                if location.get('xmin', None) is not None and \
                        location.get('xmax', None) is not None and \
                        row_col_bound_dict['col'].get(col_index + 1, None) is not None and \
                        row_col_bound_dict['col'].get(col_index + 1, {}).get('xmin', None) is not None:
                    current_xmax = location.get('xmax', None)
                    next_xmin = row_col_bound_dict['col'].get(col_index + 1, {}).get('xmin', None)
                    if next_xmin < current_xmax + 5:
                        location['xmax'] = next_xmin - 5

            # Output document; every detected line is appended to 'shapes'.
            table_net_data = {'version': '3.16.7',
                              'flags': {},
                              'lineColor': [0, 255, 0, 128],
                              'fillColor': [255, 0, 0, 128],
                              'imagePath': file_path,
                              'shapes': []}
            # First add the four border lines of the table.
            table_net_data['shapes'].append({'label': '0',
                                             'line_color': [0, 0, 128],
                                             'fill_color': [0, 0, 128],
                                             'points': [[0, 0], [width, 0]],
                                             'thickness': 2})
            table_net_data['shapes'].append({'label': '0',
                                             'line_color': [0, 0, 128],
                                             'fill_color': [0, 0, 128],
                                             'points': [[0, height - 1],
                                                        [width, height - 1]],
                                             'thickness': 2})
            table_net_data['shapes'].append({'label': '1',
                                             'line_color': [0, 0, 128],
                                             'fill_color': [0, 0, 128],
                                             'points': [[0, 0],
                                                        [0, height]],
                                             'thickness': 2})
            table_net_data['shapes'].append({'label': '1',
                                             'line_color': [0, 0, 128],
                                             'fill_color': [0, 0, 128],
                                             'points': [[width - 1, 0],
                                                        [width - 1, height]],
                                             'thickness': 2})

            row_list = list(row_col_bound_dict.get('row', {}).keys())
            row_list.sort()
            col_list = list(row_col_bound_dict.get('col', {}).keys())
            col_list.sort()
            # Only add the horizontal line at the end (bottom) of each row.
            for row in row_list:
                # The last row needs no line below it (the border covers it).
                if row == row_list[-1]:
                    continue
                same_row_list = [structure for structure
                                 in structure_data_list
                                 if structure['start_row'] == row]

                # line_points accumulates one segment across adjacent cells.
                line_points = None
                same_row_list = sorted(same_row_list, key=lambda keys: keys['start_col'])
                thickness = None
                for cell in same_row_list:
                    end_row = cell['end_row']
                    if end_row == row_list[-1]:
                        continue

                    start_col = cell['start_col']
                    end_col = cell['end_col']

                    # Left end of the segment under this cell.
                    if start_col == 0:
                        xmin = 0
                    else:
                        try:
                            if is_middle:
                                xmin = round((row_col_bound_dict['col'][start_col - 1]['xmax'] +
                                              row_col_bound_dict['col'][start_col]['xmin']) / 2)
                                if xmin < row_col_bound_dict['col'][start_col - 1]['xmax']:
                                    xmin = row_col_bound_dict['col'][start_col - 1]['xmax']
                            else:
                                xmin = row_col_bound_dict['col'][start_col - 1]['xmax']
                        except:
                            # Fall back to the previous column's right bound.
                            xmin = row_col_bound_dict['col'][start_col - 1]['xmax']

                    # Right end of the segment under this cell.
                    if end_col == len(col_list) - 1:
                        xmax = table_xmax
                    else:
                        try:
                            if is_middle:
                                xmax = round((row_col_bound_dict['col'][end_col]['xmax'] +
                                              row_col_bound_dict['col'][end_col + 1]['xmin']) / 2)
                                if xmax < row_col_bound_dict['col'][end_col]['xmax']:
                                    xmax = row_col_bound_dict['col'][end_col]['xmax']
                            else:
                                xmax = row_col_bound_dict['col'][end_col]['xmax']
                        except:
                            xmax = row_col_bound_dict['col'][end_col]['xmax']

                    # y position of the line and its thickness (gap between rows).
                    try:
                        if is_middle:
                            y = round((row_col_bound_dict['row'][end_row]['ymax'] +
                                       row_col_bound_dict['row'][end_row + 1]['ymin']) / 2)
                            temp = abs(int(row_col_bound_dict['row'][end_row + 1]['ymin'] -
                                           row_col_bound_dict['row'][end_row]['ymax']))
                            if temp < 2:
                                temp = 2
                            # Keeps the smallest gap seen in this row.
                            if thickness is None:
                                thickness = temp
                            elif thickness > temp:
                                thickness = temp
                            if y < row_col_bound_dict['row'][end_row]['ymax']:
                                y = row_col_bound_dict['row'][end_row]['ymax']
                        else:
                            y = row_col_bound_dict['row'][end_row]['ymax']
                            thickness = int(row_col_bound_dict['row'][end_row]['ymax'] -
                                            row_col_bound_dict['row'][end_row + 1]['ymin'])
                    except:
                        y = row_col_bound_dict['row'][end_row]['ymax']

                    if end_row == row:
                        # Single-row cell: start a segment or extend the running one.
                        if line_points is None:
                            line_points = [[xmin, y], [xmax, y]]
                        else:
                            line_points[1][0] = xmax
                    else:
                        # Row-spanning cell: flush the running segment, then emit a
                        # separate segment for this cell at its own end row.
                        if line_points is not None:
                            table_net_data['shapes'].append({'label': '0',
                                                             'line_color': [0, 0, 128],
                                                             'fill_color': [0, 0, 128],
                                                             'points': line_points,
                                                             'thickness': thickness})
                        line_points = [[xmin, y], [xmax, y]]
                        table_net_data['shapes'].append({'label': '0',
                                                         'line_color': [0, 0, 128],
                                                         'fill_color': [0, 0, 128],
                                                         'points': line_points,
                                                         'thickness': thickness})
                        line_points = None
                # Flush the last running segment; here thickness is forced to >= 2.
                if line_points is not None:
                    if thickness is None or thickness < 2:
                        thickness = 2
                    table_net_data['shapes'].append({'label': '0',
                                                     'line_color': [0, 0, 128],
                                                     'fill_color': [0, 0, 128],
                                                     'points': line_points,
                                                     'thickness': thickness})

            # Only add the vertical line at the end (right side) of each column.
            for col in col_list:
                # The last column needs no line to its right (the border covers it).
                if col == col_list[-1]:
                    continue
                same_col_list = [structure for structure
                                 in structure_data_list
                                 if structure['start_col'] == col]
                same_col_list = sorted(same_col_list, key=lambda keys: keys['start_row'])
                line_points = None
                thickness = None
                for index, cell in enumerate(same_col_list):
                    if index > 0:
                        last_cell_end_row = same_col_list[index - 1]['end_row']
                        cell_start_row = cell['start_row']
                        # A row gap between consecutive cells ends the running segment.
                        # NOTE(review): this shape carries no 'thickness' key, unlike
                        # the other shapes - confirm draw_line tolerates that.
                        if cell_start_row - last_cell_end_row != 1 and line_points is not None:
                            table_net_data['shapes'].append({'label': '1',
                                                             'line_color': [0, 0, 128],
                                                             'fill_color': [0, 0, 128],
                                                             'points': line_points})
                            line_points = None
                    end_col = cell['end_col']
                    # A cell reaching the last column gets no right-hand line; flush
                    # whatever segment was running (again without 'thickness').
                    if end_col == col_list[-1]:
                        if line_points is not None:
                            table_net_data['shapes'].append({'label': '1',
                                                             'line_color': [0, 0, 128],
                                                             'fill_color': [0, 0, 128],
                                                             'points': line_points})
                            line_points = None
                        continue
                    start_row = cell['start_row']
                    end_row = cell['end_row']
                    # Top end of the segment beside this cell.
                    if start_row - 1 >= 0:
                        try:
                            if is_middle:
                                ymin = round((row_col_bound_dict['row'][start_row - 1]['ymax'] +
                                              row_col_bound_dict['row'][start_row]['ymin']) / 2)
                                if ymin < row_col_bound_dict['row'][start_row - 1]['ymax']:
                                    ymin = row_col_bound_dict['row'][start_row - 1]['ymax']
                            else:
                                ymin = row_col_bound_dict['row'][start_row - 1]['ymax']
                        except:
                            ymin = row_col_bound_dict['row'][start_row - 1]['ymax']
                    else:
                        ymin = 0
                    # Bottom end of the segment beside this cell.
                    if end_row == len(row_list) - 1:
                        ymax = table_ymax
                    else:
                        try:
                            if is_middle:
                                ymax = round((row_col_bound_dict['row'][end_row]['ymax'] +
                                              row_col_bound_dict['row'][end_row + 1]['ymin']) / 2)
                                if ymax < row_col_bound_dict['row'][end_row]['ymax']:
                                    ymax = row_col_bound_dict['row'][end_row]['ymax']
                            else:
                                ymax = row_col_bound_dict['row'][end_row]['ymax']
                        except:
                            ymax = row_col_bound_dict['row'][end_row]['ymax']

                    # x position of the line and its thickness (gap between columns).
                    try:
                        if is_middle:
                            x = round((row_col_bound_dict['col'][end_col]['xmax'] +
                                       row_col_bound_dict['col'][end_col + 1]['xmin']) / 2)
                            temp = abs(int(row_col_bound_dict['col'][end_col + 1]['xmin'] -
                                           row_col_bound_dict['col'][end_col]['xmax']))
                            if temp < 2:
                                temp = 2
                            # NOTE(review): keeps the LARGEST gap here, whereas the
                            # horizontal-line pass keeps the smallest - confirm intent.
                            if thickness is None:
                                thickness = temp
                            elif thickness < temp:
                                thickness = temp
                            if x < row_col_bound_dict['col'][end_col]['xmax']:
                                x = row_col_bound_dict['col'][end_col]['xmax']
                        else:
                            x = row_col_bound_dict['col'][end_col]['xmax']
                            thickness = int(row_col_bound_dict['col'][end_col]['xmax'] -
                                            row_col_bound_dict['col'][end_col + 1]['xmin'])
                    except:
                        x = row_col_bound_dict['col'][end_col]['xmax']

                    if end_col == col:
                        # Single-column cell: start a segment or extend the running one.
                        if line_points is None:
                            line_points = [[x, ymin], [x, ymax]]
                        else:
                            line_points[1][1] = ymax
                    else:
                        # Column-spanning cell: flush the running segment, then emit a
                        # separate segment for this cell at its own end column.
                        if line_points is not None:
                            table_net_data['shapes'].append({'label': '1',
                                                             'line_color': [0, 0, 128],
                                                             'fill_color': [0, 0, 128],
                                                             'points': line_points,
                                                             'thickness': thickness})
                        line_points = [[x, ymin], [x, ymax]]
                        table_net_data['shapes'].append({'label': '1',
                                                         'line_color': [0, 0, 128],
                                                         'fill_color': [0, 0, 128],
                                                         'points': line_points,
                                                         'thickness': thickness})
                        line_points = None
                # Flush the last running segment; here thickness is forced to >= 2.
                if line_points is not None:
                    if thickness is None or thickness < 2:
                        thickness = 2
                    table_net_data['shapes'].append({'label': '1',
                                                     'line_color': [0, 0, 128],
                                                     'fill_color': [0, 0, 128],
                                                     'points': line_points,
                                                     'thickness': thickness})

            # Post-process the shapes in place before drawing and saving.
            fix_table_net_data(table_net_data)
            if draw_line_img:
                # Lines on a blank (all 255) mask, saved under the image's name.
                draw_line(col_mask, table_net_data['shapes'], draw_thickness=draw_thickness)
                cv2.imwrite(os.path.join(final_cell_directory, file), col_mask)

                # cv2.imwrite(os.path.join(unet_img_directory, file), img)

                # Lines over the source image, saved as '<name>_line.png'.
                draw_line(img, table_net_data['shapes'], draw_thickness=draw_thickness)
                updated = file.replace('.png', '').replace('.jpg', '') + '_line.png'
                cv2.imwrite(os.path.join(final_table_cell_directory, updated), img)
            with open(table_net_data_file, "w", encoding='utf-8') as f:
                json.dump(table_net_data, f, indent=True, ensure_ascii=False)
            table_net_data_file_list.append(table_net_data_file)
        except Exception as e:
            # Best effort: report the failure and move on to the next file.
            traceback.print_exc()
            # print(e)
    return table_net_data_file_list


def _snap_line_ends(lines, label, blocker_label, axis, end, anchor, target):
    """Move one endpoint of each unblocked ``label`` line from ``anchor`` to ``target``.

    :param lines: list of shape dicts, each with ``'label'`` and ``'points'``
        (``[[x0, y0], [x1, y1]]``); modified in place.
    :param label: label of the lines whose endpoint may be moved.
    :param blocker_label: label of the perpendicular lines that can block a move.
    :param axis: coordinate being moved (0 for x, 1 for y).
    :param end: endpoint index being moved (0 for the first point, 1 for the second).
    :param anchor: only endpoints currently equal to this value are considered.
    :param target: value written to the endpoint when the move is not blocked.
    """
    cross = 1 - axis
    # Snapshot the candidates first so that moving one endpoint cannot change
    # which lines are considered during the same pass.
    candidates = [line for line in lines
                  if line['points'][end][axis] == anchor and line['label'] == label]
    for candidate in candidates:
        position = candidate['points'][0][cross]
        # A perpendicular line ending at the same anchor and strictly spanning
        # this line's position means the endpoint already meets a line there.
        blocked = any(other['label'] == blocker_label and
                      other['points'][0][cross] < position < other['points'][1][cross] and
                      other['points'][end][axis] == anchor
                      for other in lines)
        if not blocked:
            candidate['points'][end][axis] = target


def fix_table_net_data(table_net_data):
    """Close the gaps between inner lines and the outer table border, in place.

    ``table_net_data['shapes']`` holds line dicts; label ``'0'`` lines are
    handled as horizontal lines and label ``'1'`` lines as vertical lines.
    Lines that stop at the second-outermost coordinate on a side, and are not
    met there by a perpendicular line, are extended to that side's border:
    coordinate 0 on the left/top, the largest end coordinate on the
    right/bottom.

    :param table_net_data: dict with a ``'shapes'`` list; modified in place.
    :return: None
    """
    lines = table_net_data['shapes']
    row_x1_list = sorted({line['points'][0][0] for line in lines if line['label'] == '0'})
    row_x2_list = sorted({line['points'][1][0] for line in lines if line['label'] == '0'})
    col_y1_list = sorted({line['points'][0][1] for line in lines if line['label'] == '1'})
    col_y2_list = sorted({line['points'][1][1] for line in lines if line['label'] == '1'})

    # Fix gaps between horizontal lines and the left table border.
    if len(row_x1_list) > 1:
        _snap_line_ends(lines, '0', '1', axis=0, end=0,
                        anchor=row_x1_list[1], target=0)
    # Fix gaps between horizontal lines and the right table border.
    # The target is the largest RIGHT x; the previous code used the largest
    # LEFT x (row_x1_list[-1]), which pulled the line to the wrong position.
    if len(row_x2_list) > 1:
        _snap_line_ends(lines, '0', '1', axis=0, end=1,
                        anchor=row_x2_list[-2], target=row_x2_list[-1])
    # Fix gaps between vertical lines and the top table border.
    if len(col_y1_list) > 1:
        _snap_line_ends(lines, '1', '0', axis=1, end=0,
                        anchor=col_y1_list[1], target=0)
    # Fix gaps between vertical lines and the bottom table border.
    if len(col_y2_list) > 1:
        _snap_line_ends(lines, '1', '0', axis=1, end=1,
                        anchor=col_y2_list[-2], target=col_y2_list[-1])


def fix_structure_data(structrue_list: list) -> bool:
    """Shift column indices of cells that collide with a row-spanning cell, in place.

    The list is scanned in order. A cell whose ``end_row`` exceeds its
    ``start_row`` is remembered as the current row span. A later cell whose
    ``end_row`` lies inside that span and whose ``start_col`` is not beyond the
    span's ``end_col`` triggers a renumbering pass: that cell and the cells
    following it are moved to the columns right after the span, keeping each
    cell's own column width, until a cell outside the span's rows is reached.

    :param structrue_list: list of cell dicts with ``start_row``, ``end_row``,
        ``start_col`` and ``end_col`` keys; modified in place.
    :return: True if any cell was changed, otherwise False (also for empty or
        ``None`` input, which previously returned ``[]``).
    """
    if not structrue_list:
        return False
    is_changed = False
    span_row_start = None
    span_row_end = None
    span_row_cur_col = None
    # Indices already renumbered by an earlier pass; a set keeps the
    # membership test constant-time.
    handled_cells = set()
    for index, cell in enumerate(structrue_list):
        if index in handled_cells:
            continue
        start_row = cell['start_row']
        end_row = cell['end_row']
        start_col = cell['start_col']
        end_col = cell['end_col']

        # A row-spanning cell that reaches further down than the tracked span
        # becomes the new tracked span.
        if end_row > start_row and (span_row_end is None or span_row_end < end_row):
            span_row_start = start_row
            span_row_end = end_row
            span_row_cur_col = end_col
            continue
        if span_row_start is not None and span_row_end is not None and \
                span_row_start <= end_row <= span_row_end:
            if start_col <= span_row_cur_col:
                count_round = 1
                current_end_row = None
                for idx in range(index, len(structrue_list)):
                    current_cell = structrue_list[idx]
                    # The column offset restarts whenever the row changes.
                    if current_cell['end_row'] != current_end_row:
                        count_round = 1
                        current_end_row = current_cell['end_row']
                    if span_row_start < current_end_row <= span_row_end:
                        origin_end_start_sub = current_cell['end_col'] - current_cell['start_col']
                        # NOTE(review): the offset grows by one per cell, not by
                        # the cell's column width - confirm this is intended
                        # for rows that contain column-spanning cells.
                        current_cell['start_col'] = span_row_cur_col + count_round
                        current_cell['end_col'] = current_cell['start_col'] + origin_end_start_sub
                        is_changed = True
                        count_round += 1
                        handled_cells.add(idx)
                    else:
                        # Left the rows covered by the span: stop tracking it.
                        span_row_start = None
                        span_row_end = None
                        span_row_cur_col = None
                        break
        else:
            span_row_start = None
            span_row_end = None
            span_row_cur_col = None
    return is_changed


def _has_span_over(cells, start_key, end_key, threshold):
    """Return True if any cell has ``cell[end_key] - cell[start_key] > threshold``."""
    return any(cell[end_key] - cell[start_key] > threshold for cell in cells)


def get_span_row_col_img(root_path=r'/data/pubtabnet/train/', max_count=100):
    """List images whose structure contains large row spans or column spans.

    Reads every ``structure/*.json`` file under ``root_path``, takes its
    ``'cells'`` list and checks for a cell whose end/start row (or column)
    difference is greater than 2. Matching file names, with ``.json`` replaced
    by ``.png``, are written one per line to ``span_row_img.txt`` and
    ``span_col_img.txt`` inside ``root_path``. Scanning stops once both lists
    hold ``max_count`` entries.

    :param root_path: directory containing the ``structure`` sub-directory;
        the two output files are written here.
    :param max_count: maximum number of file names collected per list.
    :return: None
    """
    # glob returns files in arbitrary order; sort so repeated runs pick the
    # same files.
    structure_files = sorted(glob(os.path.join(root_path, 'structure/*.json')))
    span_row_img_file = os.path.join(root_path, 'span_row_img.txt')
    span_col_img_file = os.path.join(root_path, 'span_col_img.txt')

    span_row_list = []
    span_col_list = []
    for file in structure_files:
        with open(file, encoding='utf-8') as f:
            structure_data_list = json.load(f)['cells']
        img_name = os.path.basename(file).replace('.json', '.png')

        if len(span_row_list) < max_count and \
                _has_span_over(structure_data_list, 'start_row', 'end_row', 2):
            span_row_list.append(img_name + '\n')
            print('Find the {0} span row file: {1}'.format(len(span_row_list), img_name))

        if len(span_col_list) < max_count and \
                _has_span_over(structure_data_list, 'start_col', 'end_col', 2):
            span_col_list.append(img_name + '\n')
            print('Find the {0} span col file: {1}'.format(len(span_col_list), img_name))

        if len(span_row_list) >= max_count and len(span_col_list) >= max_count:
            break

    with open(span_row_img_file, mode='w', encoding='utf-8') as write:
        write.writelines(span_row_list)

    with open(span_col_img_file, mode='w', encoding='utf-8') as write:
        write.writelines(span_col_list)


def draw_line(col_mask, line_list: list, draw_thickness: bool = False):
    """Draw every line of ``line_list`` onto ``col_mask`` in colour ``(0, 0, 128)``.

    Each entry is a dict with ``'points'`` (``[[x0, y0], [x1, y1]]``) and an
    optional ``'thickness'``. A missing or ``None`` thickness counts as 2.

    * When ``draw_thickness`` is false, or the thickness is at most 2, the
      line is drawn with ``cv2.line`` and the thickness is capped at 2.
    * Otherwise the thickness is capped at 4 and the line is drawn as a filled
      ``cv2.rectangle``, widened by half the thickness on both sides for
      vertical (equal x) or horizontal (equal y) lines.

    :param col_mask: image passed straight to the cv2 drawing calls.
    :param line_list: list of line dicts as described above.
    :param draw_thickness: whether lines thicker than 2 are drawn as filled
        rectangles.
    :return: None
    """
    color = (0, 0, 128)
    for line in line_list:
        points = line['points']
        xmin, ymin = points[0][0], points[0][1]
        xmax, ymax = points[1][0], points[1][1]

        # dict.get only falls back to its default when the key is absent; an
        # explicit None value would otherwise reach the comparisons below and
        # raise TypeError.
        thickness = line.get('thickness')
        if thickness is None:
            thickness = 2

        if thickness <= 2 or not draw_thickness:
            cv2.line(col_mask,
                     (xmin, ymin),
                     (xmax, ymax),
                     color,
                     thickness=min(thickness, 2),
                     lineType=0)
            continue

        thickness = min(thickness, 4)
        half = thickness / 2
        if xmax == xmin:
            # Vertical line: widen along x.
            xmin = int(xmin - half)
            xmax = int(xmax + half)
        elif ymax == ymin:
            # Horizontal line: widen along y.
            ymin = int(ymin - half)
            ymax = int(ymax + half)
        cv2.rectangle(col_mask, (xmin, ymin), (xmax, ymax), color, -1)


def draw_cell(col_mask,
              structure_data_list: list,
              row_col_bound_dict: dict,
              is_scitsr: bool,
              table_xmin,
              table_xmax,
              table_ymin,
              table_ymax):
    """Outline every cell with known bounds, then outline the whole table.

    For each cell the left/right edges come from the ``'xmin'`` of its start
    column and the ``'xmax'`` of its end column in ``row_col_bound_dict['col']``;
    the top/bottom edges come from the ``'ymin'`` of its start row and the
    ``'ymax'`` of its end row in ``row_col_bound_dict['row']``. Cells with any
    missing edge are skipped. Finally a rectangle from ``(0, 0)`` to the table
    width/height is drawn; with ``is_scitsr`` the height is the absolute
    difference of the two y bounds.

    :param col_mask: image handed to ``draw_rectangle``.
    :param structure_data_list: cell dicts with ``start_row``, ``end_row``,
        ``start_col`` and ``end_col``.
    :param row_col_bound_dict: dict with ``'row'`` and ``'col'`` mappings from
        index to a bounds dict.
    :param is_scitsr: selects how the table height is computed.
    :param table_xmin: table left bound.
    :param table_xmax: table right bound.
    :param table_ymin: first table y bound.
    :param table_ymax: second table y bound.
    :return: None
    """
    for cell in structure_data_list:
        row_from, row_to = cell['start_row'], cell['end_row']
        col_from, col_to = cell['start_col'], cell['end_col']

        left = row_col_bound_dict['col'].get(col_from, {}).get('xmin', None)
        right = row_col_bound_dict['col'].get(col_to, {}).get('xmax', None)
        top = row_col_bound_dict['row'].get(row_from, {}).get('ymin', None)
        bottom = row_col_bound_dict['row'].get(row_to, {}).get('ymax', None)

        # Skip cells for which at least one edge could not be resolved.
        if any(edge is None for edge in (left, right, top, bottom)):
            continue
        draw_rectangle(col_mask, left, right, top, bottom)

    table_width = table_xmax - table_xmin
    if is_scitsr:
        table_height = abs(table_ymin - table_ymax)
    else:
        table_height = table_ymax - table_ymin
    draw_rectangle(col_mask, 0, table_width, 0, table_height)


def draw_rectangle(img, xmin, xmax, ymin, ymax):
    """Draw the four edges of an axis-aligned rectangle onto ``img``.

    The edges are drawn with ``cv2.line`` in colour ``(0, 0, 128)`` and a
    thickness of 2, in the order top, left, right, bottom.

    :param img: image passed to ``cv2.line``.
    :param xmin: x coordinate of the left edge.
    :param xmax: x coordinate of the right edge.
    :param ymin: y coordinate of the top edge.
    :param ymax: y coordinate of the bottom edge.
    :return: None
    """
    thickness = 2
    color = (0, 0, 128)
    edges = (
        ((xmin, ymin), (xmax, ymin)),  # top horizontal edge
        ((xmin, ymin), (xmin, ymax)),  # left vertical edge
        ((xmax, ymin), (xmax, ymax)),  # right vertical edge
        ((xmin, ymax), (xmax, ymax)),  # bottom horizontal edge
    )
    for start, end in edges:
        # The value is a line thickness, so it goes to ``thickness=``; it was
        # previously passed as ``lineType=``, leaving the thickness at
        # cv2.line's default.
        cv2.line(img, start, end, color, thickness=thickness)


if __name__ == '__main__':
    # Script entry point: call draw_cell_line on the validation directory
    # with the options collected below.
    run_options = {
        'root_path': r'/data/pubtabnet/val/',
        'is_scitsr': False,
        'output_unet': False,
        'draw_img_amount': 9000,
    }
    draw_cell_line(**run_options)

示例

原图如下:


PMC1181812_008_00.png

纯线段图如下:


PMC1181812_008_00.png

将纯线段与原图叠加如下:
PMC1181812_008_00_line.png

效果并不完美,这是因为数据源提供的坐标仅仅是各个单元格内文本区域的bounding box(而不是单元格本身的边界),之后会进一步完善

你可能感兴趣的:(PubTabNet数据集介绍(有干货))