Pytorch 笔记 -- model模型

1. 保存载入模型

import torch
torch.save(model,'model_name')  #将网络模型、模型参数全部保存
torch.save(model.state_dict(),'model_name.pkl')  #只保存模型参数

两种载入方式也不相同

torch.load('model_name')

from torchvision import models
model = models.resnet18()   #载入网络
model.load_state_dict(torch.load('model_name.pkl')) #载入参数

2. 查看模型参数

for param_name, param in model.named_parameters():
	print (param_name,param.shape)

3.修改模型参数

param.copy_(input_param)   #本程序待验证,摘自torch.nn.modules.module.py :656行 param 为原来的参数,input_param 为新的参数,两者维度相同

4. CLASS torch.nn.Module

下面记录此类包含函数的使用方法

  • add_module(name, module)
    在当前的模块中增加一个子模块
model.add_module(name, module)
...
  • children()
    返回子模块的迭代器。

查看model下的子模块

from module in model.children()
	pass
  • cpu()
    将所有模型参数和缓冲区移动到CPU。(Moves all model parameters and buffers to the CPU.)

使用cpu载入模型

model.cpu()
  • cuda(device=None)
    将所有模型参数和缓冲区移动到GPU。(Moves all model parameters and buffers to the GPU.)
    在构造优化器之前使用它

使用gpu载入模型

model.cuda()
  • double()
    将所有浮点参数和缓冲区转换为double数据类型。
  • float()
    将所有浮点参数和缓冲区转换为float数据类型。
  • eval()
    设置模型为评估模式,在测试模型之前使用
model.eval()
  • half()
    Casts all floating point parameters and buffers to half datatype.
  • forward(*input)
    定义每次调用时执行的计算。应该被所有子类覆盖。(Defines the computation performed at every call.
    Should be overridden by all subclasses.)
    每次调用网络模型时会执行此函数
    待续。。。

你可能感兴趣的:(Pytorch,Python,pytorch,模型)