AttributeError: Can't get attribute 'Net' on module '__main__'

在使用pytorch加载模型时报错:

torch.save(old_model, PATH)
new_model = torch.load(PATH)

AttributeError: Can't get attribute 'Net' on

 

解决办法:

1、将类的定义添加到加载模型的这个py文件中,这个方法有点。。。

2、使用官方推荐的方法:https://pytorch.org/docs/master/notes/serialization.html

只保存,加载模型的权重参数:

torch.save(the_model.state_dict(), PATH)
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

 

你可能感兴趣的:(错误警告,pytorch)