前言
PubTabNet是IBM公司公布的基于图像的表格识别数据集。
其包含了568k+表格图片,其标注数据是HTML的表格结构,下载压缩包磁盘存储大小10G+。
GitHub相关地址
IBM的下载地址
相关论文:
Image-based table recognition: data, model, and evaluation
此篇论文的核心在于通过编码器-双解码器(Encoder-Dual-Decoder,EDD)结构,同时识别表格结构与单元格内容。
目前暂时没有GitHub复现项目。
PubTabNet数据
表格图片
该数据集的表格都是PDF截图,清晰度不是很高
示例如下:
训练结构数据
训练结构数据位于.jsonl文件中,该文件每一行都是一条json数据,可以通过jsonlines库逐条读取。
数据结构:
- imgid: 图像id
- html: html单元格详细描述
cells:单元格列表
cells列表成员:
tokens:单元格文本字符级别信息,
bbox:单元格文本范围的bounding box。注意这并不是单元格的范围,而是单元格内文本的范围!坐标格式为x_min, y_min, x_max, y_max,即文本范围对角线上两个点的坐标。空单元格没有bbox字段。
该坐标即图片本身的像素坐标,可以直接使用。
structure字典
structure字典成员:
tokens:表格结构的HTML标签序列(如<thead>、<tr>、<td>、</td>等),其中每个<td>按顺序对应cells中的一个单元格
- split:表示训练数据或验证数据,分别为train与val
- filename:图像名称,如PMC2838834_005_00.png
示例如下(只贴出两个单元格以及部分HTML结构数据):
{
"imgid": 4,
"html": {
"cells": [
{
"tokens": [
"<b>",
"M",
"a",
"i",
"n",
" ",
"c",
"e",
"l",
"l",
"u",
"l",
"a",
"r",
" ",
"p",
"r",
"o",
"c",
"e",
"s",
"s",
"</b>"
],
"bbox": [
1,
4,
76,
13
]
},
{
"tokens": [
"<b>",
"M",
"o",
"d",
"u",
"l",
"a",
"t",
"e",
"d",
" ",
"p",
"a",
"t",
"h",
"w",
"a",
"y",
"s",
"</b>"
],
"bbox": [
92,
4,
167,
13
]
}
],
"structure": {
"tokens": [
"<thead>",
"<tr>",
"<td>",
"</td>",
"<td>",
"</td>",
"<td>",
"</td>",
"<td>",
"</td>",
"</tr>",
"...",
"..."
]
}
},
"split": "train",
"filename": "PMC2838834_005_00.png"
}
GitHub上PubTabNet的官方代码只提供了读取数据以及将标注转换为HTML的功能,并没有表格识别相关的代码。
PubTabNet转换为SciTSR训练数据格式
因为之前在研究SciTSR数据,所以需要将PubTabNet的训练集数据转换为SciTSR的格式,从而扩大训练集,即我们需要根据.jsonl中的数据,提炼出chunk, structure数据来。
SciTSR的介绍链接在此
这一段是干货,因为这意味着不同数据集可以通用了。
import tqdm
import os
import jsonlines
import json
import logging
import sys
import re
# Make modules in the parent directory importable. Note: os.path.abspath
# resolves '../' against the current working directory, not this file's location.
sys.path.insert(0, os.path.abspath('../'))
class Transform:
    """Convert PubTabNet ``.jsonl`` annotations into SciTSR-format files.

    For every record of the PubTabNet file, two JSON files are written under
    the directory that contains the input file::

        <dir>/<split>/structure/<image stem>.json   # {"cells": [...]}
        <dir>/<split>/chunk/<image stem>.chunk      # {"chunks": [...]}

    Structure cells carry ``id``, ``tex``, ``content`` and the grid position
    (``start_row``/``end_row``/``start_col``/``end_col``). Chunks carry
    ``pos`` = [x0, x1, y0, y1] and ``text``.
    """

    # Span attribute tokens that follow a '<td' token, e.g. ' colspan="2"'.
    _SPAN_PATTERN = re.compile(r'(rowspan|colspan)\s*=\s*"?(\d+)')
    # Any HTML tag inside the cell tokens, e.g. <b> or </sup>.
    _TAG_PATTERN = re.compile(r'<[^>]+>')

    def transform(self, pubtabnet_data_file: str):
        """Convert every record of ``pubtabnet_data_file`` (a PubTabNet ``.jsonl``).

        Logs an error and returns when the file does not exist. Records whose
        cell list does not line up with their structure tokens are skipped
        (``transfer_GFTE_data`` logs a warning for them).
        """
        if pubtabnet_data_file is None or not os.path.exists(pubtabnet_data_file):
            logging.error('No PubTabNet data file!')
            return
        root_path = os.path.dirname(pubtabnet_data_file)
        with open(pubtabnet_data_file, encoding='utf-8') as reader:
            for img in jsonlines.Reader(reader):
                img_filename = img['filename']
                img_type = img['split']
                gfte_structure_folder = '{0}/{1}/structure/'.format(root_path, img_type)
                gfte_chunk_folder = '{0}/{1}/chunk/'.format(root_path, img_type)
                os.makedirs(gfte_structure_folder, exist_ok=True)
                os.makedirs(gfte_chunk_folder, exist_ok=True)
                print('Handle image: {0}'.format(img_filename))
                cells = img['html']['cells']
                structure_list = img['html']['structure']['tokens']
                gfte_structure = self.get_row_col_position(structure_list)
                structure_data_dict, chunk_data_dict = self.transfer_GFTE_data(cells, gfte_structure)
                if structure_data_dict is None or chunk_data_dict is None:
                    continue
                # 'PMC2838834_005_00.png' -> 'PMC2838834_005_00'
                image_stem = os.path.splitext(img_filename)[0]
                with open(os.path.join(gfte_structure_folder, image_stem + '.json'),
                          'w', encoding='utf-8') as write_file:
                    json.dump(structure_data_dict, write_file)
                with open(os.path.join(gfte_chunk_folder, image_stem + '.chunk'),
                          'w', encoding='utf-8') as write_file:
                    json.dump(chunk_data_dict, write_file)

    @staticmethod
    def _place_cell(occupied: set, row: int, col: int, row_span: int, col_span: int):
        """Place a cell at the first free column >= ``col`` in ``row``.

        Marks every grid slot the cell covers in ``occupied``. Returns the
        cell's position dict and the column index right after the cell.
        """
        while (row, col) in occupied:
            col += 1
        end_row = row + row_span - 1
        end_col = col + col_span - 1
        for r in range(row, end_row + 1):
            for c in range(col, end_col + 1):
                occupied.add((r, c))
        position = {'start_row': row,
                    'end_row': end_row,
                    'start_col': col,
                    'end_col': end_col}
        return position, end_col + 1

    def get_row_col_position(self, structure_list: list):
        """Compute the grid position of every cell from PubTabNet structure tokens.

        ``structure_list`` is the HTML token sequence, e.g.
        ``['<thead>', '<tr>', '<td>', '</td>', '<td', ' colspan="2"', '>', '</td>', '</tr>', ...]``.
        Each ``<td>``, or ``<td`` ... ``>`` with span attributes, yields one
        entry in document order, so the result lines up with ``html['cells']``.
        Grid slots already covered by a rowspan from an earlier row are skipped.

        :return: list of dicts with ``start_row``, ``end_row``, ``start_col``, ``end_col``.
        """
        gfte_structure = []
        occupied = set()  # (row, col) slots already covered by some cell
        row_index = 0
        col_index = 0
        pending_span = None  # spans of a '<td' tag that is still waiting for its '>'
        for token in structure_list:
            if token == '<tr>':
                col_index = 0
            elif token == '</tr>':
                row_index += 1
            elif token == '<td>':
                position, col_index = self._place_cell(occupied, row_index, col_index, 1, 1)
                gfte_structure.append(position)
            elif token == '<td':
                pending_span = {'rowspan': 1, 'colspan': 1}
            elif pending_span is not None:
                if token == '>':
                    position, col_index = self._place_cell(occupied, row_index, col_index,
                                                           pending_span['rowspan'],
                                                           pending_span['colspan'])
                    gfte_structure.append(position)
                    pending_span = None
                else:
                    for name, value in self._SPAN_PATTERN.findall(token):
                        pending_span[name] = max(1, int(value))
        return gfte_structure

    def transfer_GFTE_data(self, cells: list, gfte_structure: list):
        """Build the SciTSR ``structure`` and ``chunk`` dicts for one table.

        :param cells: PubTabNet ``html['cells']``. Each cell has ``tokens`` and,
            when non-empty, ``bbox`` = [x_min, y_min, x_max, y_max].
        :param gfte_structure: output of ``get_row_col_position``, aligned with ``cells``.
        :return: ``(structure_data_dict, chunk_data_dict)``, or ``(None, None)``
            when the number of cells and of structure positions differ.
        """
        if len(cells) != len(gfte_structure):
            logging.warning('Cell count %d does not match structure count %d',
                            len(cells), len(gfte_structure))
            return None, None
        structure_data_dict = {'cells': []}
        chunk_data_dict = {'chunks': []}
        for cell_id, (cell, structure) in enumerate(zip(cells, gfte_structure)):
            raw_text = ''.join(cell.get('tokens', []))
            pure_text = self._TAG_PATTERN.sub('', raw_text).strip()
            structure_data = {'id': cell_id,
                              'tex': raw_text,
                              'content': pure_text.split()}
            structure_data.update(structure)
            structure_data_dict['cells'].append(structure_data)
            bbox = cell.get('bbox', None)
            if bbox is not None:
                # PubTabNet bbox is [x_min, y_min, x_max, y_max]; SciTSR pos is [x0, x1, y0, y1].
                pos = [bbox[0], bbox[2], bbox[1], bbox[3]]
            else:
                pos = [0, 0, 0, 0]
            chunk_data_dict['chunks'].append({'pos': pos, 'text': raw_text})
        self._estimate_missing_positions(structure_data_dict['cells'], chunk_data_dict['chunks'])
        return structure_data_dict, chunk_data_dict

    @staticmethod
    def _estimate_missing_positions(structure_cells: list, chunks: list):
        """Fill ``[0, 0, 0, 0]`` chunk positions (cells without bbox) in place.

        y0/y1 are taken from another positioned cell sharing the start/end
        row, and x0/x1 from one sharing the start/end column. The position is
        only replaced when all four values were found.
        """
        empty = [0, 0, 0, 0]
        for structure_data, chunk_data in zip(structure_cells, chunks):
            if chunk_data['pos'] != empty:
                continue
            x0 = x1 = y0 = y1 = None
            for other, other_chunk in zip(structure_cells, chunks):
                if other['id'] == structure_data['id'] or other_chunk['pos'] == empty:
                    continue
                if y0 is None and other['start_row'] == structure_data['start_row']:
                    y0 = other_chunk['pos'][2]
                if y1 is None and other['end_row'] == structure_data['end_row']:
                    y1 = other_chunk['pos'][3]
                if x0 is None and other['start_col'] == structure_data['start_col']:
                    x0 = other_chunk['pos'][0]
                if x1 is None and other['end_col'] == structure_data['end_col']:
                    x1 = other_chunk['pos'][1]
                if all(value is not None for value in (x0, x1, y0, y1)):
                    chunk_data['pos'] = [x0, x1, y0, y1]
                    break
