最近一直在搞课题,因为看代码不直观,所以将网络结构进行可视化处理。使用了两种方法,各有优缺点,下面记录一下使用方法供人参考
方法一:torchsummary可视化
torchsummary可视化是pytorch可视化的一种方法,需要安装库,关于库的安装可以搜一下帖子,然后就是关于使用方法。首先导入这个库,在model里更改需要可视化的结构,这里我可视化的是我的判别器,然后传入网络设定的256x1024大小的图片,就可以打印出每层的输入输出和参数
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#Discriminator()是我需要可视化的判别器
model = Discriminator().to(device)
#传入model和一个3通道256x1024大小的图片
summary(model, (3, 256, 1024))
model = Discriminator()
print(model)
这里打印出的就是输入图片经过每一层的输入输出,可以看到每一层的变化以及最后的参数量
E:\software\anaconda\envs\pytorch\python.exe F:/GASDA-master/models/networks.py
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 128, 512] 3,136
LeakyReLU-2 [-1, 64, 128, 512] 0
Conv2d-3 [-1, 128, 64, 256] 131,072
BatchNorm2d-4 [-1, 128, 64, 256] 256
LeakyReLU-5 [-1, 128, 64, 256] 0
Conv2d-6 [-1, 256, 32, 128] 524,288
BatchNorm2d-7 [-1, 256, 32, 128] 512
LeakyReLU-8 [-1, 256, 32, 128] 0
Conv2d-9 [-1, 512, 31, 127] 2,097,152
BatchNorm2d-10 [-1, 512, 31, 127] 1,024
LeakyReLU-11 [-1, 512, 31, 127] 0
Conv2d-12 [-1, 1, 30, 126] 8,193
================================================================
Total params: 2,765,633
Trainable params: 2,765,633
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 182.17
Params size (MB): 10.55
Estimated Total Size (MB): 195.72
torchsummary可视化的优点就在于可以直观的看到图片经过了哪些层,以及经过了每一层的输入输出和这个网络模块的参数量,这样在对网络进行轻量化的时候可以对比参数量比较直观。缺点就是对于复杂的网络结构看起来很费劲,因为在可视化u-net等类似于跳层连接的网络时很难看清到底从哪层跳到了哪层,所以需要方法二。
方法二:netron可视化
pytorch的可视化方法之一,安装方法可以去搜一下,可以安装库来可视化也可以在网页上可视化。这里有一个坑,就是如果网络跑完只保存了权重而没有保存网络模型的时候那么这时候可视化是看不出来任何东西的,类似于以下这种,画出来也看不到任何网络结构,需要额外处理
在可视化网络时可以用以下代码,这里randn中的第一个参数是batch,设成1就行,输出网络结果后会在终端给出一个链接,点进去就能看到可视化结果了
myNet = Discriminator() # 实例化
x = torch.randn(1, 3, 256, 1024) # 随机生成一个输入
modelData = "./demo.pth" # 定义模型数据保存的路径
# modelData = "./demo.onnx" # 有人说应该是 onnx 文件,但我尝试 pth 是可以的
torch.onnx.export(myNet, x, modelData) # 将 pytorch 模型以 onnx 格式导出并保存
netron.start(modelData) # 输出网络结构
这里是我的判别器的一个网络结构,因为这个结构比较简单,所以整个图片不长,在可视化u-net的时候图片就特别长,而且因为画的很细,所以看起来也比较复杂,优点是可以看清跳层连接和拼接操作,到底是从哪到哪的,缺点就是没法看到总体参数,如果网络一旦复杂就看起来很痛苦
关于代码加到哪里的问题:
最后就是关于在代码的哪里使用这些可视化的代码,还不太弄得懂代码的时候很容易不知道代码要放在哪里,用我的代码举例,首先要找到model这个存储模型的文件夹,然后要找到需要可视化的网络模块,一般的网络模块都会继承于pytorch的nn.Module,在这里封装好模块,在网络模型里进行调用,下图就是我的判别器模块,在可视化的时候把这个模块名加入到上面的代码里就可以进行以上两种方法的可视化了