论文地址:https://ieeexplore.ieee.org/document/8579014/
论文代码(pytorch):https://github.com/yunjey/stargan
论文翻译与解读:https://blog.csdn.net/m0_61985580/article/details/125766783?spm=1001.2014.3001.5501
1、此文仅作为学习笔记,注释可能会有一些偏差,如有注释错误欢迎留言更正。
2、如有使用此文,请标注出处。
首先定义一个CelebA数据集类。
CelebA类包含 preprocess(self)、__getitem__(self, index)、__len__(self) 等方法;
另外还有一个模块级函数 get_loader(),用于构建并返回数据加载器。
def preprocess(self):
    """Parse the CelebA attribute file and fill the train/test sample lists.

    The file at ``self.attr_path`` is read line by line. The second line
    holds the attribute names; every line from the third one on holds an
    image file name followed by one value per attribute.

    Side effects:
        * ``self.attr2idx`` / ``self.idx2attr`` are filled with the
          name <-> column-index mapping.
        * ``self.test_dataset`` / ``self.train_dataset`` receive
          ``[filename, label]`` pairs, where ``label`` is a list of booleans
          (one per entry of ``self.selected_attrs``) that is True when the
          file stores ``'1'`` for that attribute.
        * The global ``random`` generator is re-seeded with 1234 so the
          shuffle (and therefore the split) is reproducible.
    """
    # Use a context manager so the file handle is released deterministically
    # instead of being left to the garbage collector.
    with open(self.attr_path, 'r') as attr_file:
        lines = [line.rstrip() for line in attr_file]

    all_attr_names = lines[1].split()
    for i, attr_name in enumerate(all_attr_names):
        self.attr2idx[attr_name] = i
        self.idx2attr[i] = attr_name

    lines = lines[2:]
    random.seed(1234)
    random.shuffle(lines)

    # The selected columns are the same for every line; resolve them once.
    selected_indices = [self.attr2idx[attr_name] for attr_name in self.selected_attrs]

    for i, line in enumerate(lines):
        split = line.split()
        filename = split[0]
        values = split[1:]
        # '1' -> True, anything else (the file uses '-1') -> False.
        label = [values[idx] == '1' for idx in selected_indices]

        # NOTE: ``(i+1) < 2000`` puts the first 1999 shuffled samples into the
        # test split; the threshold is kept as-is so the split does not change.
        if (i+1) < 2000:
            self.test_dataset.append([filename, label])
        else:
            self.train_dataset.append([filename, label])

    print('Finished preprocessing the CelebA dataset...')
def __getitem__(self, index):
    """Load the sample at ``index`` from the split selected by ``self.mode``.

    Returns:
        A pair ``(image, label)``: the image opened from ``self.image_dir``
        after ``self.transform`` has been applied, and the stored label
        converted to a ``torch.FloatTensor``.
    """
    if self.mode == 'train':
        samples = self.train_dataset
    else:
        samples = self.test_dataset

    filename, label = samples[index]
    image_path = os.path.join(self.image_dir, filename)
    image = Image.open(image_path)
    return self.transform(image), torch.FloatTensor(label)
def __len__(self):
    """Report the dataset size as stored in ``self.num_images``."""
    count = self.num_images
    return count
