References: official documentation, PyPI
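Previously I inspected a network with `print(model)` or by reading values in the debugger. After `pip install torchinfo`, you can instead get a compact summary of the network's layers, output shapes and parameter counts: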
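```python
# Official example
from torchinfo import summary
import torchvision

model = torchvision.models.resnet152()
# input_size = (batch, channels, height, width); depth = how many levels of nested modules to display
summary(model, (1, 3, 224, 224), depth=3)
```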
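Next, [try it on the ViViT model from an earlier post](https://blog.csdn.net/ResumeProject/article/details/123470594?). It works reasonably well:

```python
import torch
from torchinfo import summary

# ViViT is the model class defined in the earlier post (it is not part of torchinfo).
# Dummy video batch, b t c h w = 1 x 16 x 3 x 64 x 64
img = torch.ones([1, 16, 3, 64, 64]).cuda()
# Patch embedding: b t c (h p1) (w p2) -> b t (h w) (p1 p2 c) with p1 = p2 = 16,
# i.e. [1, 16, 3, 64, 64] -> [1, 16, 16, 768]; a Linear layer then projects this to [1, 16, 16, 192]
# Assumed positional args: image_size=224, patch_size=16, num_classes=100, num_frames=16
model = ViViT(224, 16, 100, 16).cuda()

summary(
    model,               # PyTorch model
    (1, 16, 3, 64, 64),  # shape of the input data as a list/tuple/torch.Size
    # input_data=img,    # alternatively, pass an actual input tensor instead of a shape
    # dtypes=[torch.long],
    # verbose=2,
    # col_width=16,
    # col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    # row_settings=["var_names"],
)
```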
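The shape comment in the code above can be checked with plain arithmetic, without running the model. The sketch below assumes square 16×16 patches and a projection width of 192 (read off the Rearrange and Linear rows of the summary below); `patch_embedding_shapes` is a hypothetical helper written for illustration, not part of torchinfo or einops.

```python
# Shape bookkeeping for the ViViT patch embedding (pure arithmetic, no torch).
# Assumption: square patches (p1 = p2 = p) followed by Linear(p * p * c -> dim).

def patch_embedding_shapes(b, t, c, h_img, w_img, p, dim):
    """Return (rearranged_shape, projected_shape, linear_params)."""
    if h_img % p or w_img % p:
        raise ValueError("image height and width must be divisible by the patch size")
    h, w = h_img // p, w_img // p          # number of patches along each axis
    patch_dim = p * p * c                  # (p1 p2 c): values in one flattened patch
    rearranged = (b, t, h * w, patch_dim)  # b t (h w) (p1 p2 c)
    projected = (b, t, h * w, dim)         # after the Linear projection
    linear_params = patch_dim * dim + dim  # weight matrix + bias vector
    return rearranged, projected, linear_params


rearranged, projected, n_params = patch_embedding_shapes(1, 16, 3, 64, 64, p=16, dim=192)
print(rearranged)  # (1, 16, 16, 768)
print(projected)   # (1, 16, 16, 192)
print(n_params)    # 147648
```

The 768 is the per-patch feature size before projection, and 147,648 matches the `Linear: 2-4` row.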
```text
====================================================================================================
Layer (type:depth-idx)                       Output Shape              Param #
====================================================================================================
ViViT                                        --                        --
├─Transformer: 1                             --                        --
│    └─ModuleList: 2-1                       --                        --
│    │    └─ModuleList: 3-1                  --                        444,288
│    │    └─ModuleList: 3-2                  --                        444,288
│    │    └─ModuleList: 3-3                  --                        444,288
│    │    └─ModuleList: 3-4                  --                        444,288
├─Transformer: 1                             --                        --
│    └─ModuleList: 2-2                       --                        --
│    │    └─ModuleList: 3-5                  --                        444,288
│    │    └─ModuleList: 3-6                  --                        444,288
│    │    └─ModuleList: 3-7                  --                        444,288
│    │    └─ModuleList: 3-8                  --                        444,288
├─Sequential: 1-1                            [1, 16, 16, 192]          --
│    └─Rearrange: 2-3                        [1, 16, 16, 768]          --
│    └─Linear: 2-4                           [1, 16, 16, 192]          147,648
├─Dropout: 1-2                               [1, 16, 17, 192]          --
├─Transformer: 1-3                           [16, 17, 192]             --
│    └─LayerNorm: 2-5                        [16, 17, 192]             384
├─Transformer: 1-4                           [1, 17, 192]              --
│    └─LayerNorm: 2-6                        [1, 17, 192]              384
├─Sequential: 1-5                            [1, 100]                  --
│    └─LayerNorm: 2-7                        [1, 192]                  384
│    └─Linear: 2-8                           [1, 100]                  19,300
====================================================================================================
Total params: 3,722,404
Trainable params: 3,722,404
Non-trainable params: 0
Total mult-adds (M): 30.39
====================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 20.37
Params size (MB): 14.89
Estimated Total Size (MB): 36.05
====================================================================================================
```
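Several numbers in the summary can be reproduced by hand, which is a useful sanity check when reading torchinfo output. The sketch below assumes float32 storage (4 bytes per value) and decimal megabytes (1 MB = 1e6 bytes), which is what reproduces the printed sizes; the 20.37 MB forward/backward figure is copied from the summary rather than computed. The `encoder_layer_params` breakdown is hypothetical: it assumes a pre-norm block with 3 heads of size 64, a bias-free QKV projection and an MLP width of 768, a configuration that happens to give exactly 444,288 but is not confirmed by this post.

```python
# Hand cross-check of the ViViT torchinfo summary (pure arithmetic).

def linear_params(in_features, out_features, bias=True):
    return in_features * out_features + (out_features if bias else 0)

def layernorm_params(dim):
    return 2 * dim  # elementwise weight (gamma) and bias (beta)

# Every "Param #" row that shows a number in the summary
row_params = [444_288] * 8 + [
    linear_params(768, 192),  # Linear: 2-4    -> 147,648
    layernorm_params(192),    # LayerNorm: 2-5 -> 384
    layernorm_params(192),    # LayerNorm: 2-6 -> 384
    layernorm_params(192),    # LayerNorm: 2-7 -> 384
    linear_params(192, 100),  # Linear: 2-8    -> 19,300
]
total_params = sum(row_params)

BYTES_PER_FLOAT32 = 4
input_numel = 1 * 16 * 3 * 64 * 64
input_mb = input_numel * BYTES_PER_FLOAT32 / 1e6
params_mb = total_params * BYTES_PER_FLOAT32 / 1e6
FWD_BWD_MB = 20.37  # copied from the summary, not derived here
total_mb = input_mb + FWD_BWD_MB + params_mb

def encoder_layer_params(dim=192, heads=3, dim_head=64, mlp_dim=768):
    # Hypothetical layer: LayerNorm + attention (bias-free QKV, biased output
    # projection), then LayerNorm + two-layer MLP.
    inner = heads * dim_head
    attn = layernorm_params(dim) + linear_params(dim, 3 * inner, bias=False) + linear_params(inner, dim)
    ff = layernorm_params(dim) + linear_params(dim, mlp_dim) + linear_params(mlp_dim, dim)
    return attn + ff

print(f"Total params: {total_params:,}")             # Total params: 3,722,404
print(f"Input size (MB): {input_mb:.2f}")             # Input size (MB): 0.79
print(f"Params size (MB): {params_mb:.2f}")           # Params size (MB): 14.89
print(f"Estimated Total Size (MB): {total_mb:.2f}")   # Estimated Total Size (MB): 36.05
print(encoder_layer_params())                         # 444288
```

In other words, "Total params" is just the sum of the per-row counts, and "Params size" is that total times 4 bytes.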