收缩分割多边形(PSENet中有使用)

目的:将文本框多边形向内收缩,使相邻的文本实例在分割图中不再粘连,从而解决密集文本的分割问题。

代码:

# -*- coding=utf-8 -*-
import os
import cv2
import Polygon as plg
import pyclipper
import numpy as np


def dist(a, b):
    """Return the Euclidean distance between points ``a`` and ``b``.

    The operands must support element-wise subtraction (e.g. NumPy arrays
    of the same shape).
    """
    delta = a - b
    squared_sum = np.sum(np.square(delta))
    return np.sqrt(squared_sum)

def perimeter(bbox):
    """Return the perimeter of the closed polygon ``bbox``.

    ``bbox`` is indexed along its first axis as a sequence of vertices
    (it must expose ``.shape``); the last vertex is joined back to the
    first so that the closing edge is included.
    """
    n_pts = bbox.shape[0]
    edge_lengths = (
        # The modulo wraps the final vertex around to vertex 0.
        dist(bbox[idx], bbox[(idx + 1) % n_pts])
        for idx in range(n_pts)
    )
    # Start from 0.0 so that a polygon without vertices yields a float.
    return sum(edge_lengths, 0.0)

def shrink(bboxes, rate, max_shr=20):
    """Shrink every polygon in ``bboxes`` inward (PSENet-style kernels).

    For each polygon the inward offset is
    ``area * (1 - rate ** 2) / perimeter``, rounded to the nearest integer
    and capped at ``max_shr``.

    Args:
        bboxes: Iterable of polygons, each an array of vertices
            (``main`` passes ``(N, 2)`` integer arrays).
        rate: Shrink ratio. It is squared before use; a value of 1 gives a
            zero offset and smaller values give a larger offset.
        max_shr: Upper bound on the inward offset.

    Returns:
        A list with one entry per input polygon: the shrunken polygon as a
        NumPy array of vertices, or the original polygon unchanged when
        the offset removes it entirely or leaves at most two vertices.
    """
    rate = rate * rate
    shrinked_bboxes = []
    for bbox in bboxes:
        area = plg.Polygon(bbox).area()
        peri = perimeter(bbox)

        pco = pyclipper.PyclipperOffset()
        pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        # The 0.001 keeps the division defined for a zero-length perimeter;
        # adding 0.5 before truncation rounds to the nearest integer.
        offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)

        shrunk_paths = pco.Execute(-offset)
        if not shrunk_paths:
            # The polygon vanished under this offset: keep the original.
            shrinked_bboxes.append(bbox)
            continue

        # Convert only the first path. Offsetting can split a polygon into
        # several paths with different vertex counts, and building one
        # array from that ragged list raises ValueError on recent NumPy.
        shrinked_bbox = np.array(shrunk_paths[0])
        if shrinked_bbox.shape[0] <= 2:
            # Fewer than three vertices is not a polygon: keep the original.
            shrinked_bboxes.append(bbox)
            continue

        shrinked_bboxes.append(shrinked_bbox)

    return shrinked_bboxes


def main(shrink_threshold=0.9,
         img_path='./src_imgs/img_43.jpg',
         label_path='./labels_txt/img_43.txt'):
    """Render the original and the shrunken polygon masks for one image.

    Reads the image only to obtain its size, parses one polygon per line
    from the label file, and writes two grayscale masks to the current
    directory: ``./ori_mask_img.jpg`` (original polygons filled with 255)
    and ``./shrink_mask_img.jpg`` (shrunken polygons filled with 100).

    Args:
        shrink_threshold: Shrink ratio forwarded to ``shrink``.
        img_path: Path of the source image.
        label_path: Path of the label file. Each line holds
            comma-separated coordinates followed by one trailing field
            that is discarded.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError('cannot read image: {}'.format(img_path))
    H, W = img.shape[:2]

    with open(label_path, 'r') as f:
        label_lines = f.readlines()

    gt_boxes = []
    for line in label_lines:
        # Drop surrounding whitespace and any byte-order-mark residue.
        line = line.strip().strip('\ufeff').strip('\xef\xbb\xbf')
        # The last comma-separated field is not a coordinate; skip it.
        box_points = [int(float(item)) for item in line.split(',')[:-1]]
        if not box_points:
            # Blank or comma-less line: there is no polygon to draw.
            continue
        box_info = np.array(box_points).reshape((-1, 2))
        gt_boxes.append(box_info)

    # The masks are saved as 8-bit images, so allocate them as uint8.
    ori_mask_img = np.zeros((H, W), dtype=np.uint8)
    for box in gt_boxes:
        # fillPoly expects 32-bit integer vertex arrays.
        cv2.fillPoly(ori_mask_img, [box.astype(np.int32)], 255)
    cv2.imwrite('./ori_mask_img.jpg', ori_mask_img)

    shrink_mask_img = np.zeros((H, W), dtype=np.uint8)
    new_gt_boxes = shrink(gt_boxes, shrink_threshold)
    for box in new_gt_boxes:
        cv2.fillPoly(shrink_mask_img, [box.astype(np.int32)], 100)

    cv2.imwrite('./shrink_mask_img.jpg', shrink_mask_img)

# Script entry point: run the demo with a shrink ratio of 0.6.
if __name__ == "__main__":
    main(shrink_threshold=0.6)

收缩分割多边形(PSENet中有使用)_第1张图片

                               1.原图

收缩分割多边形(PSENet中有使用)_第2张图片    收缩分割多边形(PSENet中有使用)_第3张图片

                               2.原先分割图                                                                         3.收缩后分割图

你可能感兴趣的:(OCR,numpy)