构建并返回数据加载器,对数据进行水平翻转,裁剪更改图片大小等操作
def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader for CelebA or RaFD.

    Args:
        image_dir: Directory containing the images.
        attr_path: Attribute file path (passed to ``CelebA`` only).
        selected_attrs: Attribute names to use (passed to ``CelebA`` only).
        crop_size: Side length of the centre crop.
        image_size: Size the cropped image is resized to.
        batch_size: Mini-batch size.
        dataset: ``'CelebA'`` or ``'RaFD'``.
        mode: ``'train'`` enables random horizontal flips and shuffling.
        num_workers: Number of loader worker processes.

    Returns:
        A ``data.DataLoader`` over the selected dataset.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
            Previously the raw string was handed to ``DataLoader`` as if it
            were a dataset.
    """
    # Fail fast on an unsupported name before doing any other work.
    if dataset not in ('CelebA', 'RaFD'):
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))

    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip())  # augmentation only while training
    transform.append(T.CenterCrop(crop_size))
    transform.append(T.Resize(image_size))
    transform.append(T.ToTensor())
    # Normalise each channel with mean 0.5 / std 0.5.
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)

    # Keep the ``dataset`` argument (a name) separate from the dataset object.
    if dataset == 'CelebA':
        dataset_obj = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
    else:  # 'RaFD'
        dataset_obj = ImageFolder(image_dir, transform)

    data_loader = data.DataLoader(dataset=dataset_obj,
                                  batch_size=batch_size,
                                  shuffle=(mode == 'train'),
                                  num_workers=num_workers)
    return data_loader
这个主要是用来加载TensorBoard
object为加载对象
def __init__(self, log_dir):
    """Create the summary writer that writes into ``log_dir``."""
    writer = tf.summary.FileWriter(log_dir)
    self.writer = writer
def scalar_summary(self, tag, value, step):
    """Record the scalar ``value`` under ``tag`` at training step ``step``."""
    entry = tf.Summary.Value(tag=tag, simple_value=value)
    summary = tf.Summary(value=[entry])
    self.writer.add_summary(summary, step)
这个主要是用来调参,设置配置参数
def str2bool(v):
    """Convert a command-line string to a boolean.

    Returns True only when ``v`` equals ``'true'`` ignoring case; every other
    string yields False.

    The previous check ``v.lower() in ('true')`` compared against a plain
    string (the parentheses do not make a tuple), so it was a substring test
    and values such as ``''``, ``'t'`` or ``'ru'`` were parsed as True.
    """
    return v.lower() == 'true'
def main(config):
    """Create output directories, build the data loaders and run the solver.

    Args:
        config: Parsed command-line namespace. The fields read here are the
            four output directories, ``dataset``, ``mode`` and the loader
            settings forwarded to ``get_loader``; the whole object is also
            handed to ``Solver``.
    """
    # For fast training.
    cudnn.benchmark = True

    # Create directories if they do not exist. ``exist_ok=True`` avoids the
    # check-then-create race of ``os.path.exists`` followed by ``os.makedirs``.
    for directory in (config.log_dir, config.model_save_dir,
                      config.sample_dir, config.result_dir):
        os.makedirs(directory, exist_ok=True)

    # Data loaders; a loader stays None when its dataset is not requested.
    celeba_loader = None
    rafd_loader = None

    if config.dataset in ['CelebA', 'Both']:
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                   config.celeba_crop_size, config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)
    if config.dataset in ['RaFD', 'Both']:
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size, config.batch_size,
                                 'RaFD', config.mode, config.num_workers)

    # Solver for training and testing StarGAN.
    solver = Solver(celeba_loader, rafd_loader, config)

    if config.mode == 'train':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.train()
        elif config.dataset in ['Both']:
            solver.train_multi()
    elif config.mode == 'test':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.test()
        elif config.dataset in ['Both']:
            solver.test_multi()
if __name__ == '__main__':
    # Command-line entry point: declare every option, parse, then run main().
    parser = argparse.ArgumentParser()

    # Model configuration.
    parser.add_argument('--c_dim', type=int, default=5,
                        help='dimension of domain labels (1st dataset)')
    parser.add_argument('--c2_dim', type=int, default=8,
                        help='dimension of domain labels (2nd dataset)')
    parser.add_argument('--celeba_crop_size', type=int, default=178,
                        help='crop size for the CelebA dataset')
    parser.add_argument('--rafd_crop_size', type=int, default=256,
                        help='crop size for the RaFD dataset')
    parser.add_argument('--image_size', type=int, default=128,
                        help='image resolution')
    parser.add_argument('--g_conv_dim', type=int, default=64,
                        help='number of conv filters in the first layer of G')
    parser.add_argument('--d_conv_dim', type=int, default=64,
                        help='number of conv filters in the first layer of D')
    parser.add_argument('--g_repeat_num', type=int, default=6,
                        help='number of residual blocks in G')
    parser.add_argument('--d_repeat_num', type=int, default=6,
                        help='number of strided conv layers in D')
    parser.add_argument('--lambda_cls', type=float, default=1,
                        help='weight for domain classification loss')
    parser.add_argument('--lambda_rec', type=float, default=10,
                        help='weight for reconstruction loss')
    parser.add_argument('--lambda_gp', type=float, default=10,
                        help='weight for gradient penalty')

    # Training configuration.
    parser.add_argument('--dataset', type=str, default='CelebA',
                        choices=['CelebA', 'RaFD', 'Both'])
    parser.add_argument('--batch_size', type=int, default=16,
                        help='mini-batch size')
    parser.add_argument('--num_iters', type=int, default=200000,
                        help='number of total iterations for training D')
    parser.add_argument('--num_iters_decay', type=int, default=100000,
                        help='number of iterations for decaying lr')
    parser.add_argument('--g_lr', type=float, default=0.0001,
                        help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.0001,
                        help='learning rate for D')
    parser.add_argument('--n_critic', type=int, default=5,
                        help='number of D updates per each G update')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999,
                        help='beta2 for Adam optimizer')
    parser.add_argument('--resume_iters', type=int, default=None,
                        help='resume training from this step')
    # nargs='+' reads one or more values for this option ('*' would allow zero).
    parser.add_argument('--selected_attrs', '--list', nargs='+',
                        help='selected attributes for the CelebA dataset',
                        default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])

    # Test configuration.
    parser.add_argument('--test_iters', type=int, default=200000,
                        help='test model from this step')

    # Miscellaneous.
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--use_tensorboard', type=str2bool, default=True)

    # Directories.
    parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images')
    parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt')
    parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
    parser.add_argument('--log_dir', type=str, default='stargan/logs')
    parser.add_argument('--model_save_dir', type=str, default='stargan/models')
    parser.add_argument('--sample_dir', type=str, default='stargan/samples')
    parser.add_argument('--result_dir', type=str, default='stargan/results')

    # Step size.
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=1000)
    parser.add_argument('--model_save_step', type=int, default=10000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    config = parser.parse_args()
    print(config)
    main(config)
这个文件主要是生成器与鉴别器的网络结构以及两者的具体参数
这个是残差块的定义
class ResidualBlock(nn.Module):
    """Two 3x3 conv + instance-norm stages wrapped in a skip connection."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        norm_kwargs = dict(affine=True, track_running_stats=True)
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, **conv_kwargs),
            nn.InstanceNorm2d(dim_out, **norm_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, **conv_kwargs),
            nn.InstanceNorm2d(dim_out, **norm_kwargs),
        )

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        residual = self.main(x)
        return x + residual
