CTPN的Python实现笔记一

文章目录

  • 一、疑难代码讲解
    • 1. 文本框左上角标注置信度
      • (1) `s = str(round(i[-1] * 100, 2)) + '%'`
      • (2) `cv2.putText()` 函数
      • (3) `cv2.line()`函数
    • 2. 文本框进行扩展操作
    • 3. 文本框进行NMS操作
      • (1) 非极大值抑制函数`def nms(dets, thresh):`
        • a. `order = scores.argsort()[::-1]`
        • b. `xx1 = np.maximum(x1[i], x1[order[1:]])`
        • c. `inds = np.where(ovr <= thresh)[0]`
      • (2) `np.hstack()`函数
    • 4. 显示图片函数
  • 二、附录

一、疑难代码讲解


1. 文本框左上角标注置信度

 for i in text:
                s = str(round(i[-1] * 100, 2)) + '%'
                i = [int(j) for j in i]
                cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
                cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)
                cv2.putText(image_c, s, (i[0]+13, i[1]+13),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1,
                            (255,0,0),
                            2,
                            cv2.LINE_AA)

(1) s = str(round(i[-1] * 100, 2)) + '%'

这行代码将i的最后一个元素乘以100,保留2位小数并将其转换为字符串,再在这个字符串后面加上百分号,最后将这个字符串赋给变量s。

i[-1]表示i的最后一个元素,round(i[-1] * 100, 2)表示将i的最后一个元素乘以100并保留2位小数,str(round(i[-1] * 100, 2))表示将这个数字转换为字符串,+ '%'表示在字符串后面加上百分号。

这个变量s可能在后面会被用来在图像上显示识别结果的置信度。

round()是Python内置函数,用于将一个数字四舍五入为指定小数位数的值。

语法如下:

round(number[, ndigits])

其中,number是要进行四舍五入的数字,ndigits是可选参数,表示保留小数点后几位。如果ndigits未指定,则默认为0,即返回整数。

x = round(3.1415) # 返回3
y = round(3.1415, 2) # 返回3.14

(2) cv2.putText() 函数

cv2.putText(image_c, s, (i[0]+13, i[1]+13),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1,
                            (255,0,0),
                            2,
                            cv2.LINE_AA)

使用cv2.putText() 函数在图像上绘制文本。其中,image_c是要绘制文本的图像,s是要绘制的文本,(i[0]+13, i[1]+13)是文本的左下角坐标,cv2.FONT_HERSHEY_SIMPLEX是字体,1是字体大小,(255,0,0)是文本颜色,2是线宽,cv2.LINE_AA是抗锯齿标志。

cv2.putText()是OpenCV库中的一个函数,用于在图像上绘制文本。

语法如下:

cv2.putText(img, text, org, fontFace, fontScale, color[, thickness[, lineType[, bottomLeftOrigin]]])
  • img : 要绘制文本的图像。
  • text : 要绘制的文本。
  • org : 文本左下角的坐标。
  • fontFace : 字体。可以是以下字体之一:cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_DUPLEX, cv2.FONT_HERSHEY_COMPLEX, cv2.FONT_HERSHEY_TRIPLEX, cv2.FONT_HERSHEY_COMPLEX_SMALL, cv2.FONT_HERSHEY_SCRIPT_SIMPLEX or cv2.FONT_HERSHEY_SCRIPT_COMPLEX
  • fontScale : 字体大小。
  • color : 文本颜色。
  • thickness : 文本的线宽。
  • lineType : 线条类型(如8邻域连接或抗锯齿)。
  • bottomLeftOrigin : 一个布尔值,如果为真,则坐标是文本的左下角,如果为假,则坐标是文本的左上角。

在这段代码中,cv2.putText(image_c, s, (i[0]+13, i[1]+13), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,0,0), 2, cv2.LINE_AA) 函数在image_c这个图像上绘制文本s, 文本左下角坐标为(i[0]+13, i[1]+13), 字体是 cv2.FONT_HERS。

为什么我(255,0,0)画出来的文字是蓝色的?

