argparse模块介绍
步骤
- 导包
- 创建解析器
- 添加参数和选项
- 解析参数
- 使用
导包
import argparse
创建解析器
parser = argparse.ArgumentParser()
添加参数
ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])
参数说明:https://blog.csdn.net/weixin_45800242/article/details/125658094
举例:
https://blog.csdn.net/topsogn/article/details/121562512
# parser.add_argument('--dataset', default='ModelNet40', help='ModelNet10|ModelNet40|ShapeNet')
parser.add_argument('--dataroot', default='dataset/train', help='path to dataset')
parser.add_argument('--workers', type=int, default=0, help='number of data loading workers')
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--pnum', type=int, default=2048, help='the point number of a sample')
parser.add_argument('--crop_point_num', type=int, default=1024,
help='0 means do not use else use with this weight')
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for')
parser.add_argument('--weight_decay', type=float, default=0.001)
parser.add_argument('--learning_rate', default=0.0002, type=float, help='learning rate in training')
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
parser.add_argument('--cuda', type=bool, default=True, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='Trained_Model_peach/point_netG10.pth',
help="path to netG (to continue training)")
# parser.add_argument('--netD', default='Trained_Model/point_netD50.pth', help="path to netD (to continue training)")
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--drop', type=float, default=0.2)
parser.add_argument('--num_scales', type=int, default=3, help='number of scales')
parser.add_argument('--point_scales_list', type=list, default=[2048, 1024, 512], help='number of points in each scales')
parser.add_argument('--each_scales_size', type=int, default=1, help='each scales size')
parser.add_argument('--wtl2', type=float, default=0.9, help='0 means do not use else use with this weight')
parser.add_argument('--cropmethod', default='random_center', help='random|center|random_center')
–dataset:数据集名称,默认为’ModelNet40’
–dataroot:数据集路径,默认为’dataset/train’
–workers:数据加载时的工作进程数,默认为0
–batchSize:输入批量大小,默认为1
–pnum:样本中点的数量,默认为2048
–crop_point_num:裁剪后的点的数量,默认为1024,如果为0则不使用,否则使用该权重
–nc:输入图像的通道数,默认为3
–niter:训练时的epoch数,默认为300
–weight_decay:权重衰减值,默认为0.001
–learning_rate:训练时的学习率,默认为0.0002
–beta1:Adam优化器的beta1值,默认为0.9
–cuda:是否启用CUDA,默认为True
–ngpu:使用的GPU数目,默认为1
–netG:Generator网络的路径,默认为’Trained_Model_peach/point_netG10.pth’
–manualSeed:手动设置的随机种子
–drop:Dropout层的概率,默认为0.2
–num_scales:金字塔尺度的数量,默认为3
–point_scales_list:每个尺度中的点数列表,默认为[2048, 1024, 512]
–each_scales_size:每个尺度的大小,默认为1
–wtl2:L2正则化的权重,默认为0.9,如果为0则不使用,否则使用该权重
–cropmethod:裁剪方法,可选值为’random’、‘center’和’random_center’,默认为’random_center’
解析参数
opt = parser.parse_args()
使用
例如给生成器添加参数 直接 opt.参数名称
# 加载生成器
point_netG = _netG(opt.num_scales, opt.each_scales_size, opt.point_scales_list, 1024)
point_netG = torch.nn.DataParallel(point_netG)
point_netG.to(device)
point_netG.load_state_dict(torch.load(opt.netG, map_location=lambda storage, location: storage)['state_dict'])
point_netG.eval()