生成器采用的是CycleGAN中的生成器结构与参数。
class Generator(nn.Module):
    """Generator: stem conv, 2 down-sampling stages, residual bottleneck, 2 up-sampling stages."""

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()
        norm_kwargs = dict(affine=True, track_running_stats=True)

        # Stem: the input is the image (3 channels) concatenated with the
        # spatially replicated label (c_dim channels).
        layers = [
            nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(conv_dim, **norm_kwargs),
            nn.ReLU(inplace=True),
        ]

        # Down-sampling layers: each stage doubles the channel count.
        curr_dim = conv_dim
        for _ in range(2):
            out_dim = curr_dim * 2
            layers.extend([
                nn.Conv2d(curr_dim, out_dim, kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(out_dim, **norm_kwargs),
                nn.ReLU(inplace=True),
            ])
            curr_dim = out_dim

        # Bottleneck layers.
        for _ in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-sampling layers: each stage halves the channel count.
        for _ in range(2):
            out_dim = curr_dim // 2
            layers.extend([
                nn.ConvTranspose2d(curr_dim, out_dim, kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(out_dim, **norm_kwargs),
                nn.ReLU(inplace=True),
            ])
            curr_dim = out_dim

        # Project back to 3 channels and squash with tanh.
        layers.extend([
            nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False),
            nn.Tanh(),
        ])
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        """Translate images ``x`` conditioned on the label vectors ``c``."""
        # Replicate spatially and concatenate domain information.
        # Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
        # This is because instance normalization ignores the shifting (or bias) effect.
        batch, channels = c.size(0), c.size(1)
        label_map = c.view(batch, channels, 1, 1)
        label_map = label_map.repeat(1, 1, x.size(2), x.size(3))
        conditioned = torch.cat([x, label_map], dim=1)
        return self.main(conditioned)
class Discriminator(nn.Module):
    """PatchGAN-style discriminator with a real/fake head and a domain-label head."""

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()

        # First strided conv takes the 3-channel image.
        layers = [
            nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.01),
        ]

        # Remaining strided convs; each doubles the channels. With the default
        # arguments this ends at 64 * 2**5 = 2048 channels.
        curr_dim = conv_dim
        for _ in range(1, repeat_num):
            next_dim = curr_dim * 2
            layers.append(nn.Conv2d(curr_dim, next_dim, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01))
            curr_dim = next_dim

        # Kernel of the label head: image_size divided by 2**repeat_num, i.e.
        # the feature-map size left after ``repeat_num`` stride-2 convs.
        downscale = np.power(2, repeat_num)
        kernel_size = int(image_size / downscale)

        self.main = nn.Sequential(*layers)
        # Real/fake score map.
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # Domain-label logits.
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False)

    def forward(self, x):
        """Return ``(out_src, out_cls)``; ``out_cls`` is flattened to 2-D."""
        features = self.main(x)
        out_src = self.conv1(features)
        out_cls = self.conv2(features)
        flat_cls = out_cls.view(out_cls.size(0), out_cls.size(1))
        return out_src, flat_cls
