ICPR MTWI 2018 Challenge 2: Text Detection in Web Images

Contents

1. Dataset

1.1 Displaying multiple images

1.2 Displaying the annotated text boxes

1.3 Fixing garbled Chinese text with cv2

2. Building the data-processing module

3. Building the SSD model

4. Building the training module

1. Dataset

In the internet world, images are an important medium for conveying information. In e-commerce, social media, and search in particular, hundreds of millions of images circulate every day. Optical character recognition (OCR) on images has significant commercial value: it underpins data digitization and the bridging of online and offline channels, and it is also an active research topic in academia. However, there has so far been no web-image OCR dataset centered on Chinese text. This competition releases a mixed Chinese-English dataset of web images. The dataset is large, covering dozens of fonts, font sizes ranging from a few pixels to several hundred, a variety of layouts, and plenty of cluttered backgrounds. The hope is that academia can use it for in-depth research, and that industry can build on it for OCR-based image moderation, search, information entry, and other AI applications.

For this competition, 20,000 images are provided as the dataset: 50% for training and 50% for testing. All images come from the web and consist mainly of synthetic images, product descriptions, and online advertisements. A typical image is shown in Figure 1:

A sample image and its corresponding annotation file:

[Figure 1: a sample image from the dataset]

[Figure 2: the corresponding annotation file]
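Each image name.jpg has a matching name.txt annotation file. As the parsing code in section 1.2 shows, every line stores the four corner points of one text region followed by its transcription (the coordinate values below are hypothetical, for illustration only):

x1,y1,x2,y2,x3,y3,x4,y4,text
20.0,30.0,180.0,30.0,180.0,60.0,20.0,60.0,sample text

Regions whose text is illegible are transcribed as "###" and are skipped during training (see AnnoTransform in section 2).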

1.1 Displaying multiple images

import os
import glob
import cv2
import matplotlib.pyplot as plt

def show_multi_img(imgpath, num):
    """
    :param imgpath: directory containing the images
    :param num: grid size per side, e.g. num=6 shows 36 images in a 6*6 grid
    :return:
    """
    img_path = glob.glob(imgpath + "/*")
    plt.figure()
    for i in range(1, num * num + 1):
        # cv2 reads BGR, so convert to RGB for correct colors in matplotlib
        img = cv2.cvtColor(cv2.imread(img_path[i - 1]), cv2.COLOR_BGR2RGB)
        title = os.path.basename(img_path[i - 1])
        plt.subplot(num, num, i)
        plt.imshow(img)
        plt.title(title, fontsize=6)
        plt.xticks([])
        plt.yticks([])
        plt.axis("on")
    plt.savefig("./final.png")  # save once, after the whole grid is drawn
    plt.show()
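A usage sketch (the image_train directory name matches the dataset layout used in section 2):

show_multi_img("./image_train", 6)  # 6x6 grid, also saved to ./final.png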

[Figure 3: image grid produced by show_multi_img]

1.2 Displaying the annotated text boxes

import os
import cv2

def show_box(input_img: str):
    """
    :param input_img: path to the input image, e.g. ./image_train/<name>.jpg
    :return: displays the image with all text boxes drawn
    """
    tmp = input_img.split("/")[2][:-4]  # image name without extension
    img = cv2.imread(input_img)
    img_bbox_dir = "./txt_train"  # directory holding the annotation .txt files
    with open(os.path.join(img_bbox_dir, tmp) + ".txt", "r", encoding="utf-8") as f:
        for line in f.readlines():
            print(line)
            # each line is x1,y1,...,x4,y4,text; split at most 8 times so
            # commas inside the text itself are preserved
            parts = line.strip().split(",", 8)
            x1, y1, x2, y2, x3, y3, x4, y4 = map(float, parts[:8])
            text = parts[8]
            # approximate the quadrilateral by its top-left (x1,y1) and
            # bottom-right (x3,y3) corners
            cv2.rectangle(img, (int(x1), int(y1)), (int(x3), int(y3)), (255, 0, 255), 2)
            # cv2.putText only renders ASCII; Chinese characters come out as "?????"
            cv2.putText(img, str(text), (int(x1), int(y1) - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 2)
        cv2.imwrite("plot_bbox.png", img)
        cv2.imshow("show", img)
        cv2.waitKey(0)
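Usage sketch; the file name is hypothetical, but the path must have the form ./image_train/<name>.jpg because the function takes split("/")[2] as the image name:

show_box("./image_train/T1_img_0001.jpg")  # hypothetical file name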

[Figure 4: boxes drawn by show_box, with Chinese labels rendered as question marks]

1.3 Fixing garbled Chinese text with cv2

import os
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont

def cv2AddChineseText(input_img: str, textColor=(0, 255, 0), textSize: int = 15):
    """
    Fix garbled Chinese labels by drawing text with PIL instead of cv2.putText.
    :param input_img: path to the input image, e.g. ./image_train/<name>.jpg
    :param textColor: text color
    :param textSize: font size
    :return:
    """
    tmp = input_img.split("/")[2][:-4]  # image name without extension
    img = cv2.imread(input_img)
    if isinstance(img, np.ndarray):  # OpenCV image (BGR) -> PIL image (RGB)
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    img_bbox_dir = "./txt_train"  # directory holding the annotation .txt files
    # drawing handle for the PIL image
    draw = ImageDraw.Draw(img)
    # a font that contains CJK glyphs
    fontStyle = ImageFont.truetype("simsun.ttc", textSize, encoding="utf-8")
    with open(os.path.join(img_bbox_dir, tmp) + ".txt", "r", encoding="utf-8") as f:
        for line in f.readlines():
            parts = line.strip().split(",", 8)
            x1, y1, x2, y2, x3, y3, x4, y4 = map(float, parts[:8])
            text = parts[8]
            draw.rectangle([(x1, y1), (x3, y3)], fill=None, outline="red")
            # draw the (possibly Chinese) label
            draw.text((int(x1) - 5, int(y1)), str(text), textColor, font=fontStyle)
        img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
        cv2.imshow('show_chinese', img)
        cv2.waitKey(0)
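Called the same way as show_box; note that simsun.ttc (SimSun) must be available to PIL, e.g. shipped with Windows or placed next to the script:

cv2AddChineseText("./image_train/T1_img_0001.jpg", textColor=(0, 255, 0), textSize=15)  # hypothetical file name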

[Figure 5: boxes and Chinese labels drawn correctly by cv2AddChineseText]

2. Building the data-processing module

import os
import torch
import torch.utils.data as data
import cv2
import numpy as np
from PIL import Image
num_classes = ("text",)
from data import *

class AnnoTransform(object):
    def __init__(self,class_to_ind=None,keep_difficult=False):
        """
        Args:
            class_to_ind: 类别索引
            keep_difficult: 是否保留difficult=1的物体
        """
        # {“text”:0,...}
        self.class_to_ind = class_to_ind or dict(zip(num_classes,range(len(num_classes))))
        self.keep_difficult = keep_difficult

    def __call__(self, target, width, height):
        """
        Args:
            target: the annotation lines for one image
            width: image width
            height: image height
        Returns: [[x1, y1, ..., x4, y4, label_ind], ... ]
        """
        list_target = []
        for tar in target:
            # each annotation line is comma-separated: 8 coordinates, then the text
            tar_list = tar.strip().split(",")
            # "###" marks illegible text; skip those boxes
            if tar_list[8] == "###":
                continue
            name = "text"
            bndbox = []
            # read the 8 coordinates and normalize them to [0, 1]
            for i, points in enumerate(tar_list[:8]):
                pt = float(points)
                if i % 2 == 0:
                    pt /= width   # x coordinates
                else:
                    pt /= height  # y coordinates
                bndbox.append(pt)
            # if the cross product of the first two edges is negative, the
            # vertices run in the wrong direction, so reverse their order
            p12 = (bndbox[0] - bndbox[2], bndbox[1] - bndbox[3])
            p23 = (bndbox[2] - bndbox[4], bndbox[3] - bndbox[5])
            if p12[0] * p23[1] - p12[1] * p23[0] < 0:
                bndbox[0:7:2] = bndbox[6::-2]
                bndbox[1:8:2] = bndbox[7::-2]
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            list_target += [bndbox]  # [x1, y1, x2, y2, x3, y3, x4, y4, label_ind]
        return list_target
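
# Quick check of the reordering logic above with a hypothetical quadrilateral
# given as (0.1,0.1) -> (0.1,0.2) -> (0.3,0.2) -> (0.3,0.1), already normalized:
#   p12 = (0.0, -0.1), p23 = (-0.2, 0.0)
#   cross = 0.0*0.0 - (-0.1)*(-0.2) = -0.02 < 0, so this box's vertex order is reversed.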

# inherits from torch.utils.data.Dataset
class MyDataset(data.Dataset):
    def __init__(self,root,transform=None,target_transform=AnnoTransform()):
        self.root = root  # dataset root directory
        self.transform = transform
        self.target_transform = target_transform
        self.annotation_path = os.path.join(self.root, 'txt_train', '{}.txt')  # annotation file path template
        self.imgpath = os.path.join(self.root, "image_train", "{}.jpg")  # image path template
        self.idx = list()
        for txt in os.listdir(os.path.join(self.root, 'txt_train')):
            self.idx.append(txt.replace('.txt', ''))

    def __getitem__(self, index):
        img_id = self.idx[index]
        with open(self.annotation_path.format(img_id), 'r', encoding='utf-8') as f:
            target = f.readlines()
        # NOTE: cv2.imread returns a BGR np.ndarray, while Image.open() returns RGB;
        # isinstance(img, np.ndarray) can be used to tell the two apart
        img = cv2.imread(self.imgpath.format(img_id))
        # img = Image.open(self.imgpath.format(img_id))
        try:
            height, width, channels = img.shape
        except AttributeError:
            # cv2.imread returns None for missing/corrupt files; report and re-raise
            print(img_id)
            raise

        if self.target_transform is not None:
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :8], target[:, 8])
            img = img[:, :, (2, 1, 0)]  # BGR -> RGB
            # np.hstack stacks horizontally; np.expand_dims adds the label column
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width

    def __len__(self):
        return len(self.idx)


if __name__ == "__main__":
    root = "./data/datasets"
    data_1 = MyDataset(root)
    data_loader = data.DataLoader(data_1, 1, num_workers=0, shuffle=False,pin_memory=True)
    print(len(data_loader))
    for tp in data_loader:
        img,target,h,w = tp
        print(img)
    # batch_iterator = iter(data_loader)
    # images, targets = next(batch_iterator)
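The training script in section 4 passes a detection_collate function (imported from the data package) to its DataLoader, since each image carries a different number of boxes and the default collate_fn cannot stack the targets. A minimal sketch of such a function, following the usual ssd.pytorch layout (an assumption, not this repo's exact code):

def detection_collate(batch):
    """Stack the images into one tensor and keep the variable-length targets
    as a list of FloatTensors, one entry per image."""
    imgs, targets = [], []
    for sample in batch:
        imgs.append(sample[0])                        # CxHxW image tensor
        targets.append(torch.FloatTensor(sample[1]))  # (num_boxes, 9) array
    return torch.stack(imgs, 0), targets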

3. Building the SSD model

This section references the blog post 目标检测系列——SSD (an object-detection series on SSD).

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layers import *
import os

from torch.autograd import Function
from ..box_utils import decode, nms

class Detect(Function):
    """At test time, Detect is the final layer of SSD.  Decode location preds,
    apply non-maximum suppression to location predictions based on conf
    scores and threshold to a top_k number of output predictions for both
    confidence score and locations.
    """
    def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
        self.num_classes = num_classes
        self.background_label = bkg_label
        self.top_k = top_k
        # Parameters used in nms.
        self.nms_thresh = nms_thresh
        if nms_thresh[0] <= 0:
            raise ValueError('nms_threshold must be non-negative.')
        self.conf_thresh = conf_thresh
        self.variance = cfg['variance']

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers
                Shape: [batch,num_priors*4]
            conf_data: (tensor) Shape: Conf preds from conf layers
                Shape: [batch*num_priors,num_classes]
            prior_data: (tensor) Prior boxes and variances from priorbox layers
                Shape: [1,num_priors,4]
        """
        num = loc_data.size(0)  # batch size
        num_priors = prior_data.size(0)
        output = torch.zeros(num, self.num_classes, self.top_k, 13)
        conf_preds = conf_data.view(num, num_priors,
                                    self.num_classes).transpose(2, 1)

        # Decode predictions into bboxes.
        for i in range(num):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)
            # For each class, perform nms
            conf_scores = conf_preds[i].clone()

            for cl in range(1, self.num_classes):
                # keep only the predictions whose score exceeds the confidence threshold
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.size(0) == 0:
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                boxes = decoded_boxes[l_mask].view(-1, 12)
                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                output[i, cl, :count] = \
                    torch.cat((scores[ids[:count]].unsqueeze(1),
                               boxes[ids[:count]]), 1)
        flt = output.contiguous().view(num, -1, 13)
        _, idx = flt[:, :, 0].sort(1, descending=True)
        _, rank = idx.sort(1)
        flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
        return output

# 'M' is max pooling; 'C' is max pooling with ceil_mode=True
base = {
    '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
            512, 512, 512],
    '384': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
            512, 512, 512],
    '768': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
            512, 512, 512],
}
# extra feature layers for additional scales
extras = {
    '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
    '384': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
    '768': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
}
# number of default boxes per feature-map location
mbox = {
    '384': [6, 8, 8, 8, 6, 6],  # number of boxes per feature map location
    '768': [6, 8, 8, 8, 6, 6],
}
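
# For a 384x384 input, the six source feature maps fed to the multibox head
# work out to 48x48, 24x24, 12x12, 6x6, 4x4, and 2x2 (one mbox entry per map),
# given the pooling layout in `base` and the stride-2 convs in `extras`.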


class SSD(nn.Module):
    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        if size == 384:
            self.cfg = mtwi384
        elif size == 768:
            self.cfg = mtwi768
        self.priorbox = PriorBox(self.cfg)
        with torch.no_grad():
            self.priors = Variable(self.priorbox.forward())
        self.size = size

        # SSD network
        self.vgg = nn.ModuleList(base)
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.extras = nn.ModuleList(extras)

        self.loc = nn.ModuleList(head[0])
        self.conf = nn.ModuleList(head[1])

        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes, 0, 200, 0.01, (0.8, 0.1))

    def forward(self, x):
        """Applies network layers and ops on input image(s) x.

        Args:
            x: input image or batch of images. Shape: [batch,3,300,300].

        Return:
            Depending on phase:
            test:
                Variable(tensor) of output class label predictions,
                confidence score, and corresponding location predictions for
                each object detected. Shape: [batch,num_classes,top_k,13]

            train:
                list of concat outputs from:
                    1: confidence layers, Shape: [batch*num_priors,num_classes]
                    2: localization layers, Shape: [batch,num_priors*12]
                    3: priorbox layers, Shape: [2,num_priors*4], center-offset form
        """
        sources = list()
        loc = list()
        conf = list()

        # apply vgg up to conv4_3 relu to get the conv4_3 feature map
        for k in range(23):
            x = self.vgg[k](x)

        s = self.L2Norm(x)
        sources.append(s)

        # apply vgg up to fc7 to get the fc7 feature map
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)

        # apply extra layers and cache source layer outputs
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # apply the multibox head (localization and classification layers) to the source layers
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        # flatten and concatenate the outputs from every source layer
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
        if self.phase == "test":
            # loc is reshaped to (batch_size, num_anchors, 12)
            # conf is reshaped to (batch_size, num_anchors, num_classes)
            output = self.detect.forward(
                loc.view(loc.size(0), -1, 12),                   # loc preds
                self.softmax(conf.view(conf.size(0), -1,
                             self.num_classes)),                # conf preds
                self.priors.type(type(x.data))                  # default boxes
            )
        else:
            output = (
                loc.view(loc.size(0), -1, 12),
                conf.view(conf.size(0), -1, self.num_classes),
                self.priors
            )
        return output

    def load_weights(self, base_file):
        other, ext = os.path.splitext(base_file)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            self.load_state_dict(torch.load(base_file,
                                 map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')


def vgg(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    for v in cfg:
        # 'M': plain max pooling
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        # 'C': max pooling with ceil_mode=True
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    # max pooling that keeps height and width: 19*19*512 -> 19*19*512
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    # dilated convolution: 19*19*512 -> 19*19*1024
    conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
    # 1x1 convolution: 19*19*1024 -> 19*19*1024
    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
    # append conv6 and conv7 with their ReLU activations
    layers += [pool5, conv6,
               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers
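A quick sanity check of the backbone layout (a sketch; the indices follow from the cfg above, and index 21 is the third 512-channel conv, i.e. conv4_3, which matches vgg_source = [21, -2] in multibox below):

layers = vgg(base['384'], 3)
print(len(layers))  # 35 Conv2d/ReLU/MaxPool2d modules, ending with conv7's ReLU
print(layers[21])   # Conv2d(512, 512, ...) -> the conv4_3 source layer
print(layers[-2])   # Conv2d(1024, 1024, kernel_size=1) -> conv7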


def multibox(vgg, extra_layers, cfg, num_classes):
    loc_layers = []
    conf_layers = []
    vgg_source = [21, -2]
    for k, v in enumerate(vgg_source):
        loc_layers += [nn.Conv2d(vgg[v].out_channels,
                                 cfg[k] * 12, kernel_size=(3, 5), padding=(1, 2))]
        conf_layers += [nn.Conv2d(vgg[v].out_channels,
                        cfg[k] * num_classes, kernel_size=(3, 5), padding=(1, 2))]
    for k, v in enumerate(extra_layers[1::2], 2):
        loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
                                 * 12, kernel_size=(3, 5), padding=(1, 2))]
        conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
                                  * num_classes, kernel_size=(3, 5), padding=(1, 2))]
    return vgg, extra_layers, (loc_layers, conf_layers)

def add_extras(cfg, i, batch_norm=False):
    # Extra layers added to VGG for feature scaling
    layers = []
    in_channels = i
    flag = False
    for k, v in enumerate(cfg):
        if in_channels != 'S':
            if v == 'S':
                layers += [nn.Conv2d(in_channels, cfg[k + 1],
                           kernel_size=(1, 3)[flag], stride=2, padding=1)]
            else:
                layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
            flag = not flag
        in_channels = v
    return layers


def build_ssd(phase, size=384, num_classes=21):
    if phase != "test" and phase != "train":
        print("ERROR: Phase: " + phase + " not recognized")
        return
    base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
                                     add_extras(extras[str(size)], 1024),
                                     mbox[str(size)], num_classes)
    return SSD(phase, size, base_, extras_, head_, num_classes)
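A minimal smoke test, assuming the repo's data package provides the mtwi384 prior-box config that SSD.__init__ expects (the shapes in the comments are what the reshapes in forward imply, not verified output):

net = build_ssd('train', size=384, num_classes=2)
x = torch.randn(1, 3, 384, 384)
loc, conf, priors = net(x)
print(loc.shape)     # expected (1, num_priors, 12)
print(conf.shape)    # expected (1, num_priors, 2)
print(priors.shape)  # expected (num_priors, 4)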

4. Building the training module


from data import mtwi2018,MTWIDetection
from data import *
from utils.augmentations import SSDAugmentation
from layers.modules import MultiBoxLoss
from ssd import build_ssd
import os
import sys
import time
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.utils.data as data
import numpy as np

save_folder = "weights/"
basenet_path = save_folder + "vgg16_reducedfc.pth"  # pretrained VGG16 backbone weights
# use CUDA tensors when available
if torch.cuda.is_available():
    print("Using CUDA")
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(save_folder):
    os.mkdir(save_folder)

def xavier(param):
    init.xavier_uniform_(param)

def weights_init(m):
    # Xavier initialization for convolution weights
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        m.bias.data.zero_()

def train():
    print("starting training......")
    dataset_root = './data/datasets'
    MEANS = (104, 117, 123)
    dataset = MTWIDetection(dataset_root,transform=SSDAugmentation(384,MEANS))

    ssd_net = build_ssd('train',384,2)
    # print(ssd_net)  # print the network

    net = torch.nn.DataParallel(ssd_net).cuda()
    cudnn.benchmark = True

    vgg_weights = torch.load(basenet_path)
    print('Loading base network...')
    ssd_net.vgg.load_state_dict(vgg_weights)
    start_iter = 0

    print('Initializing weights...')
    # xavier method
    ssd_net.extras.apply(weights_init)
    ssd_net.loc.apply(weights_init)
    ssd_net.conf.apply(weights_init)

    optimizer = optim.SGD(net.parameters(),lr=1e-4, momentum=0.9,weight_decay=5e-4)
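    # MultiBoxLoss positional args below, assuming the standard ssd.pytorch
    # signature: (num_classes=2, overlap_thresh=0.5, prior_for_matching=True,
    # bkg_label=0, neg_mining=True, neg_pos=3, neg_overlap=0.5,
    # encode_target=False, use_gpu=True)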
    criterion = MultiBoxLoss(2,0.5,True,0,True,3,0.5,False,True)

    net.train()

    loc_loss = 0
    conf_loss = 0
    batch_size = 16
    epoch = 0
    epoch_size = len(dataset) // batch_size
    step_index = 0

    data_loader = data.DataLoader(dataset, batch_size,num_workers=0,shuffle=False, collate_fn=detection_collate,pin_memory=True)

    # create batch iterator
    batch_iterator = iter(data_loader)
    for iteration in range(start_iter,100):
        if iteration != 0 and (iteration % epoch_size == 0):
            #reset epoch loss counters
            loc_loss = 0
            conf_loss = 0
            epoch += 1

        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)
            images, targets = next(batch_iterator)

        images = Variable(images.cuda())
        with torch.no_grad():
            targets = [Variable(ann.cuda()) for ann in targets]
        # forward
        t0 = time.time()
        out = net(images)
        # backprop
        optimizer.zero_grad()
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()
        t1 = time.time()
        loc_loss += loss_l.item()
        conf_loss += loss_c.item()

        if iteration % 10 == 0:
            print('timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')

        if iteration % 10 == 0:
            # net.state_dict() holds the model parameters; optimizer.state_dict() would hold the optimizer state if added here
            checkpoint = {
                'iteration': iteration,
                'model_state_dict': net.state_dict()
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(checkpoint, 'checkpoint/ssd_mtwi.pth')
            print('save checkpoint')


        if iteration != 0 and iteration % 100 == 0:
            print('Saving state, iter:', iteration)
            torch.save(ssd_net.state_dict(), 'weights/ssd_mtwi_' +repr(iteration) + '.pth')

    torch.save(ssd_net.state_dict(),save_folder + 'mtwi.pth')


if __name__ == '__main__':
    train()