if __name__ == "__main__":
    # Smaller input for quick checks: '/data/scitsr/examples/PubTabNet_Examples.jsonl'
    transformer = Transform()
    transformer.transform('/data/pubtabnet/PubTabNet_2.0.0.jsonl')
UNET所需数据准备
UNET非常适合图像分割,对于表格来说,如果能够通过UNET将表格中行与列进行标注,则能够方便根据表格结构提取各个单元格信息。
事实上,UNET预测的是“线”。训练数据类似如下格式:每条线由两个点确定,标签0代表横线,标签1代表竖线。
{
"label": "0",
"line_color": [
0,
0,
128
],
"fill_color": [
0,
0,
128
],
"points": [
[
0,
0
],
[
503,
0
]
]
},
{
"label": "0",
"line_color": [
0,
0,
128
],
"fill_color": [
0,
0,
128
],
"points": [
[
0,
275
],
[
503,
275
]
]
},
{
"label": "1",
"line_color": [
0,
0,
128
],
"fill_color": [
0,
0,
128
],
"points": [
[
0,
0
],
[
0,
276
]
]
}
训练数据转换代码,则是根据SciTSR的训练数据:chunk与structure得到。
如下代码实现了:
- 根据每行每列的最大与最小坐标,以及邻接单元格信息,尝试绘制表格中的横线与竖线。而这仅仅是根据原数据中的单元格中文本框的坐标计算得到的
- 将得到的线绘制到单独一张图片,构成训练数据示意图片
import os
import numpy as np
from PIL import Image
import cv2
import json
from glob import glob
import traceback
# Tolerances (top edge, bottom edge) in pixels for sameTable.
def sameTable(ymin_1, ymin_2, ymax_1, ymax_2):
    """Tell whether two columns belong to the same table.

    The columns match when the differences of their top edges (ymin) and
    bottom edges (ymax) stay within one of the tolerance pairs
    (5, 5), (4, 7) or (7, 4) pixels.
    """
    top_gap = abs(ymin_1 - ymin_2)
    bottom_gap = abs(ymax_1 - ymax_2)
    tolerance_pairs = ((5, 5), (4, 7), (7, 4))
    return any(top_gap <= top_limit and bottom_gap <= bottom_limit
               for top_limit, bottom_limit in tolerance_pairs)
