读取Pytorch .pt格式的模型

核心代码

class Net()...  //需要先声明类Net
torch.load('xxx.pt')
//打印实际的层数和参数
for name,param in m_state_dict.named_parameters():
    print(name,param)

截图结果如下:

读取Pytorch .pt格式的模型_第1张图片

你可能感兴趣的:(Python,Pytorch,AI)