论文传送门:https://arxiv.org/pdf/1609.04802.pdf
SRGAN模型目的:输入低分辨率图像,生成高分辨率图像。
生成网络由三部分构成:
①卷积+PReLU激活函数;
②16个残差块(每块为:卷积+BN+PReLU+卷积+BN,块的输入与输出经残差边相加),其后接卷积+BN,再与①的输出经残差边相加;
③(卷积+像素重组+PReLU)x2+卷积;
①②用于提取图像特征,③用于图像上采样,实现超分。
生成网络的目的:输入低分辨率图像,输出高分辨率图像。
鉴别网络类似VGG结构:先经过卷积+LeakyReLU,再经过7个(卷积+BN+LeakyReLU)结构块进行下采样,最后通过自适应平均池化、1x1卷积和Sigmoid输出得分。
鉴别网络目的:输入高分辨率图像,判断输入图像是真实图像还是生成图像。
class D_Block(nn.Module):
    """Discriminator building block: 3x3 convolution -> batch norm -> LeakyReLU.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channle: Number of channels produced by the convolution.
        strid: Stride of the convolution (a stride of 2 halves H and W).
    """

    def __init__(self, in_channel, out_channle, strid):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channle,
            kernel_size=3,
            stride=strid,
            padding=1,
        )
        norm = nn.BatchNorm2d(num_features=out_channle)
        activation = nn.LeakyReLU(negative_slope=0.2)
        # The layers sit at indices 0, 1 and 2 of ``self.block``.
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the conv / batch-norm / LeakyReLU stack."""
        out = self.block(x)
        return out
class Discriminator(nn.Module):
    """SRGAN discriminator: scores an image as real or generated.

    The network is a 3-channel stem (conv + LeakyReLU), seven ``D_Block``
    stages (four of them with stride 2, so H and W shrink by 16x overall),
    and a head made of adaptive average pooling, two 1x1 convolutions
    (used in place of fully connected layers) and a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Stem: conv + LeakyReLU without batch norm.
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.leakyrelu = nn.LeakyReLU(0.2)
        # VGG-like stack; every stride-2 block halves the spatial size.
        self.downsample = nn.Sequential(
            D_Block(64, 64, 2),
            D_Block(64, 128, 1),
            D_Block(128, 128, 2),
            D_Block(128, 256, 1),
            D_Block(256, 256, 2),
            D_Block(256, 512, 1),
            D_Block(512, 512, 2),
        )
        # Head: pooling to 1x1 makes the network independent of input size.
        self.linear = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, 1, 1, 0),  # 1x1 conv instead of a linear layer
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, 1, 1, 0),  # 1x1 conv instead of a linear layer
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one sigmoid score per image.

        Args:
            x: Image batch of shape (n, 3, H, W).

        Returns:
            Tensor of shape (n,) with values in (0, 1).
        """
        # e.g. (n,3,256,256) --> (n,64,256,256)
        x = self.leakyrelu(self.conv1(x))
        # e.g. (n,64,256,256) --> (n,512,16,16)
        x = self.downsample(x)
        # (n,512,h,w) --> (n,512,1,1) --> (n,1024,1,1) --> (n,1,1,1)
        x = self.linear(x)
        # flatten() rather than squeeze(): squeeze() would also drop the batch
        # axis when n == 1 and return a 0-d tensor instead of shape (1,).
        return x.flatten()
class G_Block(nn.Module):
    """Generator residual block: conv -> BN -> PReLU -> conv -> BN, plus skip.

    Args:
        channel: Channel count; the block keeps it unchanged so the input
            can be added to the output.
    """

    def __init__(self, channel):
        super().__init__()
        layers = [
            nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channel),
            nn.PReLU(num_parameters=channel),  # one learnable slope per channel
            nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channel),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + F(x)`` where ``F`` is the stacked layers."""
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """SRGAN generator: maps a low-resolution image to a 4x larger one.

    Structure:
        1. 9x9 conv + PReLU stem.
        2. 16 ``G_Block`` residual blocks, then conv + BN, added back to the
           stem output through a long skip connection.
        3. Two (conv + PixelShuffle(2) + PReLU) stages, each doubling H and W,
           followed by a 9x9 conv back to 3 channels.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 9, 1, 4)
        self.prelu1 = nn.PReLU(64)
        # 16 residual blocks, registered as blocks.0 ... blocks.15.
        self.blocks = nn.Sequential(*(G_Block(64) for _ in range(16)))
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),  # (n,256,h,w) --> (n,64,2h,2w)
            nn.PReLU(64),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.PReLU(64),
            nn.Conv2d(64, 3, 9, 1, 4),
        )

    def forward(self, x):
        """Upscale a batch of low-resolution images.

        Args:
            x: Image batch of shape (n, 3, H, W).

        Returns:
            Tensor of shape (n, 3, 4*H, 4*W).
        """
        # e.g. (n,3,64,64) --> (n,64,64,64)
        x = self.prelu1(self.conv1(x))
        # Long skip connection. The add must be out-of-place: the first
        # residual conv has already saved ``x`` for its weight gradient, so
        # ``x += ...`` would invalidate it and make backward() raise.
        x = x + self.bn2(self.conv2(self.blocks(x)))
        # e.g. (n,64,64,64) --> (n,64,128,128) --> (n,64,256,256) --> (n,3,256,256)
        return self.upsample(x)