《动手学深度学习》(PyTorch版)代码注释 - 16 【Model_construction】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上4.5节
此节功能为:读取和存储
由于此节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 4.5 读取和存储
#注释:黄文俊
#邮箱:[email protected]

import torch
from torch import nn

x = torch.ones(3)
torch.save(x, 'Read&Write/x.pt')

x2 = torch.load('Read&Write/x.pt')
print(x2)

y = torch.zeros(4)
torch.save([x, y], 'Read&Write/x.pt')

xy_list = torch.load('Read&Write/x.pt')
print(xy_list)


# 4.5.2 读写模型
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
print(net)
print(net.state_dict())
# state_dict是一个从参数名称隐射到参数Tesnor的字典对象,即可以以字典形式返回net中不同层的参数值

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(optimizer.state_dict())
# 优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。


# 4.5.2.2 保存和加载模型
'''
1. 保存和加载state_dict(推荐方式)
保存:
torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth

加载:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

2. 保存和加载整个模型
保存:torch.save(model, PATH)
加载:model = torch.load(PATH)

'''

X = torch.randn(2, 3)
Y = net(X)

PATH = "./Read&Write/net.pt"
torch.save(net.state_dict(), PATH)

net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
print(Y2 == Y)







print("*"*30)

你可能感兴趣的:(python,深度学习,人工智能,pycharm,pytorch)