RealBasicVSR 实现超分辨率

文章目录

  • MMEditing 安装
    • 1. 创建并激活 conda 虚拟环境
    • 2. 安装 PyTorch 和 torchvision
    • 3. 安装mmcv-full
    • 4. 克隆 MMEditing 仓库
    • 5. 安装相关依赖和 MMEditing
    • 6. 验证安装
  • 使用RealBasicVSR
    • 克隆RealBasicVSR仓库
    • 下载需要的模型文件
    • 修改源码限制测试的视频长度
    • 测试
  • 效果
    • 原视频:
    • 超分后视频:
    • 效果对比:
    • 图片
    • 速度

MMEditing 安装

没有conda环境的需要先安装下,推荐anaconda。

1. 创建并激活 conda 虚拟环境

conda create -n mmedit python=3.8 -y
conda activate mmedit

2. 安装 PyTorch 和 torchvision

可以参照官网安装适合自己环境的版本

conda install pytorch==1.7.1 torchvision cudatoolkit=10.1 -c pytorch

3. 安装mmcv-full

在这里下载适合自己环境的MMCV

安装命令如下:

pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html

将{mmcv_version}换成要安装的对应版本号即可例如

pip install mmcv-full==1.4.6 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html

4. 克隆 MMEditing 仓库

git clone https://github.com/open-mmlab/mmediting.git
cd mmediting

5. 安装相关依赖和 MMEditing

pip install -r requirements.txt
pip install -v -e .  # or "python setup.py develop"

**注意:**这里有时候会出现不同的包版本冲突的问题,可以打开requirement.txt文件看看总共安装了些什么,将版本调至不冲突即可。可以pip install 包名==版本号或者 conda install 包名=版本号来重新安装冲突的包

6. 验证安装

python
import mmedit
mmedit.__version__

显示出版本号。

使用RealBasicVSR

克隆RealBasicVSR仓库

git clone https://github.com/ckkelvinchan/RealBasicVSR.git
cd RealBasicVSR

下载需要的模型文件

新建checkpoints文件夹

mkdir checkpoints

在这里下载模型并存放到checkpoints文件夹下

修改源码限制测试的视频长度

由于直接传视频会导致GPU的显存不够用,这里取视频的前25帧。我们设置fps=25,则刚好是1秒钟。

     # read images
     file_extension = os.path.splitext(args.input_dir)[1]
     if file_extension in VIDEO_EXTENSIONS:  # input is a video file
         video_reader = mmcv.VideoReader(args.input_dir)
         inputs = []
         i = 0
         for frame in video_reader:
             inputs.append(np.flip(frame, axis=2))
             i = i + 1
             if(i > 25):
                 break
     elif file_extension == '':  # input is a directory
         inputs = []
         input_paths = sorted(glob.glob(f'{args.input_dir}/*'))
         for input_path in input_paths:
             img = mmcv.imread(input_path, channel_order='rgb')
             inputs.append(img)
     else:
         raise ValueError('"input_dir" can only be a video or a directory.')

测试

python inference_realbasicvsr.py configs/realbasicvsr_x4.py checkpoints/RealBasicVSR_x4.pth input/demo_001.mp4 output/demo_001out.mp4 --fps=25

效果

原视频:

demo_001

超分后视频:

demo_001out

效果对比:

对比

图片

原图:
RealBasicVSR 实现超分辨率_第1张图片
超分后:
RealBasicVSR 实现超分辨率_第2张图片

速度

我这里RTX4000显卡,每秒能跑8帧左右

你可能感兴趣的:(人工智能,pytorch,超分辨率,BasicVSR,RealBasicVSR)