pytorch---仿射变换

一、仿射变换

图片的旋转、平移、缩放等可以看做一个像素的重采样过程。将原图的像素映射到目标图像的对应位置上,可以

                                                           \begin{bmatrix} x\\ y\\ 1 \end{bmatrix} = \begin{bmatrix} {x}^{s} & {y}^{s} & 1 \end{bmatrix} * \begin{bmatrix} a & b &0 \\ c & d & 0\\ e &f & 1 \end{bmatrix}

   其中{x}^{s} ,{y}^{s}为原图的坐标,x,y为目标图的坐标,该变换称为前向变换,遍历原图像素,求出改像素在目标图像的对应位置。

   前向变换虽然符合逻辑,但是却使得目标图像上很多位置没有对应的像素。因此一种更合理的方式是使用后向变换,即从目标图像出发,遍历目标图像的每个位置,求出每个位置在原图中的对应像素。此时,公式变为:

                                                            \begin{bmatrix} x^{s}\\ y^{s}\\ 1 \end{bmatrix} = \begin{bmatrix} {x} & {y} & 1 \end{bmatrix} *{ \begin{bmatrix} a & b &0 \\ c & d & 0\\ e &f & 1 \end{bmatrix}}^{-1}

  二、pytorch中的仿射变换 

pytorch中就使用的为后向变换。主要涉及两个函数

  • F.affine_grid(theta,size)
  • F.grid_sample(input, grid, mode='bilinear', padding_mode='zeros')

     

1.F.affine_grid根据输入的变换矩阵theta和尺寸利用后向变换求出目标图像每个像素在原图像的位置。

    theta是一个[N,2,3]的tensor,N为batchsize大小;2行3列共六个参数,为affine的变换矩阵,第一行为x坐标,即横坐标的变换参数,前两个为权重,最后一个为偏移,值得注意的是偏移值是一个相对于图像宽归一化的参数a,c,e(并非像素值),例如0.5表示左移半个图像的宽度。第二行表示y坐标的变换参数(b,d,f)。

   size是一个tuple,为(N,C,H,W)

   output为[N,h,w,2]的Tensor,表示在原图中的对应位置。

2. F.grid_sample()为重采样函数,根据输入的原图和位置对应关系矩阵(F.affine_grid的输出)对原图像素进行重采样,构成变换后的图像。由于重采样过程中,在原图中的位置会出现小数,因此需要对原图进行插值,插值方式为可选参数,默认双线性插值。

下面我们来看一个例子:

将图像顺时针旋转45度,注意pytorch使用的为后向变换。

对于前向变换来说,顺时针旋转45度的变换矩阵为\begin{bmatrix} cos\theta &sin \theta & 0\\ -sin \theta& cos\theta & 0 \\ 0& 0 & 1\end{bmatrix},后向变换应该对其求逆。但是我们可以换一个角度理解,原图到目标图需要顺时针旋转45度,那么目标图到原图不就是逆时针旋转45度吗,因此直接取\theta = -45带入原公式计算即可

 

代码如下:

import torch
import cv2
import torch.nn.functional as F
import matplotlib.pyplot as plt


theta = torch.Tensor([[0.707,0.707,0],[-0.707,0.707,0]]).unsqueeze(dim=0)
img = cv2.imread('achor.png',cv2.IMREAD_GRAYSCALE)
plt.subplot(2,1,1)
plt.imshow(img,cmap='gray')
plt.axis('off')
img = torch.Tensor(img).unsqueeze(0).unsqueeze(0)
grid =  F.affine_grid(theta,size=img.shape)
output = F.grid_sample(img,grid)[0].numpy().transpose(1,2,0).squeeze()
plt.subplot(2,1,2)
plt.imshow(output,cmap='gray')
plt.axis('off')
plt.show()

 

结果如下(pytorch中以图像中心点为原点,与一般的左上角为原点不太一样):

pytorch---仿射变换_第1张图片

 

你可能感兴趣的:(pytorch)