pytorch里面nn.Module讲解

nn.Module是在pytorch使用非常广泛的类,搭建网络基本都需要用到这个。

当我们搭建自己的网络时,可以继承官方写好的nn.Module模块,为什么要用这个呢?好处如下:

nn.Module作用

    • 1.可以提供一些现成的基本模块比如:
    • 2. 容器
    • 3.参数管理
    • 4. 所有modules的节点 孩子节点都是直系的
    • 5.to(device)
    • 6.保存和加载模型
    • 7.训练/测试
    • 8.实现自己的类
      • 8.1举一个自己写的线性层的例子

1.可以提供一些现成的基本模块比如:

Linear、ReLU、Sigmoid、Conv2d、Dropout

不用自己一个一个的写这些函数了,这也是为什么我们用框架的原因之一吧。

2. 容器

比如我们经常用到的 nn.Sequential(),顾名思义,将网络模块封装在一个容器中,可以方面网络搭建
如下面一个例子:

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
                                 nn.MaxPool2d(2, 2),
                                 Flatten(),
                                 nn.Linear(1*14*14, 10))
    def forward(self, x):
        return self.net(x)

3.参数管理

参数名字可以自动生成(想想如果自己去命名,百万参数的网络没法搭建),然后这些参数都可以传到优化器里面去优化

4. 所有modules的节点 孩子节点都是直系的

class BasicNet(nn.Module):
    def __init__(self):
        super(BasicNet, self).__init__()
        self.net = nn.Linear(4, 3)
        
    def forward(self, x):
        return self.net(x)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(BasicNet(),
                                 nn.ReLU(),
                                 nn.Linear(3, 2))
    def forward(self, x):
        return self.net(x)

比如上面的代码,我们可以看出Net网络中有5个孩子节点:nn.Sequential,BasicNet, nn.ReLU,nn.Linear,BasicNet里面的nn.Linear

5.to(device)

nn.Module还有一个功能是将某个网络所有成员、函数、操作都搬移到GPU上面。
采用代码如下:

    device = torch.device('cuda')
    net = Net()
    net.to(device)

上面device代表当前的设备是GPU还是CPU,需要注意的是为什么我们不写

net = net.to(device)

其实效果是一样的,采用nn.Module模块,net加上.to(device),还是net。如果是变量则不是一样的,即如果对于tensor bias,那么biasbias.to(device)不是一样的,则需要重新命名。

6.保存和加载模型

可以方面我们保存和加载模型

加载模型:

net.load_state_dict(torch.load('ckpt.mdl'))

保存模型:

torch.save(net.state_dict(), 'ckpt.mdl')

7.训练/测试

方便训练和测试进行切换,为什么?因为网络中Dropout和BN在训练和测试是不一样的,需要切换
如果不切换效果就会很差,这个是容易犯的一个错误

    net.train()
    net.eval()

8.实现自己的类

官方给的模块还是基础操作的,如果自己要搭建复杂的操作也容易实现,一个典型的例子就是可以自己设计一个新的损失函数。
下面给出将tensor压平的例子(nn.Module没有这个操作):

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        return input.view(input.size(0), -1)
        
class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
                                 nn.MaxPool2d(2, 2),
                                 Flatten(),  #自己定义的
                                 nn.Linear(1*14*14, 10))

    def forward(self, x):
        return self.net(x)

Flatten压平的操作则是我们自己构建的类,可以方便后续BasicNet类使用,注意nn.Sequential里面必须是类。
且在上面例子中Flatten不需要接任何参数。

8.1举一个自己写的线性层的例子

class MyLinear(nn.Module):
    def __init__(self, inp, outp):
        super(MyLinear, self).__init__()
        # requires_grad = True
        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))

    def forward(self, x):
        x = x @ self.w.t() + self.b
        return x

在上面自己写的线性层 y = w x + b y=wx+b y=wx+b,可以看出 w w w b b b必须要使用nn.Parameter这个模块。原因是只用加上了nn.Parameter后, w w w b b b才可以用优化器SGD等进行优化。

如果不写nn.Parameter那么则需要写requires_grad = True,还要自己写优化器,就很麻烦。用了Parameter可以方便我们优化网络:

model = MyLinear.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()

你可能感兴趣的:(pytorch,机器学习)