def draw_cell_line(root_path: str = r'/data/scitsr/train',
                   img_path: str = 'img',
                   is_scitsr: bool = True,
                   output_unet: bool = False,
                   draw_img_amount: int = None,
                   file_list: list = None,
                   is_middle: bool = False,
                   draw_thickness: bool = False,
                   draw_line_img: bool = True):
    """Turn SciTSR-format chunk/structure files into table_net line annotations.

    The SciTSR-format data comes from PubTabNet via the ``transform`` method of
    pubtabnet_format_transform.py. For every image in
    ``<root_path>/<img_path>`` that has matching ``chunk/*.chunk`` and
    ``structure/*.json`` files, the following steps run:

    * Horizontal (label "0") and vertical (label "1") separator lines are
      estimated from the cells' text bounding boxes.
    * The lines are saved as a labelme-like JSON file in ``table_net_data[_middle]``.
    * Optionally, the line mask and the image with the lines overlaid are saved.

    :param root_path: split directory containing ``img_path``, ``chunk`` and ``structure``.
    :param img_path: name of the image sub-directory.
    :param is_scitsr: True when chunk y coordinates use SciTSR's bottom-up axis
        (they are flipped against the image height). False for chunks whose
        y axis already points down, as written by the PubTabNet converter.
    :param output_unet: write the line masks to ``mask_unet`` (also creates ``img_unet``).
    :param draw_img_amount: when ``file_list`` is None and this is a positive
        int, stop after this many images with annotation files.
    :param file_list: explicit image file names to process instead of listing the directory.
    :param is_middle: draw lines in the middle between neighbouring text areas
        instead of on the text edge; this is closer to real table ruling lines.
    :param draw_thickness: draw lines thicker than 2 px (up to 4 px) as filled rectangles.
    :param draw_line_img: also save the line mask and the image with lines drawn on it.
    :return: list of the JSON annotation files written.
    """
    directory = os.path.join(root_path, img_path)
    chunk_directory = os.path.join(root_path, 'chunk')
    structure_directory = os.path.join(root_path, 'structure')
    suffix = '_middle' if is_middle else ''
    final_cell_directory = os.path.join(root_path, 'cell_mask_img' + suffix)
    final_table_cell_directory = os.path.join(root_path, 'table_cell_mask_img' + suffix)
    table_net_data_directory = os.path.join(root_path, 'table_net_data' + suffix)
    # Only created here; nothing below writes into it.
    final_table_directory = os.path.join(root_path, 'table_mask_img')
    for folder in (final_cell_directory, final_table_cell_directory,
                   table_net_data_directory, final_table_directory):
        os.makedirs(folder, exist_ok=True)
    if output_unet:
        os.makedirs(os.path.join(root_path, 'img_unet'), exist_ok=True)
        final_cell_directory = os.path.join(root_path, 'mask_unet')
        os.makedirs(final_cell_directory, exist_ok=True)

    files = os.listdir(directory) if file_list is None else file_list
    use_limit = (file_list is None and isinstance(draw_img_amount, int)
                 and draw_img_amount > 0)
    count = 1  # 1 + number of images with annotation files handled so far
    table_net_data_file_list = []
    for file in files:
        # '>' (not '==') so that exactly draw_img_amount images are processed.
        if use_limit and count > draw_img_amount:
            break
        print('Handle the {0} file: {1}'.format(count, file))
        try:
            file_path = os.path.join(directory, file)
            chunk_file = os.path.join(chunk_directory,
                                      file.replace('.png', '.chunk').replace('.jpg', '.chunk'))
            json_name = file.replace('.png', '.json').replace('.jpg', '.json')
            structure_file = os.path.join(structure_directory, json_name)
            table_net_data_file = os.path.join(table_net_data_directory, json_name)
            if not os.path.exists(chunk_file) or not os.path.exists(structure_file):
                print('No structure data for {0}'.format(file))
                continue
            count += 1
            img = cv2.imread(file_path)
            if img is None:
                print('Cannot read image {0}'.format(file_path))
                continue
            height, width = img.shape[:2]
            # Grayscale canvas (white background) for the pure line mask.
            col_mask = np.ones((height, width), dtype=np.int32) * 255
            with open(chunk_file, encoding='utf-8') as f:
                chunk_data_list = json.load(f)['chunks']
            with open(structure_file, encoding='utf-8') as f:
                structure_data_list = json.load(f)['cells']
            # The whole image is treated as the table area.
            table_xmin, table_xmax, table_ymin, table_ymax = 0, width, 0, height

            bounds = _collect_row_col_bounds(structure_data_list, chunk_data_list, is_scitsr,
                                             table_xmin, table_ymin, table_ymax)
            _complete_row_col_bounds(bounds)
            row_list = sorted(bounds['row'])
            col_list = sorted(bounds['col'])

            shapes = _table_frame_shapes(width, height)
            shapes.extend(_horizontal_line_shapes(bounds, structure_data_list, row_list,
                                                  col_list, table_xmax, is_middle))
            shapes.extend(_vertical_line_shapes(bounds, structure_data_list, row_list,
                                                col_list, table_ymax, is_middle))
            table_net_data = {'version': '3.16.7',
                              'flags': {},
                              'lineColor': [0, 255, 0, 128],
                              'fillColor': [255, 0, 0, 128],
                              'imagePath': file_path,
                              'shapes': shapes}
            fix_table_net_data(table_net_data)
            if draw_line_img:
                draw_line(col_mask, table_net_data['shapes'], draw_thickness=draw_thickness)
                cv2.imwrite(os.path.join(final_cell_directory, file), col_mask)
                draw_line(img, table_net_data['shapes'], draw_thickness=draw_thickness)
                updated = file.replace('.png', '').replace('.jpg', '') + '_line.png'
                cv2.imwrite(os.path.join(final_table_cell_directory, updated), img)
            with open(table_net_data_file, "w", encoding='utf-8') as f:
                json.dump(table_net_data, f, indent=True, ensure_ascii=False)
            table_net_data_file_list.append(table_net_data_file)
        except Exception:
            # Best effort: report the broken sample and continue with the next image.
            traceback.print_exc()
    return table_net_data_file_list


