命令行解析之 parse_args()，以及从 checkpoint 文件中恢复模型参数与优化器状态（checkpoint['state_dict']、checkpoint['optimizer']）

import argparse
def parse_args(argv=None):
    """Parse the command-line arguments of the DeepCluster script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, so existing calls to
            ``parse_args()`` keep reading the real command line; passing an
            explicit list makes the parser usable from tests or notebooks.

    Returns:
        argparse.Namespace with the attributes:
            data (str): path to the dataset directory (required positional).
            resume (str): path to a checkpoint to resume from; ``''`` (the
                default) means no resuming.
            start_epoch (int): epoch to start from; 0 unless given on the
                command line or overwritten after loading a checkpoint.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Implementation of DeepCluster')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    # The checkpoint-resume code reads ``args.resume`` and assigns
    # ``args.start_epoch``. Without these options, ``args.resume`` raises
    # AttributeError. The empty default keeps resuming opt-in.
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to checkpoint to resume from (default: none)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='manual epoch number, useful on restarts (default: 0)')
    return parser.parse_args(argv)

另一个片段：从 checkpoint 文件恢复训练（resume）：

# Optionally resume training from the checkpoint path given in ``args.resume``.
# Assumes ``args``, ``model``, ``optimizer``, ``os`` and ``torch`` are already
# defined/imported at this point in the script -- TODO confirm.
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only resume from checkpoints you trust.
        checkpoint = torch.load(args.resume)
        # Continue the epoch counter from the value stored in the checkpoint.
        args.start_epoch = checkpoint['epoch']
        # Restore the model weights and the optimizer state saved alongside them.
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        # A non-empty but missing path is reported, not treated as fatal.
        print("=> no checkpoint found at '{}'".format(args.resume))

你可能感兴趣的:(pytorch)