"""Solver for training and testing StarGAN."""
对这些配置参数进行实例化,具体参数可以看main.py
def __init__(self, celeba_loader, rafd_loader, config):
    """Copy the configuration onto the solver and build the networks.

    Args:
        celeba_loader: Data loader stored as ``self.celeba_loader``.
        rafd_loader: Data loader stored as ``self.rafd_loader``.
        config: Object whose attributes are copied one-to-one onto ``self``.
    """
    # Data loaders.
    self.celeba_loader = celeba_loader
    self.rafd_loader = rafd_loader

    model_fields = ('c_dim', 'c2_dim', 'image_size', 'g_conv_dim', 'd_conv_dim',
                    'g_repeat_num', 'd_repeat_num', 'lambda_cls', 'lambda_rec', 'lambda_gp')
    training_fields = ('dataset', 'batch_size', 'num_iters', 'num_iters_decay',
                       'g_lr', 'd_lr', 'n_critic', 'beta1', 'beta2',
                       'resume_iters', 'selected_attrs')
    test_fields = ('test_iters',)
    misc_fields = ('use_tensorboard',)
    directory_fields = ('log_dir', 'sample_dir', 'model_save_dir', 'result_dir')
    step_fields = ('log_step', 'sample_step', 'model_save_step', 'lr_update_step')

    # Every field keeps the same name on the solver as on the config.
    for group in (model_fields, training_fields, test_fields,
                  misc_fields, directory_fields, step_fields):
        for field in group:
            setattr(self, field, getattr(config, field))

    # Run on the GPU when one is available.
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the model and tensorboard.
    self.build_model()
    if self.use_tensorboard:
        self.build_tensorboard()
def build_model(self):
    """Instantiate G and D, their Adam optimizers, and move both to ``self.device``."""
    single_dataset = self.dataset in ['CelebA', 'RaFD']
    if single_dataset:
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
    elif self.dataset in ['Both']:
        # The generator label is both label vectors plus 2 for the mask vector.
        joint_dim = self.c_dim + self.c2_dim
        self.G = Generator(self.g_conv_dim, joint_dim + 2, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, joint_dim, self.d_repeat_num)

    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

    # Print both networks to stdout.
    self.print_network(self.G, 'G')
    self.print_network(self.D, 'D')

    self.G.to(self.device)
    self.D.to(self.device)
def print_network(self, model, name):
    """Print the model, its name and its total number of parameter elements."""
    num_params = sum(param.numel() for param in model.parameters())
    print(model)
    print(name)
    print("The number of parameters: {}".format(num_params))
def restore_model(self, resume_iters):
    """Load the generator and discriminator weights saved at step ``resume_iters``.

    Checkpoints are read from ``self.model_save_dir`` using the file names
    ``<step>-G.ckpt`` and ``<step>-D.ckpt``.
    """
    print('Loading the trained models from step {}...'.format(resume_iters))

    def checkpoint_path(pattern):
        return os.path.join(self.model_save_dir, pattern.format(resume_iters))

    def keep_storage(storage, loc):
        # Return the storage unchanged so tensors are not remapped to the
        # device recorded in the checkpoint.
        return storage

    G_path = checkpoint_path('{}-G.ckpt')
    D_path = checkpoint_path('{}-D.ckpt')
    # load_state_dict copies parameters and buffers into the module and its
    # descendants; by default the keys must match exactly.
    self.G.load_state_dict(torch.load(G_path, map_location=keep_storage))
    self.D.load_state_dict(torch.load(D_path, map_location=keep_storage))
