pytorch 之 __call__, __init__,forward

在学习pytorch之前,你会看到这样一段代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        #1个输入图像通道,6个输出通道,3x3平方卷积核
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
 
net = Net()
print(net)
>>>Net(
 
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
 
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
 
  (fc1): Linear(in_features=576, out_features=120, bias=True)
 
  (fc2): Linear(in_features=120, out_features=84, bias=True)
 
  (fc3): Linear(in_features=84, out_features=10, bias=True)
 )

那关于forward是怎么被调用的,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数

接下来我们看几个例子,了解一下__call__, __init__,forward
class A():
    def __call__(self):
        print('i can be called like a function')
 
a = A()
a()
>>>i can be called like a function

class B():
    def __init__(self):
        print('i can be called like a function')
 
a = B()
a
>>>i can be called like a function
>>><__main__.B at 0x5978fd0>

class C():
    def __init__(self):
        print('a')
    def __call__(self):
        print('b')    
 
a = C()
a()
>>>a
>>>b

class D():
    def __init__(self, init_age):
        print('我年龄是:',init_age)
        self.age = init_age
 
    def __call__(self, added_age):
        res = self.forward(added_age)
        return res
 
    def forward(self, input_):
        print('forward 函数被调用了')
        
        return input_ + self.age
print('对象初始化。。。。')
a = D(10)
 
input_param = a(2)
print("我现在的年龄是:", input_param)

>>>对象初始化。。。。
我年龄是: 10
forward 函数被调用了
我现在的年龄是: 12

pytorch主要也是就是按照__call____init__,forward三个函数实现网络层之间的架构的,了解这个还得了解super继承,所以从以上例子可看出,定义__call__方法的类可以当作函数调用,当把定义的网络模型model当作函数调用的时候就自动调用定义的网络模型的forward方法。nn.Module 的__call__方法部分源码如下所示:

def __call__(self, *input, **kwargs):
   result = self.forward(*input, **kwargs)
   for hook in self._forward_hooks.values():
       #将注册的hook拿出来用
       hook_result = hook(self, input, result)
   ...
   return result

可以看到,当执行model(x)的时候,底层自动调用forward方法计算结果。

接下来可以看一个例子:

class Animal():
    def __init__(self,name): 
        self.name = name
    def greet(self):
        print('animal is %s' % self.name)
class Dog(Animal):
    def greet(self):
        super(Dog, self).greet()
        print('wangwang')
a=Dog('dog')
a.greet()
>>>animal is dog
wangwang

有关super的知识点可以看一下python关于类的继承等等。

那么调用forward方法的具体流程是什么样的呢?具体流程是这样的:

以一个Module为例:
1. 调用module的call方法
2. module的call里面调用module的forward方法
3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下
4. 调用Function的call方法
5. Function的call方法调用了Function的forward方法。
6. Function的forward返回值
7. module的forward返回值
8. 在module的call进行forward_hook操作,然后返回值。

你可能感兴趣的:(pytorch 之 __call__, __init__,forward)