grid的shape为[B,H,W,2],首先不用去考虑batch,那就是[H,W,2],这个2存储的是input的坐标值,取值为-1~1。
实际上grid就是一个map,这个map告诉你output的某个点 ( x o , y o ) (x_o,y_o) (xo,yo)来自于input的某个点 ( x i , y i ) (x_i,y_i) (xi,yi)
生成一个grid的代码如下
B, C, H, W = x.size()
# mesh grid
xx = torch.arange(0, W).view(1,-1).repeat(H,1)
yy = torch.arange(0, H).view(-1,1).repeat(1,W)
xx = xx.view(1,1,H,W).repeat(B,1,1,1)
yy = yy.view(1,1,H,W).repeat(B,1,1,1)
grid = torch.cat((xx,yy),1).float()
vgrid = grid
vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone()/max(W-1,1)-1.0
#取出光流v这个维度,原来范围是0~W-1,再除以W-1,范围是0~1,再乘以2,范围是0~2,再-1,范围是-1~1
vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone()/max(H-1,1)-1.0 #取出光流u这个维度,同上
vgrid = vgrid.permute(0,2,3,1)
举个简单的例子
比如 grid的shape是 224 × 224 × 2 224\times224\times2 224×224×2,input要从 200 × 200 200\times200 200×200插值到 224 × 224 224\times224 224×224
grid[0,0,:]=[-1,-1],那就表示output在坐标(0,0)位置点的值来源于input坐标为(-1,-1)的点(左上角)
grid[223,223,:]=[1,1],那就表示output在坐标(223,223)位置点的值来源于input坐标为(1,1)的点(右下角)
得到input和output的对应关系后,即可选择插值方式,双线性插值or最邻近插值
torch.nn.functional.grid_sample(input,
grid,
mode='bilinear',
padding_mode='zeros',
align_corners=None)
光流法的亮度恒定假设为同一目标在不同帧间运动时,其亮度不会发生改变,因此可以通过找到亮度不变的点来对相邻帧图片的点进行对应,所谓warp操作就是在原来的对应关系上加上偏移量使得亮度不变的点得到对应。
比如原来[112,112]的点插值到[224,224]的点,用的是grid,现在我得到了一个光流flow,我就可以进行warp操作,然后再进行grid_sample
grid = grid + flow
output = nn.functional.grid_sample(x, vgrid,align_corners=True)