一般来说,在OpenCV中如果使用的是 BGR 颜色空间,那么 (255,0,0) 代表的颜色是蓝色,而不是红色。

图1 轮胎图片的文本框检索

这段代码可能用于在图像中标记识别出的文本框,并在每个文本框的左上角显示识别结果的置信度,方便用户了解识别结果的可信度。


(3) cv2.line()函数

cv2.line()是OpenCV库中的一个函数,用于在图像上绘制直线。语法如下:

cv2.line(img, start_point, end_point, color, thickness)
  • img : 要绘制直线的图像。
  • start_point : 直线的起始坐标。
  • end_point : 直线的终止坐标。
  • color : 直线的颜色。
  • thickness : 直线的线宽。

在这段代码中,

cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)

这几行代码在图像image_c中绘制了四条直线, 这四条直线构成了一个矩形,线条颜色为(0, 0, 255),线宽为2, i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7]是四个顶点坐标。

这个函数非常简单,可以在图像上绘制一条或多条直线,在图像处理中经常使用。


2. 文本框进行扩展操作

        if expand:
            for idx in range(len(text)):
                text[idx][0] = max(text[idx][0] - 10, 0)
                text[idx][2] = min(text[idx][2] + 10, w - 1)
                text[idx][4] = max(text[idx][4] - 10, 0)
                text[idx][6] = min(text[idx][6] + 10, w - 1)

这段代码中,如果 expand 变量为真,那么将对文本框进行扩展。具体来说,对于每一个文本框,通过修改文本框的四个顶点坐标来实现扩展。

这段代码中, 就是将每个文本框的X轴方向的边框扩大10个像素,以便更好地识别文本。

为什么只调整0,2,4,6?


这里只调整文本框的0,2,4,6点,是因为文本框的顶点坐标是按照顺时针顺序来存储的。通常情况下,一个文本框的顶点坐标会存储为(x1,y1,x2,y2,x3,y3,x4,y4)

  • (x1,y1) 为左上角坐标
  • (x2,y2) 为右上角坐标
  • (x3,y3) 为右下角坐标
  • (x4,y4) 为左下角坐标

这样,0,2,4,6点就对应了文本框四个顶点的x坐标。
所以当expand为真时,通过更改文本框四个顶点的x坐标来扩大文本框,使文本框更容易被识别。


3. 文本框进行NMS操作

# nms
        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        nmsbox = np.hstack((select_anchor, select_score))
        keep = nms(nmsbox, 0.3)
        # print(keep)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

这段代码是进行NMS处理的一段,用了两个numpy的函数。

首先是筛选掉不符合条件的边界框:select_anchor = select_anchor[keep_index]select_score = select_score[keep_index]

接着是数组的形状变换:select_score = np.reshape(select_score, (select_score.shape[0], 1)),即将一维数组转为二维数组,使得数组每一行只有一个元素,方便后面的合并。

最后是数组的合并:nmsbox = np.hstack((select_anchor, select_score)),通过hstack函数,把选择的边界框和对应的分数合并到一个数组中,作为NMS处理的输入。

最后一步是调用NMS处理函数:keep = nms(nmsbox, 0.3),将NMS处理的结果保留下来。


(1) 非极大值抑制函数def nms(dets, thresh):

def nms(dets, thresh):
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]
    return keep

a. order = scores.argsort()[::-1]

order = scores.argsort()[::-1] 这一行的作用是将 scores 数组从大到小排序,并得到排序后元素的索引

例如,如果 scores = [0.8, 0.7, 0.9, 0.6],那么 scores.argsort() 就是 [3, 1, 0, 2],也就是第一个元素的索引是3,第二个元素的索引是1,以此类推。

scores.argsort()[::-1] 就是将上述索引数组倒序,得到 [2, 0, 1, 3]。

[::-1]是什么意思?

[::-1]是列表的切片操作,代表从后往前逆向取整个列表中的元素,即将列表逆序。在这里,scores.argsort()表示将scores列表中的元素从小到大排序后得到的下标列表,[::-1]则将这个下标列表逆序,也就是将scores中从大到小排序后得到的下标。argsort是argument sort的缩写。


