光流方法Flownet的简单调用

文章目录

  • 1 概述
  • 2 代码下载
  • 3 数据下载
  • 4 预训练模型下载
  • 5 代码讲解
  • 6 输出示意

1 概述

如果是自己训练,30G的FLyingChairs数据集还是很吃设备,这里只介绍如何使用该算法。

TIps:假设已经安装好了所有库。

2 代码下载

Torch: https://github.com/ClementPinard/FlowNetPytorch

这里主要使用的是run_inference.py文件:
光流方法Flownet的简单调用_第1张图片

3 数据下载

链接https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs

点进去是这样的:
光流方法Flownet的简单调用_第2张图片
下载这个看网速了,这里我放几个测试文件,方便大家调用:
https://download.csdn.net/download/weixin_44575152/86863644

4 预训练模型下载

推荐下载pytorch的,不然torch转有点麻烦:
https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3
光流方法Flownet的简单调用_第3张图片

5 代码讲解

如果只想直接用,按照以上步骤运行run_inference.py即可,否则可以阅读以下带有注释的代码。需要修改一下parser.add_argument中的存储位置:

  1. 数据集位置:–data
  2. 预训练模型位置:–pretrained
  3. 生成图像存储位置:–output
    尽量不要把output放在data里哇~
import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import models
import torchvision.transforms as transforms
import flow_transforms
import numpy as np

from path import Path
from tqdm import tqdm
from imageio import imread, imwrite
from util import flow2rgb
import warnings
warnings.filterwarnings("ignore")


model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__"))

parser = argparse.ArgumentParser(description='PyTorch FlowNet inference on a folder of img pairs',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# 图像存储路径
parser.add_argument('--data', metavar='DIR', default=r"D:\Data\Flow\FlyingChairs",
                    help='path to images folder, image names must match \'[name]0.[ext]\' and \'[name]1.[ext]\'')
# 预训练模型
parser.add_argument('--pretrained', metavar='PTH', default=r"D:\Data\Flow\FlowNet\model\flownets.pth.tar", help='path to pre-trained model')
# 文件存储位置
parser.add_argument('--output', '-o', metavar='DIR', default=r"D:\Data\Flow\FlowNet\data",
                    help='path to output folder. If not set, will be created in data folder')
# 存储值的类型
parser.add_argument('--output-value', '-v', choices=['raw', 'vis', 'both'], default='both',
                    help='which value to output, between raw input (as a npy file) and color vizualisation (as an image file).'
                         ' If not set, will output both')
#
parser.add_argument('--div-flow', default=20, type=float,
                    help='value by which flow will be divided. overwritten if stored in pretrained file')
# 图像类型
parser.add_argument("--img-exts", metavar='EXT', default=['png', 'jpg', 'bmp', 'ppm'], nargs='*', type=str,
                    help="images extensions to glob")
# 最大流值
parser.add_argument('--max_flow', default=None, type=float,
                    help='max flow value. Flow map color is saturated above this value. If not set, will use flow map\'s max value')
# 未设置输出原始输入,即4次下采样;如果选择,则输出指定上采样下的完整分辨率流图
parser.add_argument('--upsampling', '-u', choices=['nearest', 'bilinear'], default=None,
                    help='if not set, will output FlowNet raw input,'
                         'which is 4 times downsampled. If set, will output full resolution flow map, with selected upsampling')
# 设置,则输出反转流和常规流
parser.add_argument('--bidirectional', action='store_true',
                    help='if set, will output invert flow (from 1 to 0) along with regular flow')

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@torch.no_grad()
def main():
    global args, save_path
    args = parser.parse_args()

    # 输出方式
    output_string = ""
    if args.output_value == 'both':
        output_string = "raw output and RGB visualization"
    elif args.output_value == 'raw':
        output_string = "raw output"
    elif args.output_value == 'vis':
        output_string = "RGB visualization"
    print("=> will save " + output_string)
    data_dir = Path(args.data)
    print("=> fetching img pairs in '{}'".format(args.data))
    if args.output is None:
        save_path = data_dir / 'flow'
    else:
        save_path = Path(args.output)
    print('=> will save everything to {}'.format(save_path))
    save_path.makedirs_p()
    # Data loading code
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])

    img_pairs = []
    for ext in args.img_exts:
        # 读取与当前格式匹配的图像,后缀为1.ppm
        test_files = data_dir.files('*1.{}'.format(ext))
        for file in test_files:
            # 单个图像,后缀为2.ppm
            img_pair = file.parent / (file.stem[:-1] + '2.{}'.format(ext))
            if img_pair.isfile():
                # 存储图像对
                img_pairs.append([file, img_pair])

    print('{} samples found'.format(len(img_pairs)))
    # create model
    network_data = torch.load(args.pretrained, map_location=torch.device('cpu'))
    print("=> using pre-trained model '{}'".format(network_data['arch']))
    # 读取模型
    model = models.__dict__[network_data['arch']](network_data).to(device)
    model.eval()
    cudnn.benchmark = True
    if 'div_flow' in network_data.keys():
        args.div_flow = network_data['div_flow']

    # 遍历图像对
    for (img1_file, img2_file) in tqdm(img_pairs):
        # 以下均以飞行椅子为例
        # (3, 384, 512)
        img1 = input_transform(imread(img1_file))
        img2 = input_transform(imread(img2_file))
        # (1, 6, 384, 515)
        input_var = torch.cat([img1, img2]).unsqueeze(0)

        if args.bidirectional:
            # feed inverted pair along with normal pair
            inverted_input_var = torch.cat([img2, img1]).unsqueeze(0)
            input_var = torch.cat([input_var, inverted_input_var])

        input_var = input_var.to(device)
        # compute output
        output = model(input_var)
        if args.upsampling is not None:
            # 采样
            output = F.interpolate(output, size=img1.size()[-2:], mode=args.upsampling, align_corners=False)
        for suffix, flow_output in zip(['flow', 'inv_flow'], output):
            filename = save_path / '{}{}'.format(img1_file.stem[:-1], suffix)
            if args.output_value in ['vis', 'both']:
                rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
                to_save = (rgb_flow * 255).astype(np.uint8).transpose(1, 2, 0)
                imwrite(filename + '.png', to_save)
            # if args.output_value in ['raw', 'both']:
            #     # Make the flow map a HxWx2 array as in .flo files
            #     to_save = (args.div_flow * flow_output).cpu().numpy().transpose(1, 2, 0)
            #     np.save(filename + '.npy', to_save)
        break


if __name__ == '__main__':
    main()

6 输出示意

你可能感兴趣的:(#,深度学习,编程实战之Python,深度学习,pytorch,flownet,光流)