CycleGAN模型——pytorch实现

# 输入图像shape默认为(3,256,256)
class Discriminator(nn.Module):  # 定义判别器
    def __init__(self):  # 初始化方法
        super(Discriminator, self).__init__()  # 继承初始化方法
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)  # conv
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)  # conv
        self.isn2 = nn.InstanceNorm2d(128)  # instancenorm,实例标准化,在图像风格转化任务中,生成图像依赖于某个图像的实例,所以batchnorm并不适用于风格转化任务
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)  # conv
        self.isn3 = nn.InstanceNorm2d(256)  # in
        self.conv4 = nn.Conv2d(256, 512, 4, 2, 1)  # conv
        self.isn4 = nn.InstanceNorm2d(512)  # in
        self.conv5 = nn.Conv2d(512, 1, 3, 1, 1)  # conv

        self.leakyrelu = nn.LeakyReLU(0.2)  # leakyrelu
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):  # 前传函数
        x = self.conv1(x)  # conv,(n,3,256,256)-->(n,64,128,128)
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv2(x)  # conv,(n,64,128,128)-->(n,128,64,64)
        x = self.isn2(x)  # in
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv3(x)  # conv,(n,128,64,64)-->(n,256,32,32)
        x = self.isn3(x)  # in
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv4(x)  # conv,(n,256,32,32)-->(n,512,16,16)
        x = self.isn4(x)  # in
        x = self.leakyrelu(x)  # leakyrelu
        x = self.conv5(x)  # conv,(n,512,16,16)-->(n,1,16,16)
        x = self.sigmoid(x)  # sigmoid,输入映射至(0,1)

        return x  # 返回图像真假的得分(16,16),相当于对16x16个区域的真假进行评分,而非对整体图片的真假进行评分


class IdentityBlock(nn.Module):  # 定义残差块
    def __init__(self):  # 初始化方法
        super(IdentityBlock, self).__init__()  # 继承初始化方法
        self.conv1 = nn.Conv2d(256, 256, 3, 1, 1)  # conv
        self.isn1 = nn.InstanceNorm2d(256)  # in
        self.conv2 = nn.Conv2d(256, 256, 3, 1, 1)  # conv
        self.isn2 = nn.InstanceNorm2d(256)  # in
        self.relu = nn.ReLU()  # relu

    def forward(self, x):  # 前传函数
        y = self.conv1(x)  # conv,(n,256,64,64)-->(n,256,64,64)
        y = self.isn1(y)  # in
        y = self.relu(y)  # relu
        y = self.conv2(y)  # conv,(n,256,64,64)-->(n,256,64,64)
        y = self.isn2(y)  # in
        y += x  # F(x) + x,(n,256,64,64)+(n,256,64,64)-->(n,256,64,64)
        y = self.relu(y)  # relu
        return y


class Generator(nn.Module):  # 定义生成器
    def __init__(self):  # 初始化方法
        super(Generator, self).__init__()  # 继承初始化方法
        self.conv1 = nn.Conv2d(3, 64, 7, 1, 3)  # conv
        self.isn1 = nn.InstanceNorm2d(64)  # in
        self.conv2 = nn.Conv2d(64, 128, 3, 2, 1)  # conv
        self.isn2 = nn.InstanceNorm2d(128)  # in
        self.conv3 = nn.Conv2d(128, 256, 3, 2, 1)  # conv
        self.isn3 = nn.InstanceNorm2d(256)  # in
        self.relu = nn.ReLU()  # relu
        self.layers = []  # 用于存放残差块结构
        for i in range(9):  # 共9个残差块
            self.layers.append(IdentityBlock())  # 向layers中添加残差块结构
        self.resnet = nn.Sequential(*self.layers)  # 将layers列表转化为模型结构序列
        self.ups = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # upsample,上采样
        self.conv4 = nn.Conv2d(256, 128, 3, 1, 1)  # conv
        self.isn4 = nn.InstanceNorm2d(128)  # in
        self.conv5 = nn.Conv2d(128, 64, 3, 1, 1)  # conv
        self.isn5 = nn.InstanceNorm2d(64)  # in
        self.conv6 = nn.Conv2d(64, 3, 7, 1, 3)  # conv
        self.tanh = nn.Tanh()  # tanh

    def forward(self, x):  # 前传函数
        x = self.conv1(x)  # conv,(n,3,256,256)-->(n,64,256,256)
        x = self.isn1(x)  # in
        x = self.relu(x)  # relu
        x = self.conv2(x)  # conv,(n,64,256,256)-->(n,128,128,128)
        x = self.isn2(x)  # in
        x = self.relu(x)  # relu
        x = self.conv3(x)  # conv,(n,128,128,128)-->(n,256,64,64)
        x = self.isn3(x)  # in
        x = self.relu(x)  # relu
        x = self.resnet(x)  # 9次残差结构计算,(n,256,64,64)-->(n,256,64,64)
        x = self.ups(x)  # upsample,(n,256,64,64)-->(n,256,128,128)
        x = self.conv4(x)  # conv,(n,256,128,128)-->(n,128,128,128)
        x = self.isn4(x)  # in
        x = self.relu(x)  # relu
        x = self.ups(x)  # upsample,(n,128,128,128)-->(n,128,256,256)
        x = self.conv5(x)  # conv,(n,128,256,256)-->(n,64,256,256)
        x = self.isn5(x)  # in
        x = self.relu(x)  # relu
        x = self.conv6(x)  # conv,(n,64,256,256)-->(n,3,256,256)
        x = self.tanh(x)  # tanh,输出映射至(-1,1)

        return x  # 返回风格迁移后的图像

你可能感兴趣的:(pytorch,python,深度学习)