薄板样条插值(Thin plate splines)的实现与使用

最近项目用到了tps算法,opencv2封装的tps实现起来比较慢,于是用pytorch实现了一下,可以支持gpu加速,就很nice了,在这里记录一下!

1. 简介

薄板样条函数(TPS)是一种很常见的插值方法。因为它一般都是基于2D插值,所以经常用在在图像配准中。在两张图像中找出N个匹配点,应用TPS可以将这N个点形变到对应位置,同时给出了整个空间的形变(插值)。
薄板样条插值(Thin plate splines)的实现与使用_第1张图片

2. 实现

1. opencv的tps使用

import cv2
import numpy as np
import random
import torch
from torchvision.transforms import ToTensor, ToPILImage

DEVICE = torch.device("cpu")

def choice3(img):
    '''
    产生波浪型文字
    :param img:
    :return:
    '''
    h, w = img.shape[0:2]
    N = 5
    pad_pix = 50
    points = []
    dx = int(w/ (N - 1))
    for i in range( N):
        points.append((dx * i,  pad_pix))
        points.append((dx * i, pad_pix + h))

    #加边框
    img = cv2.copyMakeBorder(img, pad_pix, pad_pix, 0, 0, cv2.BORDER_CONSTANT,
                             value=(int(img[0][0][0]), int(img[0][0][1]), int(img[0][0][2])))

    #原点
    source = np.array(points, np.int32)
    source = source.reshape(1, -1, 2)

    #随机扰动幅度
    rand_num_pos = random.uniform(20, 30)
    rand_num_neg = -1 * rand_num_pos

    newpoints = []
    for i in range(N):
        rand = np.random.choice([rand_num_neg, rand_num_pos], p=[0.5, 0.5])
        if(i == 1):
            nx_up = points[2 * i][0]
            ny_up = points[2 * i][1] + rand
            nx_down = points[2 * i + 1][0]
            ny_down = points[2 * i + 1][1] + rand
        elif (i == 4):
            rand = rand_num_neg if rand > 1 else rand_num_pos
            nx_up = points[2 * i][0]
            ny_up = points[2 * i][1] + rand
            nx_down = points[2 * i + 1][0]
            ny_down = points[2 * i + 1][1] + rand
        else:
            nx_up = points[2 * i][0]
            ny_up = points[2 * i][1]
            nx_down = points[2 * i + 1][0]
            ny_down = points[2 * i + 1][1]

        newpoints.append((nx_up, ny_up))
        newpoints.append((nx_down, ny_down))

    #target点
    target = np.array(newpoints, np.int32)
    target = target.reshape(1, -1, 2)

    #计算matches
    matches = []
    for i in range(1, 2*N + 1):
        matches.append(cv2.DMatch(i, i, 0))

    return source, target, matches, img

def norm(points_int, width, height):
	"""
	将像素点坐标归一化至 -1 ~ 1
    """
	points_int_clone = torch.from_numpy(points_int).detach().float().to(DEVICE)
	x = ((points_int_clone * 2)[..., 0] / (width - 1) - 1)
	y = ((points_int_clone * 2)[..., 1] / (height - 1) - 1)
	return torch.stack([x, y], dim=-1).contiguous().view(-1, 2)


class TPS(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, X, Y, w, h, device):

        """ 计算grid"""
        grid = torch.ones(1, h, w, 2, device=device)
        grid[:, :, :, 0] = torch.linspace(-1, 1, w)
        grid[:, :, :, 1] = torch.linspace(-1, 1, h)[..., None]
        grid = grid.view(-1, h * w, 2)

        """ 计算W, A"""
        n, k = X.shape[:2]
        device = X.device

        Z = torch.zeros(1, k + 3, 2, device=device)
        P = torch.ones(n, k, 3, device=device)
        L = torch.zeros(n, k + 3, k + 3, device=device)

        eps = 1e-9
        D2 = torch.pow(X[:, :, None, :] - X[:, None, :, :], 2).sum(-1)
        K = D2 * torch.log(D2 + eps)

        P[:, :, 1:] = X
        Z[:, :k, :] = Y
        L[:, :k, :k] = K
        L[:, :k, k:] = P
        L[:, k:, :k] = P.permute(0, 2, 1)

        Q = torch.solve(Z, L)[0]
        W, A = Q[:, :k], Q[:, k:]

        """ 计算U """
        eps = 1e-9
        D2 = torch.pow(grid[:, :, None, :] - X[:, None, :, :], 2).sum(-1)
        U = D2 * torch.log(D2 + eps)

        """ 计算P """
        n, k = grid.shape[:2]
        device = grid.device
        P = torch.ones(n, k, 3, device=device)
        P[:, :, 1:] = grid

        # grid = P @ A + U @ W
        grid = torch.matmul(P, A) + torch.matmul(U, W)
        return grid.view(-1, h, w, 2)

if __name__=='__main__':
    # 弯曲水平文本
    img = cv2.imread('data/test.jpg', cv2.IMREAD_COLOR)
    source, target, matches, img = choice3(img)
    # #opencv版tps
    # tps = cv2.createThinPlateSplineShapeTransformer()
    # tps.estimateTransformation(source, target, matches)
    # img = tps.warpImage(img)
    # cv2.imshow('test.png', img)
    # cv2.imwrite('test.png', img)
    # cv2.waitKey(0)

    #torch实现tps
    ten_img = ToTensor()(img).to(DEVICE)
    h, w = ten_img.shape[1], ten_img.shape[2]
    ten_source = norm(source, w, h)
    ten_target = norm(target, w, h)

    tps = TPS()
    warped_grid = tps(ten_target[None, ...], ten_source[None, ...], w, h, DEVICE)   #这个输入的位置需要归一化,所以用norm
    ten_wrp = torch.grid_sampler_2d(ten_img[None, ...], warped_grid, 0, 0)
    new_img_torch = np.array(ToPILImage()(ten_wrp[0].cpu()))

    cv2.imshow('test.png', new_img_torch)
    cv2.imwrite('test.png', new_img_torch)
    cv2.waitKey(0)

3. 效果

  • 贴个效果图对比:
    薄板样条插值(Thin plate splines)的实现与使用_第2张图片
    薄板样条插值(Thin plate splines)的实现与使用_第3张图片

上图可以看出,pytorch实现与cv2的tps的效果完全对齐,所以重点看耗时,接下来贴耗时的对比图(差距还是蛮大的,图片越大差距越大)
薄板样条插值(Thin plate splines)的实现与使用_第4张图片
如果对你有帮助的话,希望给个赞,谢谢~


参考1:TPS 薄板样条插值 python的opencv实现
注,这个参考可以初步了解使用cv2的tps使用,但是具体细节上还存在错误

参考2:薄板样条函数(Thin plate splines)的讨论与分析
参考3:数值方法——薄板样条插值(Thin-Plate Spline)

你可能感兴趣的:(项目)