def build_tensorboard(self):
    """Create the logger that writes into ``self.log_dir``."""
    # Imported here, so the logger module is only needed when tensorboard logging is enabled.
    from logger import Logger
    log_dir = self.log_dir
    self.logger = Logger(log_dir)
def update_lr(self, g_lr, d_lr):
    """Write new learning rates into every param group of both optimizers."""
    schedule = ((self.g_optimizer, g_lr), (self.d_optimizer, d_lr))
    for optimizer, new_lr in schedule:
        for group in optimizer.param_groups:
            group['lr'] = new_lr
def reset_grad(self):
    """Clear the gradient buffers of the generator and discriminator optimizers."""
    for optimizer in (self.g_optimizer, self.d_optimizer):
        optimizer.zero_grad()
def denorm(self, x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside that range."""
    rescaled = (x + 1) / 2
    rescaled.clamp_(0, 1)
    return rescaled
def gradient_penalty(self, y, x):
"""Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
weight = torch.ones(y.size()).to(self.device)
dydx = torch.autograd.grad(outputs=y,
inputs=x,
grad_outputs=weight,
retain_graph=True,
create_graph=True,
only_inputs=True)[0]
dydx = dydx.view(dydx.size(0), -1)
dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
return torch.mean((dydx_l2norm-1)**2)
def label2onehot(self, labels, dim):
    """Turn a 1-D tensor of label indices into a ``(batch, dim)`` one-hot tensor."""
    num_rows = labels.size(0)
    onehot = torch.zeros(num_rows, dim)
    rows = np.arange(num_rows)
    cols = labels.long()  # indices must be integral to be used for indexing
    onehot[rows, cols] = 1
    return onehot
def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
    """Generate one target-label tensor per domain for debugging and testing.

    For ``'CelebA'`` each entry is a copy of ``c_org`` with attribute ``i``
    changed: a hair-colour attribute is switched on while the other hair
    colours are switched off, any other attribute is flipped. For ``'RaFD'``
    entry ``i`` is the one-hot vector of class ``i`` for every sample.

    Returns:
        A list of ``c_dim`` tensors moved to ``self.device``.
    """
    hair_names = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']
    if dataset == 'CelebA':
        # Positions of the hair-colour attributes among the selected ones.
        hair_color_indices = [pos for pos, attr_name in enumerate(selected_attrs)
                              if attr_name in hair_names]

    c_trg_list = []
    for i in range(c_dim):
        if dataset == 'CelebA':
            c_trg = c_org.clone()
            if i in hair_color_indices:
                # Exactly one hair colour is active: this one.
                for j in hair_color_indices:
                    c_trg[:, j] = 1 if j == i else 0
            else:
                # Reverse attribute value.
                c_trg[:, i] = (c_trg[:, i] == 0)
        elif dataset == 'RaFD':
            class_ids = torch.ones(c_org.size(0)) * i
            c_trg = self.label2onehot(class_ids, c_dim)

        c_trg_list.append(c_trg.to(self.device))
    return c_trg_list
def classification_loss(self, logit, target, dataset='CelebA'):
    """Compute the domain classification loss.

    Args:
        logit: Raw classifier outputs.
        target: Labels; a multi-hot float tensor for ``'CelebA'``, class
            indices for ``'RaFD'``.
        dataset: ``'CelebA'`` (independent binary attributes, so binary
            cross entropy summed and divided by the batch size) or
            ``'RaFD'`` (one class per image, so softmax cross entropy).

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: If ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``
            (previously the function silently returned ``None``).
    """
    if dataset == 'CelebA':
        # ``reduction='sum'`` replaces the deprecated ``size_average=False``.
        return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
    elif dataset == 'RaFD':
        return F.cross_entropy(logit, target)
    raise ValueError(
        "Unsupported dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))
def train(self):
    """Train StarGAN within a single dataset.

    Each iteration updates the discriminator; the generator is updated once
    every ``self.n_critic`` iterations. Losses are logged every
    ``self.log_step`` iterations, sample translations of a fixed batch are
    saved every ``self.sample_step`` iterations, checkpoints are written
    every ``self.model_save_step`` iterations and the learning rates are
    decayed every ``self.lr_update_step`` iterations once the last
    ``self.num_iters_decay`` iterations have been reached.
    """
    # Set data loader.
    if self.dataset == 'CelebA':
        data_loader = self.celeba_loader
    elif self.dataset == 'RaFD':
        data_loader = self.rafd_loader

    # Fetch fixed inputs for debugging: one batch that is translated to every
    # target label whenever samples are saved.
    data_iter = iter(data_loader)
    x_fixed, c_org = next(data_iter)
    x_fixed = x_fixed.to(self.device)
    c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

    # Learning rate cache for decaying.
    g_lr = self.g_lr
    d_lr = self.d_lr

    # Start training from scratch or resume training.
    start_iters = 0
    if self.resume_iters:
        start_iters = self.resume_iters
        self.restore_model(self.resume_iters)

    # Start training.
    print('Start training...')
    start_time = time.time()
    for i in range(start_iters, self.num_iters):

        # =================================================================================== #
        #                             1. Preprocess input data                                #
        # =================================================================================== #

        # Fetch real images and labels. Only an exhausted iterator restarts
        # the loader; any other error (e.g. a failing worker or Ctrl-C) must
        # propagate instead of being swallowed by a bare ``except``.
        try:
            x_real, label_org = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            x_real, label_org = next(data_iter)

        # Generate target domain labels randomly by permuting the batch labels.
        rand_idx = torch.randperm(label_org.size(0))
        label_trg = label_org[rand_idx]

        if self.dataset == 'CelebA':
            c_org = label_org.clone()
            c_trg = label_trg.clone()
        elif self.dataset == 'RaFD':
            c_org = self.label2onehot(label_org, self.c_dim)
            c_trg = self.label2onehot(label_trg, self.c_dim)

        x_real = x_real.to(self.device)           # Input images.
        c_org = c_org.to(self.device)             # Original domain labels.
        c_trg = c_trg.to(self.device)             # Target domain labels.
        label_org = label_org.to(self.device)     # Labels for computing classification loss.
        label_trg = label_trg.to(self.device)     # Labels for computing classification loss.

        # =================================================================================== #
        #                             2. Train the discriminator                              #
        # =================================================================================== #

        # Compute loss with real images: out_src is the real/fake score,
        # out_cls the predicted domain labels.
        out_src, out_cls = self.D(x_real)
        d_loss_real = - torch.mean(out_src)
        d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)

        # Compute loss with fake images. ``detach`` stops the gradient from
        # flowing back into the generator during the discriminator update.
        x_fake = self.G(x_real, c_trg)
        out_src, out_cls = self.D(x_fake.detach())
        d_loss_fake = torch.mean(out_src)

        # Compute loss for gradient penalty on random interpolations between
        # real and fake images (one alpha per sample).
        alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
        x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
        out_src, _ = self.D(x_hat)
        d_loss_gp = self.gradient_penalty(out_src, x_hat)

        # Backward and optimize. The loss has four terms: real score, fake
        # score, classification of real images and the gradient penalty.
        d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        # Logging.
        loss = {}
        loss['D/loss_real'] = d_loss_real.item()
        loss['D/loss_fake'] = d_loss_fake.item()
        loss['D/loss_cls'] = d_loss_cls.item()
        loss['D/loss_gp'] = d_loss_gp.item()

        # =================================================================================== #
        #                               3. Train the generator                                #
        # =================================================================================== #

        # One generator update per ``n_critic`` discriminator updates.
        if (i+1) % self.n_critic == 0:
            # Original-to-target domain.
            x_fake = self.G(x_real, c_trg)
            out_src, out_cls = self.D(x_fake)
            g_loss_fake = - torch.mean(out_src)
            g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)

            # Target-to-original domain: reconstruct the input from the fake
            # image and the original label.
            x_reconst = self.G(x_fake, c_org)
            g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

            # Backward and optimize.
            g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging.
            loss['G/loss_fake'] = g_loss_fake.item()
            loss['G/loss_rec'] = g_loss_rec.item()
            loss['G/loss_cls'] = g_loss_cls.item()

        # =================================================================================== #
        #                                 4. Miscellaneous                                    #
        # =================================================================================== #

        # Print out training information.
        if (i+1) % self.log_step == 0:
            et = time.time() - start_time
            et = str(datetime.timedelta(seconds=et))[:-7]
            log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
            for tag, value in loss.items():
                log += ", {}: {:.4f}".format(tag, value)
            print(log)

            if self.use_tensorboard:
                for tag, value in loss.items():
                    self.logger.scalar_summary(tag, value, i+1)

        # Translate fixed images for debugging: the fixed batch followed by
        # its translation to each fixed label, concatenated along the width.
        if (i+1) % self.sample_step == 0:
            with torch.no_grad():
                x_fake_list = [x_fixed]
                for c_fixed in c_fixed_list:
                    x_fake_list.append(self.G(x_fixed, c_fixed))
                x_concat = torch.cat(x_fake_list, dim=3)
                sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(sample_path))

        # Save model checkpoints.
        if (i+1) % self.model_save_step == 0:
            G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
            D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
            torch.save(self.G.state_dict(), G_path)
            torch.save(self.D.state_dict(), D_path)
            print('Saved model checkpoints into {}...'.format(self.model_save_dir))

        # Decay learning rates.
        if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
            g_lr -= (self.g_lr / float(self.num_iters_decay))
            d_lr -= (self.d_lr / float(self.num_iters_decay))
            self.update_lr(g_lr, d_lr)
            print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
