命令行参数解析器:argparser.ArgumentParser()

一、什么是ArgParser?

argparse是一个Python模块:命令行选项、参数和子命令解析器;argparse模块可以让人轻松编写用户友好的命令行接口。程序定义它需要的参数。然后argparser将弄清如何从sys.argv解析出那些参数。argparse模块还会自动生成帮助和使用手册,并在用户给程序传入无效参数时报出错误信息。

二、用法

在使用时最好将其分化成三个部分:trainer参数(如gpus)、特定于模型的参数(如层数维度)、程序参数(如路径)

1.基础用法:

#导入模块
from argparse import ArgumentParser
#创建解析器对象
parser = ArgumentParser()
#添加参数
#给一个ArgumentParser添加程序阐述信息是通过调用add_arguement()方法完成的。
parser.add_argument("--layer_1_dim", type=int, default=128)
#解析参数
args = parser.parse_args()


#可以输入一下命令调用程序
python trainer.py --layer_1_dim 64

(1)ArgumentParser对象

prog - 程序的名称(默认: sys.argv[0],prog猜测是programma的缩写)

usage - 描述程序用途的字符串(默认值:从添加到解析器的参数生成)

description - 在参数帮助文档之后显示的文本 (默认值:无)

(2)add_argument()方法

name or flags - 一个命名或者一个选项字符串的列表

action - 表示该选项要执行的操作

default - 当参数未在命令行中出现时使用的值

dest - 用来指定参数的位置

type - 为参数类型,例如int

choices - 用来选择输入参数的范围。例如choice = [1, 5, 10], 表示输入参数只能为1,5 或10

help - 用来描述这个选项的作用

2.在主Trainer文件中,添加Trainer参数、程序参数和模型参数

# ----------------
# trainer_main.py
# ----------------
from argparse import ArgumentParser

parser = ArgumentParser()

# 添加程序级参数
parser.add_argument("--conda_env", type=str, default="some_name")
parser.add_argument("--notification_email", type=str, default="[email protected]")

# 添加特定于模型的参数
parser = LitModel.add_model_specific_args(parser)

# 将所有可用的trainer选项添加到argparse
parser = Trainer.add_argparse_args(parser)

args = parser.parse_args()

在调用时,可以如下书写调用:

python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12

3、我们通常有多个模块,每个模块都有不同的参数。Module允许你为每个文件定义参数,而不用污染main.py文件。

class LitMNIST(LightningModule):
    def __init__(self, layer_1_dim, **kwargs):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, layer_1_dim)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("LitMNIST")
        parser.add_argument("--layer_1_dim", type=int, default=128)
        return parent_parser


class GoodGAN(LightningModule):
    def __init__(self, encoder_layers, **kwargs):
        super().__init__()
        self.encoder = Encoder(layers=encoder_layers)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("GoodGAN")
        parser.add_argument("--encoder_layers", type=int, default=12)
        return parent_parser

 现在,我们可以允许每个模型在main.py中插入所需的参数:

def main(args):
    dict_args = vars(args)

    # 挑选模型
    if args.model_name == "gan":
        model = GoodGAN(**dict_args)
    elif args.model_name == "mnist":
        model = LitMNIST(**dict_args)

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)

    # 找出要使用的模型
    parser.add_argument("--model_name", type=str, default="gan", help="gan or mnist")

    # 这一行是提取模型名称的关键
    temp_args, _ = parser.parse_known_args()

    # 让模型添加它想要的东西
    if temp_args.model_name == "gan":
        parser = GoodGAN.add_model_specific_args(parser)
    elif temp_args.model_name == "mnist":
        parser = LitMNIST.add_model_specific_args(parser)

    args = parser.parse_args()

    # 训练
    main(args)

现在我们可以使用命令行界面训练MNIST或GAN!

 python main.py --model_name gan --encoder_layers 24
 python main.py --model_name mnist --layer_1_dim 128

相关参考

1.PyTorch Lightning初步教程(下) - 知乎

2.argparse.ArgumentParser()用法解析 - 重大的小鸿 - 博客园

你可能感兴趣的:(pytorch)