官方链接:
https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html#torch.nn.functional.grid_sample
https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample
更多相关请看我的另一篇文章。
在PyTorch中,流场(flow field)通常用来表示图像中像素的运动或位移信息。它是一个二维矢量场,每个像素都对应一个二维矢量,表示该像素从一个图像到另一个图像的位移。流场通常用于计算光流(optical flow)等计算机视觉任务,用于追踪物体的运动、分析视频序列等。
简单来说,就是在把图像a上的点a_{i, j} 变换到图像b的点b_{x, y}上。
它是一个4维张量,形状为(N, Hout, Wout, 2)。其中N表示批次大小,Hout和Wout表示输出的高度和宽度,2表示每个像素在新图像上的(x, y)坐标。
TORCH.MESHGRID 生成grid
https://pytorch.org/docs/stable/generated/torch.meshgrid.html
MAKE_GRID 用于一次显示多张图
https://pytorch.org/vision/stable/generated/torchvision.utils.make_grid.html
**affine_grid用于仿射变换**
https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html#torch.nn.functional.affine_grid
For each output location
output[n, :, h, w]
, the size-2 vectorgrid[n, h, w]
specifiesinput
pixel locationsx
andy
, which are used to interpolate the output valueoutput[n, :, h, w]
.
根据输入值和映射网格(flow-field grid)计算输出。它主要用于在图像处理和计算机视觉任务中,根据给定的网格对输入数据进行采样和插值。
提供一个input的Tensor以及一个对应的flow-field网格(比如光流,体素流等),然后根据grid中每个位置提供的坐标信息(这里指input中pixel的坐标),将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。
grid_sample底层是应用双线性插值,把输入的tensor转换为指定大小。那它和interpolate有啥区别呢?
interpolate是规则采样(uniform),但是grid_sample的转换方式,内部采点的方式并不是规则的,是一种更为灵活的方式。
torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
align_corners
: 一个可选参数,通常为None。如果为True,则网格中的像素坐标(-1, -1)和(1, 1)将对准输入图像的四个角。如果为False或None,则(-1, -1)对应于输入图像的左上角,(1, 1)对应于右下角。
例子中,我们将一个大小为4x4的tensor 转换为了一个20x20的。grid的大小指定了输出大小,每个grid的位置是一个(x,y)坐标,其值来自于:输入input的(x,y)中 的四邻域插值得到的。
import torch
from torch.nn import functional as F
inp = torch.ones(1, 1, 4, 4)
# 目的是得到一个 长宽为20的tensor
out_h = 20
out_w = 20
# grid的生成方式等价于用mesh_grid
new_h = torch.linspace(-1, 1, out_h).view(-1, 1).repeat(1, out_w)
new_w = torch.linspace(-1, 1, out_w).repeat(out_h, 1)
grid = torch.cat((new_h.unsqueeze(2), new_w.unsqueeze(2)), dim=2)
grid = grid.unsqueeze(0)
outp = F.grid_sample(inp, grid=grid, mode='bilinear')
print(outp.shape) #torch.Size([1, 1, 20, 20])
图片来自于SFnet(eccv2020)。flow field是grid, low_resolution是input, high resolution是output。
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
def visualize_affine_transformation(input_image, affine_matrix, padding_mode='zeros'):
# 使用 grid_sample 进行仿射变换
output_image = F.grid_sample(input_image, affine_matrix.unsqueeze(0).expand(input_image.size(0), -1, -1), padding_mode=padding_mode)
# 可视化输入和输出图像
input_image = input_image.squeeze().numpy()
output_image = output_image.squeeze().numpy()
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(input_image, cmap='gray')
plt.title('Input Image')
plt.subplot(1, 2, 2)
plt.imshow(output_image, cmap='gray')
plt.title('Output Image after Affine Transformation')
plt.show()
# 创建一个示例输入图像(单通道)
input_image = torch.zeros(1, 1, 5, 5)
input_image[0, 0, :, 2] = 1 # 在中心放置一个白色像素
# 定义不同的仿射变换矩阵和 padding_mode
affine_matrix1 = torch.tensor([[1, 0, 2], [0, 1, 2], [0, 0, 1]], dtype=torch.float32)
affine_matrix2 = torch.tensor([[0.5, 0, 0], [0, 2, 0], [0, 0, 1]], dtype=torch.float32)
padding_mode1 = 'zeros'
padding_mode2 = 'border'
# 调用可视化函数并尝试不同的参数组合
visualize_affine_transformation(input_image, affine_matrix1, padding_mode1)
visualize_affine_transformation(input_image, affine_matrix2, padding_mode1)
visualize_affine_transformation(input_image, affine_matrix1, padding_mode2)
visualize_affine_transformation(input_image, affine_matrix2, padding_mode2)
AFFINE_GRID用于生成仿射变换所需的矩阵。也就是映射所需的流场。
Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices
theta
.
其中的具体参数
torch.nn.functional.affine_grid(theta, size, align_corners=None)
theta
:一个4x2的张量,表示仿射变换的参数矩阵。这个矩阵通常由用户指定,它包含了仿射变换的缩放、旋转、平移和错切等信息。矩阵的形状应为 (N, 2, 3)
,其中 N
是批次大小。通常情况下,你可以使用PyTorch的 torch.tensor
创建这个矩阵。
size
:一个包含两个整数的元组 (H, W)
,指定生成的仿射变换网格的大小。H
表示输出的高度,W
表示输出的宽度。
align_corners
:一个布尔值或None,通常影响网格生成的坐标点的精确位置。
当它为True时,生成的网格的坐标点会与输入图像的四个角对齐,这意味着生成的网格将确切地覆盖输入图像的所有四个角。这种情况下,坐标点的精确位置与输入图像的四个角相吻合,对于某些精确的几何变换可能更合适。
如果为False或None(默认值),则生成的网格的坐标点会与输入图像的左上角对齐。这种情况下,坐标点的精确位置可能会在输入图像的像素之间,对于一般的仿射变换通常更常见。
grid_sample()函数及双线性采样