b. xx1 = np.maximum(x1[i], x1[order[1:]])

np.maximum() 函数是 Numpy 中的元素对比函数,它的作用是逐个位置判断两个数组中的元素,输出两个数组中该位置元素的最大值。

语法:

numpy.maximum(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) = <ufunc 'maximum'>
  • x1, x2 : array_like
    输入数组,形状必须相同。
  • out : ndarray, None, or tuple of ndarray and None, optional
    如果提供了,将结果存储在此数组中。
  • where : array_like, optional
    可选,当计算结果时需要替换的位置。
  • casting : {‘no’, ‘equiv’, ‘safe’, ‘same_kind’, ‘unsafe’}, optional
    可选,类型转换的方式。
  • order : {‘K’, ‘A’, ‘C’, ‘F’}, optional
    可选,输出数组的存储方式。
  • dtype : data-type, optional
    可选,数组元素的数据类型。
import numpy as np

a = np.array([1, 4, 2, 5, 3])
b = np.array([2, 5, 3, 1, 4])
c = np.maximum(a, b)
print(c)

# Output:
# [2 5 3 5 4]

注意两个数组有必要array_like


c. inds = np.where(ovr <= thresh)[0]

Tuple 是 Python 中的一种数据结构,它是用圆括号括起来的一系列元素的集合,是不可变的,支持多种数据类型(整数、字符串、列表等)。Tuple 元素之间以逗号分隔。可以用索引访问 tuple 中的元素。

Tuple是一种不可变的序列类型,它可以保存任意多个不同类型的数据项。下面是一个关于Tuple的示例:

a = (1, 'hello', [3, 4], 5.6)

这是一个元素为整数、字符串、列表和浮点数的Tuple。

我们举例说明一下np.where()的用法

import numpy as np

a = np.array([3, 4, 5, 6, 7, 8])
indices = np.where(a > 5)

print(indices)
'''
(array([3, 4, 5], dtype=int64),)
'''

print(type(indices))
'''

'''

print(indices[0])
'''
[3 4 5]
'''

print(type(indices[1]))
#IndexError: tuple index out of range

print(type(indices[0]))
#

inds = np.where(ovr <= thresh)[0]

上述NMS代码段中,使用了np.where()[0],是因为该函数返回的是一个Tuple类型的数据,为了能够让index是标准的numpy.ndarray类型,必须要选定Tuple当中的第一个元素。


(2) np.hstack()函数

np.hstack()是numpy中用于水平堆叠数组的函数。它将多个数组沿着**水平方向(列方向)**堆叠在一起,形成一个新的数组。如果想在垂直方向堆叠数组,可以使用 np.vstack() 函数。

假设有两个二维数组A和B,其中A为 3 × 4 3\times4 3×4的数组,数组元素如下:

[[1, 2, 3, 4],
 [5, 6, 7, 8],
 [9,10,11,12]]

B为 3 × 4 3\times4 3×4的数组,数组元素如下:

[[13, 14, 15, 16],
 [17, 18, 19, 20],
 [21, 22, 23, 24]]

调用 np.vstack((A, B)) 后会得到一个新数组C,C为 6 × 4 6\times4 6×4的数组,数组元素如下:

[[ 1,  2,  3,  4],
 [ 5,  6,  7,  8],
 [ 9, 10, 11, 12],
 [13, 14, 15, 16],
 [17, 18, 19, 20],
 [21, 22, 23, 24]]

可以看出,在垂直方向上将A和B两个数组堆叠在一起得到了新数组C。

如果你使用 np.hstack((A,B)) 则会将A和B在水平方向上堆叠起来。数组C为 3 × 8 3\times8 3×8的数组。

数组A和B行列不对齐可以拼接吗?

A和B两个数组如果不是同一类型或者不具有相同的行列,那么np.vstack((A, B))会报错,因为在这种情况下两个数组不能够直接拼接在一起。

