下载:
# Clone the StarGAN repository and change into it.
git clone https://github.com/yunjey/StarGAN.git
cd StarGAN/
下载CelebA训练数据:
# Fetch the CelebA training data with the repository's download script.
bash download.sh
训练:
# Train on CelebA with 5 domain labels (c_dim=5) at 128x128 resolution;
# samples, logs, model checkpoints and results go under stargan_celebA/.
python main.py --mode='train' --dataset='CelebA' --c_dim=5 --image_size=128 \
--sample_path='stargan_celebA/samples' --log_path='stargan_celebA/logs' \
--model_save_path='stargan_celebA/models' --result_path='stargan_celebA/results'
第一个卷积层:输入为图像和label在通道维度上的串联,3表示图像为3通道,c_dim为label的维度,
# Input layer: 3 image channels plus c_dim label channels -> conv_dim
# features; a 7x7 kernel with padding 3 keeps the spatial size.
layers = [
    nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
    nn.InstanceNorm2d(conv_dim, affine=True),
    nn.ReLU(inplace=True),
]
2个卷积层,stride=2,即下采样,
# Down-Sampling: two 4x4 / stride-2 / padding-1 convolutions, each halving
# the spatial size and doubling the number of channels.
curr_dim = conv_dim
for _ in range(2):
    next_dim = curr_dim * 2
    layers += [
        nn.Conv2d(curr_dim, next_dim, kernel_size=4, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(next_dim, affine=True),
        nn.ReLU(inplace=True),
    ]
    curr_dim = next_dim
残差层,
# Bottleneck: repeat_num residual blocks that keep the channel count at
# curr_dim.
layers.extend(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim) for _ in range(repeat_num))
残差网络结构,
class ResidualBlock(nn.Module):
    """Residual Block.

    Two 3x3 convolutions (stride 1, padding 1, no bias), each followed by an
    affine InstanceNorm, with a ReLU between them. The result is added to the
    block's input through an identity shortcut.

    Args:
        dim_in: number of input channels.
        dim_out: number of output channels; must equal ``dim_in`` because the
            shortcut adds the input to the convolution output unchanged.

    Raises:
        ValueError: if ``dim_in != dim_out``.
    """

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        if dim_in != dim_out:
            # Fail at construction time instead of with a shape mismatch on
            # the first forward pass (x + main(x) needs equal channel counts).
            raise ValueError(
                'ResidualBlock requires dim_in == dim_out, got %d and %d' % (dim_in, dim_out))
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True))

    def forward(self, x):
        """Return ``x + self.main(x)``."""
        return x + self.main(x)
上采样,
# Up-Sampling: two 4x4 / stride-2 / padding-1 transposed convolutions, each
# doubling the spatial size and halving the number of channels.
for _ in range(2):
    next_dim = curr_dim // 2
    layers += [
        nn.ConvTranspose2d(curr_dim, next_dim, kernel_size=4, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(next_dim, affine=True),
        nn.ReLU(inplace=True),
    ]
    curr_dim = next_dim
最后一层,得到输出维度为3,
# Output layer: project back to 3 image channels and squash to [-1, 1].
layers.extend([
    nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False),
    nn.Tanh(),
])
self.main = nn.Sequential(*layers)
对于输入图像x,label向量c,串联如下,
def forward(self, x, c):
    """Run the generator on images ``x`` conditioned on labels ``c``.

    Each label vector is tiled over every pixel of its image and appended to
    the image as extra channels before calling ``self.main``.

    Args:
        x: 4-D image batch, indexed as (batch, channels, height, width).
        c: 2-D label batch of shape (batch, c_dim).

    Returns:
        ``self.main`` applied to the channel-wise concatenation of ``x`` and
        the tiled labels.
    """
    batch, c_dim = c.size(0), c.size(1)
    height, width = x.size(2), x.size(3)
    # (batch, c_dim) -> (batch, c_dim, 1, 1) -> broadcast to the image size.
    label_maps = c.unsqueeze(-1).unsqueeze(-1).expand(batch, c_dim, height, width)
    return self.main(torch.cat([x, label_maps], dim=1))
