当我们使用pytorch进行模型训练或测试时,有时候希望能知道模型每一层分别是什么,具有怎样的参数。此时我们可以将模型打印出来,输出每一层的名字、类型、参数等。
常用的命令行打印模型结构的方法有两种:
一是直接print
二是使用torchsummary库的summary
但是二者在输出上有着一些区别。首先说结论:
1. print输出结果是每一层的名字、类别、以及构造时的参数,例如对于卷积层,还包括用户定义的stride、bias等;而torch summary则会打印类别、深度、输出Tensor的形状、参数数量等。
2. 这也是很重要的一点,print打印的每一层顺序,是模型init函数中定义的顺序,而torchsummary则是模型执行起来输入张量真正计算的顺序。
以下举例说明。
我们首先定义一个网络:
class testModel(nn.Module):
def __init__(self):
super(testModel,self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, bias=False)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, bias=False)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, bias=False)
def forward(self, x):
out = self.conv1(x)
out = self.relu1(out)
out = self.conv2(out)
out = self.relu2(out)
out = self.conv3(out)
return out
包含3个卷积层与2个激活函数,平平无奇,而且,定义的模块顺序与执行顺序一致。(输入变量依次经过模型的conv1,relu1,conv2,relu2,conv3)
对于直接print的方法,我们通过:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t = testModel().to(device)
print(t)
实现,结果如下:
按照conv1,relu1,conv2,relu2,conv3的顺序打印了每一层的结构与参数。
对于torch summary的方式,首先pip install torch-summary
安装这个包,在代码中:
from torchsummary import summary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t = testModel().to(device)
summary(t, (3, 224, 224))
得到结果:
可见,同样打印出了每一层的信息,不同之处在于,此时看不出一些设计细节,比如卷积层的stride,是否有bias等(可以算出,不够直接)。但是多了一些信息,比如输入张量通过每一层后的输出是什么形状,以及每一层和总的参数量。
这里注意,我们定义每一层的顺序,和模型跑起来的实际执行顺序一致,因此两种打印方式看起来每一层的顺序也都是一样的。
我们修改testModel的init函数:
class testModel(nn.Module):
def __init__(self):
super(testModel, self).__init__()
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, bias=False)
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, bias=False)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, bias=False)
self.relu2 = nn.ReLU(inplace=True)
def forward(self, x):
out = self.conv1(x)
out = self.relu1(out)
out = self.conv2(out)
out = self.relu2(out)
out = self.conv3(out)
return out
打乱了定义时的顺序。这一点其实很常见,对于复杂的网络或多次修改了代码的网络,初始化这里看起来比较乱是很正常的(虽然不建议)。如上述代码所示,初始化虽然还是这5层,但是顺序跟执行顺序是不一样的,但执行顺序仍未改变。我们再用两种方式打印一下:
通过print:
通过torch summary:
可见print的方式打印出来仍然是定义时的结构,哪个成员变量先声明,就先打印出哪个。虽然对于这个简单的网络,我们从直觉和参数情况能够推断出一定的顺序,但是对于复杂网络就很有可能被迷惑。
而torch summary的方式则按照正常执行顺序打印了每一层的信息。
综上所述,两种方式在打印的信息熵存在差异,而更重要的是,print打印模型的方式在判断模型层次结构方面是不完全可信的,因为它是按照初始化定义等顺序打印的。而torch summary的方式虽然没有一些网络设计细节,但是在层次顺序上是可靠的,同时给出了参数量等信息。
究其原因,个人认为可能是如下原因造成的:
首先,我们知道python是能够print复杂数据结构的,如:
比C++的输入流高到哪里去(不是)。个人猜测python的print可能是递归每一层数据结构并打印的。
而一个torch.nn.Module的每层信息如何存储呢,可以参看这篇博客的部分内容,每一层都是存在一个字典里的。打印的时候自然是遍历整个模型的child,也就是字典中记录的每一层并打印,这个顺序就由初始化也就是加入字典的顺序而规定。
而torch summary则是传入模型与输入张量的尺寸,可以认为传入了dummy input,来探索输入张量真实经过每一层的顺序以及通过每一层后的尺寸,其结果自然保存了真实执行顺序。
总结来说
torch summary作为一个完善的工具库,确实具有一些由于print模型的优点。如果只是简单查看模型各层,完全可以使用print,速度更快。但如果想知道模型每一层的真正的顺序,以及参数量或每一层输出大小这些细节,还是建议使用torch summary。尤其是对于复杂网络,各层的定义顺序与真实执行顺序可能存在差别。