使用对抗loss来提升图像的视觉效果

判别器

要使用对抗（adversarial）loss，第一步自然是创建一个判别器。这里先给出一个最常规、最简单的判别器：它由若干个步长为 2 的卷积块和一个输出单通道得分图的卷积层组成，最后对得分图做全局平均，为每张图输出一个真/假得分。

class Discriminator(nn.Module):
    """Convolutional discriminator that gives each input image a single score.

    The network is a stack of stride-2 4x4 convolution blocks followed by a
    final 4x4 convolution that produces a single-channel score map; that map
    is global-average-pooled so every image gets exactly one score.

    Args:
        input_nc: number of channels of the input image.
        ngf: number of filters in the first block; every further block
            doubles the channel count.
        n_blocks: number of stride-2 downsampling blocks. The default (3)
            reproduces the original architecture; raise it to make the
            discriminator deeper. Each block halves the spatial size, so the
            input height/width must be at least ``2 ** (n_blocks + 1)``.

    Raises:
        ValueError: if ``n_blocks`` is smaller than 1.
    """

    def __init__(self, input_nc=3, ngf=32, n_blocks=3):
        super(Discriminator, self).__init__()
        if n_blocks < 1:
            raise ValueError(f"n_blocks must be >= 1, got {n_blocks}")

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Return the layers of one stride-2 downsampling block."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # The first block has no normalization layer; the following blocks
        # use InstanceNorm and double the number of channels each time.
        layers = discriminator_block(input_nc, ngf, normalize=False)
        channels = ngf
        for _ in range(n_blocks - 1):
            layers += discriminator_block(channels, channels * 2)
            channels *= 2
        # Final 4x4 conv maps the features to a single-channel score map.
        layers.append(nn.Conv2d(channels, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return one score per image, shaped (batch, 1)."""
        x = self.model(img)
        # Average the score map over its spatial dims -> one score per sample.
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

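ReplayBuffer（生成样本缓冲池）

训练判别器时，不只使用当前这一批生成图像。缓冲池（默认容量 50）填满之后，每个生成样本都有 50% 的概率被换成池中一张更早生成的图像。这一做法沿用自 CycleGAN，用于减少判别器训练过程中的振荡。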

class ReplayBuffer:
    """Pool of previously generated samples used when training the discriminator.

    Every call to :meth:`push_and_pop` stores the incoming samples in the pool
    and returns a batch of the same size. Until the pool is full, samples are
    stored and passed through unchanged. Once it is full, each sample is, with
    probability 1/2, swapped for a randomly chosen older sample from the pool
    (which the new sample then replaces).

    Args:
        max_size: maximum number of samples kept in the pool (must be > 0).

    Raises:
        ValueError: if ``max_size`` is not positive.
    """

    def __init__(self, max_size=50):
        # Raise instead of assert: asserts are stripped when Python runs with -O.
        if max_size <= 0:
            raise ValueError('Empty buffer or trying to create a black hole. Be careful.')
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Store ``data`` in the pool and return the batch to feed the discriminator.

        Args:
            data: batch tensor whose first dimension indexes samples.

        Returns:
            A tensor with the same shape as ``data``, detached from the
            autograd graph.
        """
        # Detach so the pool never holds on to the generator's graph.
        data = data.detach()
        if data.size(0) == 0:
            # torch.cat([]) would fail; an empty batch has nothing to swap.
            return data
        to_return = []
        for element in data:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Pool not full yet: keep the sample and pass it through as-is.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Hand out an older sample and store the new one in its slot.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)

训练过程

# Choose the device once and use it everywhere. Previously net_D was
# hard-coded to 'cuda:0' while the label tensors fell back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
fake_buffer = ReplayBuffer()

net_D = Discriminator(input_nc=1, ngf=32).to(device)

# The generator optimizer must update the generator (`model`), not the
# discriminator. Otherwise the "generator step" would train D to be fooled
# and G would never learn from the adversarial loss.
optimizer_G = optim.Adam(model.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(net_D.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Least-squares GAN objective: MSE between D's score and a 1 (real) / 0 (fake) label.
criterion_GAN = torch.nn.MSELoss()

# NOTE(review): range(1, total_iterations) runs total_iterations - 1 steps; confirm intent.
# NOTE(review): assumes `model` has already been moved to `device`.
for iters in range(1, total_iterations):
    inputs, target = train_provider.next()
    inputs, target = inputs.to(device), target.to(device)

    # ---- Generator update ----
    optimizer_G.zero_grad()
    pred = model(inputs)
    D_pred = net_D(pred)
    # G wants D to score its output as real. Labels are built from the
    # prediction's own shape so a smaller final batch still lines up.
    loss_G = criterion_GAN(D_pred, torch.ones_like(D_pred))
    # NOTE(review): this adversarial term is often combined with a
    # reconstruction/content loss before backward; adjust for your task.
    loss_G.backward()
    optimizer_G.step()

    # ---- Discriminator update ----
    # zero_grad also discards the gradients loss_G.backward() left in net_D.
    optimizer_D.zero_grad()
    # Real loss
    pred_real = net_D(target)
    loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
    # Fake loss: score a mix of current and past fakes drawn from the buffer.
    fake = fake_buffer.push_and_pop(pred)
    pred_fake = net_D(fake.detach())
    loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
    # Total loss
    loss_ad = (loss_D_real + loss_D_fake) * 0.5
    loss_ad.backward()
    optimizer_D.step()

你可能感兴趣的:(日用小技能,gan,对抗损失)