使用torch.nn.functional.grid_sample对flownet得到的光流结果进行对齐

flownet2 计算视频中前后两帧的光流信息

    def resample(self, image, flow):    
        '''
        image: 上一帧的图片,torch.Size([1, 3, 256, 256])
        flow: 光流, torch.Size([1, 2, 256, 256])
        final_grid:  torch.Size([1, 2, 256, 256])
        '''
        b, c, h, w = image.size()
        grid = get_grid(b, h, w, gpu_id=flow.get_device())    
        flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)    
        final_grid = (grid + flow).permute(0, 2, 3, 1).cuda(image.get_device())
        output = torch.nn.functional.grid_sample(image, final_grid, mode='bilinear', padding_mode='border')
        return output

Reference:
1.crop pooling
2.What is the equivalent of torch.nn.functional.grid_sample in Tensorflow / Numpy?

你可能感兴趣的:(使用torch.nn.functional.grid_sample对flownet得到的光流结果进行对齐)