def _line_shape(label: str, points: list, thickness=None) -> dict:
    """Build one line annotation; ``thickness`` is omitted when it is None."""
    shape = {'label': label,
             'line_color': [0, 0, 128],
             'fill_color': [0, 0, 128],
             'points': points}
    if thickness is not None:
        shape['thickness'] = thickness
    return shape


def _clamp_thickness(thickness):
    """Return ``thickness`` raised to at least 2 (2 when it is unknown)."""
    return 2 if thickness is None or thickness < 2 else thickness


def _table_frame_shapes(width: int, height: int) -> list:
    """Return the four table border lines: top, bottom, left, right."""
    return [_line_shape('0', [[0, 0], [width, 0]], 2),
            _line_shape('0', [[0, height - 1], [width, height - 1]], 2),
            _line_shape('1', [[0, 0], [0, height]], 2),
            _line_shape('1', [[width - 1, 0], [width - 1, height]], 2)]


def _collect_row_col_bounds(structure_data_list, chunk_data_list, is_scitsr,
                            table_xmin, table_ymin, table_ymax) -> dict:
    """Collect text extents per row (``ymin``/``ymax``) and column (``xmin``/``xmax``).

    Returns ``{'row': {row: {...}}, 'col': {col: {...}}}``. A row or column
    may miss a key when no cell starts or ends in it.
    """
    bounds = {'row': {}, 'col': {}}
    rows, cols = bounds['row'], bounds['col']
    for structure in structure_data_list:
        pos = chunk_data_list[structure['id']]['pos']  # [x0, x1, y0, y1]
        xmin = int(int(pos[0]) - table_xmin)
        xmax = int(int(pos[1]) - table_xmin)
        if is_scitsr:
            # SciTSR's y axis points upwards, so flip it against the table height.
            ymin = int(abs(int(pos[3]) - table_ymax))
            ymax = int(abs(int(pos[2]) - table_ymax))
        else:
            ymin = int(int(pos[2]) - table_ymin)
            ymax = int(int(pos[3]) - table_ymin)

        top = rows.setdefault(structure['start_row'], {})
        if top.get('ymin') is None or top['ymin'] > ymin:
            top['ymin'] = ymin
        bottom = rows.setdefault(structure['end_row'], {})
        if bottom.get('ymax') is None or bottom['ymax'] < ymax:
            bottom['ymax'] = ymax

        start_col = structure['start_col']
        left = cols.setdefault(start_col, {})
        if left.get('xmin') is None:
            left['xmin'] = xmin
        elif left['xmin'] > xmin:
            # Only widen to the left while staying clear of the previous column.
            previous = cols.get(start_col - 1) if start_col >= 1 else None
            if previous is None or (previous.get('xmax') is not None
                                    and previous['xmax'] < xmin):
                left['xmin'] = xmin
        right = cols.setdefault(structure['end_col'], {})
        if right.get('xmax') is None or right['xmax'] < xmax:
            right['xmax'] = xmax
    return bounds


