最近调整网络模型结构,添加模块的时候遇到输入输出维度的问题,每一次都需要debug来看网络输出的维度,很麻烦,因此去找了一些能将模型每一层输入输出每一层可视化的代码。
需要注意的是,每一个模块之后会有一行是显示该模块的输入和输出,并不单指卷积、激活等操作,如下图,1-3行是每一次计算操作的输入输出及参数数量,而1-3行属于模块BasicConv,所以第4显示的是BasicConv模块整体输入输出和参数数量:
# coding:utf8
import torch
from torch.autograd import Variable
from collections import OrderedDict
from torch import nn
import pandas as pd
import numpy as np
from nets.yolo import YoloBody
def get_names_dict(model):
names = {}
def _get_names(module, parent_name=''):
for key, module in module.named_children():
name = parent_name + '.' + key if parent_name else key
names[name] = module
if isinstance(module, torch.nn.Module):
_get_names(module, parent_name=name)
_get_names(model)
return names
def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=False):
def register_hook(module):
def hook(module, input, output):
name = ''
for key, item in names.items():
if item == module:
name = key
#
class_name = str(module.__class__).split('.')[-1].split("'")[0]
module_idx = len(summary)
m_key = module_idx + 1
summary[m_key] = OrderedDict()
# summary[m_key]['name'] = name
#这个name可能会特别长,影响可视化,不介意的可以也打出来
summary[m_key]['class_name'] = class_name
if input_shape:
summary[m_key]['input_shape'] = (-1,) + tuple(input[0].size())[1:]
summary[m_key]['output_shape'] = (-1,) + tuple(output[0].size())[:]
if weights:
summary[m_key]['weights'] = list(
[tuple(p.size()) for p in module.parameters()])
# summary[m_key]['trainable'] = any([p.requires_grad for p in module.parameters()])
if nb_trainable:
params_trainable = sum(
[torch.LongTensor(list(p.size())).prod() for p in module.parameters() if p.requires_grad])
summary[m_key]['nb_trainable'] = params_trainable
params = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters()])
summary[m_key]['nb_params'] = params
if not isinstance(module, nn.Sequential) and \
not isinstance(module, nn.ModuleList) and \
not (module == model):
hooks.append(module.register_forward_hook(hook))
# 名称存储在parent中,path+name是唯一的,而不是名称
names = get_names_dict(model)
# 检查网络是否有多个输入
if isinstance(input_size[0], (list, tuple)):
x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
else:
x = Variable(torch.rand(1, *input_size))
if next(model.parameters()).is_cuda:
x = x.cuda()
# 创建属性
summary = OrderedDict()
hooks = []
# 统计参数信息/ 注册hook
model.apply(register_hook)
# 向前传递
model(x)
# 移除这些hook
for h in hooks:
h.remove()
# 制作结构
df_summary = pd.DataFrame.from_dict(summary, orient='index')
return df_summary
注意 :第36行的位置,我将summary[m_key]['name'] = name注释调了,因为name显示的是每一层在代码中的名称,有时候特别长会影响我查找,如果大家有需要可以将其还原,以下是存在name的情况下:
# 导入项目中模型,在最上面导入就好
from nets.yolo import YoloBody
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model= YoloBody([[6, 7, 8], [3, 4, 5], [0, 1, 2]], 20, phi=0).to(device)
# input_size可以不用考虑batch_size这里只看网络结构
# 显示出来的网络结构会在batch_size的位置显示-1
df = torch_summarize_df(input_size=(3, 416, 416), model=model)
# 以下代码是解决print的时候会存在内容过多而省略无法显示的情况
# 如果以下代码还不能完全打印,请根据注释提示将其修改
np.set_printoptions(threshold=np.inf)
pd.set_option('display.width', 500)# 设置字符显示宽度
pd.set_option('display.max_rows', None)# 设置显示最大行
pd.set_option('display.max_columns', None)# 设置显示最大列,None为显示所有列
# 自动创建一个名为input_output.txt
# 'w'是覆盖式写入文件,以防调整需要的信息时省去删除文件内容的操作
# 不想覆盖的可以换成'a'
f = open("input_output.txt", 'w')
print(df,file=f) # 写入文件input_output.txt
上述文件运行后会在项目文件下直接生成一个
以上内容参考pytorch实用工具总结 - 知乎进行修改,但是该文章内容我在一开始运行的时候有一定的错误,本文章的内容都是在调整修改后的代码,直接使用即可。