Link to the companion SRGAN paper walk-through (super-resolution series). The code below comprises the training script, then data_utils.py, model.py, loss.py, pytorch_ssim.py, and finally the single-image and single-video test scripts.
import argparse  # parses command-line arguments in four steps; step 1: import the module
import os
from math import log10
import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator
# Step 2: create a parser object
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
# Step 3: register the arguments and options of interest; each add_argument call registers one of them
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=2, type=int, help='train epoch number')
if __name__ == '__main__':
    # Step 4: call parse_args(); once parsing succeeds, the values are ready to use
opt = parser.parse_args()
CROP_SIZE = opt.crop_size
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs
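    # Example invocation (hypothetical; any flag falls back to its default when omitted):
    #   python train.py --crop_size 96 --upscale_factor 2 --num_epochs 100
    # after which opt.crop_size == 96, opt.upscale_factor == 2 and opt.num_epochs == 100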
    # Instantiate the training and validation datasets
train_set = TrainDatasetFromFolder('E:\\Datasets\\SR\\DIV2K\\DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('E:\\Datasets\\SR\\DIV2K\\DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
    # Build the generator and discriminator
netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
    # Generator loss
generator_criterion = GeneratorLoss()
    # Move the models and the loss onto the GPU if one is available
if torch.cuda.is_available():
netG.cuda()
netD.cuda()
generator_criterion.cuda()
    # Optimizers (Adam with default hyperparameters)
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())
    # d_loss: discriminator loss
    # g_loss: generator loss
    # d_score: discriminator score on real (HR) images, D(x)
    # g_score: discriminator score on generated (SR) images, D(G(z))
    # psnr: peak signal-to-noise ratio
    # ssim: structural similarity
results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
    # Training loop
for epoch in range(1, NUM_EPOCHS + 1):
        # Progress bar over the training dataloader
train_bar = tqdm(train_loader)
running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
        # Put both networks in training mode
netG.train()
netD.train()
for data, target in train_bar:
batch_size = data.size(0)
running_results['batch_sizes'] += batch_size
############################
# (1) Update D network: maximize D(x)-1-D(G(z))
###########################
            # HR image (real sample)
real_img = Variable(target)
if torch.cuda.is_available():
real_img = real_img.cuda()
            # LR image (generator input)
z = Variable(data)
if torch.cuda.is_available():
z = z.cuda()
            # Generator forward pass: produce the SR image
fake_img = netG(z)
            # Clear the discriminator's gradients
netD.zero_grad()
            # Probability the discriminator assigns to the HR image being real
real_out = netD(real_img).mean()
            # Discriminator forward pass: probability assigned to the SR image
fake_out = netD(fake_img).mean()
            # Discriminator loss: 1 - D(HR) + D(SR)
            # d_loss -> 0: strong discriminator (e.g. 1 - 0.9 + 0.1 = 0.2)
            # d_loss -> 1: weak discriminator (e.g. 1 - 0.5 + 0.5 = 1.0)
d_loss = 1 - real_out + fake_out
            # Backpropagate, retaining the graph so it can be reused for the generator update
d_loss.backward(retain_graph=True)
            # Discriminator parameter update
optimizerD.step()
############################
# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
###########################
            # Clear the generator's gradients
netG.zero_grad()
            # The two lines below are added to prevent a runtime error in Google Colab
            # Generator forward pass: regenerate the SR image (already computed once above)
fake_img = netG(z)
            # Discriminator forward pass: re-score the SR image (already computed once above)
fake_out = netD(fake_img).mean()
            # Generator loss: image loss + 0.001*adversarial + 0.006*perceptual + 2e-8*TV loss
g_loss = generator_criterion(fake_out, fake_img, real_img)
            # Backpropagate the generator loss
g_loss.backward()
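            # Recompute the SR image and its score after the backward pass (the second
            # half of the Colab workaround); these feed the running statistics below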
fake_img = netG(z)
fake_out = netD(fake_img).mean()
            # Generator parameter update
optimizerG.step()
# loss for current batch before optimization
running_results['g_loss'] += g_loss.item() * batch_size
running_results['d_loss'] += d_loss.item() * batch_size
running_results['d_score'] += real_out.item() * batch_size
running_results['g_score'] += fake_out.item() * batch_size
            # Update the progress-bar description with the running averages
train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
running_results['g_loss'] / running_results['batch_sizes'],
running_results['d_score'] / running_results['batch_sizes'],
running_results['g_score'] / running_results['batch_sizes']))
        # Validation
netG.eval()
out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
if not os.path.exists(out_path):
os.makedirs(out_path)
with torch.no_grad():
val_bar = tqdm(val_loader)
valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
val_images = []
for val_lr, val_hr_restore, val_hr in val_bar:
batch_size = val_lr.size(0)
valing_results['batch_sizes'] += batch_size
lr = val_lr
hr = val_hr
if torch.cuda.is_available():
lr = lr.cuda()
hr = hr.cuda()
sr = netG(lr)
batch_mse = ((sr - hr) ** 2).data.mean()
valing_results['mse'] += batch_mse * batch_size
batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size  # accumulated SSIM over all validation images
                valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))  # running-average PSNR
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']  # running-average SSIM
                # Progress bar
val_bar.set_description(
desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
valing_results['psnr'], valing_results['ssim']))
                # Collect validation images: bicubic restoration, HR ground truth, SR output
val_images.extend(
[display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
display_transform()(sr.data.cpu().squeeze(0))])
            # torch.stack(): stack equally-shaped tensors along a new dimension (default dim=0)
val_images = torch.stack(val_images)
            # torch.chunk(): split a tensor into the given number of chunks
val_images = torch.chunk(val_images, val_images.size(0) // 15)
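            # each chunk holds 15 images (5 triplets); with nrow=3 every saved grid
            # shows 5 rows of (bicubic restore, HR ground truth, SR output)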
            # Progress bar while saving the validation grids
val_save_bar = tqdm(val_images, desc='[saving training results]')
index = 1
for image in val_save_bar:
                # Arrange each chunk as a grid, 3 images per row (restore / HR / SR)
image = utils.make_grid(image, nrow=3, padding=5)
utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
index += 1
# save model parameters
torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
    # save losses, scores, PSNR and SSIM
results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
results['psnr'].append(valing_results['psnr'])
results['ssim'].append(valing_results['ssim'])
if epoch % 10 == 0 and epoch != 0:
out_path = 'statistics/'
data_frame = pd.DataFrame(
data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
index=range(1, epoch + 1))
data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
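# ---- data_utils.py: dataset and transform helpers imported by the training script ----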
from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, InterpolationMode  # use the InterpolationMode enum, as the torchvision deprecation warning suggests
def is_image_file(filename):
"""用于判断filename是否是png、jpg、jpeg等格式"""
# any函数用于检查生成器表达式的结果序列,如果其中任何一个结果为True(文件名以任何一个图像文件扩展名结尾),则any函数返回True,否则返回False。
# 用endswith()判断字符串是否以指定字符串结尾
return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
def calculate_valid_crop_size(crop_size, upscale_factor):
"""将图片剪裁成缩放因子的整数倍"""
# crop_size=25, upscale_factor=4
# return 256 - (256 % 4) = 256
# return 255 - (255 % 4) = 252
return crop_size - (crop_size % upscale_factor)
def train_hr_transform(crop_size):
return Compose([
RandomCrop(crop_size),
ToTensor(),
])
def train_lr_transform(crop_size, upscale_factor):
return Compose([
ToPILImage(),
        Resize(crop_size // upscale_factor, interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC replaced by the InterpolationMode enum
ToTensor()
])
def display_transform():
return Compose([
ToPILImage(),
Resize(400),
CenterCrop(400),
ToTensor()
])
# Custom training dataset
class TrainDatasetFromFolder(Dataset):
def __init__(self, dataset_dir, crop_size, upscale_factor):
super(TrainDatasetFromFolder, self).__init__()
self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)  # trim the crop size to a multiple of the upscale factor
self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)  # bicubic downsampling produces the LR input
def __getitem__(self, index):
hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
lr_image = self.lr_transform(hr_image)
return lr_image, hr_image
def __len__(self):
return len(self.image_filenames)
class ValDatasetFromFolder(Dataset):
def __init__(self, dataset_dir, upscale_factor):
super(ValDatasetFromFolder, self).__init__()
self.upscale_factor = upscale_factor
self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
def __getitem__(self, index):
hr_image = Image.open(self.image_filenames[index])
w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)  # trim the shorter side to a multiple of the upscale factor
        # Image.BICUBIC replaced by the InterpolationMode enum
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=InterpolationMode.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=InterpolationMode.BICUBIC)
hr_image = CenterCrop(crop_size)(hr_image)
lr_image = lr_scale(hr_image)
hr_restore_img = hr_scale(lr_image)
return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
def __len__(self):
return len(self.image_filenames)
class TestDatasetFromFolder(Dataset):
def __init__(self, dataset_dir, upscale_factor):
super(TestDatasetFromFolder, self).__init__()
self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
self.upscale_factor = upscale_factor
self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]
def __getitem__(self, index):
image_name = self.lr_filenames[index].split('/')[-1]
lr_image = Image.open(self.lr_filenames[index])
w, h = lr_image.size
hr_image = Image.open(self.hr_filenames[index])
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=InterpolationMode.BICUBIC)  # Image.BICUBIC replaced by the InterpolationMode enum
hr_restore_img = hr_scale(lr_image)
return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
def __len__(self):
return len(self.lr_filenames)
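if __name__ == '__main__':
    # Minimal shape check for the transforms above (a sketch on a random image, not
    # part of the original file): an 88-pixel HR crop downscaled 4x gives a 22x22 LR tensor.
    import torch
    dummy = ToPILImage()(torch.rand(3, 100, 100))
    hr = train_hr_transform(88)(dummy)
    lr = train_lr_transform(88, 4)(hr)
    print(hr.shape, lr.shape)  # torch.Size([3, 88, 88]) torch.Size([3, 22, 22])

# ---- model.py: SRGAN Generator and Discriminator ----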
import math
import torch
from torch import nn
class Generator(nn.Module):
def __init__(self, scale_factor):
        # Number of 2x upsampling blocks: log2(scale_factor)
        # e.g. scale_factor=4 -> upsample_block_num=2 (two 2x blocks give a 4x upscale)
upsample_block_num = int(math.log(scale_factor, 2))
super(Generator, self).__init__()
        # Shallow feature extraction
self.block1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=9, padding=4),
nn.PReLU()
)
        # Deep feature extraction (residual blocks)
self.block2 = ResidualBlock(64)
self.block3 = ResidualBlock(64)
self.block4 = ResidualBlock(64)
self.block5 = ResidualBlock(64)
self.block6 = ResidualBlock(64)
self.block7 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64)
)
        # Upsampling stage
block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
self.block8 = nn.Sequential(*block8)
def forward(self, x):
block1 = self.block1(x)
block2 = self.block2(block1)
block3 = self.block3(block2)
block4 = self.block4(block3)
block5 = self.block5(block4)
block6 = self.block6(block5)
block7 = self.block7(block6)
block8 = self.block8(block1 + block7)
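        # tanh maps to [-1, 1]; shifting and halving rescales the SR image to [0, 1]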
return (torch.tanh(block8) + 1) / 2
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(512, 1024, kernel_size=1),
nn.LeakyReLU(0.2),
nn.Conv2d(1024, 1, kernel_size=1)
)
def forward(self, x):
batch_size = x.size(0)
return torch.sigmoid(self.net(x).view(batch_size))
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
return x + residual
class UpsampleBLock(nn.Module):
"""上采样块设计"""
def __init__(self, in_channels, up_scale):
super(UpsampleBLock, self).__init__()
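        # the conv expands channels by up_scale^2; PixelShuffle then rearranges
        # (C*r^2, H, W) -> (C, r*H, r*W), i.e. sub-pixel convolution upsampling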
self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
self.pixel_shuffle = nn.PixelShuffle(up_scale)
self.prelu = nn.PReLU()
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.prelu(x)
return x
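if __name__ == '__main__':
    # Shape sanity check (a sketch, not in the original file): with scale_factor=4
    # the generator maps a 24x24 LR input to a 96x96 SR output, and the
    # discriminator returns one realness score per image.
    netG = Generator(4)
    netD = Discriminator()
    sr = netG(torch.rand(1, 3, 24, 24))
    print(sr.shape)        # torch.Size([1, 3, 96, 96])
    print(netD(sr).shape)  # torch.Size([1])

# ---- loss.py: the combined SRGAN generator loss ----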
import torch
from torch import nn
from torchvision.models.vgg import vgg16
class GeneratorLoss(nn.Module):
def __init__(self):
super(GeneratorLoss, self).__init__()
vgg = vgg16(pretrained=True)
        # Use the first 31 layers of VGG16's feature extractor as the loss network
loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
for param in loss_network.parameters():
param.requires_grad = False
self.loss_network = loss_network
self.mse_loss = nn.MSELoss()
self.tv_loss = TVLoss()
def forward(self, out_labels, out_images, target_images):
        # Adversarial loss: mean(1 - D(SR)), pushes the generator to fool the discriminator
adversarial_loss = torch.mean(1 - out_labels)
        # Perceptual loss: MSE between VGG features of SR and HR
perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image loss: pixel-wise MSE between SR and HR
image_loss = self.mse_loss(out_images, target_images)
        # TV loss: total-variation regularizer that penalizes differences between neighbouring pixels
tv_loss = self.tv_loss(out_images)
        # Total generator loss = image loss + 0.001*adversarial + 0.006*perceptual + 2e-8*TV
return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
class TVLoss(nn.Module):
def __init__(self, tv_loss_weight=1):
super(TVLoss, self).__init__()
self.tv_loss_weight = tv_loss_weight
def forward(self, x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])  # element count of the vertically shifted slice
        count_w = self.tensor_size(x[:, :, :, 1:])  # element count of the horizontally shifted slice
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()  # squared differences of vertically adjacent pixels
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()  # squared differences of horizontally adjacent pixels
return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
@staticmethod
def tensor_size(t):
return t.size()[1] * t.size()[2] * t.size()[3]
if __name__ == "__main__":
g_loss = GeneratorLoss()
print(g_loss)
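    # Illustrative check (a sketch: random tensors stand in for real batches, and the
    # pretrained VGG16 weights must be available locally or downloadable):
    fake = torch.rand(2, 3, 64, 64)
    real = torch.rand(2, 3, 64, 64)
    print(g_loss(torch.rand(2), fake, real))  # one combined scalar loss

# ---- pytorch_ssim.py: the SSIM metric used during validation ----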
from math import exp
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def gaussian(window_size, sigma):
"""生成一维高斯滤波函数"""
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
return gauss / gauss.sum()
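# e.g. gaussian(11, 1.5) returns 11 weights that sum to 1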
def create_window(window_size, channel):
"""创建二维窗口"""
# 生成一个一维的高斯滤波器
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
# 将 _1D_window 与其转置相乘,生成一个二维的高斯滤波器
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
# 将 _2D_window 在第一个维度上进行扩展,以适应输入数据的通道数
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def _ssim(img1, img2, window, window_size, channel, size_average=True):
"""结构相似度:用于比较两幅图像的相似度"""
# 均值
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
    # squared means and their cross term
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
    # local variances: Gaussian-filtered squares minus the squared means
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    # local covariance: Gaussian-filtered product minus the product of the means
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
    # ssim_map = ((2*mu1*mu2 + C1) * (2*sigma12 + C2)) / ((mu1^2 + mu2^2 + C1) * (sigma1^2 + sigma2^2 + C2))
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
        # Reuse the cached window if the channel count and dtype match the input;
        # otherwise rebuild it (moving it onto the GPU when the input lives there).
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size=11, size_average=True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
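if __name__ == '__main__':
    # Minimal usage sketch (not in the original file): SSIM is 1.0 for identical
    # images and much smaller for unrelated noise.
    a = torch.rand(1, 3, 64, 64)
    b = torch.rand(1, 3, 64, 64)
    print(ssim(a, a).item())  # 1.0
    print(ssim(a, b).item())  # well below 1.0

# ---- single-image test script ----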
import argparse
import time
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from model import Generator
parser = argparse.ArgumentParser(description='Test Single Image')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
parser.add_argument('--image_name', default='SUT1.jpg', type=str, help='test low resolution image name')
parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()
UPSCALE_FACTOR = opt.upscale_factor
TEST_MODE = True if opt.test_mode == 'GPU' else False
IMAGE_NAME = opt.image_name
IMAGE_PATH = 'test_photo/'
MODEL_NAME = opt.model_name
model = Generator(UPSCALE_FACTOR).eval()
if TEST_MODE:
model.cuda()
model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
else:
model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))
image = Image.open(IMAGE_PATH + IMAGE_NAME)
image = Variable(ToTensor()(image)).unsqueeze(0)
print(image.shape)
if TEST_MODE:
image = image.cuda()
start = time.process_time()
out = model(image)
elapsed = (time.process_time() - start)
print('cost ' + str(elapsed) + ' s')
out_img = ToPILImage()(out[0].data.cpu())
out_img.save('test_photo/out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)
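# Example invocation (hypothetical file names):
#   python test_image.py --upscale_factor 4 --test_mode GPU --image_name SUT1.jpg --model_name netG_epoch_4_100.pth

# ---- single-video test script ----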
import argparse
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from tqdm import tqdm
from model import Generator
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Test Single Video')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--video_name', type=str, help='test low resolution video name')
parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()
UPSCALE_FACTOR = opt.upscale_factor
VIDEO_NAME = opt.video_name
MODEL_NAME = opt.model_name
model = Generator(UPSCALE_FACTOR).eval()
if torch.cuda.is_available():
model = model.cuda()
    # on a CPU-only machine, load with map_location instead:
# model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))
model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
videoCapture = cv2.VideoCapture(VIDEO_NAME)
fps = videoCapture.get(cv2.CAP_PROP_FPS)
frame_numbers = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
sr_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR),
int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR)
compared_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10),
int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR + 10 + int(
int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10) / int(
10 * int(int(
videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 + 1)) * int(
int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 - 9)))
output_sr_name = 'out_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
output_compared_name = 'compare_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
sr_video_writer = cv2.VideoWriter(output_sr_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, sr_video_size)
compared_video_writer = cv2.VideoWriter(output_compared_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps,
compared_video_size)
# read frame
success, frame = videoCapture.read()
test_bar = tqdm(range(int(frame_numbers)), desc='[processing video and saving result videos]')
for index in test_bar:
if success:
            # volatile=True was removed from PyTorch; run inference under torch.no_grad() instead
            with torch.no_grad():
                image = ToTensor()(frame).unsqueeze(0)
                if torch.cuda.is_available():
                    image = image.cuda()
                out = model(image)
out = out.cpu()
out_img = out.data[0].numpy()
out_img *= 255.0
out_img = (np.uint8(out_img)).transpose((1, 2, 0))
# save sr video
sr_video_writer.write(out_img)
            # build the comparison frame: five-crop patches (four corners + center) from both images
out_img = ToPILImage()(out_img)
crop_out_imgs = transforms.FiveCrop(size=out_img.width // 5 - 9)(out_img)
crop_out_imgs = [np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(img)) for img in crop_out_imgs]
out_img = transforms.Pad(padding=(5, 0, 0, 5))(out_img)
            # Image.BICUBIC replaced by the InterpolationMode enum, matching data_utils
            compared_img = transforms.Resize(size=(sr_video_size[1], sr_video_size[0]),
                                             interpolation=transforms.InterpolationMode.BICUBIC)(
                ToPILImage()(frame))
crop_compared_imgs = transforms.FiveCrop(size=compared_img.width // 5 - 9)(compared_img)
crop_compared_imgs = [np.asarray(transforms.Pad(padding=(0, 5, 10, 0))(img)) for img in crop_compared_imgs]
compared_img = transforms.Pad(padding=(0, 0, 5, 5))(compared_img)
# concatenate all the pictures to one single picture
top_image = np.concatenate((np.asarray(compared_img), np.asarray(out_img)), axis=1)
bottom_image = np.concatenate(crop_compared_imgs + crop_out_imgs, axis=1)
bottom_image = np.asarray(transforms.Resize(
size=(int(top_image.shape[1] / bottom_image.shape[1] * bottom_image.shape[0]), top_image.shape[1]))(
ToPILImage()(bottom_image)))
final_image = np.concatenate((top_image, bottom_image))
# save compared video
compared_video_writer.write(final_image)
# next frame
success, frame = videoCapture.read()
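# Example invocation (hypothetical file name):
#   python test_video.py --upscale_factor 4 --video_name input.mp4 --model_name netG_epoch_4_100.pth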