github项目中的argparse超参数如何在代码中输入使用

argparse 如何在jupyter中使用

通常github很多项目都是使用argparse输入超参数,而我们想要测试该项目的话会比较麻烦,因为需要各种调试。

可以在jupyter中将项目中定义的argparse函数复制到cell中,以ImageMol的项目文件为例,作者提示我们想要使用pretrain.py中的代码,可以使用以下命令:

python pretrain.py --ckpt_dir ./ckpts/pretraining/ \
                   --checkpoints 1 \
                   --Jigsaw_lambda 1 \
                   --cluster_lambda 1 \
                   --constractive_lambda 1 \
                   --matcher_lambda 1 \
                   --is_recover_training 1 \
                   --batch 256 \
                   --dataroot ./datasets/pretraining/ \
                   --dataset data \
                   --gpu 0,1,2,3 \
                   --ngpu 4

我们可以查看该文件中写的代码

import argparse
import os
.
.
.


def parse_args():
    parser = argparse.ArgumentParser(description='parameters of pretraining ImageMol')

    parser.add_argument('--lr', default=0.01, type=float, help='learning rate (default: 0.01)')
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: -5)')
    parser.add_argument('--workers', default=2, type=int, help='number of data loading workers (default: 2)')
    parser.add_argument('--val_workers', default=16, type=int, help='number of data loading workers (default: 16)')
    parser.add_argument('--epochs', type=int, default=151, help='number of total epochs to run (default: 151)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts) (default: 0)')
    parser.add_argument('--batch', default=256, type=int, help='mini-batch size (default: 256)')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
    parser.add_argument('--checkpoints', type=int, default=1,
                        help='how many iterations between two checkpoints (default: 1)')
    parser.add_argument('--seed', type=int, default=31, help='random seed (default: 31)')
    parser.add_argument('--dataroot', type=str, default="./datasets/pretraining/", help='data root')
    parser.add_argument('--dataset', type=str, default="toy", help='dataset name, e.g. data, toy')
    parser.add_argument('--ckpt_dir', default='./ckpts/pretrain_model', help='path to checkpoint')
    parser.add_argument('--modelname', type=str, default="ResNet18", choices=["ResNet18"], help='supported model')
    parser.add_argument('--verbose', action='store_true', help='')
    parser.add_argument('--ngpu', type=int, default=8, help='number of GPUs to use')
    parser.add_argument('--gpu', type=str, default="0", help='GPUs of CUDA_VISIBLE_DEVICES')
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--ndf', type=int, default=64)
    parser.add_argument('--imageSize', type=int, default=224, help='the height / width of the input image to network')
    parser.add_argument('--Jigsaw_lambda', type=float, default=1,
                        help='start JPP task, 1 means start, 0 means not start')
    parser.add_argument('--cluster_lambda', type=float, default=1, help='start M3GC task')
    parser.add_argument('--constractive_lambda', type=float, default=0, help='start MCL task')
    parser.add_argument('--matcher_lambda', type=float, default=0, help='start MRD task')
    parser.add_argument('--is_recover_training', type=int, default=1, help='start MIR task')
    parser.add_argument('--cl_mask_type', type=str, default="rectangle_mask", help='',
                        choices=["random_mask", "rectangle_mask", "mix_mask"])
    parser.add_argument('--cl_mask_shape_h', type=int, default=16, help='mask_utils->create_rectangle_mask()')
    parser.add_argument('--cl_mask_shape_w', type=int, default=16, help='mask_utils->create_rectangle_mask()')
    parser.add_argument('--cl_mask_ratio', type=float, default=0.001, help='mask_utils->create_random_mask()')

    return parser.parse_args()

.
.
.
if __name__ == '__main__':
    args = parse_args()
    main(args)

可以发现有一个函数控制了参数的生成,这一部分就是我们需要复制到cell的部分,而main函数是我们需要的运行的代码。

import argparse
import pretrain
def parse_args(args):
    parser = argparse.ArgumentParser(description='parameters of pretraining ImageMol')

    parser.add_argument('--lr', default=0.01, type=float, help='learning rate (default: 0.01)')
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: -5)')
    parser.add_argument('--workers', default=2, type=int, help='number of data loading workers (default: 2)')
    parser.add_argument('--val_workers', default=16, type=int, help='number of data loading workers (default: 16)')
    parser.add_argument('--epochs', type=int, default=151, help='number of total epochs to run (default: 151)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts) (default: 0)')
    parser.add_argument('--batch', default=256, type=int, help='mini-batch size (default: 256)')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
    parser.add_argument('--checkpoints', type=int, default=1,
                        help='how many iterations between two checkpoints (default: 1)')
    parser.add_argument('--seed', type=int, default=31, help='random seed (default: 31)')
    parser.add_argument('--dataroot', type=str, default="./datasets/pretraining/", help='data root')
    parser.add_argument('--dataset', type=str, default="toy", help='dataset name, e.g. data, toy')
    parser.add_argument('--ckpt_dir', default='./ckpts/pretrain_model', help='path to checkpoint')
    parser.add_argument('--modelname', type=str, default="ResNet18", choices=["ResNet18"], help='supported model')
    parser.add_argument('--verbose', action='store_true', help='')
    parser.add_argument('--ngpu', type=int, default=8, help='number of GPUs to use')
    parser.add_argument('--gpu', type=str, default="0", help='GPUs of CUDA_VISIBLE_DEVICES')
    parser.add_argument('--nc', type=int, default=3)
    parser.add_argument('--ndf', type=int, default=64)
    parser.add_argument('--imageSize', type=int, default=224, help='the height / width of the input image to network')
    parser.add_argument('--Jigsaw_lambda', type=float, default=1,
                        help='start JPP task, 1 means start, 0 means not start')
    parser.add_argument('--cluster_lambda', type=float, default=1, help='start M3GC task')
    parser.add_argument('--constractive_lambda', type=float, default=0, help='start MCL task')
    parser.add_argument('--matcher_lambda', type=float, default=0, help='start MRD task')
    parser.add_argument('--is_recover_training', type=int, default=1, help='start MIR task')
    parser.add_argument('--cl_mask_type', type=str, default="rectangle_mask", help='',
                        choices=["random_mask", "rectangle_mask", "mix_mask"])
    parser.add_argument('--cl_mask_shape_h', type=int, default=16, help='mask_utils->create_rectangle_mask()')
    parser.add_argument('--cl_mask_shape_w', type=int, default=16, help='mask_utils->create_rectangle_mask()')
    parser.add_argument('--cl_mask_ratio', type=float, default=0.001, help='mask_utils->create_random_mask()')

    return parser.parse_args(args)
args = ['--ckpt_dir','./ckpts/pretraining-toy/','--dataroot','./datasets/toy/pretraining/','--dataset','data','--epochs','5','--ngpu','1','--batch','1']
args = parse_args(args)
pretrain.main(args)

​ 我们可以将想要输入的参数通过列表的形式输入,输入规格为[’–参数名称‘,‘参数输入’],如此重复,在该函数中添加参数args,然后在返回的函数中也输入该参数,通过调用该函数的时候将我们想要的参数列表传入其中,即可获得代码所需要的args,将其传入pretrain中的main函数,即可完成任务。

你可能感兴趣的:(python,人工智能,学习,github)