pytorch原版github地址:https://github.com/yunjey/StarGAN
tensorflow版github地址:https://github.com/taki0112/StarGAN-Tensorflow
两个版本实现相差不大,以pytorch版来介绍。
以CelebA数据集为例,下载后的数据包括label文件和图像。
文件的第一行为图像的总数,为202599.
第二行为数据处理的类别,包括40种,
5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
第三行及之后的每一行为图像名,以及对应的40种类别的label,label值为1或-1
000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
all_attr_names表示全部40种任务类别集合
self.selected_attrs表示我们训练选用的任务类别集合,默认的是['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
def preprocess(self):
    """Preprocess the CelebA attribute file.

    Reads ``self.attr_path``, whose second line holds the attribute names
    and whose remaining lines hold ``<filename> <v1> <v2> ...`` with each
    value being the string ``1`` or ``-1``.

    Side effects:
        * fills ``self.attr2idx`` / ``self.idx2attr`` with the
          name <-> column-index mapping of every attribute in the file;
        * appends ``[filename, label]`` pairs to ``self.test_dataset`` and
          ``self.train_dataset``, where ``label`` is a list of booleans,
          one per entry of ``self.selected_attrs`` (``True`` for ``1``);
        * re-seeds the global ``random`` module with 1234 so the shuffle,
          and therefore the train/test split, is reproducible.

    Raises:
        KeyError: if an entry of ``self.selected_attrs`` is not an
            attribute name found in the file.
    """
    # Read inside a context manager so the file handle is always closed.
    with open(self.attr_path, 'r') as attr_file:
        lines = [line.rstrip() for line in attr_file]

    all_attr_names = lines[1].split()
    for i, attr_name in enumerate(all_attr_names):
        self.attr2idx[attr_name] = i
        self.idx2attr[i] = attr_name

    # Drop the two header lines; also drop blank lines (e.g. a trailing
    # newline at the end of the file), which would otherwise crash below.
    lines = [line for line in lines[2:] if line]
    random.seed(1234)
    random.shuffle(lines)  # Shuffle the images with a fixed seed.

    # The selected column indices do not depend on the image, so resolve
    # them once instead of once per line.
    selected_indices = [self.attr2idx[attr_name] for attr_name in self.selected_attrs]

    for i, line in enumerate(lines):
        split = line.split()
        filename = split[0]  # Image file name.
        values = split[1:]   # The attribute values of this image.

        # '1' becomes True and '-1' becomes False.
        label = [values[idx] == '1' for idx in selected_indices]

        # NOTE: this condition sends the first 1999 shuffled images (not
        # 2000) to the test set; it is kept as is so the split is unchanged.
        if (i+1) < 2000:
            self.test_dataset.append([filename, label])
        else:
            self.train_dataset.append([filename, label])

    print('Finished preprocessing the CelebA dataset...')
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader.

    The image pipeline is: optional random horizontal flip (only when
    ``mode == 'train'``), center crop to ``crop_size``, resize to
    ``image_size``, conversion to a tensor and per-channel normalization
    with mean 0.5 and std 0.5, i.e. ``x -> (x - 0.5) / 0.5``.

    ``dataset`` selects the dataset class: ``'CelebA'`` builds a ``CelebA``
    dataset from ``image_dir``/``attr_path``/``selected_attrs``, ``'RaFD'``
    builds an ``ImageFolder`` over ``image_dir``. Batches are shuffled only
    in training mode.
    """
    is_training = mode == 'train'

    # Flipping is a data augmentation, so it is skipped outside training.
    augmentation = [T.RandomHorizontalFlip()] if is_training else []
    pipeline = T.Compose(augmentation + [
        T.CenterCrop(crop_size),
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    if dataset == 'CelebA':
        dataset = CelebA(image_dir, attr_path, selected_attrs, pipeline, mode)
    elif dataset == 'RaFD':
        dataset = ImageFolder(image_dir, pipeline)

    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=is_training,
                           num_workers=num_workers)
def build_model(self):
    """Create a generator and a discriminator.

    For a single dataset ('CelebA' or 'RaFD') both networks use
    ``self.c_dim`` label channels. For 'Both', the discriminator uses
    ``self.c_dim + self.c2_dim`` and the generator two more, which carry
    the mask vector. An Adam optimizer is created for each network, both
    networks are printed and then moved to ``self.device``.
    """
    if self.dataset in ['CelebA', 'RaFD']:
        self.G = Generator(conv_dim=self.g_conv_dim,
                           c_dim=self.c_dim,
                           repeat_num=self.g_repeat_num)
        self.D = Discriminator(image_size=self.image_size,
                               conv_dim=self.d_conv_dim,
                               c_dim=self.c_dim,
                               repeat_num=self.d_repeat_num)
    elif self.dataset in ['Both']:
        # The generator input gets 2 extra channels for the mask vector.
        self.G = Generator(conv_dim=self.g_conv_dim,
                           c_dim=self.c_dim+self.c2_dim+2,
                           repeat_num=self.g_repeat_num)
        self.D = Discriminator(image_size=self.image_size,
                               conv_dim=self.d_conv_dim,
                               c_dim=self.c_dim+self.c2_dim,
                               repeat_num=self.d_repeat_num)

    betas = [self.beta1, self.beta2]
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, betas)
    self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, betas)

    # Print both networks, then move both to the target device.
    for network, tag in ((self.G, 'G'), (self.D, 'D')):
        self.print_network(network, tag)
    for network in (self.G, self.D):
        network.to(self.device)
# Fetch fixed inputs for debugging.
# One batch is taken from the loader up front; its images (x_fixed) and the
# target labels derived from its labels (c_fixed_list) are kept for reuse.
data_iter = iter(data_loader)
x_fixed, c_org = next(data_iter)# One batch of images and, presumably, their labels (original comment is cut off; TODO confirm).
x_fixed = x_fixed.to(self.device)
# One target-label tensor per attribute, as built by create_labels.
c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
其中create_labels返回的c_trg_list(此处赋值给c_fixed_list)是一个list,包含5个形状为[16,5]的张量(batch_size=16,c_dim=5)。
这5个张量依次是效果图中后五列图像对应的目标标签(['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])。
给个看训练效果时看到的图:
def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
    """Generate target domain labels for debugging and testing.

    Args:
        c_org: tensor of original labels, indexed as [batch, attribute].
        c_dim: number of target label tensors to generate.
        dataset: 'CelebA' or 'RaFD'.
        selected_attrs: attribute names in column order (CelebA only). It is
            used to find which columns are hair colors. ``None`` is treated
            as "no names known", so every column is simply inverted.

    Returns:
        A list of ``c_dim`` label tensors on ``self.device``. For CelebA,
        the i-th tensor is a copy of ``c_org`` with attribute ``i`` changed:
        a hair-color column is set to 1 and the other hair-color columns to
        0; any other column is inverted (0 -> 1, non-zero -> 0). For RaFD,
        the i-th tensor is ``self.label2onehot`` applied to a batch of
        class index ``i`` (presumably a one-hot encoding; see that helper).

    Raises:
        ValueError: if ``dataset`` is neither 'CelebA' nor 'RaFD'.
    """
    if dataset not in ('CelebA', 'RaFD'):
        # Fail with a clear message instead of a NameError on c_trg below.
        raise ValueError("Unsupported dataset: {!r}".format(dataset))

    # Get hair color indices.
    hair_color_names = ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair')
    hair_color_indices = []
    if dataset == 'CelebA':
        hair_color_indices = [i for i, attr_name in enumerate(selected_attrs or [])
                              if attr_name in hair_color_names]

    c_trg_list = []
    for i in range(c_dim):
        if dataset == 'CelebA':
            c_trg = c_org.clone()
            if i in hair_color_indices:
                # Hair colors are mutually exclusive: set the requested one
                # to 1 and every other hair color to 0.
                for j in hair_color_indices:
                    c_trg[:, j] = 0
                c_trg[:, i] = 1
            else:
                # Reverse the attribute value, e.g. male <-> female.
                c_trg[:, i] = (c_trg[:, i] == 0).to(c_trg.dtype)
        else:  # 'RaFD'
            c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim)

        c_trg_list.append(c_trg.to(self.device))
    return c_trg_list
用来存效果图的代码:
# Translate fixed images for debugging.
# Every sample_step iterations the fixed batch is passed through G once per
# fixed target label; the input and all translations are concatenated along
# dim 3 and written to a single image file under sample_dir.
if (i+1) % self.sample_step == 0:
with torch.no_grad():
# The first entry is the untouched input batch.
x_fake_list = [x_fixed]
for c_fixed in c_fixed_list:
x_fake_list.append(self.G(x_fixed, c_fixed))
x_concat = torch.cat(x_fake_list, dim=3)
sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
# denorm is applied before saving (presumably undoing the input normalization; see that helper).
save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
print('Saved real and fake images into {}...'.format(sample_path))
# Fetch real images and labels.
try:
    x_real, label_org = next(data_iter)
except StopIteration:
    # The loader is exhausted: restart it and take the first batch of the
    # new pass. Only StopIteration is caught so that real data-loading
    # errors and KeyboardInterrupt are not silently retried.
    data_iter = iter(data_loader)
    x_real, label_org = next(data_iter)
# Generate target domain labels randomly.
# The target labels are the batch's own original labels in a random order.
rand_idx = torch.randperm(label_org.size(0))
label_trg = label_org[rand_idx]
# c_org / c_trg are the label tensors fed to G; label_org / label_trg are
# kept separately for the classification loss.
if self.dataset == 'CelebA':
c_org = label_org.clone()
c_trg = label_trg.clone()
elif self.dataset == 'RaFD':
# RaFD labels go through label2onehot before being fed to G.
c_org = self.label2onehot(label_org, self.c_dim)
c_trg = self.label2onehot(label_trg, self.c_dim)
x_real = x_real.to(self.device) # Input images.
c_org = c_org.to(self.device) # Original domain labels.
c_trg = c_trg.to(self.device) # Target domain labels.
label_org = label_org.to(self.device) # Labels for computing classification loss.
label_trg = label_trg.to(self.device) # Labels for computing classification loss.
# Compute loss with real images.
# D takes a batch of real images and returns out_src (real/fake scores) and
# out_cls (label estimates). With the defaults shown in this file (batch 16,
# 128x128 images, c_dim 5, repeat_num 6) these are [16, 1, 2, 2] and [16, 5].
out_src, out_cls = self.D(x_real)
d_loss_real = - torch.mean(out_src)# The higher D scores the real images, the smaller this loss.
d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)# Gap between the true labels and D's label estimates.
# Compute loss with fake images.
x_fake = self.G(x_real, c_trg)# Generate fake images from the real batch and the target labels.
# detach() keeps this loss from back-propagating into G.
out_src, out_cls = self.D(x_fake.detach())
d_loss_fake = torch.mean(out_src)# The lower D scores the fake images, the smaller this loss.
# Compute loss for gradient penalty.
# x_hat is a per-sample random interpolation between real and fake images.
alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
out_src, _ = self.D(x_hat)
d_loss_gp = self.gradient_penalty(out_src, x_hat)
# Backward and optimize.
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
这其中,判别器的结构如下:
class Discriminator(nn.Module):
    """Discriminator network with PatchGAN.

    A stack of ``repeat_num`` stride-2 convolutions (each followed by
    LeakyReLU) halves the resolution and, after the first layer, doubles
    the channel count at every step. Two bias-free heads read the result:
    ``conv1`` gives a one-channel real/fake score map and ``conv2``, whose
    kernel covers the whole remaining feature map, gives ``c_dim`` values
    per image.
    """

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()

        # Output width of every down-sampling convolution, in order.
        widths = [conv_dim]
        for _ in range(1, repeat_num):
            widths.append(widths[-1] * 2)

        blocks = []
        in_channels = 3
        for out_channels in widths:
            blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))
            blocks.append(nn.LeakyReLU(0.01))
            in_channels = out_channels
        self.main = nn.Sequential(*blocks)

        # Spatial size left after repeat_num halvings of the input image.
        feature_size = int(image_size / np.power(2, repeat_num))
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels, c_dim, kernel_size=feature_size, bias=False)

    def forward(self, x):
        """Return ``(out_src, out_cls)`` with ``out_cls`` flattened to 2-D."""
        features = self.main(x)
        score_map = self.conv1(features)
        class_map = self.conv2(features)
        return score_map, class_map.view(class_map.size(0), class_map.size(1))
生成器结构如下:
class Generator(nn.Module):
    """Generator network.

    Layout: a 7x7 convolution over the image concatenated with the label
    channels, two stride-2 down-sampling convolutions (channels double each
    time), ``repeat_num`` residual blocks, two stride-2 transposed
    convolutions (channels halve each time), and a final 7x7 convolution to
    3 channels followed by Tanh.

    Args:
        conv_dim: number of channels after the first convolution.
        c_dim: number of label channels appended to the 3 image channels.
        repeat_num: number of residual blocks in the bottleneck.
    """

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super(Generator, self).__init__()

        layers = []
        layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim
        for _ in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck layers.
        for _ in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers.
        # Floor division keeps the channel counts integers; true division
        # would pass floats to the layer constructors on Python 3.
        for _ in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        """Translate images ``x`` towards the labels ``c``.

        Args:
            x: image batch indexed as [batch, channel, height, width],
                e.g. [16, 3, 128, 128] with this file's default settings.
            c: label batch indexed as [batch, c_dim], e.g. [16, 5].

        Returns:
            The output of the network for ``x`` concatenated with the
            spatially replicated labels.
        """
        # Replicate spatially and concatenate domain information: every
        # label value becomes one constant-valued plane of the image size.
        c = c.view(c.size(0), c.size(1), 1, 1)    # e.g. [16, 5, 1, 1]
        c = c.repeat(1, 1, x.size(2), x.size(3))  # e.g. [16, 5, 128, 128]
        x = torch.cat([x, c], dim=1)              # e.g. [16, 8, 128, 128]
        return self.main(x)
# Update the generator only once every n_critic iterations (the original
# note says once per 5 discriminator updates, i.e. n_critic = 5 — TODO confirm).
if (i+1) % self.n_critic == 0:
# Original-to-target domain.
x_fake = self.G(x_real, c_trg)# Generate fake images from the real batch and the target labels.
out_src, out_cls = self.D(x_fake)# D's real/fake scores and label estimates for the fake images.
g_loss_fake = - torch.mean(out_src)# The higher D scores the fake images, the smaller this loss.
g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)# The closer D's label estimates are to the target labels, the smaller this loss.
# Target-to-original domain.
x_reconst = self.G(x_fake, c_org)# Feed the fake images and the original labels to reconstruct the original images.
g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))# Mean absolute difference: the closer the reconstruction, the smaller this loss.
# Backward and optimize.
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad()
g_loss.backward()
self.g_optimizer.step()
def test(self):
    """Translate images using StarGAN trained on a single dataset.

    Restores the model saved at ``self.test_iters``, then for every batch
    of the selected loader runs the generator once per target label from
    ``create_labels`` and saves the input next to all translations
    (concatenated along dim 3) as ``<batch number>-images.jpg`` under
    ``self.result_dir``. Everything runs without gradient tracking.
    """
    # Load the trained generator.
    self.restore_model(self.test_iters)

    # Pick the loader that matches the dataset.
    if self.dataset == 'CelebA':
        data_loader = self.celeba_loader
    elif self.dataset == 'RaFD':
        data_loader = self.rafd_loader

    with torch.no_grad():
        for batch_no, (x_real, c_org) in enumerate(data_loader, start=1):
            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            targets = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

            # Real images first, then one translation per target label.
            columns = [x_real] + [self.G(x_real, c_trg) for c_trg in targets]

            # Save the translated images.
            x_concat = torch.cat(columns, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_no))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))
标签追加一个mask
# Build the label vector fed to G, laid out as
# [first-dataset labels (c_dim) | second-dataset labels (c2_dim) | mask (2)].
# The slots of the dataset that is not in use are filled with zeros.
if dataset == 'CelebA':
c_org = label_org.clone()
c_trg = label_trg.clone()
zero = torch.zeros(x_real.size(0), self.c2_dim)
# Mask built from class index 0 (presumably one-hot; see label2onehot).
mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
c_org = torch.cat([c_org, zero, mask], dim=1)
c_trg = torch.cat([c_trg, zero, mask], dim=1)
elif dataset == 'RaFD':
c_org = self.label2onehot(label_org, self.c2_dim)
c_trg = self.label2onehot(label_trg, self.c2_dim)
zero = torch.zeros(x_real.size(0), self.c_dim)
# Mask built from class index 1 (presumably one-hot; see label2onehot).
mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
c_org = torch.cat([zero, c_org, mask], dim=1)
c_trg = torch.cat([zero, c_trg, mask], dim=1)