一、什么是ArgParser?
argparse是一个Python模块:命令行选项、参数和子命令解析器;argparse模块可以让人轻松编写用户友好的命令行接口。程序定义它需要的参数。然后argparser将弄清如何从sys.argv解析出那些参数。argparse模块还会自动生成帮助和使用手册,并在用户给程序传入无效参数时报出错误信息。
二、用法
在使用时最好将其分化成三个部分:trainer参数(如gpus)、特定于模型的参数(如层数维度)、程序参数(如路径)
1.基础用法:
#导入模块
from argparse import ArgumentParser
#创建解析器对象
parser = ArgumentParser()
#添加参数
#给一个ArgumentParser添加程序阐述信息是通过调用add_arguement()方法完成的。
parser.add_argument("--layer_1_dim", type=int, default=128)
#解析参数
args = parser.parse_args()
#可以输入一下命令调用程序
python trainer.py --layer_1_dim 64
(1)ArgumentParser对象
prog - 程序的名称(默认: sys.argv[0],prog猜测是programma的缩写)
usage - 描述程序用途的字符串(默认值:从添加到解析器的参数生成)
description - 在参数帮助文档之后显示的文本 (默认值:无)
(2)add_argument()方法
name or flags - 一个命名或者一个选项字符串的列表
action - 表示该选项要执行的操作
default - 当参数未在命令行中出现时使用的值
dest - 用来指定参数的位置
type - 为参数类型,例如int
choices - 用来选择输入参数的范围。例如choice = [1, 5, 10], 表示输入参数只能为1,5 或10
help - 用来描述这个选项的作用
2.在主Trainer文件中,添加Trainer参数、程序参数和模型参数
# ----------------
# trainer_main.py
# ----------------
from argparse import ArgumentParser
parser = ArgumentParser()
# 添加程序级参数
parser.add_argument("--conda_env", type=str, default="some_name")
parser.add_argument("--notification_email", type=str, default="[email protected]")
# 添加特定于模型的参数
parser = LitModel.add_model_specific_args(parser)
# 将所有可用的trainer选项添加到argparse
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
在调用时,可以如下书写调用:
python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12
3、我们通常有多个模块,每个模块都有不同的参数。Module允许你为每个文件定义参数,而不用污染main.py文件。
class LitMNIST(LightningModule):
def __init__(self, layer_1_dim, **kwargs):
super().__init__()
self.layer_1 = nn.Linear(28 * 28, layer_1_dim)
@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("LitMNIST")
parser.add_argument("--layer_1_dim", type=int, default=128)
return parent_parser
class GoodGAN(LightningModule):
def __init__(self, encoder_layers, **kwargs):
super().__init__()
self.encoder = Encoder(layers=encoder_layers)
@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("GoodGAN")
parser.add_argument("--encoder_layers", type=int, default=12)
return parent_parser
现在,我们可以允许每个模型在main.py中插入所需的参数:
def main(args):
dict_args = vars(args)
# 挑选模型
if args.model_name == "gan":
model = GoodGAN(**dict_args)
elif args.model_name == "mnist":
model = LitMNIST(**dict_args)
trainer = Trainer.from_argparse_args(args)
trainer.fit(model)
if __name__ == "__main__":
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
# 找出要使用的模型
parser.add_argument("--model_name", type=str, default="gan", help="gan or mnist")
# 这一行是提取模型名称的关键
temp_args, _ = parser.parse_known_args()
# 让模型添加它想要的东西
if temp_args.model_name == "gan":
parser = GoodGAN.add_model_specific_args(parser)
elif temp_args.model_name == "mnist":
parser = LitMNIST.add_model_specific_args(parser)
args = parser.parse_args()
# 训练
main(args)
现在我们可以使用命令行界面训练MNIST或GAN!
python main.py --model_name gan --encoder_layers 24
python main.py --model_name mnist --layer_1_dim 128
相关参考
1.PyTorch Lightning初步教程(下) - 知乎
2.argparse.ArgumentParser()用法解析 - 重大的小鸿 - 博客园