利用pytorch的affine_grid和grid_sample实现rroi_align

原始图片:
利用pytorch的affine_grid和grid_sample实现rroi_align_第1张图片

import random
import math
import torch
import numpy as np
import torch.nn.functional as F
import cv2
import matplotlib.pyplot as plt
from data_gen import draw_box_points

path = './test/timg.jpeg'
im_data = cv2.imread(path)
img = im_data.copy()
# plt.imshow(im_data)
# plt.show()

# 参数设置
debug = True
norm_height = 44
gt = np.asarray([[205,150],[202,126],[365,93],[372,111]])
im_data = torch.from_numpy(im_data).unsqueeze(0)

im_data = im_data.permute(0,3,1,2).to(torch.float)

center = (gt[0, :] + gt[1, :] + gt[2, :] + gt[3, :]) / 4
dw = gt[2, :] - gt[1, :]
dh =  gt[1, :] - gt[0, :] 

w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])  + random.randint(-2, 2)


angle_gt = ( math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0]) ) / 2
        
input_W = im_data.size(3)
input_H = im_data.size(2)
target_h = norm_height  

scale = target_h / h 
target_gw = int(w * scale) + random.randint(0, int(target_h)) 
target_gw = max(8, int(round(target_gw / 4)) * 4) 

xc = center[0] 
yc = center[1] 
w2 = w 
h2 = h 


scalex =  (w2 + random.randint(0, int(h2))) / input_W 
scaley = h2 / input_H 

th11 =  scalex * math.cos(angle_gt)
th12 = -math.sin(angle_gt) * scaley
th13 =  (2 * xc - input_W - 1) / (input_W - 1) #* torch.cos(angle_var) - (2 * yc - input_H - 1) / (input_H - 1) * torch.sin(angle_var)

th21 = math.sin(angle_gt) * scalex 
th22 =  scaley * math.cos(angle_gt)  
th23 =  (2 * yc - input_H - 1) / (input_H - 1) #* torch.cos(angle_var) + (2 * xc - input_W - 1) / (input_W - 1) * torch.sin(angle_var)

t = np.asarray([th11, th12, th13, th21, th22, th23], dtype=np.float)
t = torch.from_numpy(t).type(torch.FloatTensor)
theta = t.view(-1, 2, 3)

grid = F.affine_grid(theta, torch.Size((1, 3, int(target_h ), int(target_gw))))
x = F.grid_sample(im_data, grid)


if debug:
    x_c = x.data.cpu().numpy()[0]
    x_data_draw = x_c.swapaxes(0, 2)
    x_data_draw = x_data_draw.swapaxes(0, 1)

    x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
    x_data_draw = x_data_draw[:, :, ::-1]

    cv2.circle(img, (int(center[0]), int(center[1])), 5, (0, 255, 0))      
    cv2.imshow('im_data', x_data_draw)

    # draw_box_points(img, pts)
    draw_box_points(img, gt, color=(0, 0, 255))

    cv2.imshow('img', img)
    cv2.waitKey(100)

裁剪出来的图片:
效果还是有的,但是采用了pytorch的affine_grid和grid_sample,并不知道theta矩阵的计算方式。
在这里插入图片描述
在这里插入图片描述
将文字区域调整到同样的高度,不同的长度,但是字会出现左右(最左,最右的字)会超出文字区域。

第二种方案

采用rroi_align的方式进行旋转矫正和crop操作,并使用cuda进行运算加速

实现细节

结果

可以看到效果比上面的结果好多了。
在这里插入图片描述
在这里插入图片描述

下一步工作

实现批量的rroi_align操作

你可能感兴趣的:(PYTHON,深度学习,ocr)