def _complete_row_col_bounds(bounds: dict):
    """Infer missing ``ymax``/``xmax`` values from the next row/column, in place.

    A missing end is set to the next row's or column's start minus 5 px
    when that start lies more than 5 px after this one's start, and to the
    next start itself otherwise. A column ``xmax`` that comes within 5 px of
    the next column's ``xmin`` is pulled back to ``xmin - 5``.
    """
    rows = bounds['row']
    for row_index, location in rows.items():
        next_ymin = rows.get(row_index + 1, {}).get('ymin')
        if next_ymin is None:
            continue
        if location.get('ymin') is not None and location.get('ymax') is None:
            location['ymax'] = next_ymin - 5 if next_ymin > location['ymin'] + 5 else next_ymin
    cols = bounds['col']
    for col_index, location in cols.items():
        next_xmin = cols.get(col_index + 1, {}).get('xmin')
        if next_xmin is None:
            continue
        if location.get('xmin') is not None and location.get('xmax') is None:
            location['xmax'] = next_xmin - 5 if next_xmin > location['xmin'] + 5 else next_xmin
        if location.get('xmin') is not None and location.get('xmax') is not None \
                and next_xmin < location['xmax'] + 5:
            location['xmax'] = next_xmin - 5


def _gap_position(bounds: dict, index: int, max_key: str, min_key: str, is_middle: bool):
    """Coordinate of the separator after row/column ``index``.

    This is the ``max_key`` edge of ``index``. With ``is_middle``, and when
    ``index + 1`` has a ``min_key`` edge, the separator moves to the midpoint
    between the two edges, but never before the first edge. A missing
    ``max_key`` edge raises KeyError.
    """
    edge = bounds[index][max_key]
    if is_middle:
        next_edge = bounds.get(index + 1, {}).get(min_key)
        if next_edge is not None:
            return max(edge, round((edge + next_edge) / 2))
    return edge


