pixel2style2pixel 源码解析【2】

文章目录

  • 项目分析
  • 一些重要功能的实现函数

项目分析

文章中提到用psp可以实现多种应用。 参考【论文解析】Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation.
项目中,包scriptes 包含了这些应用对应的运行文件。
pixel2style2pixel 源码解析【2】_第1张图片
上一次,我们在【源码解析】Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation 中已经走通了inference.py的整个流程,为新的测试图像进行了风格转换。

这次主要了解train.py

  1. 运行的所有参数通过TrainOptions 类进行传递。这个在train_options.py中实现。
    具体内容如下: 整体来说,还是比较好理解。

    self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
    self.parser.add_argument('--dataset_type', default='ffhq_encode', type=str, help='Type of dataset/experiment to run')
    self.parser.add_argument('--encoder_type', default='GradualStyleEncoder', type=str, help='Which encoder to use') # 编码类型
    self.parser.add_argument('--input_nc', default=3, type=int, help='Number of input image channels to the psp encoder') # 编码器对应的输入图像通道。
    self.parser.add_argument('--label_nc', default=0, type=int, help='Number of input label channels to the psp encoder')  # 输入标签通道。 标签是什么?
    self.parser.add_argument('--output_size', default=1024, type=int, help='Output size of generator') # 生成器的输出规模
    
    self.parser.add_argument('--batch_size', default=4, type=int, help='Batch size for training')
    self.parser.add_argument('--test_batch_size', default=2, type=int, help='Batch size for testing and inference')
    self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
    self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')
    
    self.parser.add_argument('--learning_rate', default=0.0001, type=float, help='Optimizer learning rate')
    self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')  # 优化起选择
    self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') # 是否训练解码器
    self.parser.add_argument('--start_from_latent_avg', action='store_true', help='Whether to add average latent vector to generate codes from encoder.')
    # 是否加入平均潜在向量从编码器生成编码。
    self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space instead of w+') # 选择学习空间
    
    # 关于损失的设定
    self.parser.add_argument('--lpips_lambda', default=0.8, type=float, help='LPIPS loss multiplier factor')
    self.parser.add_argument('--id_lambda', default=0, type=float, help='ID loss multiplier factor')
    self.parser.add_argument('--l2_lambda', default=1.0, type=float, help='L2 loss multiplier factor')
    self.parser.add_argument('--w_norm_lambda', default=0, type=float, help='W-norm loss multiplier factor')
    self.parser.add_argument('--lpips_lambda_crop', default=0, type=float, help='LPIPS loss multiplier factor for inner image region')
    self.parser.add_argument('--l2_lambda_crop', default=0, type=float, help='L2 loss multiplier factor for inner image region')
    self.parser.add_argument('--moco_lambda', default=0, type=float, help='Moco-based feature similarity loss multiplier factor')
    
    self.parser.add_argument('--stylegan_weights', default=model_paths['stylegan_ffhq'], type=str, help='Path to StyleGAN model weights') # stylegan网络的权重路劲
    self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') # 模型路径
    
    self.parser.add_argument('--max_steps', default=500000, type=int, help='Maximum number of training steps')  # 最多训练多少步
    self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') # 记录图像的间隔
    self.parser.add_argument('--board_interval', default=50, type=int, help='Interval for logging metrics to tensorboard')
    self.parser.add_argument('--val_interval', default=1000, type=int, help='Validation interval') # 验证间隔
    self.parser.add_argument('--save_interval', default=None, type=int, help='Model checkpoint interval') # 保存checkpoint的间隔。
    
    # arguments for weights & biases support
    self.parser.add_argument('--use_wandb', action="store_true", help='Whether to use Weights & Biases to track experiment.')
    # 是否使用权重和偏差来跟踪实验
    # arguments for super-resolution
    self.parser.add_argument('--resize_factors', type=str, default=None, help='For super-res, comma-separated resize factors to use for inference.')
    
  2. train.py 的主要内容

