pytorch学习笔记(七)---网络模型的使用修改、保存以及加载

        本篇自学笔记来自于b站《PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】》,Up主讲的非常通俗易懂,文章下方有视频连接,如有需要可移步up主讲解视频,如有侵权,实非故意,深表歉意,请与我联系,删除相关内容!

        本节将介绍的内容有:1.使用torchvison中定义好的模型 ,2.如何修改定义好的模型,3.保存模型的方式和对应的加载模型的方式。

        1.使用定义好的模型(以VGG16为例)

         给出官方解释:如图可以看到有两个字段,分别为pretrained和progress,pretrained表示是否使用预训练好的模型,该模型是在ImageNet上训练好的。progress则为是否显示下载进度条。pytorch学习笔记(七)---网络模型的使用修改、保存以及加载_第1张图片

        代码为:分别写了pretrained为true和false的两种情况。

vgg16_false= torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

        2.修改上述模型 

        首先输出VGG16的网络模型如下:

pytorch学习笔记(七)---网络模型的使用修改、保存以及加载_第2张图片

 pytorch学习笔记(七)---网络模型的使用修改、保存以及加载_第3张图片

        首先演示给模型的classifier中添加一层Linear层:

vgg16_true.classifier.add_module(name="add_linear",module=nn.Linear(in_features=1000,out_features=10))

         其中add_module的官方定义如下,需要给出添加的名字,以及具体添加的某一层的定义。    pytorch学习笔记(七)---网络模型的使用修改、保存以及加载_第4张图片

        添加完之后结果如下:pytorch学习笔记(七)---网络模型的使用修改、保存以及加载_第5张图片         其次也可以在模型中直接修改,接下来展示的是直接修改classifier中(6)Linear层。

vgg16_false.classifier[6] = nn.Linear(in_features=4096,out_features=10)

        结果如下:pytorch学习笔记(七)---网络模型的使用修改、保存以及加载_第6张图片 

        3.网络模型的保存和加载 

        保存方式主要有两种,对应的加载方式也是两种,下面展示第一种保存方式,这种保存方式保存了模型的结构和参数。第一个参数为要保存的模型名,第二个参数为要保存的路径。 

torch.save(vgg16,"vgg16_method1.pth")

         对应的读取方式为:

vgg_method1 = torch.load("vgg16_method1.pth")

        读取的结果如下图,可以看到确实保存了模型的结构和参数pytorch学习笔记(七)---网络模型的使用修改、保存以及加载_第7张图片

         第二种保存方式,这种保存方式是以字典的形式保存了模型参数,而不保存模型结构,所以对应的读取方式也会不同。

#保存方式2
#只保存模型参数
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
#对应的读取
vgg16_method2 = vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16_method2)

附视频地址:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili

你可能感兴趣的:(pytorch,学习笔记,python,pytorch,深度学习)