Pytorch参数注册

register_parameter()和parameter()

pytorch模型注册参数的常用方法

相同点:将一个不可训练的类型Tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面,相当于变成了模型的一部分,成为了模型中可以根据训练进行变化的参数。

不同点:获取参数时,使用的名称不同

class Example(nn.Module):
    def __init__(self):
        super(Example, self).__init__()
        self.W1_params = nn.Parameter(torch.rand(2,3))
        self.register_parameter('W2_params', nn.Parameter(torch.rand(2,3)))
        
    def forward(self, x):
        return x

nn.ModuleList和nn.ModuleDict

这是另外两种注册可学习参数的途径

nn.ModuleList

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
	def __init__(self):
		super(Net,self).__init__()
		self.linears = nn.ModuleList([nn.Linear(4,4),nn.Linear(4,4),nn.Linear(4,2)])
	
	def forward(self,x):
		for linear in self.linears:
			x = linear(x)
			x = F.relu(x)
		return x

if __name__ == '__main__':
	net = Net()
	for parameter in net.parameters():
		print(parameter)		

nn.ModuleDict

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

Parameter和Buffer

分析他们的异同,作为参数注册的总结吧

  1. 模型中需要进行更新的参数注册为Parameter,不需要进行更新的参数注册为buffer
  2. 模型保存的参数是 model.state_dict() 返回的 OrderDict
  3. 模型进行设备移动时,模型中注册的参数(Parameter和buffer)会同时进行移动

参考资料

pytorch中的register_parameter()和parameter()

Pytorch参数注册问题和nn.ModuleList nn.ModuleDict

Pytorch模型中的parameter与buffer

你可能感兴趣的:(科学炼丹,pytorch,深度学习,python)