PyTorch implementation of SRGAN
156服务器,2080Ti , cuda10.0
conda create -n torch11 python=3.6.9
conda activate torch11
conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch
pip install pillow==5.2.0
pip install opencv-python
pip install scipy
pip install thop
pip install matplotlib
pip install pandas
pip install tqdm
我使用的训练数据:DIV2K
测试数据:点击上图 gitHub 提供的下载链接下载
我的存放位置如下:
这个代码 训练数据只需要 HR
图片即可:
在 train.py
中设置 训练数据集、和 评估集 路径;
python train.py
optional arguments:
--crop_size training images crop size [default value is 88]
--upscale_factor super resolution upscale factor [default value is 4](choices:[2, 4, 8])
--num_epochs train epoch number [default value is 100]
The output val super resolution images are on training_results
directory.
python test_benchmark.py --upscale_factor 2 --model_name netG_epoch_2_10.pth
optional arguments:
--upscale_factor super resolution upscale factor [default value is 4]
--model_name generator model epoch name [default value is netG_epoch_4_100.pth]
The output super resolution images are on benchmark_results directory.
初次训练 会自动下载 pytorch版本的 vgg16 model 用来 计算 loss ,考验网速哈:
CUDA out of memory 报错如下:
RuntimeError: CUDA out of memory. Tried to allocate 1018.00 MiB (GPU 0; 7.79 GiB total capacity; 4.72 GiB already allocated; 853.50 MiB free; 1.52 GiB cached)
test_benchmark.py 测试 ssim 计算 报错,处理方法如下:
156服务器,显卡 2080Ti , cuda10.0 ,8G
训练数据: DIV2K
训练命令: python train.py
train.py 中 参数设置如下:
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
parser.add_argument('--crop_size', default=96, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=2, type=int, choices=[2, 4, 8],
help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')
训练时长: 2小时20分钟
测试时长: 57s
效果如下: