Pytorch搭建模型需要注意一个问题

Pytorch搭建模型需要注意一个问题

  • 【问题描述】
  • 【解决方式】
  • 【代码示例】

【问题描述】

我在搭建一个包括多分支的网络,由于分支数量不确定需要用到容器;最初直接用的list,出现了子模块的self.training属性无法和模型自身同步的问题。

【解决方式】

需要将list转为torch.nn.ModuleList类别,才能同步self.training属性。搭建网络时,子模块都需要使用torch定义的container存放。如果是串接就用torch.nn.Sequential (这个很熟悉,相信都不会犯错);如果不是串接,自己定义forward方式的话,要用torch.nn.ModuleList而不能用Python中原生的容器。

【代码示例】

以下示意代码仅供说明这种情况:

import torch
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
    
    def forward(self, x):
        if self.training:
            print('running sub-module in training mode.')
            return x
        else:
            print('running sub-module in evaluation mode.')
            return x+1


class Model(nn.Module):
    def __init__(self, is_container=False, num_of_subs=3):
        super(Model, self).__init__()
        
        # self.sub_modules = nn.ModuleList()
        # for i in range(num_of_subs):
        #     self.sub_modules.append(SubModule())

        self.sub_modules = []
        for i in range(num_of_subs):
            self.sub_modules.append(SubModule())
        if is_container:       
            self.sub_modules = nn.ModuleList(self.sub_modules)
    
    def forward(self, x):
        outputs = []
        for md in self.sub_modules:
            outputs.append(md(x))
        return sum(outputs)


if __name__ == '__main__':
    x = 0
    print('[WRONG]-------------------------')
    model = Model(is_container=False)
    model.eval()        
    print('model is training mode.' if model.training else 'model is evaluation mode.')
    with torch.no_grad():
        y = model(x)
        print(y)

    print('[RIGHT]-------------------------')
    model = Model(is_container=True)
    model.eval()
    print('model is training mode.' if model.training else 'model is evaluation mode.')
    with torch.no_grad():
        y = model(x)
        print(y)

输出如下:

[WRONG]-------------------------
model is evaluation mode.
running sub-module in training mode.
running sub-module in training mode.
running sub-module in training mode.
0
[RIGHT]-------------------------
model is evaluation mode.
running sub-module in evaluation mode.
running sub-module in evaluation mode.
running sub-module in evaluation mode.
3

通过结果可以发现,就算指定了model.eval()并且加上了torch.no_grad()限制,若不使用torch定义的容器存放子模块,它是无法同步模型的training和evaluation状态的。此种情况下,当需要针对子模块在训练和推理阶段做不同操作时就会产生错误。

你可能感兴趣的:(深度学习,pytorch,深度学习,python)