"""Train StarGAN with multiple datasets.使用多个数据集训练 StarGAN"""
def train_multi(self):
    """Train StarGAN with multiple datasets (CelebA and RaFD jointly).

    Every iteration performs one training step per dataset. The label fed to
    the generator is the concatenation ``[CelebA label, RaFD label, mask]``,
    where the labels of the other dataset are zero and the 2-element one-hot
    mask says which dataset the known labels belong to.
    """
    # Data iterators.
    celeba_iter = iter(self.celeba_loader)
    rafd_iter = iter(self.rafd_loader)

    # Fetch fixed inputs for debugging.
    x_fixed, c_org = next(celeba_iter)
    x_fixed = x_fixed.to(self.device)
    c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
    c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
    zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
    zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
    mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
    mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device)     # Mask vector: [0, 1].

    # Learning rate cache for decaying.
    g_lr = self.g_lr
    d_lr = self.d_lr

    # Start training from scratch or resume training.
    start_iters = 0
    if self.resume_iters:
        start_iters = self.resume_iters
        self.restore_model(self.resume_iters)

    # Start training.
    print('Start training...')
    start_time = time.time()
    for i in range(start_iters, self.num_iters):
        for dataset in ['CelebA', 'RaFD']:

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels. Only an exhausted iterator
            # restarts the loader; other errors propagate instead of being
            # swallowed by a bare ``except``.
            data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter

            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                if dataset == 'CelebA':
                    celeba_iter = iter(self.celeba_loader)
                    x_real, label_org = next(celeba_iter)
                elif dataset == 'RaFD':
                    rafd_iter = iter(self.rafd_loader)
                    x_real, label_org = next(rafd_iter)

            # Generate target domain labels randomly by permuting the batch labels.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            # Build the joint label: own labels, zeros for the other dataset,
            # then the mask identifying the dataset.
            if dataset == 'CelebA':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
                zero = torch.zeros(x_real.size(0), self.c2_dim)
                mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
                c_org = torch.cat([c_org, zero, mask], dim=1)
                c_trg = torch.cat([c_trg, zero, mask], dim=1)
            elif dataset == 'RaFD':
                c_org = self.label2onehot(label_org, self.c2_dim)
                c_trg = self.label2onehot(label_trg, self.c2_dim)
                zero = torch.zeros(x_real.size(0), self.c_dim)
                mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
                c_org = torch.cat([zero, c_org, mask], dim=1)
                c_trg = torch.cat([zero, c_trg, mask], dim=1)

            x_real = x_real.to(self.device)             # Input images.
            c_org = c_org.to(self.device)               # Original domain labels.
            c_trg = c_trg.to(self.device)               # Target domain labels.
            label_org = label_org.to(self.device)       # Labels for computing classification loss.
            label_trg = label_trg.to(self.device)       # Labels for computing classification loss.

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #

            # Compute loss with real images. The classifier output holds the
            # CelebA logits in the first ``c_dim`` columns and the RaFD
            # logits in the rest; keep only the slice of the current dataset.
            out_src, out_cls = self.D(x_real)
            out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
            d_loss_real = - torch.mean(out_src)
            d_loss_cls = self.classification_loss(out_cls, label_org, dataset)

            # Compute loss with fake images; ``detach`` keeps the generator
            # out of the discriminator update.
            x_fake = self.G(x_real, c_trg)
            out_src, _ = self.D(x_fake.detach())
            d_loss_fake = torch.mean(out_src)

            # Compute loss for gradient penalty on random interpolations
            # between real and fake images.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src, _ = self.D(x_hat)
            d_loss_gp = self.gradient_penalty(out_src, x_hat)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #

            # One generator update per ``n_critic`` discriminator updates.
            if (i+1) % self.n_critic == 0:
                # Original-to-target domain.
                x_fake = self.G(x_real, c_trg)
                out_src, out_cls = self.D(x_fake)
                out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                g_loss_fake = - torch.mean(out_src)
                g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)

                # Target-to-original domain.
                x_reconst = self.G(x_fake, c_org)
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training info (per dataset, since the losses differ).
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

        # The steps below do not depend on ``dataset``; they run once per
        # iteration so samples/checkpoints are not written twice and the
        # learning rate is not decayed twice.

        # Translate fixed images for debugging.
        if (i+1) % self.sample_step == 0:
            with torch.no_grad():
                x_fake_list = [x_fixed]
                for c_fixed in c_celeba_list:
                    c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
                    x_fake_list.append(self.G(x_fixed, c_trg))
                for c_fixed in c_rafd_list:
                    c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
                    x_fake_list.append(self.G(x_fixed, c_trg))
                x_concat = torch.cat(x_fake_list, dim=3)
                sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(sample_path))

        # Save model checkpoints.
        if (i+1) % self.model_save_step == 0:
            G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
            D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
            torch.save(self.G.state_dict(), G_path)
            torch.save(self.D.state_dict(), D_path)
            print('Saved model checkpoints into {}...'.format(self.model_save_dir))

        # Decay learning rates.
        if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
            g_lr -= (self.g_lr / float(self.num_iters_decay))
            d_lr -= (self.d_lr / float(self.num_iters_decay))
            self.update_lr(g_lr, d_lr)
            print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
