PyTorch中nn.Module理解

PyTorch中nn.Module理解_第1张图片

  • nn.Module是Pytorch封装的一个类,是搭建神经网络时需要继承的父类:
import torch
import torch.nn as nn

# 括号中加入nn.Module(父类)。Test2变成子类,继承父类(nn.Module)的所有特性。
class Test2(nn.Module):  
    def __init__(self):  # Test2类定义初始化方法
       super(Test2, self).__init__()  # 父类初始化
       self.M = nn.Parameter(torch.ones(10))
        
    def weightInit(self):
        print('Testing')

    def forward(self, n):
        # print(2 * n)
        print(self.M * n)
        self.weightInit()

# 调用方法
network = Test2()
network(2)  # 2赋值给forward(self, n)中的n。
……省略一部分代码……
# 因为Test2是nn.Module的子类,所以也可以执行父类中的方法。如:
model_dict = network.state_dict()  # 调用父类中的方法state_dict(),将Test2中训练参数赋值model_dict。
for k, v in model_dict.items():  # 查看自己网络参数各层名称、数值
	print(k)  # 输出网络参数名字
    # print(v)  # 输出网络参数数值

继承nn.Module的子类程序是从forward()方法开始执行的,如果要想执行其他方法,必须把它放在forward()方法中。这一点与python中继承有稍许的不同。

PyTorch中nn.Module理解_第2张图片

你可能感兴趣的:(机器学习,nn.Module,pytorch)