def _horizontal_line_shapes(bounds, structure_data_list, row_list, col_list,
                            table_xmax, is_middle) -> list:
    """Horizontal separators (label "0") below every row except the last.

    Cells that start and end in the same row are merged into one segment;
    a cell spanning further rows gets its own segment at its bottom edge.
    """
    rows, cols = bounds['row'], bounds['col']
    shapes = []
    for row in row_list:
        if row == row_list[-1]:
            continue
        same_row_list = sorted((cell for cell in structure_data_list if cell['start_row'] == row),
                               key=lambda cell: cell['start_col'])
        line_points = None
        thickness = None
        for cell in same_row_list:
            end_row = cell['end_row']
            if end_row == row_list[-1]:
                continue
            start_col = cell['start_col']
            end_col = cell['end_col']
            if start_col == 0:
                xmin = 0
            else:
                xmin = _gap_position(cols, start_col - 1, 'xmax', 'xmin', is_middle)
            if end_col == len(col_list) - 1:
                xmax = table_xmax
            else:
                xmax = _gap_position(cols, end_col, 'xmax', 'xmin', is_middle)
            y = rows[end_row]['ymax']
            next_ymin = rows.get(end_row + 1, {}).get('ymin')
            if next_ymin is not None:
                # Gap between this row's text and the next row's text.
                gap = abs(int(next_ymin - y))
                if is_middle:
                    y = max(y, round((y + next_ymin) / 2))
                    gap = max(gap, 2)
                    thickness = gap if thickness is None else min(thickness, gap)
                else:
                    thickness = gap
            if end_row == row:
                if line_points is None:
                    line_points = [[xmin, y], [xmax, y]]
                else:
                    line_points[1][0] = xmax
            else:
                if line_points is not None:
                    shapes.append(_line_shape('0', line_points, _clamp_thickness(thickness)))
                shapes.append(_line_shape('0', [[xmin, y], [xmax, y]], _clamp_thickness(thickness)))
                line_points = None
        if line_points is not None:
            shapes.append(_line_shape('0', line_points, _clamp_thickness(thickness)))
    return shapes


def _vertical_line_shapes(bounds, structure_data_list, row_list, col_list,
                          table_ymax, is_middle) -> list:
    """Vertical separators (label "1") right of every column except the last.

    Consecutive cells of one column are merged into one segment. A break in
    the rows, or a cell reaching the last column, ends the current segment.
    """
    rows, cols = bounds['row'], bounds['col']
    shapes = []
    for col in col_list:
        if col == col_list[-1]:
            continue
        same_col_list = sorted((cell for cell in structure_data_list if cell['start_col'] == col),
                               key=lambda cell: cell['start_row'])
        line_points = None
        thickness = None
        for index, cell in enumerate(same_col_list):
            if index > 0 and line_points is not None and \
                    cell['start_row'] - same_col_list[index - 1]['end_row'] != 1:
                shapes.append(_line_shape('1', line_points))
                line_points = None
            end_col = cell['end_col']
            if end_col == col_list[-1]:
                if line_points is not None:
                    shapes.append(_line_shape('1', line_points))
                    line_points = None
                continue
            start_row = cell['start_row']
            end_row = cell['end_row']
            if start_row - 1 >= 0:
                ymin = _gap_position(rows, start_row - 1, 'ymax', 'ymin', is_middle)
            else:
                ymin = 0
            if end_row == len(row_list) - 1:
                ymax = table_ymax
            else:
                ymax = _gap_position(rows, end_row, 'ymax', 'ymin', is_middle)
            x = cols[end_col]['xmax']
            next_xmin = cols.get(end_col + 1, {}).get('xmin')
            if next_xmin is not None:
                # Gap between this column's text and the next column's text.
                gap = abs(int(next_xmin - x))
                if is_middle:
                    x = max(x, round((x + next_xmin) / 2))
                    gap = max(gap, 2)
                    thickness = gap if thickness is None else max(thickness, gap)
                else:
                    thickness = gap
            if end_col == col:
                if line_points is None:
                    line_points = [[x, ymin], [x, ymax]]
                else:
                    line_points[1][1] = ymax
            else:
                if line_points is not None:
                    shapes.append(_line_shape('1', line_points, _clamp_thickness(thickness)))
                shapes.append(_line_shape('1', [[x, ymin], [x, ymax]], _clamp_thickness(thickness)))
                line_points = None
        if line_points is not None:
            shapes.append(_line_shape('1', line_points, _clamp_thickness(thickness)))
    return shapes
def fix_table_net_data(table_net_data):
    """Extend lines that stop one step short of the table border, in place.

    Horizontal lines (label "0") that start at the second-smallest start x
    are moved to x = 0. Horizontal lines that end at the second-largest
    end x are extended to the largest end x. In both cases this is skipped
    when a vertical line at that x crosses them.

    Vertical lines (label "1") are treated the same way against the top
    (y = 0) and the largest end y. All extreme values are taken from the
    shapes as they were before any change.
    """
    lines = table_net_data['shapes']
    horizontal = [line for line in lines if line['label'] == '0']
    vertical = [line for line in lines if line['label'] == '1']
    row_x1_list = sorted({line['points'][0][0] for line in horizontal})
    row_x2_list = sorted({line['points'][1][0] for line in horizontal})
    col_y1_list = sorted({line['points'][0][1] for line in vertical})
    col_y2_list = sorted({line['points'][1][1] for line in vertical})

    # Close the gap between horizontal lines and the left border.
    if len(row_x1_list) > 1:
        second_min_x = row_x1_list[1]
        for h_line in [line for line in horizontal if line['points'][0][0] == second_min_x]:
            h_y = h_line['points'][0][1]
            if not any(v['points'][0][1] < h_y < v['points'][1][1]
                       and v['points'][0][0] == second_min_x for v in vertical):
                h_line['points'][0][0] = 0
    # Close the gap between horizontal lines and the right border
    # (extend to the largest END x, not the largest start x).
    if len(row_x2_list) > 1:
        second_max_x = row_x2_list[-2]
        for h_line in [line for line in horizontal if line['points'][1][0] == second_max_x]:
            h_y = h_line['points'][0][1]
            if not any(v['points'][0][1] < h_y < v['points'][1][1]
                       and v['points'][1][0] == second_max_x for v in vertical):
                h_line['points'][1][0] = row_x2_list[-1]
    # Close the gap between vertical lines and the top border.
    if len(col_y1_list) > 1:
        second_min_y = col_y1_list[1]
        for v_line in [line for line in vertical if line['points'][0][1] == second_min_y]:
            v_x = v_line['points'][0][0]
            if not any(h['points'][0][0] < v_x < h['points'][1][0]
                       and h['points'][0][1] == second_min_y for h in horizontal):
                v_line['points'][0][1] = 0
    # Close the gap between vertical lines and the bottom border.
    if len(col_y2_list) > 1:
        second_max_y = col_y2_list[-2]
        for v_line in [line for line in vertical if line['points'][1][1] == second_max_y]:
            v_x = v_line['points'][0][0]
            if not any(h['points'][0][0] < v_x < h['points'][1][0]
                       and h['points'][1][1] == second_max_y for h in horizontal):
                v_line['points'][1][1] = col_y2_list[-1]