这里使用 np.hstack((select_anchor, select_score)) 将候选框坐标数组select_anchor和得分数组select_score沿着列方向堆叠在一起,形成一个新的数组nmsbox。nmsbox的每一行都是一个完整的框,包含了坐标和得分。


4. 显示图片函数

def dis(image):
    cv2.imshow('image', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

cv2.waitKey(0) 是 OpenCV 中的函数,它可以等待特定的按键事件。参数 0 表示等待键盘输入,直到按下任意按键才继续执行程序。可以通过更改参数的值来控制等待的时长(单位为毫秒)。

cv2.destroyAllWindows()是OpenCV中的函数,用于关闭所有OpenCV窗口。一般用于在显示图像后,手动关闭窗口时使用。


二、附录

#-*- coding:utf-8 -*-
#'''
# Created on 18-12-11 上午10:03
#
# @Author: Greg Gao(laygin)
#'''
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import numpy as np

import torch
import torch.nn.functional as F
from detect.ctpn_model import CTPN_Model
from detect.ctpn_utils import gen_anchor, bbox_transfor_inv, clip_box, filter_bbox,nms, TextProposalConnectorOriented
from detect.ctpn_utils import resize
from detect import config

prob_thresh = 0.5
height = 720
gpu = True
if not torch.cuda.is_available():
    gpu = False
device = torch.device('cuda:0' if gpu else 'cpu')
weights = os.path.join(config.checkpoints_dir, 'CTPN.pth')
model = CTPN_Model()
model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict'])
model.to(device)
model.eval()


def dis(image):
    cv2.imshow('image', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def get_det_boxes(image,display = True, expand = True):
    image = resize(image, height=height)
    image_r = image.copy()
    image_c = image.copy()
    h, w = image.shape[:2]
    image = image.astype(np.float32) - config.IMAGE_MEAN
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float()

    with torch.no_grad():
        image = image.to(device)
        cls, regr = model(image)
        cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
        regr = regr.cpu().numpy()
        anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
        bbox = bbox_transfor_inv(anchor, regr)
        bbox = clip_box(bbox, [h, w])
        # print(bbox.shape)

        fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
        # print(np.max(cls_prob[0, :, 1]))
        select_anchor = bbox[fg, :]
        select_score = cls_prob[0, fg, 1]
        select_anchor = select_anchor.astype(np.int32)
        # print(select_anchor.shape)
        keep_index = filter_bbox(select_anchor, 16)

        # nms
        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        nmsbox = np.hstack((select_anchor, select_score))
        keep = nms(nmsbox, 0.3)

        # print(keep)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

        # text line-
        textConn = TextProposalConnectorOriented()#定义了一个类,用于将Proposal转化为文本框
        text = textConn.get_text_lines(select_anchor, select_score, [h, w])

        # expand text
        if expand:
            for idx in range(len(text)):
                text[idx][0] = max(text[idx][0] - 10, 0)
                text[idx][2] = min(text[idx][2] + 10, w - 1)
                text[idx][4] = max(text[idx][4] - 10, 0)
                text[idx][6] = min(text[idx][6] + 10, w - 1)


        # print(text)
        if display:
            blank = np.zeros(image_c.shape,dtype=np.uint8)
            for box in select_anchor:
                pt1 = (box[0], box[1])
                pt2 = (box[2], box[3])
                blank = cv2.rectangle(blank, pt1, pt2, (50, 0, 0), -1)
            image_c = image_c+blank
            image_c[image_c>255] = 255
            for i in text:
                s = str(round(i[-1] * 100, 2)) + '%'
                i = [int(j) for j in i]
                cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
                cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)
                cv2.putText(image_c, s, (i[0]+13, i[1]+13),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1,
                            (255,0,0),
                            2,
                            cv2.LINE_AA)
            # dis(image_c)
        # print(text)
        return text,image_c,image_r

if __name__ == '__main__':
    img_path = 'images/t1.png'
    image = cv2.imread(img_path)
    text,image,_ = get_det_boxes(image)
    print(text)
    dis(image)

你可能感兴趣的:(机器学习,python,opencv,计算机视觉)