Scripts on GitHub usually parse their arguments in the following format:
from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--test_dir", type=str, default="test", help='Directory containing the test data (must have subdirectory noisy/)')
    parser.add_argument("--enhanced_dir", type=str, default="test", help='Directory containing the enhanced data')
    parser.add_argument("--ckpt", type=str, default="logs/epoch=326-step=408750.ckpt", help='Path to model checkpoint.')
    parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.")
    parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
    parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynamics.")
    parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
    args = parser.parse_args()
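If you would rather keep the original ArgumentParser definitions, one option is to hand parse_args() an explicit list instead of letting it read sys.argv. The sketch below assumes a trimmed copy of the parser above (only some options, help strings dropped) wrapped in a hypothetical build_parser() helper; parse_args([]) yields all defaults, and a list of strings overrides values exactly as on the command line.

```python
from argparse import ArgumentParser


def build_parser() -> ArgumentParser:
    # Hypothetical helper: a subset of the options from the script above
    parser = ArgumentParser()
    parser.add_argument("--test_dir", type=str, default="test")
    parser.add_argument("--enhanced_dir", type=str, default="test")
    parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald")
    parser.add_argument("--snr", type=float, default=0.5)
    parser.add_argument("--N", type=int, default=30)
    return parser


parser = build_parser()

# An explicit list means sys.argv is ignored, so the Jupyter kernel's
# own "-f <connection file>" argument is never parsed.
defaults = parser.parse_args([])
print(defaults.N, defaults.snr)  # 30 0.5

# Override selected values just as on the command line; type=int still applies
args = parser.parse_args(["--N", "4", "--enhanced_dir", "/kaggle/working/"])
print(args.N, args.enhanced_dir)  # 4 /kaggle/working/
```

The result is a regular argparse.Namespace, so every args.xxx access in the project works unchanged. parser.parse_known_args()[0] is another option: it returns the namespace and silently collects unrecognized arguments instead of exiting.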
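This does not run as is in a Jupyter notebook. The kernel is launched with its own command-line arguments (such as -f followed by a connection file), so parse_args() exits with an "unrecognized arguments" error. The arguments defined with ArgumentParser therefore have to be replaced. However, the project reads many of these values as attributes (args.<name>), so the replacement must support the same attribute access. A small dict subclass that exposes its keys as attributes does the job: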
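class Point_Attribute(dict):
    """A dict whose keys can also be read and written as attributes (args.N)."""

    def __init__(self, **kwargs):
        super().__init__(kwargs)

    def __setattr__(self, key, value):
        # args.N = 4  ->  args["N"] = 4
        self[key] = value

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails: args.N -> args["N"]
        try:
            return self[key]
        except KeyError:
            # Raise AttributeError (not KeyError) so hasattr() and getattr(obj, name, default) behave normally
            raise AttributeError(key) from None


# Same names as the command-line options, so code that uses args.xxx keeps working
args = Point_Attribute(
    test_dir="test",
    enhanced_dir="/kaggle/working/",
    ckpt="logs/epoch=326-step=408750.ckpt",
    corrector="ald",
    corrector_steps=1,
    snr=0.5,
    N=4,
)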
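The standard library already offers a container with this kind of attribute access: argparse.Namespace, which is the type parse_args() itself returns, accepts keyword arguments directly. The sketch below reuses the values from above and assumes the project only needs attribute access (args.xxx) or vars(args), not dict-style indexing such as args["N"].

```python
from argparse import Namespace

# Namespace takes keyword arguments, so it can stand in for parse_args() output
args = Namespace(
    test_dir="test",
    enhanced_dir="/kaggle/working/",
    ckpt="logs/epoch=326-step=408750.ckpt",
    corrector="ald",
    corrector_steps=1,
    snr=0.5,
    N=4,
)

print(args.N)  # 4
args.snr = 0.33  # attributes can be reassigned
# vars() returns the attribute dict, as it does for a real parse_args() result
print(vars(args)["snr"])  # 0.33
```

Because this is the same type argparse produces, isinstance checks and vars(args) calls elsewhere in the project keep working. The dict subclass is mainly useful if the code also indexes args like a dictionary.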