为了使用对抗（adversarial）loss，第一步当然是创建一个判别器。这里先放一个最常规、最简单的判别器，随后给出配套的历史样本缓冲区（ReplayBuffer）和训练循环：
class Discriminator(nn.Module):
    """Small convolutional discriminator that emits one score per image.

    The input passes through three downsampling stages. Each stage is a 4x4
    conv with stride 2 and padding 1, followed by InstanceNorm (except the
    first stage) and LeakyReLU(0.2). A final 4x4 conv maps to a single
    channel. That one-channel score map is then averaged over its full
    spatial extent.

    Args:
        input_nc: number of channels of the input image.
        ngf: channel width of the first stage; doubled at each following
            stage (ngf -> 2*ngf -> 4*ngf).
    """

    def __init__(self, input_nc=3, ngf=32):
        super().__init__()

        def conv_stage(n_in, n_out, normalize=True):
            """Build one stage: conv(4, stride 2, pad 1) [+ InstanceNorm] + LeakyReLU(0.2)."""
            stage = [nn.Conv2d(n_in, n_out, 4, stride=2, padding=1)]
            if normalize:
                stage.append(nn.InstanceNorm2d(n_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        widths = [input_nc, ngf, ngf * 2, ngf * 4]
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            # Only the first stage is built without InstanceNorm.
            layers += conv_stage(n_in, n_out, normalize=idx > 0)
        # For a higher-capacity discriminator, one more stage
        # conv_stage(ngf * 4, ngf * 8) can be appended here (the head's input
        # width must then be changed to ngf * 8 as well).
        layers.append(nn.Conv2d(ngf * 4, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a (batch, 1) tensor of spatially averaged discriminator scores."""
        score_map = self.model(img)
        pooled = F.avg_pool2d(score_map, score_map.shape[2:])
        return pooled.view(score_map.shape[0], -1)
class ReplayBuffer:
    """Bounded pool of previously generated samples for discriminator training.

    ``push_and_pop`` pushes every sample of a batch into the pool. It returns
    a batch with the same number of samples. Once the pool is full, each
    returned entry has a 50% chance of being an older pooled sample, which
    the new sample then replaces in the pool.

    Args:
        max_size: maximum number of samples kept in the pool (must be > 0).

    Raises:
        ValueError: if ``max_size`` is not positive.
    """

    def __init__(self, max_size=50):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not max_size > 0:
            raise ValueError('Empty buffer or trying to create a black hole. Be careful.')
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Store a batch in the pool and return a (possibly mixed) batch.

        Args:
            data: batched tensor, iterated one sample at a time along dim 0.

        Returns:
            A tensor with no autograd history and the same number of samples
            as ``data``. While the pool is still filling, samples are returned
            unchanged. After that, each sample is, with probability 0.5,
            swapped for a randomly chosen stored one.
        """
        # detach() instead of the legacy ``.data``: same values and no graph
        # link, but autograd can still detect in-place modifications.
        samples = data.detach()
        to_return = []
        for element in samples:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # clone() so the pool does not keep the whole input batch's
                # storage alive through a single-row view.
                self.data.append(element.clone())
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element.clone()
            else:
                to_return.append(element)
        if not to_return:
            # torch.cat rejects an empty list; hand back the empty batch.
            return samples.clone()
        # torch.cat already yields a plain tensor; no Variable wrapper needed.
        return torch.cat(to_return)
# One device for the discriminator and the GAN targets alike (the original
# hard-coded 'cuda:0' for D while the targets followed cuda availability).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
# Batch-sized label templates, kept in case other code references them.
# The loop below builds targets shaped like each prediction instead, so a
# short final batch does not mismatch.
target_real = torch.ones(cfg.TRAIN.batch_size, 1, device=device)
target_fake = torch.zeros(cfg.TRAIN.batch_size, 1, device=device)
fake_buffer = ReplayBuffer()
net_D = Discriminator(input_nc=1, ngf=32).to(device)
# The generator optimizer must own the generator's parameters. Building it
# over net_D.parameters() left the generator untrained and made the
# "generator" step push D toward scoring fakes as real.
optimizer_G = optim.Adam(model.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(net_D.parameters(), lr=0.0001, betas=(0.5, 0.999))
# MSE against 1/0 labels (least-squares GAN objective).
criterion_GAN = torch.nn.MSELoss()
# NOTE(review): range(1, total_iterations) runs total_iterations - 1 steps;
# confirm that is intended.
# NOTE(review): assumes `model` and the tensors from train_provider already
# live on `device` — confirm.
for iters in range(1, total_iterations):
    inputs, target = train_provider.next()

    # ---- Generator update ----
    # Freeze D so the generator's backward pass does not accumulate
    # gradients into D's parameters (they would be discarded anyway).
    for param in net_D.parameters():
        param.requires_grad = False
    optimizer_G.zero_grad()
    pred = model(inputs)
    # Generator loss: G wants D to score its output as real.
    D_pred = net_D(pred)
    loss_G = criterion_GAN(D_pred, torch.ones_like(D_pred))
    loss_G.backward()
    optimizer_G.step()

    # ---- Discriminator update ----
    for param in net_D.parameters():
        param.requires_grad = True
    optimizer_D.zero_grad()
    # Real loss: ground-truth images should be scored as real.
    pred_real = net_D(target)
    loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
    # Fake loss: generator outputs, mixed with older ones from the replay
    # buffer, should be scored as fake. detach() keeps this off G's graph.
    pred = fake_buffer.push_and_pop(pred)
    pred_fake = net_D(pred.detach())
    loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
    # Total discriminator loss.
    loss_ad = (loss_D_real + loss_D_fake) * 0.5
    loss_ad.backward()
    optimizer_D.step()