目录
预备知识
main.py
解析命令行参数
解析配置文件
由于代码中除了一些必要的对模型、数据进行操作的PyTorch函数外,还有一些辅助显示训练等过程有关信息的,或辅助对文件目录进行操作的库。因此,建议读者先对这些库进行了解,试着写一写示例代码,理解库中函数的使用方法后再阅读下面的讲解,这样可以更顺畅。
import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
首先对输出的选项进行设定,让输出的内容不按科学计数法模式。
torch.set_printoptions(sci_mode=False) # 设置为不按照科学计数法表示输出
然后程序进入main()函数中,在main函数中完成了以下任务:
后面我们逐一进行代码分析。
def main():
args, config = parse_args_and_config() # 解析命令行参数和配置文件
logging.info("Writing log file to {}".format(args.log_path)) # 显示日志存储路径信息
logging.info("Exp instance id = {}".format(os.getpid())) # 显示进程id信息
logging.info("Exp comment = {}".format(args.comment)) # 显示实验注释信息
try:
runner = Diffusion(args, config) # 构建扩散运行实例对象
if args.sample: # 如果是采样操作,就执行采样函数
runner.sample()
elif args.test: # 如果是测试模型,就执行测试函数
runner.test()
else: # 否则就执行训练函数
runner.train()
except Exception: # 如果报错就输出错误信息日志
logging.error(traceback.format_exc())
return 0
对命令行参数的解析在parse_args_and_config函数中完成,每一个参数的含义以注释的形式标明,如果有异议欢迎在评论中指出。
def parse_args_and_config():
parser = argparse.ArgumentParser(description=globals()["__doc__"])
parser.add_argument( # config文件路径
"--config", type=str, required=True, help="Path to the config file"
)
parser.add_argument("--seed", type=int, default=1234, help="Random seed") # 随机种子
parser.add_argument( # 用于保存运行相关数据的路径
"--exp", type=str, default="exp", help="Path for saving running related data."
)
parser.add_argument( # log日志文件夹名称
"--doc",
type=str,
required=True,
help="A string for documentation purpose. "
"Will be the name of the log folder.",
)
parser.add_argument( # 实验注释
"--comment", type=str, default="", help="A string for experiment comment"
)
parser.add_argument( # logging日志的级别: info, debug, warning, critical
"--verbose",
type=str,
default="info",
help="Verbose level: info | debug | warning | critical",
)
parser.add_argument("--test", action="store_true", help="Whether to test the model") # 是否测试模型
parser.add_argument( # 是否从模型产生采样
"--sample",
action="store_true",
help="Whether to produce samples from the model",
)
parser.add_argument("--fid", action="store_true") # FID指标
parser.add_argument("--interpolation", action="store_true") # 插值
parser.add_argument( # 是否为继续训练
"--resume_training", action="store_true", help="Whether to resume training"
)
parser.add_argument( # 采样的文件夹名称
"-i",
"--image_folder",
type=str,
default="images",
help="The folder name of samples",
)
parser.add_argument( # 无交互
"--ni",
action="store_true",
help="No interaction. Suitable for Slurm Job launcher",
)
parser.add_argument("--use_pretrained", action="store_true") # 使用预训练
parser.add_argument( # 采样类型
"--sample_type",
type=str,
default="generalized",
help="sampling approach (generalized or ddpm_noisy)",
)
parser.add_argument( # 跳跃类型
"--skip_type",
type=str,
default="uniform",
help="skip according to (uniform or quadratic)",
)
parser.add_argument( # 步数
"--timesteps", type=int, default=1000, help="number of steps involved"
)
parser.add_argument( # \eta超参数用于控制方差
"--eta",
type=float,
default=0.0,
help="eta used to control the variances of sigma",
)
parser.add_argument("--sequence", action="store_true") # 是否为序列
args = parser.parse_args() # 解析参数
args.log_path = os.path.join(args.exp, "logs", args.doc) # log日志路径: exp/logs/$doc$
...
解析配置文件的过程也是在parse_args_and_config函数中,args.config应该是bedroom,celeba,church,cifar10中的一个。这样我们可以直接打开文件夹configs中对应数据集的yaml配置文件,此时config为字典类型。经过dict2namespace函数,将字典类型转换为argparse中命名空间的形式。
def parse_args_and_config():
...
# parse config file
with open(os.path.join("configs", args.config), "r") as f:
config = yaml.safe_load(f)
new_config = dict2namespace(config)
...
转换函数如下:
def dict2namespace(config):
namespace = argparse.Namespace()
for key, value in config.items():
if isinstance(value, dict):
new_value = dict2namespace(value)
else:
new_value = value
setattr(namespace, key, new_value)
return namespace
之后还有一步设定tensorboard日志的路径,可以在训练时用tensorboard查看训练进度信息:
def parse_args_and_config():
...
tb_path = os.path.join(args.exp, "tensorboard", args.doc) # tensorboard日志路径: exp/tensorboard/$doc$
...
之后会执行训练 / 采样 / 测试不同的代码部分:
首先看一下对于训练会执行的代码:
def parse_args_and_config():
...
if not args.test and not args.sample:
if not args.resume_training:
if os.path.exists(args.log_path): # 如果log输出路径存在的话
overwrite = False # 选择不覆盖
if args.ni: # 如果ni为True
overwrite = True # 选择覆盖
else:
response = input("Folder already exists. Overwrite? (Y/N)") # 询问是否覆盖
if response.upper() == "Y": # 如果Y, 则选择覆盖原有log
overwrite = True
if overwrite: # 如果选择覆盖
shutil.rmtree(args.log_path) # 删除原有log文件路径
shutil.rmtree(tb_path) # 删除原有tensorboard文件路径
os.makedirs(args.log_path) # 创建新的log文件路径
if os.path.exists(tb_path): # 如果tensorboard文件路径存在, 就删除它
shutil.rmtree(tb_path)
else: # 如果选择不覆盖, 则提示文件夹存在, 程序停止
print("Folder exists. Program halted.")
sys.exit(0)
else: # 如果log输出路径不存在就创建路径
os.makedirs(args.log_path)
with open(os.path.join(args.log_path, "config.yml"), "w") as f:
yaml.dump(new_config, f, default_flow_style=False)
new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)
# setup logger
level = getattr(logging, args.verbose.upper(), None) # 20 (logging.INFO) 或者其它的级别
if not isinstance(level, int): # 如果为None的话就会报错
raise ValueError("level {} not supported".format(args.verbose))
handler1 = logging.StreamHandler() # 将log在CLI输出的handler
handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) # 将log在文件输出的handler
formatter = logging.Formatter( # 控制log输出格式的formatter
"%(levelname)s - %(filename)s - %(asctime)s - %(message)s" # INFO - __main__ - ... - ....
)
handler1.setFormatter(formatter) # 设置CLI输出handler的格式
handler2.setFormatter(formatter) # 设置文件输出handler的格式
logger = logging.getLogger() # root logger
logger.addHandler(handler1) # 添加CLI输出handler
logger.addHandler(handler2) # 添加文件输出handler
logger.setLevel(level) # 设定root logger的级别
...
然后是采样 / 测试会执行的代码:
def parse_args_and_config():
...
else:
level = getattr(logging, args.verbose.upper(), None)
if not isinstance(level, int):
raise ValueError("level {} not supported".format(args.verbose))
handler1 = logging.StreamHandler()
formatter = logging.Formatter(
"%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
)
handler1.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler1)
logger.setLevel(level)
if args.sample: # 如果是采样
os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) # 创建目录: exp/image_samples
args.image_folder = os.path.join( # 添加图像文件夹参数: exp/image_samples/$image_folder$
args.exp, "image_samples", args.image_folder
)
if not os.path.exists(args.image_folder): # 如果图像文件夹不存在就创建一个
os.makedirs(args.image_folder)
else: # 如果图像文件夹存在
if not (args.fid or args.interpolation):
overwrite = False
if args.ni:
overwrite = True
else:
response = input(
f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"
)
if response.upper() == "Y":
overwrite = True
if overwrite: # 如果覆盖, 删除并新建文件夹
shutil.rmtree(args.image_folder)
os.makedirs(args.image_folder)
else:
print("Output image folder exists. Program halted.")
sys.exit(0)
...
最后是对PyTorch进行设置:
def parse_args_and_config():
...
# add device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info("Using device: {}".format(device))
new_config.device = device
# set random seed
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = True
return args, new_config
至此,就基本结束main.py的学习了,后面讲进入Diffusion类中查看具体初始化、训练、采样、测试这些函数是如何实现的了。