为了解决实现快速debug,精准描述深度网络的输入结构,输出结构及参数等信息,人们使用torchinfo工具包。
import torch.nn as nn
from torchinfo import summary
transformer_model = nn.Transformer(nhead=16)
summary(transformer_model, [(10, 32, 10), (10, 32, 10)])
====================================================================================================
Layer (type:depth-idx) Output Shape Param #
====================================================================================================
Transformer -- --
├─TransformerEncoder: 1-1 -- (recursive)
│ └─ModuleList: 2-1 -- --
├─TransformerDecoder: 1 -- --
│ └─ModuleList: 2-2 -- --
├─TransformerEncoder: 1-2 [10, 32, 512] --
│ └─ModuleList: 2-1 -- --
│ │ └─TransformerEncoderLayer: 3-1 [10, 32, 512] 3,152,384
│ │ └─TransformerEncoderLayer: 3-2 [10, 32, 512] 3,152,384
│ │ └─TransformerEncoderLayer: 3-3 [10, 32, 512] 3,152,384
│ │ └─TransformerEncoderLayer: 3-4 [10, 32, 512] 3,152,384
│ │ └─TransformerEncoderLayer: 3-5 [10, 32, 512] 3,152,384
│ │ └─TransformerEncoderLayer: 3-6 [10, 32, 512] 3,152,384
├─TransformerDecoder: 1 -- --
│ └─ModuleList: 2-3 -- (recursive)
├─TransformerEncoder: 1-1 -- (recursive)
│ └─LayerNorm: 2-4 [10, 32, 512] 1,024
├─TransformerDecoder: 1-3 [20, 32, 512] --
│ └─ModuleList: 2-2 -- --
│ │ └─TransformerDecoderLayer: 3-7 [20, 32, 512] 4,204,032
│ │ └─TransformerDecoderLayer: 3-8 [20, 32, 512] 4,204,032
│ │ └─TransformerDecoderLayer: 3-9 [20, 32, 512] 4,204,032
│ │ └─TransformerDecoderLayer: 3-10 [20, 32, 512] 4,204,032
│ │ └─TransformerDecoderLayer: 3-11 [20, 32, 512] 4,204,032
│ │ └─TransformerDecoderLayer: 3-12 [20, 32, 512] 4,204,032
│ └─LayerNorm: 2-5 [20, 32, 512] 1,024
====================================================================================================
Total params: 25,229,312
Trainable params: 25,229,312
Non-trainable params: 0
Total mult-adds (M): 378.47
====================================================================================================
Input size (MB): 1.97
Forward/backward pass size (MB): 184.81
Params size (MB): 100.92
Estimated Total Size (MB): 287.69
====================================================================================================
命令窗口输入
pip install tensorboardX
from tensorboardX import SummaryWriter
writer = SummaryWriter('./runs')
命令窗口输入
tensorboard --logdir=/path/to/logs/ --port=xxxx
transformer_model = nn.Transformer(nhead=16)
writer.add_graph(transformer_model, input_to_model=[torch.rand(10, 32, 512),
torch.rand(20, 32, 512)])
writer.close()