torch.nn.Conv2d源代码解析

1,isinstance(padding, str)什么意思?

torch.nn.Conv2d源代码解析_第1张图片

2,super(Conv2d, self).__init__什么意思?

第一步:

super函数用于多层继承(multilevel inheritance)的情况,简单来说,就是之继承最近的那个父类。

class A:
    def __init__(self):
        print('Initializing: class A')

    def sub_method(self, b):
        print('Printing from class A:', b)


class B(A):
    def __init__(self):
        print('Initializing: class B')
        super().__init__()

    def sub_method(self, b):
        print('Printing from class B:', b)
        super().sub_method(b + 1)


class C(B):
    def __init__(self):
        print('Initializing: class C')
        super().__init__()

    def sub_method(self, b):
        print('Printing from class C:', b)
        super().sub_method(b + 1)


if __name__ == '__main__':
    c = C()
    c.sub_method(1)

# Initializing: class C
# Initializing: class B
# Initializing: class A
# Printing from class C: 1
# Printing from class B: 2
# Printing from class A: 3

c = C() 创建了一个class C的实例,然后可以看到初始化是从C->B->A的。

c.sub_method(1) 首先调用了C类里的sub_method(),输出了1,然后通过super().sub_method(b + 1)调用了B类里的sub_method()。可以看到C类里的super()就是代替了class C(B)里的B类

第二步:

super(Net, self).__init__()

Python中的super(Net, self).__init__()是指首先找到Net的父类(比如是类NNet),然后把类Net的对象self转换为类NNet的对象,然后“被转换”的类NNet对象调用自己的init函数,其实简单理解就是子类把父类的__init__()放到自己的__init__()当中,这样子类就有了父类的__init__()的那些东西。 

回过头来看看我们的我们最上面的代码,Net类继承nn.Module,super(Net, self).__init__()就是对继承自父类nn.Module的属性进行初始化。而且是用nn.Module的初始化方法来初始化继承的属性。

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 输入图像channel:1;输出channel:6;5x5卷积核
        self.conv1 = nn.Conv2d(1, 6, 5)

 也就是说,子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化。

class Person:
    def __init__(self,name,gender):
        self.name = name
        self.gender = gender
    def printinfo(self):
        print(self.name,self.gender)

class Stu(Person):
    def __init__(self,name,gender,school):
        super(Stu, self).__init__(name,gender) # 使用父类的初始化方法来初始化子类
        self.school = school
    def printinfo(self): # 对父类的printinfo方法进行重写
        print(self.name,self.gender,self.school)

if __name__ == '__main__':
    stu = Stu('djk','man','nwnu')
    stu.printinfo()

 

参考链接:

https://bramblexu.com/posts/3adca41/
https://blog.csdn.net/dongjinkun/article/details/114575998

3,def _conv_forward()和def forward什么意思?

我们在使用Pytorch的时候,模型训练时,不需要调用forward这个函数,只需要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。这个称之为前向传播。

 class Module(nn.Module):
    def __init__(self):
        super().__init__()
        # ......

    def forward(self, x):
        # ......
        return x


data = ......  # 输入数据

# 实例化一个对象
model = Module()

# 前向传播
model(data)

# 而不是使用下面的
# model.forward(data) 

但是实际上model(data)是等价于model.forward(data)

4,

torch.nn.Conv2d源代码解析_第2张图片

 torch.nn.Conv2d源代码解析_第3张图片

 torch.nn.Conv2d源代码解析_第4张图片

 torch.nn.Conv2d源代码解析_第5张图片

 torch.nn.Conv2d源代码解析_第6张图片

 

你可能感兴趣的:(pytorch,pytorch)