def fix_structure_data(structrue_list: list):
    """Shift cells that collide with a preceding row-spanning cell, in place.

    Scanning the cells in order, a cell spanning several rows is remembered
    (its row range and its ``end_col``). The span is replaced only by one
    that reaches further down.

    When a later single-row cell inside that row range starts at or left of
    the remembered ``end_col``, the renumbering starts. From that cell on,
    every cell whose ``end_row`` lies in (span start, span end] is moved to
    ``end_col + 1``, ``end_col + 2``, ... (restarting for each row) while
    keeping its width. The scan stops at the first cell outside the range,
    and the span is forgotten afterwards.

    :param structrue_list: SciTSR structure cells (dicts with ``start_row``,
        ``end_row``, ``start_col``, ``end_col``); modified in place.
    :return: True when any cell was changed, otherwise False (also for empty input).
    """
    if not structrue_list:
        return False
    is_changed = False
    span_row_start = None
    span_row_end = None
    span_row_cur_col = None
    handled_cells = set()
    for index, cell in enumerate(structrue_list):
        if index in handled_cells:
            continue
        start_row = cell['start_row']
        end_row = cell['end_row']
        start_col = cell['start_col']
        if end_row - start_row > 0:
            if span_row_end is None or span_row_end < end_row:
                span_row_start = start_row
                span_row_end = end_row
                span_row_cur_col = cell['end_col']
            continue
        if span_row_start is None or span_row_end is None \
                or not span_row_start <= end_row <= span_row_end:
            continue
        if start_col > span_row_cur_col:
            continue
        count_round = 1
        current_end_row = None
        for idx in range(index, len(structrue_list)):
            current_cell = structrue_list[idx]
            if current_end_row is None:
                current_end_row = current_cell['end_row']
            if current_cell['end_row'] != current_end_row:
                # New row: restart numbering right after the spanning cell.
                count_round = 1
                current_end_row = current_cell['end_row']
            if not span_row_start < current_end_row <= span_row_end:
                break
            cell_width = current_cell['end_col'] - current_cell['start_col']
            current_cell['start_col'] = span_row_cur_col + count_round
            current_cell['end_col'] = current_cell['start_col'] + cell_width
            is_changed = True
            count_round += 1
            handled_cells.add(idx)
        # Whether the scan broke out or reached the end, this span is done.
        span_row_start = None
        span_row_end = None
        span_row_cur_col = None
    return is_changed
def get_span_row_col_img(root_path: str = r'/data/pubtabnet/train/',
                         max_files: int = 100,
                         span_threshold: int = 2):
    """List sample images whose tables contain large row or column spans.

    The files ``<root_path>/structure/*.json`` (SciTSR structure files) are
    scanned in sorted order. Two text files are written into ``root_path``,
    one image name per line (``.json`` replaced by ``.png``):

    * ``span_row_img.txt``: up to ``max_files`` images with a cell where
      ``end_row - start_row > span_threshold``.
    * ``span_col_img.txt``: the same check for columns.

    Scanning stops once both lists are full.
    """
    structure_files = sorted(glob(os.path.join(root_path, 'structure', '*.json')))
    span_row_img_file = os.path.join(root_path, 'span_row_img.txt')
    span_col_img_file = os.path.join(root_path, 'span_col_img.txt')
    span_row_list = []
    span_col_list = []
    for file in structure_files:
        if len(span_row_list) >= max_files and len(span_col_list) >= max_files:
            break
        with open(file, encoding='utf-8') as f:
            structure_data_list = json.load(f)['cells']
        image_name = os.path.basename(file).replace('.json', '.png')
        if len(span_row_list) < max_files and any(
                cell['end_row'] - cell['start_row'] > span_threshold for cell in structure_data_list):
            span_row_list.append(image_name + '\n')
            print('Find the {0} span row file: {1}'.format(len(span_row_list), image_name))
        if len(span_col_list) < max_files and any(
                cell['end_col'] - cell['start_col'] > span_threshold for cell in structure_data_list):
            span_col_list.append(image_name + '\n')
            print('Find the {0} span col file: {1}'.format(len(span_col_list), image_name))
    with open(span_row_img_file, mode='w', encoding='utf-8') as write:
        write.writelines(span_row_list)
    with open(span_col_img_file, mode='w', encoding='utf-8') as write:
        write.writelines(span_col_list)