def test(self):
    """Translate images using StarGAN trained on a single dataset.

    Restores the checkpoint of step ``self.test_iters``, translates every
    batch of the loader to each target label and saves one image per batch
    into ``self.result_dir``.
    """
    # Load the trained generator.
    self.restore_model(self.test_iters)

    # Set data loader.
    if self.dataset == 'CelebA':
        data_loader = self.celeba_loader
    elif self.dataset == 'RaFD':
        data_loader = self.rafd_loader

    with torch.no_grad():
        for batch_no, (x_real, c_org) in enumerate(data_loader, start=1):
            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

            # Translate images: the real batch first, then one translation per target label.
            x_fake_list = [x_real] + [self.G(x_real, c_trg) for c_trg in c_trg_list]

            # Save the translated images, concatenated along the width.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_no))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))
def test_multi(self):
    """Translate images using StarGAN trained on multiple datasets.

    Restores the checkpoint of step ``self.test_iters`` and translates every
    CelebA batch to each CelebA target label and to each RaFD class, saving
    one image per batch into ``self.result_dir``.
    """
    # Load the trained generator.
    self.restore_model(self.test_iters)

    with torch.no_grad():
        for batch_no, (x_real, c_org) in enumerate(self.celeba_loader, start=1):
            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            batch = x_real.size(0)
            c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
            c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
            zero_celeba = torch.zeros(batch, self.c_dim).to(self.device)            # Zero vector for CelebA.
            zero_rafd = torch.zeros(batch, self.c2_dim).to(self.device)             # Zero vector for RaFD.
            mask_celeba = self.label2onehot(torch.zeros(batch), 2).to(self.device)  # Mask vector: [1, 0].
            mask_rafd = self.label2onehot(torch.ones(batch), 2).to(self.device)     # Mask vector: [0, 1].

            # Joint labels: [CelebA part, RaFD part, mask], CelebA targets first.
            joint_labels = [torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
                            for c_celeba in c_celeba_list]
            joint_labels += [torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
                             for c_rafd in c_rafd_list]

            # Translate images.
            x_fake_list = [x_real]
            for c_trg in joint_labels:
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images, concatenated along the width.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_no))
            save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))