利用pytorch预训练resnet模型并保存模型及参数

利用pytorch预训练resnet模型并保存模型及参数

  • Resnet模型简介
  • 预训练模型及保存
  • 输出结果

Resnet模型简介

若将输入设为X,将某一有参网络层设为H,那么以X为输入的此层的输出将为H(X)。一般的CNN网络如Alexnet/VGG等会直接通过训练学习出参数函数H的表达,从而直接学习X -> H(X)。

而残差学习则是致力于使用多个有参网络层来学习输入、输出之间的参差即H(X) - X即学习X -> (H(X) - X) + X。其中X这一部分为直接的identity mapping,而H(X) - X则为有参网络层要学习的输入输出间残差。

利用pytorch预训练resnet模型并保存模型及参数_第1张图片

预训练模型及保存

#dowmload and load the pretrained ResNet-152
import torchvision

resnet=torchvision.models.resnet152(pretrained=True)

#fintune the top layer
for param in resnet.parameters():
    param.requires_grad=False
    resnet.fc=nn.Linear(resnet.fc.in_features,100)
#forwad pass
images=torch.randn(128,3,224,224)
outputs=resnet(images)
print(outputs.size())

#save and load thr entire model
torch.save(resnet,'model.ckpt')
model=torch.load('model.ckpt')

#save and load only the model parameters(recommended)
torch.save(resnet.state_dict(),'params.ckpt')
resnet.load_state_dict(torch.load('params.ckpt'))


输出结果

利用pytorch预训练resnet模型并保存模型及参数_第2张图片
在这里插入图片描述

你可能感兴趣的:(pytorch深度学习,深度学习,pytorch)