# 输入图像shape默认为(3,256,256)
class Discriminator(nn.Module): # 定义判别器
def __init__(self): # 初始化方法
super(Discriminator, self).__init__() # 继承初始化方法
self.conv1 = nn.Conv2d(3, 64, 4, 2, 1) # conv
self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # conv
self.isn2 = nn.InstanceNorm2d(128) # instancenorm,实例标准化,在图像风格转化任务中,生成图像依赖于某个图像的实例,所以batchnorm并不适用于风格转化任务
self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # conv
self.isn3 = nn.InstanceNorm2d(256) # in
self.conv4 = nn.Conv2d(256, 512, 4, 2, 1) # conv
self.isn4 = nn.InstanceNorm2d(512) # in
self.conv5 = nn.Conv2d(512, 1, 3, 1, 1) # conv
self.leakyrelu = nn.LeakyReLU(0.2) # leakyrelu
self.sigmoid = nn.Sigmoid() # sigmoid
def forward(self, x): # 前传函数
x = self.conv1(x) # conv,(n,3,256,256)-->(n,64,128,128)
x = self.leakyrelu(x) # leakyrelu
x = self.conv2(x) # conv,(n,64,128,128)-->(n,128,64,64)
x = self.isn2(x) # in
x = self.leakyrelu(x) # leakyrelu
x = self.conv3(x) # conv,(n,128,64,64)-->(n,256,32,32)
x = self.isn3(x) # in
x = self.leakyrelu(x) # leakyrelu
x = self.conv4(x) # conv,(n,256,32,32)-->(n,512,16,16)
x = self.isn4(x) # in
x = self.leakyrelu(x) # leakyrelu
x = self.conv5(x) # conv,(n,512,16,16)-->(n,1,16,16)
x = self.sigmoid(x) # sigmoid,输入映射至(0,1)
return x # 返回图像真假的得分(16,16),相当于对16x16个区域的真假进行评分,而非对整体图片的真假进行评分
class IdentityBlock(nn.Module): # 定义残差块
def __init__(self): # 初始化方法
super(IdentityBlock, self).__init__() # 继承初始化方法
self.conv1 = nn.Conv2d(256, 256, 3, 1, 1) # conv
self.isn1 = nn.InstanceNorm2d(256) # in
self.conv2 = nn.Conv2d(256, 256, 3, 1, 1) # conv
self.isn2 = nn.InstanceNorm2d(256) # in
self.relu = nn.ReLU() # relu
def forward(self, x): # 前传函数
y = self.conv1(x) # conv,(n,256,64,64)-->(n,256,64,64)
y = self.isn1(y) # in
y = self.relu(y) # relu
y = self.conv2(y) # conv,(n,256,64,64)-->(n,256,64,64)
y = self.isn2(y) # in
y += x # F(x) + x,(n,256,64,64)+(n,256,64,64)-->(n,256,64,64)
y = self.relu(y) # relu
return y
class Generator(nn.Module): # 定义生成器
def __init__(self): # 初始化方法
super(Generator, self).__init__() # 继承初始化方法
self.conv1 = nn.Conv2d(3, 64, 7, 1, 3) # conv
self.isn1 = nn.InstanceNorm2d(64) # in
self.conv2 = nn.Conv2d(64, 128, 3, 2, 1) # conv
self.isn2 = nn.InstanceNorm2d(128) # in
self.conv3 = nn.Conv2d(128, 256, 3, 2, 1) # conv
self.isn3 = nn.InstanceNorm2d(256) # in
self.relu = nn.ReLU() # relu
self.layers = [] # 用于存放残差块结构
for i in range(9): # 共9个残差块
self.layers.append(IdentityBlock()) # 向layers中添加残差块结构
self.resnet = nn.Sequential(*self.layers) # 将layers列表转化为模型结构序列
self.ups = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # upsample,上采样
self.conv4 = nn.Conv2d(256, 128, 3, 1, 1) # conv
self.isn4 = nn.InstanceNorm2d(128) # in
self.conv5 = nn.Conv2d(128, 64, 3, 1, 1) # conv
self.isn5 = nn.InstanceNorm2d(64) # in
self.conv6 = nn.Conv2d(64, 3, 7, 1, 3) # conv
self.tanh = nn.Tanh() # tanh
def forward(self, x): # 前传函数
x = self.conv1(x) # conv,(n,3,256,256)-->(n,64,256,256)
x = self.isn1(x) # in
x = self.relu(x) # relu
x = self.conv2(x) # conv,(n,64,256,256)-->(n,128,128,128)
x = self.isn2(x) # in
x = self.relu(x) # relu
x = self.conv3(x) # conv,(n,128,128,128)-->(n,256,64,64)
x = self.isn3(x) # in
x = self.relu(x) # relu
x = self.resnet(x) # 9次残差结构计算,(n,256,64,64)-->(n,256,64,64)
x = self.ups(x) # upsample,(n,256,64,64)-->(n,256,128,128)
x = self.conv4(x) # conv,(n,256,128,128)-->(n,128,128,128)
x = self.isn4(x) # in
x = self.relu(x) # relu
x = self.ups(x) # upsample,(n,128,128,128)-->(n,128,256,256)
x = self.conv5(x) # conv,(n,128,256,256)-->(n,64,256,256)
x = self.isn5(x) # in
x = self.relu(x) # relu
x = self.conv6(x) # conv,(n,64,256,256)-->(n,3,256,256)
x = self.tanh(x) # tanh,输出映射至(-1,1)
return x # 返回风格迁移后的图像