def draw_line(col_mask, line_list: list, draw_thickness: bool = False):
    """Draw line annotations onto an image in color (0, 0, 128).

    Each line has ``points`` = [[x1, y1], [x2, y2]] and an optional
    ``thickness``. A line is drawn with ``cv2.line`` (at most 2 px wide) when
    its thickness is at most 2 or ``draw_thickness`` is False. Otherwise,
    axis-aligned lines are widened and drawn as filled rectangles, at most
    4 px wide. A missing, null or non-positive thickness is treated as 2.
    """
    for line in line_list:
        points = line['points']
        xmin = points[0][0]
        ymin = points[0][1]
        xmax = points[1][0]
        ymax = points[1][1]
        thickness = line.get('thickness', 2)
        if thickness is None or thickness < 1:
            # Stored annotations may hold null or negative values; null breaks the
            # comparisons below and cv2.line requires a positive thickness.
            thickness = 2
        if thickness <= 2 or not draw_thickness:
            cv2.line(col_mask,
                     (xmin, ymin),
                     (xmax, ymax),
                     (0, 0, 128),
                     thickness=min(thickness, 2),
                     lineType=0)
        else:
            thickness = min(thickness, 4)
            if xmax == xmin:
                # Vertical line: widen along x.
                xmin = int(xmin - (thickness / 2))
                xmax = int(xmax + (thickness / 2))
            elif ymax == ymin:
                # Horizontal line: widen along y.
                ymin = int(ymin - (thickness / 2))
                ymax = int(ymax + (thickness / 2))
            cv2.rectangle(col_mask, (xmin, ymin), (xmax, ymax), (0, 0, 128), -1)
def draw_cell(col_mask,
              structure_data_list: list,
              row_col_bound_dict: dict,
              is_scitsr: bool,
              table_xmin,
              table_xmax,
              table_ymin,
              table_ymax):
    """Draw every cell's rectangle and then the table frame onto ``col_mask``.

    A cell's rectangle spans the ``xmin`` of its start column to the ``xmax``
    of its end column, and the ``ymin`` of its start row to the ``ymax`` of
    its end row, taken from ``row_col_bound_dict``. Cells with any of these
    missing are skipped. The frame goes from (0, 0) to the table width and
    height; with ``is_scitsr`` the height is taken as an absolute difference.
    """
    for cell in structure_data_list:
        col_bounds = row_col_bound_dict['col']
        row_bounds = row_col_bound_dict['row']
        left = col_bounds.get(cell['start_col'], {}).get('xmin', None)
        right = col_bounds.get(cell['end_col'], {}).get('xmax', None)
        top = row_bounds.get(cell['start_row'], {}).get('ymin', None)
        bottom = row_bounds.get(cell['end_row'], {}).get('ymax', None)
        if all(edge is not None for edge in (left, right, top, bottom)):
            draw_rectangle(col_mask, left, right, top, bottom)
    frame_width = table_xmax - table_xmin
    frame_height = abs(table_ymin - table_ymax) if is_scitsr else table_ymax - table_ymin
    draw_rectangle(col_mask, 0, frame_width, 0, frame_height)
def draw_rectangle(img, xmin, xmax, ymin, ymax):
    """Draw the outline of the rectangle (xmin, ymin)-(xmax, ymax) on ``img``.

    The four edges are drawn as 2-pixel lines in color (0, 0, 128).
    """
    thickness = 2
    color = (0, 0, 128)
    # Pass the width as ``thickness``: ``lineType`` only accepts OpenCV line
    # types (LINE_4, LINE_8, LINE_AA), and 2 is not one of them.
    # Top edge
    cv2.line(img, (xmin, ymin), (xmax, ymin), color, thickness=thickness)
    # Left edge
    cv2.line(img, (xmin, ymin), (xmin, ymax), color, thickness=thickness)
    # Right edge
    cv2.line(img, (xmax, ymin), (xmax, ymax), color, thickness=thickness)
    # Bottom edge
    cv2.line(img, (xmin, ymax), (xmax, ymax), color, thickness=thickness)
if __name__ == '__main__':
    # Build table_net line annotations for the PubTabNet validation split
    # (chunk coordinates there use a downward y axis, hence is_scitsr=False).
    draw_cell_line(root_path=r'/data/pubtabnet/val/',
                   draw_img_amount=9000,
                   output_unet=False,
                   is_scitsr=False)
示例
原图如下:
纯线段图如下:
将纯线段与原图叠加如下:
效果并不完美,这是因为数据源提供的坐标仅仅是各个单元格文本区域的bounding box,而不是单元格本身的边界,之后会进一步完善。
你可能感兴趣的:(PubTabNet数据集介绍(有干货))