文章中提到用psp可以实现多种应用。 参考【论文解析】Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation.
项目中,包scriptes
包含了这些应用对应的运行文件。
上一次,我们在【源码解析】Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation 中已经走通了inference.py
的整个流程,为新的测试图像进行了风格转换。
这次主要了解train.py
。
运行的所有参数通过TrainOptions
类进行传递。这个在train_options.py
中实现。
具体内容如下: 整体来说,还是比较好理解。
self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, help='Type of dataset/experiment to run')
self.parser.add_argument('--encoder_type', default='GradualStyleEncoder', type=str, help='Which encoder to use') # 编码类型
self.parser.add_argument('--input_nc', default=3, type=int, help='Number of input image channels to the psp encoder') # 编码器对应的输入图像通道。
self.parser.add_argument('--label_nc', default=0, type=int, help='Number of input label channels to the psp encoder') # 输入标签通道。 标签是什么?
self.parser.add_argument('--output_size', default=1024, type=int, help='Output size of generator') # 生成器的输出规模
self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')
self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') # 优化起选择
self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') # 是否训练解码器
self.parser.add_argument('--start_from_latent_avg', action='store_true', help='Whether to add average latent vector to generate codes from encoder.')
# 是否加入平均潜在向量从编码器生成编码。
self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space instead of w+') # 选择学习空间
# 关于损失的设定
self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
self.parser.add_argument('--id_lambda', default=0, type=float, help='ID loss multiplier factor')
self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor')
self.parser.add_argument('--lpips_lambda_crop', default=0, type=float, help='LPIPS loss multiplier factor for inner image region')
self.parser.add_argument('--l2_lambda_crop', default=0, type=float, help='L2 loss multiplier factor for inner image region')
self.parser.add_argument('--moco_lambda', default=0, type=float, help='Moco-based feature similarity loss multiplier factor')
self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, help='Path to StyleGAN model weights') # stylegan网络的权重路劲
self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') # 模型路径
self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps') # 最多训练多少步
self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') # 记录图像的间隔
self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard')
self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval') # 验证间隔
self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval') # 保存checkpoint的间隔。
# arguments for weights & biases support
self.parser.add_argument('--use_wandb', action="store_true", help='Whether to use Weights & Biases to track experiment.')
# 是否使用权重和偏差来跟踪实验
# arguments for super-resolution
self.parser.add_argument('--resize_factors', type=str, default=None, help='For super-res, comma-separated resize factors to use for inference.')
train.py 的主要内容
def main():
opts = TrainOptions().parse()
if os.path.exists(opts.exp_dir):
raise Exception('Oops... {} already exists'.format(opts.exp_dir))
os.makedirs(opts.exp_dir) # 创建实验目录
opts_dict = vars(opts)
pprint.pprint(opts_dict)
with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
json.dump(opts_dict, f, indent=4, sort_keys=True) #把参数信息保存下来
coach = Coach(opts)
coach.train()
init函数
def __init__(self, opts):
self.opts = opts
self.global_step = 0
self.device = 'cuda:0' # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES
self.opts.device = self.device
if self.opts.use_wandb:
from utils.wandb_utils import WBLogger
self.wb_logger = WBLogger(self.opts)
# Initialize network
self.net = pSp(self.opts).to(self.device)
# Estimate latent_avg via dense sampling if latent_avg is not available
# 如果没有潜在平均值,则通过密集抽样估计潜在平均值
if self.net.latent_avg is None:
self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0].detach()
# Initialize loss
if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
raise ValueError('Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!')
# moco是什么, moco_based 特征相似性损失
self.mse_loss = nn.MSELoss().to(self.device).eval()
if self.opts.lpips_lambda > 0: # 可以通过考虑对应的lambda是否大于0来控制是否要采用这个算损失。
self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
if self.opts.id_lambda > 0:
self.id_loss = id_loss.IDLoss().to(self.device).eval()
if self.opts.w_norm_lambda > 0:
self.w_norm_loss = w_norm.WNormLoss(start_from_latent_avg=self.opts.start_from_latent_avg)
if self.opts.moco_lambda > 0:
self.moco_loss = moco_loss.MocoLoss().to(self.device).eval()
# Initialize optimizer
self.optimizer = self.configure_optimizers()
# Initialize dataset
self.train_dataset, self.test_dataset = self.configure_datasets()
self.train_dataloader = DataLoader(self.train_dataset,
batch_size=self.opts.batch_size,
shuffle=True,
num_workers=int(self.opts.workers),
drop_last=True)
self.test_dataloader = DataLoader(self.test_dataset,
batch_size=self.opts.test_batch_size,
shuffle=False,
num_workers=int(self.opts.test_workers),
drop_last=True)
# Initialize logger
log_dir = os.path.join(opts.exp_dir, 'logs')
os.makedirs(log_dir, exist_ok=True)
self.logger = SummaryWriter(log_dir=log_dir) #
# Initialize checkpoint dir
self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
os.makedirs(self.checkpoint_dir, exist_ok=True)
self.best_val_loss = None
if self.opts.save_interval is None: # 保存val_loss的时间间隔。
self.opts.save_interval = self.opts.max_steps
coach.train().
self.net.train()
设置为tranining模式。
进入循环while self.global_step < self.opts.max_steps:
进入batch for batch_idx, batch in enumerate(self.train_dataloader):
self.optimizer.zero_grad()
x, y = batch # image和标签。 所谓的标签也是image
x, y = x.to(self.device).float(), y.to(self.device).float()
y_hat, latent = self.net.forward(x, return_latents=True) # 得到psp的结果。
loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent) #得到损失和相关信息
loss.backward()
self.optimizer.step()
# Logging related
if self.global_step % self.opts.image_interval == 0 or (self.global_step < 1000 and self.global_step % 25 == 0):
self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
if self.global_step % self.opts.board_interval == 0:
self.print_metrics(loss_dict, prefix='train')
self.log_metrics(loss_dict, prefix='train')
val_loss_dict = None
if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
val_loss_dict = self.validate() # 执行验证
if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
self.best_val_loss = val_loss_dict['loss'] # 保存下最好的val_loss
self.checkpoint_me(val_loss_dict, is_best=True) # 保存对应的checkpoint
if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
if val_loss_dict is not None: # 有验证的情况下,存在验证的。
self.checkpoint_me(val_loss_dict, is_best=False)
else: # 没有验证的情况下,存训练的。
self.checkpoint_me(loss_dict, is_best=False)
具体看pSp class。
init
函数。 给定了encoder 和Generator def __init__(self, opts):
super(pSp, self).__init__()
self.set_opts(opts)
# compute number of style inputs based on the output resolution
# 根据输出分辨率计算样式输入的数量
self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
# Define architecture
self.encoder = self.set_encoder()
self.decoder = Generator(self.opts.output_size, 512, 8)
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
# 使得池化后的每个通道上的大小是256, 256
# Load weights if needed
self.load_weights()
set_encoder
会根据 opt.encoder_type
来进行编码器的选择。 def set_encoder(self): # 选择编码器类型。
if self.opts.encoder_type == 'GradualStyleEncoder':
encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) # 高斯风格编码器
elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': # w
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': # w+
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
else:
raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
return encoder
checkpoint_path
, 则需要对应的load_weights
. 但没有指定checkpoint的时候,默认从ir_se50
中加载encoder的权重,从opts.stylegan_weights
中加载decoder的权重。 两种情况下latent_avg
都需要单独加载。 def load_weights(self):
if self.opts.checkpoint_path is not None:
print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
# 权重包含了 encoder 和decoder的 。
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
self.__load_latent_avg(ckpt)
else:
# 如果了没有给定的checkpoint,则默认是使用预训练的irse50
print('Loading encoders weights from irse50!')
encoder_ckpt = torch.load(model_paths['ir_se50']) # 这个文件的路径在paths_config.py中指定
# if input to encoder is not an RGB image, do not load the input layer weights
# 如果编码器的输入不是RGB图像,不要加载输入层权重 (这个不是太明白)
if self.opts.label_nc != 0:
encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
self.encoder.load_state_dict(encoder_ckpt, strict=False)
print('Loading decoder weights from pretrained!')
ckpt = torch.load(self.opts.stylegan_weights) #decoder 从styleGan中加载权重,stylegan_weights需要给出路径。
self.decoder.load_state_dict(ckpt['g_ema'], strict=False) # 这个g_ema是什么意思?
if self.opts.learn_in_w:
self.__load_latent_avg(ckpt, repeat=1)
else: #使用w+,
self.__load_latent_avg(ckpt, repeat=self.opts.n_styles) # n_styles 不应是是18吗? 难道18是计算得到的?
forward
函数的内容。
if input_code: # 支持给定 codes的情况。
codes = x
else:
codes = self.encoder(x) #
# normalize with respect to the center of an average face
if self.opts.start_from_latent_avg: #相对于平均面的中心进行归一化
if self.opts.learn_in_w: # w
codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
else: # w+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
print(codes.shape)
if latent_mask is not None:
for i in latent_mask:
if inject_latent is not None:
if alpha is not None:
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
else:
codes[:, i] = inject_latent[:, i]
else:
codes[:, i] = 0
input_is_latent = not input_code
?decoder
的结果。images, result_latent = self.decoder([codes], #result_latent 是通过 decoder 计算
input_is_latent=input_is_latent,
randomize_noise=randomize_noise,
return_latents=return_latents) # 解码得到返回的结果。
if return_latents:
return images, result_latent
else:
return images
具体看 coach 中关于损失的计算 loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
输入的是 【输入图像、标签图像、预测图像、潜码】、
返回的是【loss, loss_dict, id_logs 】。
看了一下代码,具体内容就不展示了。loss是所以loss的加权求和结果。loss_dict中记录了每个loss的值。 id_logs 是id loss 或者moco 损失产生的。
具体看pSp的encoder
。 encoder
的使用codes = self.encoder(x)
, 输入图像x
,返回codes。==但不是最后用于计算损失的latent。
具体实现方式,代码给了三种。 通过opts.encoder_type
进行指定。
pSp的decoder
。 以codes
为输入,返回image 和result_latent。
先了解到这里——————————————————————————
人脸对齐
from scripts.align_all_parallel import align_face
def run_alignment(self, image_path): # 这个函数可以直接将图像进行人脸对齐。
aligned_image = align_face(filepath=image_path, predictor=self.predictor)
print("Aligned image has shape: {}".format(aligned_image.size))
return aligned_image