判别网络的输入为图像,用于判别输入图像的真假,以及预测输入图像的类别(label),
class Discriminator(nn.Module):
    """Discriminator. PatchGAN.

    A trunk of ``repeat_num`` 4x4 / stride-2 convolutions with LeakyReLU(0.01),
    doubling the channel count from ``conv_dim`` at each layer, feeds two heads:

    * ``conv1``: a 3x3 convolution with one output channel, i.e. one real/fake
      score per spatial position of the final feature map;
    * ``conv2``: a ``k_size`` x ``k_size`` convolution with ``c_dim`` outputs,
      the domain-label logits.

    Args:
        image_size: side length of the square input images.
        conv_dim: number of channels of the first convolution.
        c_dim: number of domain labels predicted by ``conv2``.
        repeat_num: number of stride-2 convolutions in the trunk.

    Raises:
        ValueError: if ``image_size < 2 ** repeat_num`` (``conv2`` would get a
            kernel size of 0).
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()
        # Every trunk layer halves the spatial size, so the final feature map
        # is image_size / 2**repeat_num wide; conv2 uses that as its kernel.
        # Integer division matches the former int(image_size / 2.0**repeat_num).
        k_size = image_size // (2 ** repeat_num)
        if k_size < 1:
            raise ValueError(
                'image_size (%d) must be at least 2 ** repeat_num (%d)'
                % (image_size, 2 ** repeat_num))

        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = conv_dim
        for _ in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)

    def forward(self, x):
        """Return ``(out_real, out_aux)`` for an image batch ``x``.

        ``out_real`` is ``conv1``'s output with size-1 dimensions squeezed.
        ``out_aux`` has shape (batch, c_dim); this assumes ``x`` is
        ``image_size`` x ``image_size`` so that ``conv2`` yields a 1x1 map.
        """
        h = self.main(x)
        out_real = self.conv1(h)
        out_aux = self.conv2(h)
        # Do not squeeze the label logits: squeeze() would also drop the batch
        # dimension for a batch of one image (and the label dimension when
        # c_dim == 1), so they would no longer line up with the (batch, c_dim)
        # labels in the classification loss.
        return out_real.squeeze(), out_aux.view(out_aux.size(0), out_aux.size(1))
conv1输出通道数为1,对最后特征图上的每个位置(patch)判别输入的真假;conv2输出维度为c_dim,用于预测输入图像的label.
输入包括
real_x,real_c,fake_c
fake_c并非凭空生成,而是将当前batch中的真实标签随机打乱顺序得到的,
# Generate target-domain labels by shuffling the real labels across the
# batch: every image borrows the label vector of a randomly chosen sample.
rand_idx = torch.randperm(real_label.size(0))
fake_label = real_label[rand_idx]
if self.dataset == 'CelebA':
    real_c, fake_c = real_label.clone(), fake_label.clone()
将真实图像输入判别网络,
# Critic loss on real images: the negated mean of D's real/fake scores, so
# minimizing it pushes D's scores for real images up.
out_src, out_cls = self.D(real_x)
d_loss_real = -out_src.mean()
判别网络的输入为真实图像,输出out_cls为真实图像各属性标签的logits(尚未经过sigmoid),因此可以与真实标签计算二值交叉熵损失,
if self.dataset == 'CelebA':
    # Multi-label attribute classification of real images: an independent
    # sigmoid + binary cross-entropy term per attribute, summed over the
    # batch and divided by the batch size. reduction='sum' replaces the
    # deprecated size_average=False with the same result.
    d_loss_cls = F.binary_cross_entropy_with_logits(
        out_cls, real_label, reduction='sum') / real_x.size(0)
将真实图像输入real_x和假的标签fake_c输入生成网络,得到生成图像fake_x,
# Translate the real images to the shuffled target labels fake_c.
fake_x = self.G(real_x, fake_c)
将生成图像输入判别网络,
# Critic loss on generated images. Detach them first so this discriminator
# loss does not back-propagate into the generator; detach() replaces the
# deprecated Variable(fake_x.data) wrapper with the same effect.
fake_x = fake_x.detach()
out_src, out_cls = self.D(fake_x)
d_loss_fake = torch.mean(out_src)
总的损失函数为,
# Total discriminator loss: adversarial terms on real and fake images plus
# the attribute-classification term weighted by lambda_cls.
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
根据d_loss更新判别网络参数,
# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
self.reset_grad()  # presumably zeroes accumulated gradients (helper not shown)
d_loss.backward()
self.d_optimizer.step()  # update only the discriminator's parameters
随机采样插值系数alpha,用alpha对real_x与fake_x做线性插值;将插值图像输入判别网络,计算判别网络输出对插值图像的梯度,由梯度L2范数与1的偏差得到梯度惩罚损失d_loss_gp,
# Compute gradient penalty: evaluate D on random points on the straight line
# between each real image and its generated counterpart, and penalize the
# deviation of the L2 norm of D's gradient at those points from 1.
# One mixing coefficient per sample, broadcast over channels and pixels.
# It is created on real_x's device instead of being hard-coded to .cuda().
alpha = torch.rand(real_x.size(0), 1, 1, 1, device=real_x.device).expand_as(real_x)
# Leaf tensor that requires grad (replaces the deprecated Variable wrapper).
interpolated = (alpha * real_x.data + (1 - alpha) * fake_x.data).requires_grad_(True)
out, out_cls = self.D(interpolated)
grad = torch.autograd.grad(outputs=out,
                           inputs=interpolated,
                           grad_outputs=torch.ones_like(out),
                           retain_graph=True,
                           create_graph=True)[0]  # keep graph: the penalty is differentiated again
grad = grad.view(grad.size(0), -1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
根据梯度惩罚损失d_loss_gp再次优化判别网络,
# Backward + Optimize
# A separate discriminator update driven only by the gradient penalty,
# weighted by lambda_gp.
d_loss = self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
生成网络的作用是:输入原始域的图像和目标域的label,生成目标域的图像;反过来,输入目标域的图像和原始域的label,可以生成(重建)原始域的图像.
将原图像real_x和目标label fake_c输入生成网络,得到生成图像fake_x;再将fake_x和原始label real_c输入生成网络,得到重建图像rec_x,希望rec_x与真实图像real_x尽量相似(重建损失g_loss_rec),
# Translate to the target domain, then back with the original labels.
fake_x = self.G(real_x, fake_c)
rec_x = self.G(fake_x, real_c)
# L1 reconstruction loss between the input and its round-trip result.
g_loss_rec = (real_x - rec_x).abs().mean()
将fake_x输入判别网络,
# Generator adversarial loss: minimizing the negated mean critic score on
# generated images pushes G towards images D scores as real.
out_src, out_cls = self.D(fake_x)
g_loss_fake = -out_src.mean()
计算损失函数,
# Generator adversarial loss: negated mean of D's real/fake scores on fake_x.
g_loss_fake = - torch.mean(out_src)
对于fake_x,其对应的label为fake_label;将fake_x输入判别网络,判别网络输出的label logits为out_cls,因此可以计算二值交叉熵损失,
# Target-domain classification loss: D should classify the generated images
# as having the target labels fake_label. Summed over the batch, then divided
# by the batch size; reduction='sum' replaces the deprecated size_average=False.
g_loss_cls = F.binary_cross_entropy_with_logits(
    out_cls, fake_label, reduction='sum') / fake_x.size(0)
生成网络参数更新,
# Backward + Optimize
# Generator objective: adversarial loss + weighted reconstruction loss +
# weighted target-domain classification loss.
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad()
g_loss.backward()
self.g_optimizer.step()  # update only the generator's parameters
以CelebA数据为例,下载后的数据包括label文件(list_attr_celeba.txt)和图像.
文件的第一行为图像的总数,为202599.
第二行为属性类别名,共40种,
5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
第三行及之后的每一行为图像名以及对应的40种属性的label,label值为1或-1;提取时将1记为1,-1记为0.
000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
list_attr_celeba.txt文件提取函数为,
def preprocess(self):
    """Parse the CelebA attribute list into shuffled train/test splits.

    ``self.lines`` holds the lines of ``list_attr_celeba.txt``:
    line 0 is the image count, line 1 the space-separated attribute names,
    and every later line an image filename followed by one ``1``/``-1`` value
    per attribute.

    Fills ``self.attr2idx`` / ``self.idx2attr`` (attribute name <-> column
    index), sets ``self.selected_attrs`` and rebuilds ``self.train_filenames``,
    ``self.train_labels``, ``self.test_filenames`` and ``self.test_labels``.
    Each label is a list of 0/1 ints, one per selected attribute, in the order
    the attributes appear in the file header.
    """
    attrs = self.lines[1].split()
    for i, attr in enumerate(attrs):
        self.attr2idx[attr] = i
        self.idx2attr[i] = attr

    self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
    selected = set(self.selected_attrs)  # built once; O(1) test per column
    self.train_filenames = []
    self.train_labels = []
    self.test_filenames = []
    self.test_labels = []

    lines = self.lines[2:]  # the per-image "filename value value ..." lines
    random.shuffle(lines)  # random shuffling before the train/test split
    for i, line in enumerate(lines):
        splits = line.split()
        filename = splits[0]
        values = splits[1:]
        # Keep only the selected attributes: "1" -> 1, anything else -> 0.
        label = [1 if value == '1' else 0
                 for idx, value in enumerate(values)
                 if self.idx2attr[idx] in selected]

        # NOTE(review): `i + 1 < 2000` sends the first 1999 shuffled images
        # (not 2000) to the test split; kept unchanged to preserve the split.
        if (i + 1) < 2000:
            self.test_filenames.append(filename)
            self.test_labels.append(label)
        else:
            self.train_filenames.append(filename)
            self.train_labels.append(label)
self.selected_attrs表示我们训练选用的任务类别集合.最后得到图像名数组self.train_filenames,及其对应的label数组 self.train_labels.
之后采用from torch.utils.data import DataLoader加载训练数据,
# Wrap the dataset in a torch DataLoader that yields batches of batch_size
# samples; with shuffle=True the order is reshuffled every epoch.
data_loader = DataLoader(dataset=dataset,
                         batch_size=batch_size,
                         shuffle=shuffle)
# Keep the first 4 batches (fewer if the loader is shorter) as fixed
# debugging inputs: images in fixed_x, their labels in real_c.
fixed_x, real_c = [], []
for i, (images, labels) in zip(range(4), self.data_loader):
    fixed_x.append(images)
    real_c.append(labels)
读取后的图像列表为fixed_x,label列表为real_c.其中每个图像batch的形状为(batch_size,3,image_size,image_size),每个label batch的形状为(batch_size,len(self.selected_attrs)).
拼接得到固定的输入图像fixed_x和目标label列表fixed_c_list,用于训练过程中的采样(sample).
# Fixed inputs and target domain labels for debugging.
# The image batches are stacked into one tensor, e.g. (64, 3, 128, 128) for
# 4 batches of 16.
fixed_x = self.to_var(torch.cat(fixed_x, dim=0), volatile=True)
real_c = torch.cat(real_c, dim=0)
if self.dataset == 'CelebA':
    fixed_c_list = self.make_celeb_labels(real_c)
label列表fixed_c_list的生成函数为,
def make_celeb_labels(self, real_c):
    """Generate target domain labels for CelebA for debugging/testing.

    Assumes the label layout [Black_Hair, Blond_Hair, Brown_Hair, Male, Young]
    (the selected_attrs order): columns 0-2 are a hair-colour one-hot, column 3
    is gender and column 4 is age.

    Args:
        real_c: (batch, c_dim) tensor of source labels; it is cloned, never
            modified in place.

    Returns:
        A list of label batches, each passed through
        ``self.to_var(..., volatile=True)``:

        * ``self.c_dim`` single-attribute edits: hair set to black, blond and
          brown in turn, then each remaining attribute inverted on its own;
        * only when ``self.dataset == 'CelebA'``, 4 multi-attribute edits:
          hair+gender, hair+age, gender+age and hair+gender+age (hair always
          set to brown).
    """
    hair_colors = [torch.FloatTensor([1, 0, 0]),  # black hair
                   torch.FloatTensor([0, 1, 0]),  # blond hair
                   torch.FloatTensor([0, 0, 1])]  # brown hair

    def flip(labels, col):
        # Invert one attribute for every sample: 1 -> 0, any other value -> 1.
        labels[:, col] = (labels[:, col] != 1).type_as(labels)

    fixed_c_list = []

    # Single attribute transfer.
    for i in range(self.c_dim):
        fixed_c = real_c.clone()
        if i < 3:
            fixed_c[:, :3] = hair_colors[i]  # broadcast over the batch
        else:
            flip(fixed_c, i)
        fixed_c_list.append(self.to_var(fixed_c, volatile=True))

    # Multi-attribute transfer (H+G, H+A, G+A, H+G+A).
    if self.dataset == 'CelebA':
        for i in range(4):
            fixed_c = real_c.clone()
            if i in [0, 1, 3]:  # Hair color to brown
                fixed_c[:, :3] = hair_colors[2]
            if i in [0, 2, 3]:  # Gender
                flip(fixed_c, 3)
            if i in [1, 2, 3]:  # Aged
                flip(fixed_c, 4)
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))
    return fixed_c_list
当dataset为CelebA时,fixed_c_list的长度为c_dim+4=5+4=9(5个单属性变换加4个多属性组合变换),
训练时,先用torch.randperm生成0到batch_size-1的随机排列索引,再按该索引从real_label中取值,得到fake_label,
# Start training
start_time = time.time()
for e in range(start, self.num_epochs):
    for i, (real_x, real_label) in enumerate(self.data_loader):
        # Generate fake labels randomly (target domain labels): a random
        # permutation of the batch indices gives each image another sample's labels.
        rand_idx = torch.randperm(real_label.size(0))
        fake_label = real_label[rand_idx]
        if self.dataset == 'CelebA':
            # CelebA labels are already per-attribute 0/1 vectors.
            real_c = real_label.clone()
            fake_c = fake_label.clone()
        else:
            # presumably class indices -> one-hot vectors of length c_dim (helper not shown)
            real_c = self.one_hot(real_label, self.c_dim)
            fake_c = self.one_hot(fake_label, self.c_dim)
        # Convert tensor to variable
        real_x = self.to_var(real_x)  # e.g. (16, 3, 128, 128) for batch_size=16
        real_c = self.to_var(real_c)  # e.g. (16, 5); input for the generator
        fake_c = self.to_var(fake_c)  # e.g. (16, 5)
        real_label = self.to_var(real_label)  # this is same as real_c if dataset == 'CelebA'
        fake_label = self.to_var(fake_label)