def main():
	opts = TrainOptions().parse()
	if os.path.exists(opts.exp_dir):
		raise Exception('Oops... {} already exists'.format(opts.exp_dir))
	os.makedirs(opts.exp_dir)  # 创建实验目录

	opts_dict = vars(opts)
	pprint.pprint(opts_dict)
	with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
		json.dump(opts_dict, f, indent=4, sort_keys=True)  #把参数信息保存下来

	coach = Coach(opts)
	coach.train()
  1. coach init函数
	def __init__(self, opts):
		self.opts = opts

		self.global_step = 0

		self.device = 'cuda:0'  # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES
		self.opts.device = self.device

		if self.opts.use_wandb:
			from utils.wandb_utils import WBLogger
			self.wb_logger = WBLogger(self.opts)

		# Initialize network
		self.net = pSp(self.opts).to(self.device)

		# Estimate latent_avg via dense sampling if latent_avg is not available
		# 如果没有潜在平均值,则通过密集抽样估计潜在平均值
		if self.net.latent_avg is None:
			self.net.latent_avg = self.net.decoder.mean_latent(int(1e5))[0].detach()

		# Initialize loss
		if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
			raise ValueError('Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!')
			# moco是什么, moco_based 特征相似性损失

		self.mse_loss = nn.MSELoss().to(self.device).eval()
		if self.opts.lpips_lambda > 0:  # 可以通过考虑对应的lambda是否大于0来控制是否要采用这个算损失。
			self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
		if self.opts.id_lambda > 0:
			self.id_loss = id_loss.IDLoss().to(self.device).eval()
		if self.opts.w_norm_lambda > 0:
			self.w_norm_loss = w_norm.WNormLoss(start_from_latent_avg=self.opts.start_from_latent_avg)
		if self.opts.moco_lambda > 0:
			self.moco_loss = moco_loss.MocoLoss().to(self.device).eval()

		# Initialize optimizer
		self.optimizer = self.configure_optimizers()

		# Initialize dataset
		self.train_dataset, self.test_dataset = self.configure_datasets()
		self.train_dataloader = DataLoader(self.train_dataset,
										   batch_size=self.opts.batch_size,
										   shuffle=True,
										   num_workers=int(self.opts.workers),
										   drop_last=True)
		self.test_dataloader = DataLoader(self.test_dataset,
										  batch_size=self.opts.test_batch_size,
										  shuffle=False,
										  num_workers=int(self.opts.test_workers),
										  drop_last=True)

		# Initialize logger
		log_dir = os.path.join(opts.exp_dir, 'logs')
		os.makedirs(log_dir, exist_ok=True)
		self.logger = SummaryWriter(log_dir=log_dir)  #

		# Initialize checkpoint dir
		self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
		os.makedirs(self.checkpoint_dir, exist_ok=True)
		self.best_val_loss = None
		if self.opts.save_interval is None: # 保存val_loss的时间间隔。
			self.opts.save_interval = self.opts.max_steps
  1. coach.train().

    1. self.net.train() 设置为tranining模式。

    2. 进入循环while self.global_step < self.opts.max_steps: 进入batch for batch_idx, batch in enumerate(self.train_dataloader):

      1. 核心计算过程
      self.optimizer.zero_grad()
      				x, y = batch  # image和标签。 所谓的标签也是image
      				x, y = x.to(self.device).float(), y.to(self.device).float()
      				y_hat, latent = self.net.forward(x, return_latents=True)  # 得到psp的结果。 
      				loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)  #得到损失和相关信息
      				loss.backward()
      				self.optimizer.step()
      
      1. logging related。 打印或者存储相关信息。
      # Logging related
      				if self.global_step % self.opts.image_interval == 0 or (self.global_step < 1000 and self.global_step % 25 == 0):
      					self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
      				if self.global_step % self.opts.board_interval == 0:
      					self.print_metrics(loss_dict, prefix='train')
      					self.log_metrics(loss_dict, prefix='train')
      
      1. Validation related.。 保存对应的checkpoint
      val_loss_dict = None
      if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
      	val_loss_dict = self.validate()  # 执行验证
      	if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
      		self.best_val_loss = val_loss_dict['loss']  # 保存下最好的val_loss
      		self.checkpoint_me(val_loss_dict, is_best=True)  # 保存对应的checkpoint
      
      if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
      	if val_loss_dict is not None: # 有验证的情况下,存在验证的。
      		self.checkpoint_me(val_loss_dict, is_best=False)
      	else: # 没有验证的情况下,存训练的。
      		self.checkpoint_me(loss_dict, is_best=False)
      
  2. 具体看pSp class。

    1. init函数。 给定了encoder 和Generator
    	def __init__(self, opts):
    		super(pSp, self).__init__()
    		self.set_opts(opts)
    		# compute number of style inputs based on the output resolution
    		# 根据输出分辨率计算样式输入的数量
    		self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
    		# Define architecture
    		self.encoder = self.set_encoder()
    		self.decoder = Generator(self.opts.output_size, 512, 8)
    		self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
    		# 使得池化后的每个通道上的大小是256, 256
    		# Load weights if needed
    		self.load_weights()
    
    1. set_encoder 会根据 opt.encoder_type 来进行编码器的选择。
    	def set_encoder(self):  # 选择编码器类型。
    		if self.opts.encoder_type == 'GradualStyleEncoder':
    			encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts) # 高斯风格编码器
    		elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW': # w
    			encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
    		elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus': # w+
    			encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
    		else:
    			raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
    		return encoder
    
    1. 如果opt中给出checkpoint_path, 则需要对应的load_weights. 但没有指定checkpoint的时候,默认从ir_se50 中加载encoder的权重,从opts.stylegan_weights 中加载decoder的权重。 两种情况下latent_avg都需要单独加载。
    	def load_weights(self):
    		if self.opts.checkpoint_path is not None:
    			print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
    			ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
    			# 权重包含了 encoder 和decoder的 。
    			self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
    			self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
    			self.__load_latent_avg(ckpt)
    		else:
    			# 如果了没有给定的checkpoint,则默认是使用预训练的irse50
    			print('Loading encoders weights from irse50!')
    			encoder_ckpt = torch.load(model_paths['ir_se50']) # 这个文件的路径在paths_config.py中指定
    			# if input to encoder is not an RGB image, do not load the input layer weights
    			# 如果编码器的输入不是RGB图像,不要加载输入层权重  (这个不是太明白)
    			if self.opts.label_nc != 0:
    				encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
    			self.encoder.load_state_dict(encoder_ckpt, strict=False)
    			print('Loading decoder weights from pretrained!')
    			ckpt = torch.load(self.opts.stylegan_weights) #decoder 从styleGan中加载权重,stylegan_weights需要给出路径。 
    			self.decoder.load_state_dict(ckpt['g_ema'], strict=False)  # 这个g_ema是什么意思? 
    			if self.opts.learn_in_w:
    				self.__load_latent_avg(ckpt, repeat=1)
    			else: #使用w+,
    				self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)  # n_styles 不应是是18吗? 难道18是计算得到的?
    
    1. forward函数的内容。
      1. 计算codes
      if input_code:  # 支持给定 codes的情况。
      	codes = x
      else:
      	codes = self.encoder(x) #
      	# normalize with respect to the center of an average face
      	if self.opts.start_from_latent_avg:  #相对于平均面的中心进行归一化
      		if self.opts.learn_in_w: # w
      			codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
      		else: # w+
      			codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
      print(codes.shape)
      
      if latent_mask is not None:
      	for i in latent_mask:
      		if inject_latent is not None:
      			if alpha is not None:
      				codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
      			else:
      				codes[:, i] = inject_latent[:, i]
      		else:
      			codes[:, i] = 0
      
      1. input_is_latent = not input_code
      2. 计算decoder的结果。
      images, result_latent = self.decoder([codes],  #result_latent 是通过 decoder 计算
      		                                     input_is_latent=input_is_latent,
      		                                     randomize_noise=randomize_noise,
      		                                     return_latents=return_latents)  # 解码得到返回的结果。
      
      1. 返回结果
      		if return_latents:
      			return images, result_latent
      		else:
      			return images
      
  3. 具体看 coach 中关于损失的计算 loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
    输入的是 【输入图像、标签图像、预测图像、潜码】、
    返回的是【loss, loss_dict, id_logs 】。
    看了一下代码,具体内容就不展示了。loss是所以loss的加权求和结果。loss_dict中记录了每个loss的值。 id_logs 是id loss 或者moco 损失产生的。

  4. 具体看pSp的encoderencoder的使用codes = self.encoder(x) , 输入图像x,返回codes。==但不是最后用于计算损失的latent。
    具体实现方式,代码给了三种。 通过opts.encoder_type 进行指定。

  5. pSp的decoder。 以codes 为输入,返回image 和result_latent。

先了解到这里——————————————————————————

一些重要功能的实现函数

人脸对齐

from scripts.align_all_parallel import align_face
def run_alignment(self, image_path):  # 这个函数可以直接将图像进行人脸对齐。 
     aligned_image = align_face(filepath=image_path, predictor=self.predictor)
     print("Aligned image has shape: {}".format(aligned_image.size))
     return aligned_image

你可能感兴趣的:(GAN,人工智能,pSp,Style,StyleGAN,CV)