【PyTorch】载入模型的两种方法：载入整个模型与只载入参数（state_dict）

import torch
import argparse

parser = argparse.ArgumentParser("-")
# How the checkpoint file is to be interpreted:
#   "model"  - the file holds a whole pickled model object (e.g. produced by
#              torch.save(model, path)); torch.load's result is used directly.
#   "params" - the file holds only the parameters (a state_dict, e.g. produced by
#              torch.save(model.state_dict(), path)); they are loaded into a
#              freshly constructed model.
parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
# Path of the checkpoint file; every loading branch below needs it.
parser.add_argument("--pre_trained_model_path", type=str, required=True)
args = parser.parse_args()

# Load either the whole model or only its parameters.
# Without a GPU, remap any tensors that were saved on CUDA onto the CPU so the
# file can still be read. With a GPU, keep torch.load's default placement
# (map_location=None), i.e. tensors are restored to the device they were saved from.
map_location = None if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from trusted sources. Newer PyTorch releases (2.6+)
# default to weights_only=True, which rejects whole pickled models (the "model"
# type); pass weights_only=False there if the file is trusted.
checkpoint = torch.load(args.pre_trained_model_path, map_location=map_location)

if args.pre_trained_model_type == "model":
    model = checkpoint
else:  # "params": argparse `choices` rules out any other value
    # NOTE(review): `m` is the model class; it is not defined in this snippet
    # and must be defined or imported elsewhere. TODO confirm.
    model = m()
    model.load_state_dict(checkpoint)

你可能感兴趣的：PyTorch