An Approximate Reproduction of DOA-GAN

Paper: "DOA-GAN: Dual-Order Attentive Generative Adversarial Network for Image Copy-move Forgery Detection and Localization"

Code: 170744039/doagan_clean (https://github.com/170744039/doagan_clean)

Only after reading the repository did I realize that it includes neither the discriminator nor the training code. I figured that since I already had a dataset (the 10k-image uscisi_cmsd set from buster) and a training script is mostly a matter of wiring things together, this would be easy. Once I started, things turned out to be less simple, which is why this post is called an approximate reproduction.

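Problem 1: The discriminator code is missing, but it can be found in the reference the paper cites for it. After small modifications to match the settings in the paper, the output is still slightly off.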

import functools

import torch.nn as nn


class Discriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=32, n_layers=5, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(Discriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), norm_layer(ndf), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            # gradually increase the number of filters

            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 16)
            # print(nf_mult)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # nf_mult_prev = nf_mult
        # nf_mult = min(2 ** n_layers, 8)
        # sequence += [
        #     nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
        #     norm_layer(ndf * nf_mult),
        #     nn.LeakyReLU(0.2, True)
        # ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
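        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward pass: return the patch-wise prediction map."""
        return self.model(input)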

With a 320×320 input, the output shape is (1, 1, 9, 9) rather than the (1, 1, 10, 10) given in the paper. I did not work through where the difference of 1 comes from.
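The difference of 1 can be traced with the standard convolution output-size formula. The sketch below is plain Python and only mirrors the layer settings of the Discriminator above (five stride-2 convolutions with kernel 4 and padding 1, then one stride-1 prediction convolution). The helper names and the `last_kernel` parameter are hypothetical, introduced only for this check; using a 3×3 final kernel is one guess at how a 10×10 map could arise, not something the paper confirms.

```python
def conv2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a 2D convolution (dilation 1), floored as in PyTorch."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_out(size: int, n_layers: int = 5, last_kernel: int = 4) -> int:
    """Trace one spatial dimension through the Discriminator defined above."""
    for _ in range(n_layers):
        # each of the n_layers blocks uses kernel 4, stride 2, padding 1
        size = conv2d_out(size, kernel=4, stride=2, padding=1)
    # final prediction conv: stride 1, padding 1
    return conv2d_out(size, kernel=last_kernel, stride=1, padding=1)


print(discriminator_out(320))                 # 9
print(discriminator_out(320, last_kernel=3))  # 10
```

The five stride-2 layers take 320 down to 10, and the final 4×4 convolution with padding 1 then shrinks 10 to 9. Under these assumptions the mismatch comes entirely from the last layer.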

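Problem 2: The paper's description of the loss is hard to follow, which was frustrating. A GAN needs both a discriminator loss and a generator loss, but the paper only gives three terms without explaining how they fit together. Each term is easy to implement on its own, but how they should be combined and assigned is unclear, so I could only adapt them according to my own understanding together with the referenced paper.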
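One possible reading, borrowed from the usual conditional-GAN convention, is to split the minimax objective into a discriminator loss and a generator loss. The first two lines restate the formulas quoted from the paper in the training code comments; the last two lines are an assumed interpretation (including the non-saturating generator term), not something the paper states. Here I is the input image, M the ground-truth mask, and G(I) the predicted mask.

```latex
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{adv} + \alpha\,\mathcal{L}_{ce} + \beta\,\mathcal{L}_{det},\\
\mathcal{L}_{adv}(G,D) &= \mathbb{E}_{(I,M)}\big[\log D(I,M) + \log\big(1 - D(I,G(I))\big)\big],\\
\mathcal{L}_{D} &= -\,\mathbb{E}_{(I,M)}\big[\log D(I,M)\big] - \mathbb{E}_{I}\big[\log\big(1 - D(I,G(I))\big)\big],\\
\mathcal{L}_{G} &= -\,\mathbb{E}_{I}\big[\log D(I,G(I))\big] + \alpha\,\mathcal{L}_{ce} + \beta\,\mathcal{L}_{det}.
\end{aligned}
```

If GANLoss with gan_mode="vanilla" is a binary cross-entropy on the discriminator logits, then loss_D in the training code corresponds to one half of L_D, and loss_g_fake corresponds to the first term of L_G.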

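import os

import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local names used below (Doa_gan_config, Doa_GAN_model, Doa_gan_dataset, GANLoss)
# must also be imported; their import paths depend on the repository layout and are not shown here.

if __name__ == "__main__":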
    # device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    args = Doa_gan_config.config_USC()

    if args.dataset == 'usc':
        args = Doa_gan_config.config_USC()
        args.out_channel = 3
    elif args.dataset == 'casia':
        args = Doa_gan_config.config_CASIA()
        args.out_channel = 1
    elif args.dataset == 'como':
        args = Doa_gan_config.config_COMO()
        args.out_channel = 1

    args.size = tuple(int(i) for i in args.size.split("x"))

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # model name
    model_name = args.model + "_" + \
        args.dataset + args.suffix

    print(f"Model Name: {model_name}")

    # model
    model_g = Doa_GAN_model.DOA(out_channel=args.out_channel)
    model_d = Doa_GAN_model.Discriminator(input_nc=6)
    model_g.to(device)
    model_d.to(device)

    if args.dataset == 'usc':
        data = Doa_gan_dataset.USCISI_CMD_Dataset(lmdb_dir=args.lmdb_dir, args=args,
                                                 sample_file=args.train_key)
    elif args.dataset == 'casia':
        data = Doa_gan_dataset.Dataset_CASIA(args)
    elif args.dataset == 'como':
        data = Doa_gan_dataset.Dataset_COMO(args)

    trian_data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True, num_workers=0)

    if args.ckpt is not None:
        checkpoint = torch.load(args.ckpt, map_location=device)
        model_g.load_state_dict(checkpoint["model_state"], strict=True)

    iter_num = 0
    max_epoch = args.max_epoch
    max_iterations = max_epoch * len(trian_data_loader)

    model_g.train()
    model_d.train()

    criterionGAN_loss = GANLoss(gan_mode="vanilla").to(device)
    bce_loss  = BCEWithLogitsLoss()
    ce_loss = CrossEntropyLoss()


    base_params = list(map(id, model_g.encoder.parameters()))
    logits_params = filter(lambda p: id(p) not in base_params, model_g.parameters())
    params = [{'params': model_g.encoder.parameters(), 'lr': 0.0001},
              {'params': logits_params, 'lr': 0.001},]

    optimizer_G = torch.optim.Adam(params, betas=(0.5, 0.999))


    optimizer_D = torch.optim.Adam(model_d.parameters(), lr=0.0001, betas=(0.5, 0.999))

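    # The paper says the learning rate is halved once the loss levels off after 5 epochs;
    # here ReduceLROnPlateau is used as a dynamic substitute.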
    scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, mode="min", factor=0.5, patience=3)
    scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_D, mode="min", factor=0.5, patience=3)



    last_step = 0
    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    num_iter_per_epoch = len(trian_data_loader)

    snapshot_path = "./DOA/"

    for epoch in range(max_epoch):

        epoch_loss = []
        progress_bar = tqdm(trian_data_loader)
        for iter, data in enumerate(progress_bar):
            imgs, gts, = data


            # move the batch to the selected device (GPU if available, otherwise CPU)
            imgs = imgs.to(device)
            gts = gts.to(device)
            pred_gts, pred_detc = model_g(imgs.detach())



            # # First, train only the generator for 3 epochs
            # if epoch < 3:
            #     for name, p in model_d.named_parameters():
            #         p.requires_grad = False







            # we use conditional GANs; we need to feed both input and output to the discriminator
            # Discriminator loss
            fake_AB = torch.cat((imgs, pred_gts),1)
            pred_d_fake = model_d(fake_AB.detach())
            loss_D_fake = criterionGAN_loss(pred_d_fake, False)
            # Real
            real_AB = torch.cat((imgs, gts), 1)
            pred_real = model_d(real_AB)
            loss_D_real = criterionGAN_loss(pred_real, True)
            # combine loss and calculate gradients
            loss_D = (loss_D_fake + loss_D_real) * 0.5


            if epoch >= 3 and loss_D > 0.3:
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                scheduler_d.step(loss_D)
            # print(model_d.parameters())
            # for name, parameters in model_d.state_dict().items():
            #     print(parameters)




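            # Generator loss

            fake_AB = torch.cat((imgs, pred_gts), 1)
            # no detach here: the adversarial term has to backpropagate into the generator
            pred_d_fake = model_d(fake_AB)
            loss_g_fake = criterionGAN_loss(pred_d_fake, True)

            # collapse the channel dimension of the mask into per-pixel class indices
            gts = gts.long().to(device)
            _, gts = gts.max(dim=1)
            loss_g_ce = ce_loss(pred_gts, gts)
            nb_batch = pred_detc.size(0)
            # every training image contains a copy-move forgery, so all detection labels are 1
            real_labels = torch.full((nb_batch, 1), 1., device=device)

            loss_g_bce = bce_loss(pred_detc, real_labels)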

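            # I could not work out how the paper computes the generator loss, mainly the adversarial part:
            # L = Ladv + αLce + βLdet.
            # Adversarial Loss Ladv is defined as:
            # Ladv(G, D) = E(I,M)[log(D(I,M)) + log(1 − D(I,G(I)))],
            # and it is unclear whether the two hyperparameters are meant to be learnable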
            loss_G = 0.6*loss_g_fake + 0.3*loss_g_ce + 0.1*loss_g_bce

            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()
            scheduler_g.step(loss_G)


            # # Once the discriminator loss drops to 0.3, freeze the discriminator until the loss rises again
            # if loss_D < 0.3:
            #     for name, p in model_d.named_parameters():
            #         p.requires_grad = False




            progress_bar.set_description(
                'Step: {}. Epoch: {}/{}. Iteration: {}/{}.  loss_D: {:.5f} . loss_G: {:.5f}'.format(
                    step, epoch, max_epoch, iter + 1, num_iter_per_epoch, loss_D.item(), loss_G.item()))



            #
            # # log learning_rate
            # current_lr_G = optimizer_G.param_groups[0]
            # current_lr_D = optimizer_D.param_groups[0]['lr']
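        # close the progress bar at the end of every epoch, not only when saving
        progress_bar.close()
        if epoch > 40:
            # make sure the snapshot directory exists before saving
            os.makedirs(snapshot_path, exist_ok=True)
            save_mode_path_g = os.path.join(snapshot_path, 'g_epoch_' + str(epoch) + '.pkl')
            save_mode_path_d = os.path.join(snapshot_path, 'd_epoch_' + str(epoch) + '.pkl')
            torch.save(model_g.state_dict(), save_mode_path_g)
            torch.save(model_d.state_dict(), save_mode_path_d)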

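My version deviates from the paper in several places. The paper says the loss stabilizes after 5 epochs, after which all learning rates are halved together. Since I could not tell whether my implementation was correct, I did not dare do that, and instead used the scheduler below to adjust the rate dynamically: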

scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, mode="min", factor=0.5, patience=3)
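One thing to watch: the training loop above calls scheduler_g.step(loss_G) on every iteration, so patience=3 counts iterations rather than epochs. The sketch below is a simplified plain-Python stand-in for a plateau scheduler (the class HalveOnPlateau is hypothetical and ignores details such as the improvement threshold of the real PyTorch scheduler). It shows the alternative of stepping once per epoch on the mean epoch loss. The loss values are made up.

```python
class HalveOnPlateau:
    """Toy plateau scheduler: halve lr when the epoch loss stops improving."""

    def __init__(self, lr: float, patience: int = 3, factor: float = 0.5):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, epoch_loss: float) -> float:
        if epoch_loss < self.best:
            self.best = epoch_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # reduce only after more than `patience` epochs without improvement
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


def mean(values):
    return sum(values) / len(values)


# made-up per-iteration losses for six epochs
epochs = [[0.9, 0.7], [0.6, 0.6], [0.6, 0.6], [0.6, 0.6], [0.6, 0.6], [0.6, 0.6]]
sched = HalveOnPlateau(lr=0.001, patience=3)
lrs = [sched.step(mean(losses)) for losses in epochs]
print(lrs)
```

In this toy run the rate stays at 0.001 for five epochs and is halved at the sixth, after four epochs without improvement. In the real script the same idea would mean collecting the per-iteration losses in epoch_loss and calling the scheduler once after the inner loop.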

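In the paper, the detection branch output indicates whether an image contains a copy-move forgery: the label is 1 if it does and 0 otherwise. Unfortunately my entire training set consists of copy-move forgeries only, without even a single spliced image, so all I could do was set every label to 1 before computing the loss. I know this cannot be what the authors intended, but I do not see how else to read this sentence: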

where y_im is set to 1 if the image contains copy-move forgery, otherwise it is set to 0, and ŷ_im is the output from the detection branch.
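If the training set did include untouched images, one way to obtain y_im would be to derive it from the ground-truth mask. This is a minimal sketch under assumptions: the mask is a single-channel binary grid (nested lists here) in which nonzero marks copy-move pixels, and the function name image_level_label is hypothetical. The USC masks in this post have 3 channels, so real code would first need to select the relevant channels.

```python
def image_level_label(mask) -> float:
    """Return 1.0 if any pixel of the binary mask is marked as forged, else 0.0."""
    return 1.0 if any(pixel != 0 for row in mask for pixel in row) else 0.0


forged_mask = [[0, 0, 1], [0, 1, 1]]    # toy mask with forged pixels
pristine_mask = [[0, 0, 0], [0, 0, 0]]  # toy mask for an untouched image

labels = [image_level_label(m) for m in (forged_mask, pristine_mask)]
print(labels)  # [1.0, 0.0]
```

With only positive samples, as in my training set, every label comes out as 1 and the detection loss carries little information, which matches the concern above.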

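The paper also puts two coefficients in front of the cross-entropy loss on the masks and the BCE loss of the detection branch, but neither the paper nor the source code says what values they take or whether they are learnable. So my code just uses two arbitrary numbers. This is surely not right either, but I do not know what else to do, so as long as it runs it will have to do.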
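For comparison with the paper's form L = Ladv + αLce + βLdet, the weights 0.6 / 0.3 / 0.1 used in my code can be rewritten as an overall factor times that form. The snippet below only checks this arithmetic. The function name generator_loss and the loss values are made up for illustration.

```python
def generator_loss(adv: float, ce: float, det: float, alpha: float, beta: float) -> float:
    """Paper-style combination: L = L_adv + alpha * L_ce + beta * L_det."""
    return adv + alpha * ce + beta * det


# made-up loss values, used only to check the arithmetic
adv, ce, det = 0.7, 1.2, 0.4

post_version = 0.6 * adv + 0.3 * ce + 0.1 * det
alpha, beta = 0.3 / 0.6, 0.1 / 0.6
paper_form = 0.6 * generator_loss(adv, ce, det, alpha=alpha, beta=beta)

print(alpha, round(beta, 4))
print(abs(post_version - paper_form) < 1e-12)
```

So the weights in this post amount to α = 0.5 and β ≈ 0.167 in the paper's notation, with the whole loss scaled by 0.6. Whether these are close to the authors' values is unknown.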

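I hope this post can serve as a starting point, and that someone more experienced can show us the correct way to reproduce this paper. Feel free to raise questions in the comments, though I may not reply. I only changed the two large blocks above; the rest is just a matter of adjusting paths. That's all for now.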

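By the way, training is very slow because the data are loaded directly from LMDB; the buster reproduction has the same problem. Out of respect for the original source, I have not switched to my copy of the dataset that is already converted to TIFF images. I am also not familiar with PyTorch, and Keras cannot even read LMDB data, so I do not know how to speed it up. Even on my 3090, one night gets through only 5 epochs, while the source code calls for 50. I do not plan to produce results in the near term; I will update this post once the machine is free.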

 
 
