论文地址:Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network | IEEE Conference Publication | IEEE Xplore
或者:[1609.05158] Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network (arxiv.org)
ESPCN是2016年提出的,是一篇经典的超分辨率重建算法文章,虽然它的效果和现在的文章相比不算好,但是它所提出的Efficient Sub-pixel Convolution,也叫亚像素卷积/子像素卷积为后面网络PSNR的提升做出了很大贡献,关键这个Sub-pixel Convolution比插值,反卷积,反池化这些上采样方法计算量要更少,因此网络的运行速度会有很大提升,如下图所示。
那么接下来看看这个Sub-pixel Convolution的结构,正常情况下,卷积操作会使feature map的高和宽变小,但当时,可以让卷积后的feature map的高和宽变大,就实现了分辨率的提升也就是超分辨重建,这个操作叫做sub-pixel convolution。
对于sub-pixel convolution,作者将一个H × W的低分辨率输入图像(Low Resolution)作为输入,低分辨率图像特征提取完毕后,生成n1个特征图,然后经过中间一堆操作等,不管有多少,只要到该上采样的时候,在最后一个卷积调整成就可以通过Sub-pixel操作将其变为rH x rW的高分辨率图像(High Resolution)。但是其实现过程不是直接通过插值等方式产生这个高分辨率图像,而是通过卷积先得到个通道的特征图(特征图大小和输入低分辨率图像一致),然后通过周期筛选(periodic shuffing)的方法得到这个高分辨率的图像,其中r rr为上采样因子(upscaling factor),也就是图像的扩大倍率。
sub-pixel convolution这个操作在pytorch里面提供的有接口,只需要调用就可以了。
输入:
输出:
比如:
sub= nn.PixelShuffle(4)
input = torch.tensor(1, 4**2, 4, 4)
output = sub(input)
torch.Size为[1, 1, 16, 16]
主干网络:
import torch.nn as nn
import torch.nn.init as init
class Net(nn.Module):
def __init__(self, upscale_factor):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
self._initialize_weights()
def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv4.weight)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.pixel_shuffle(x)
return x
训练:
from __future__ import print_function
from math import log10
import torch
import torch.backends.cudnn as cudnn
from SubPixelCNN.model import Net
from progress_bar import progress_bar
class SubPixelTrainer(object):
def __init__(self, config, training_loader, testing_loader):
super(SubPixelTrainer, self).__init__()
self.CUDA = torch.cuda.is_available()
self.device = torch.device('cuda' if self.CUDA else 'cpu')
self.model = None
self.lr = config.lr
self.nEpochs = config.nEpochs
self.criterion = None
self.optimizer = None
self.scheduler = None
self.seed = config.seed
self.upscale_factor = config.upscale_factor
self.training_loader = training_loader
self.testing_loader = testing_loader
def build_model(self):
self.model = Net(upscale_factor=self.upscale_factor).to(self.device)
self.criterion = torch.nn.MSELoss()
torch.manual_seed(self.seed)
if self.CUDA:
torch.cuda.manual_seed(self.seed)
cudnn.benchmark = True
self.criterion.cuda()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[50, 75, 100], gamma=0.5) # lr decay
def save(self):
model_out_path = "model_path.pth"
torch.save(self.model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
def train(self):
self.model.train()
train_loss = 0
for batch_num, (data, target) in enumerate(self.training_loader):
data, target = data.to(self.device), target.to(self.device)
self.optimizer.zero_grad()
loss = self.criterion(self.model(data), target)
train_loss += loss.item()
loss.backward()
self.optimizer.step()
progress_bar(batch_num, len(self.training_loader), 'Loss: %.4f' % (train_loss / (batch_num + 1)))
print(" Average Loss: {:.4f}".format(train_loss / len(self.training_loader)))
def test(self):
self.model.eval()
avg_psnr = 0
with torch.no_grad():
for batch_num, (data, target) in enumerate(self.testing_loader):
data, target = data.to(self.device), target.to(self.device)
prediction = self.model(data)
mse = self.criterion(prediction, target)
psnr = 10 * log10(1 / mse.item())
avg_psnr += psnr
progress_bar(batch_num, len(self.testing_loader), 'PSNR: %.4f' % (avg_psnr / (batch_num + 1)))
print(" Average PSNR: {:.4f} dB".format(avg_psnr / len(self.testing_loader)))
def run(self):
self.build_model()
for epoch in range(1, self.nEpochs + 1):
print("\n===> Epoch {} starts:".format(epoch))
self.train()
self.test()
self.scheduler.step(epoch)
if epoch == self.nEpochs:
self.save()
import math
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self, num_channels, base_channel, upscale_factor, num_residuals):
super(Net, self).__init__()
self.input_conv = nn.Conv2d(num_channels, base_channel, kernel_size=3, stride=1, padding=1)
resnet_blocks = []
for _ in range(num_residuals):
resnet_blocks.append(ResnetBlock(base_channel, kernel=3, stride=1, padding=1))
self.residual_layers = nn.Sequential(*resnet_blocks)
self.mid_conv = nn.Conv2d(base_channel, base_channel, kernel_size=3, stride=1, padding=1)
upscale = []
for _ in range(int(math.log2(upscale_factor))):
upscale.append(PixelShuffleBlock(base_channel, base_channel, upscale_factor=2))
self.upscale_layers = nn.Sequential(*upscale)
self.output_conv = nn.Conv2d(base_channel, num_channels, kernel_size=3, stride=1, padding=1)
def weight_init(self, mean=0.0, std=0.02):
for m in self._modules:
normal_init(self._modules[m], mean, std)
def forward(self, x):
x = self.input_conv(x)
residual = x
x = self.residual_layers(x)
x = self.mid_conv(x)
x = torch.add(x, residual)
x = self.upscale_layers(x)
x = self.output_conv(x)
return x
def normal_init(m, mean, std):
if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
if m.bias is not None:
m.bias.data.zero_()
class ResnetBlock(nn.Module):
def __init__(self, num_channel, kernel=3, stride=1, padding=1):
super(ResnetBlock, self).__init__()
self.conv1 = nn.Conv2d(num_channel, num_channel, kernel, stride, padding)
self.conv2 = nn.Conv2d(num_channel, num_channel, kernel, stride, padding)
self.bn = nn.BatchNorm2d(num_channel)
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
residual = x
x = self.bn(self.conv1(x))
x = self.activation(x)
x = self.bn(self.conv2(x))
x = torch.add(x, residual)
return x
class PixelShuffleBlock(nn.Module):
def __init__(self, in_channel, out_channel, upscale_factor, kernel=3, stride=1, padding=1):
super(PixelShuffleBlock, self).__init__()
self.conv = nn.Conv2d(in_channel, out_channel * upscale_factor ** 2, kernel, stride, padding)
self.ps = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = self.ps(self.conv(x))
return x
EDSR这个网络后面得单独写一篇。
论文中给出的结果如下表,我实际跑出来比原文要略低,应该是因为训练不到位和训练数据集不一样的原因,论文的训练数据集是Image,我是用BSD300训练的。
从左到右分别是原图,Bicubic,ESPCN,效果还说的过去。