Pytorch学习笔记-第七章

Pytorch学习笔记-第七章GAN生成动漫头像

  • model文件
    • 生成器
    • 判别器
  • visualize文件
  • main文件
    • 数据处理
    • 训练
      • 判别器
      • 生成器
    • 生成结果

记录一下个人学习和使用Pytorch中的一些问题。强烈推荐 《深度学习框架PyTorch:入门与实战》.写的非常好而且作者也十分用心,大家都可以看一看,本文为学习第七章GAN生成动漫头像的学习笔记。

主要分析实现代码里面main,model,visualize这3个代码文件完成整个项目模型结构定义,训练及生成,还有输出展示的整个过程。

model文件

整个模型结构是经典的生成器-判别器架构,model文件也只有这两个类,分别用于生成和判别图片。

生成器

生成器是从无到有,从一个噪声扩充数据到指定的大小,所以其中的网络层是反卷积层。除了第一层的输入通道数为参数设置之外,其他各层与判别器对称,可以从1x1x opt.nz的噪音,生成一个3x96x96的图片

class NetG(nn.Module):
    """
    生成器定义
    """

    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器feature map数

        self.main = nn.Sequential(
            # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
            # 输出形状:3 x 96 x 96
        )

判别器

正常的一个卷积网络,从3x96x96的图片得最后一个为真实样本的概率值。

class NetD(nn.Module):
    """
    判别器定义
    """

    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 输入 3 x 96 x 96
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf) x 32 x 32

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*2) x 16 x 16

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*4) x 8 x 8

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*8) x 4 x 4

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            #约束到(0,1)
            nn.Sigmoid()  # 输出一个数(概率)
        )

visualize文件

整个文件就一个用于可视化的类,在Visdom的基础上,加工出几个更方便使用的函数。
类内包含了一个visdom对象,借助__getattr__函数使得想要调用visdom原生接口的函数也可以obj.func而不用obj.visdom.func

def __getattr__(self, name):
       '''
       self.function 等价于self.vis.function
       自定义的plot,image,log,plot_many等除外
       '''
       return getattr(self.vis, name)

文件自定义的接口为:一次性画多个点,以及一次性网格展示整个batch的图片。但是一次性展示整个batch图片这个函数好像没有使用,而且有点疑惑它这部分代码为什么输入的batch数据没有RGB通道。

def img_grid(self, name, input_3d):
        """
        一个batch的图片转成一个网格图,i.e. input(36,64,64)
        会变成 6*6 的网格图,每个格子大小64*64
        """
        self.img(name, tv.utils.make_grid(
            input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0)))
            #疑惑为什么要假设输入的数据是BxHxW而没有颜色通道

main文件

该文件为主体文件,模型的各种参数也放在该文件的config类中。

数据处理

需要准备一些现成的动漫头像最为材料,供网络去学习模仿,就按一般图片的处理方式即可。因为现成数据的标签都是1,可以用默认的dataset而不进行额外处理,之后计算损失直接把target设为全1。

# 数据
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=opt.num_workers,
                                         drop_last=True
                                         )

训练

定义好生成器,判别器以及对应的优化器和误差计算函数之后开始模型的训练。

判别器

对于判别器而言需要分别出真实图像和生成的假图像,所以它的训练需要让两者的误分类损失最小,两者的损失都需要优化。

生成器

对于生成器而言,目标是让判别器误分类生成的假图像,所以它的损失函数是假图像的预测概率和真实图像标签(非我们定义假图像标签,而是我们定义的真实图像的标签)的误差。

生成结果

当生成器和判别器之间达到我们预先设定的条件,训练轮数足够或者平衡时,就可以说模型训练完成,可以查看输出的生成结果了。
为了得到好的结果,我们还是要同时用到生成器和判别器了,用一批随机噪声送入训练好的生成器,把生成的结果用判别器打个分,得分最高的(预测最接近真实标签的)就可以作为好结果输出了。

def generate(**kwargs):
    """
    随机生成动漫头像,并根据netd的分数选择较好的
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    
    device=t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)


    # 生成图片,并计算图片在判别器的分数
    fake_img = netg(noises)
    scores = netd(fake_img).detach()

    # 挑选最好的某几张
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # 保存图片
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))

你可能感兴趣的:(DL,深度学习)