通常github很多项目都是使用argparse输入超参数,而我们想要测试该项目的话会比较麻烦,因为需要各种调试。
可以在jupyter中将项目中定义的argparse函数复制到cell中,以ImageMol的项目文件为例,作者提示我们想要使用pretrain.py中的代码,可以使用以下命令:
python pretrain.py --ckpt_dir ./ckpts/pretraining/ \
--checkpoints 1 \
--Jigsaw_lambda 1 \
--cluster_lambda 1 \
--constractive_lambda 1 \
--matcher_lambda 1 \
--is_recover_training 1 \
--batch 256 \
--dataroot ./datasets/pretraining/ \
--dataset data \
--gpu 0,1,2,3 \
--ngpu 4
我们可以查看该文件中写的代码
import argparse
import os
.
.
.
def parse_args():
parser = argparse.ArgumentParser(description='parameters of pretraining ImageMol')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate (default: 0.01)')
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: -5)')
parser.add_argument('--workers', default=2, type=int, help='number of data loading workers (default: 2)')
parser.add_argument('--val_workers', default=16, type=int, help='number of data loading workers (default: 16)')
parser.add_argument('--epochs', type=int, default=151, help='number of total epochs to run (default: 151)')
parser.add_argument('--start_epoch', default=0, type=int,
help='manual epoch number (useful on restarts) (default: 0)')
parser.add_argument('--batch', default=256, type=int, help='mini-batch size (default: 256)')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
parser.add_argument('--checkpoints', type=int, default=1,
help='how many iterations between two checkpoints (default: 1)')
parser.add_argument('--seed', type=int, default=31, help='random seed (default: 31)')
parser.add_argument('--dataroot', type=str, default="./datasets/pretraining/", help='data root')
parser.add_argument('--dataset', type=str, default="toy", help='dataset name, e.g. data, toy')
parser.add_argument('--ckpt_dir', default='./ckpts/pretrain_model', help='path to checkpoint')
parser.add_argument('--modelname', type=str, default="ResNet18", choices=["ResNet18"], help='supported model')
parser.add_argument('--verbose', action='store_true', help='')
parser.add_argument('--ngpu', type=int, default=8, help='number of GPUs to use')
parser.add_argument('--gpu', type=str, default="0", help='GPUs of CUDA_VISIBLE_DEVICES')
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--imageSize', type=int, default=224, help='the height / width of the input image to network')
parser.add_argument('--Jigsaw_lambda', type=float, default=1,
help='start JPP task, 1 means start, 0 means not start')
parser.add_argument('--cluster_lambda', type=float, default=1, help='start M3GC task')
parser.add_argument('--constractive_lambda', type=float, default=0, help='start MCL task')
parser.add_argument('--matcher_lambda', type=float, default=0, help='start MRD task')
parser.add_argument('--is_recover_training', type=int, default=1, help='start MIR task')
parser.add_argument('--cl_mask_type', type=str, default="rectangle_mask", help='',
choices=["random_mask", "rectangle_mask", "mix_mask"])
parser.add_argument('--cl_mask_shape_h', type=int, default=16, help='mask_utils->create_rectangle_mask()')
parser.add_argument('--cl_mask_shape_w', type=int, default=16, help='mask_utils->create_rectangle_mask()')
parser.add_argument('--cl_mask_ratio', type=float, default=0.001, help='mask_utils->create_random_mask()')
return parser.parse_args()
.
.
.
if __name__ == '__main__':
args = parse_args()
main(args)
可以发现有一个函数控制了参数的生成,这一部分就是我们需要复制到cell的部分,而main函数是我们需要的运行的代码。
import argparse
import pretrain
def parse_args(args):
parser = argparse.ArgumentParser(description='parameters of pretraining ImageMol')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate (default: 0.01)')
parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: -5)')
parser.add_argument('--workers', default=2, type=int, help='number of data loading workers (default: 2)')
parser.add_argument('--val_workers', default=16, type=int, help='number of data loading workers (default: 16)')
parser.add_argument('--epochs', type=int, default=151, help='number of total epochs to run (default: 151)')
parser.add_argument('--start_epoch', default=0, type=int,
help='manual epoch number (useful on restarts) (default: 0)')
parser.add_argument('--batch', default=256, type=int, help='mini-batch size (default: 256)')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
parser.add_argument('--checkpoints', type=int, default=1,
help='how many iterations between two checkpoints (default: 1)')
parser.add_argument('--seed', type=int, default=31, help='random seed (default: 31)')
parser.add_argument('--dataroot', type=str, default="./datasets/pretraining/", help='data root')
parser.add_argument('--dataset', type=str, default="toy", help='dataset name, e.g. data, toy')
parser.add_argument('--ckpt_dir', default='./ckpts/pretrain_model', help='path to checkpoint')
parser.add_argument('--modelname', type=str, default="ResNet18", choices=["ResNet18"], help='supported model')
parser.add_argument('--verbose', action='store_true', help='')
parser.add_argument('--ngpu', type=int, default=8, help='number of GPUs to use')
parser.add_argument('--gpu', type=str, default="0", help='GPUs of CUDA_VISIBLE_DEVICES')
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--imageSize', type=int, default=224, help='the height / width of the input image to network')
parser.add_argument('--Jigsaw_lambda', type=float, default=1,
help='start JPP task, 1 means start, 0 means not start')
parser.add_argument('--cluster_lambda', type=float, default=1, help='start M3GC task')
parser.add_argument('--constractive_lambda', type=float, default=0, help='start MCL task')
parser.add_argument('--matcher_lambda', type=float, default=0, help='start MRD task')
parser.add_argument('--is_recover_training', type=int, default=1, help='start MIR task')
parser.add_argument('--cl_mask_type', type=str, default="rectangle_mask", help='',
choices=["random_mask", "rectangle_mask", "mix_mask"])
parser.add_argument('--cl_mask_shape_h', type=int, default=16, help='mask_utils->create_rectangle_mask()')
parser.add_argument('--cl_mask_shape_w', type=int, default=16, help='mask_utils->create_rectangle_mask()')
parser.add_argument('--cl_mask_ratio', type=float, default=0.001, help='mask_utils->create_random_mask()')
return parser.parse_args(args)
args = ['--ckpt_dir','./ckpts/pretraining-toy/','--dataroot','./datasets/toy/pretraining/','--dataset','data','--epochs','5','--ngpu','1','--batch','1']
args = parse_args(args)
pretrain.main(args)
我们可以将想要输入的参数通过列表的形式输入,输入规格为[’–参数名称‘,‘参数输入’],如此重复,在该函数中添加参数args,然后在返回的函数中也输入该参数,通过调用该函数的时候将我们想要的参数列表传入其中,即可获得代码所需要的args,将其传入pretrain中的main函数,即可完成任务。