Model visualization tool torchinfo: computing the Output Shape of each layer

Official documentation, PyPI

Previously I used print(model) or inspected values in the debugger. After `pip install torchinfo`, you can easily see each layer's output shape and parameter count:

```python
# Official example
from torchinfo import summary
import torchvision

model = torchvision.models.resnet152()
summary(model, (1, 3, 224, 224), depth=3)
```
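summary also prints memory estimates at the bottom of its report. The Input size and Params size lines amount to simple arithmetic. The sketch below assumes every value is float32 (4 bytes) and that 1 MB = 10^6 bytes. Under those assumptions it reproduces the Input size (0.79) and Params size (14.89) lines of the ViViT output further down. The helper names are hypothetical and are not part of torchinfo.

```python
from math import prod

BYTES_PER_FLOAT32 = 4  # assumption: all tensors are float32
BYTES_PER_MB = 1e6     # assumption: decimal megabytes

def input_size_mb(shape):
    """Estimated size of one input tensor of the given shape, in MB (hypothetical helper)."""
    return prod(shape) * BYTES_PER_FLOAT32 / BYTES_PER_MB

def params_size_mb(num_params):
    """Estimated size of the model parameters, in MB (hypothetical helper)."""
    return num_params * BYTES_PER_FLOAT32 / BYTES_PER_MB

print(f"{input_size_mb((1, 3, 224, 224)):.2f}")   # official example input -> 0.60
print(f"{input_size_mb((1, 16, 3, 64, 64)):.2f}")  # ViViT input -> 0.79
print(f"{params_size_mb(3_722_404):.2f}")          # ViViT total params -> 14.89
```

These two lines depend only on the input shape and the parameter count. The Forward/backward pass size line also depends on every intermediate activation, so this sketch does not cover it.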
Next, let's try it on the ViViT model from an [earlier post](https://blog.csdn.net/ResumeProject/article/details/123470594?). The result looks decent:
```python
import torch
from torchinfo import summary
# ViViT is the model class from the earlier post linked above (not included here)

img = torch.ones([1, 16, 3, 64, 64]).cuda()  # sample input; summary below only needs its shape
# Rearrange "b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)" with p1 = p2 = 16:
# torch.Size([1, 16, 3, 64, 64]) -> torch.Size([1, 16, 16, 768]), then Linear -> torch.Size([1, 16, 16, 192])
model = ViViT(224, 16, 100, 16).cuda()
summary(
    model,                 # PyTorch model
    (1, 16, 3, 64, 64),    # Shape of input data as a List/Tuple/torch.Size
    # dtypes=[torch.long],
    # verbose=2,
    # col_width=16,
    # col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    # row_settings=["var_names"],
)
```
```text
====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
ViViT                                              --                        --
├─Transformer: 1                                   --                        --
│    └─ModuleList: 2-1                             --                        --
│    │    └─ModuleList: 3-1                        --                        444,288
│    │    └─ModuleList: 3-2                        --                        444,288
│    │    └─ModuleList: 3-3                        --                        444,288
│    │    └─ModuleList: 3-4                        --                        444,288
├─Transformer: 1                                   --                        --
│    └─ModuleList: 2-2                             --                        --
│    │    └─ModuleList: 3-5                        --                        444,288
│    │    └─ModuleList: 3-6                        --                        444,288
│    │    └─ModuleList: 3-7                        --                        444,288
│    │    └─ModuleList: 3-8                        --                        444,288
├─Sequential: 1-1                                  [1, 16, 16, 192]          --
│    └─Rearrange: 2-3                              [1, 16, 16, 768]          --
│    └─Linear: 2-4                                 [1, 16, 16, 192]          147,648
├─Dropout: 1-2                                     [1, 16, 17, 192]          --
├─Transformer: 1-3                                 [16, 17, 192]             --
│    └─LayerNorm: 2-5                              [16, 17, 192]             384
├─Transformer: 1-4                                 [1, 17, 192]              --
│    └─LayerNorm: 2-6                              [1, 17, 192]              384
├─Sequential: 1-5                                  [1, 100]                  --
│    └─LayerNorm: 2-7                              [1, 192]                  384
│    └─Linear: 2-8                                 [1, 100]                  19,300
====================================================================================================
Total params: 3,722,404
Trainable params: 3,722,404
Non-trainable params: 0
Total mult-adds (M): 30.39
====================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 20.37
Params size (MB): 14.89
Estimated Total Size (MB): 36.05
====================================================================================================
```
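You can check the Output Shape and Param # columns above by hand. The sketch below assumes the settings implied by the table: patch size 16, embedding dim 192, 100 classes, and one CLS token prepended per frame and one for the temporal sequence. It uses the standard counts: a biased Linear(in, out) has in*out + out parameters, and a LayerNorm(d) has 2*d. The layout inside each 444,288-parameter layer is an assumption, because the table does not show it: pre-norm attention with 3 heads of size 64 and a bias-free QKV projection, followed by a pre-norm MLP with hidden size 768. The function names are hypothetical.

```python
def rearrange_patches(shape, p):
    """Shape after 'b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)' with p1 = p2 = p."""
    b, t, c, H, W = shape
    assert H % p == 0 and W % p == 0, "frame size must be divisible by the patch size"
    return (b, t, (H // p) * (W // p), p * p * c)

def linear_params(d_in, d_out, bias=True):
    return d_in * d_out + (d_out if bias else 0)

def layernorm_params(d):
    return 2 * d  # elementwise weight and bias

def transformer_layer_params(dim, heads=3, dim_head=64, mlp_dim=768):
    # Assumed layout: PreNorm(attention) + PreNorm(MLP); not confirmed by the table
    inner = heads * dim_head
    attn = (layernorm_params(dim) + linear_params(dim, 3 * inner, bias=False)
            + linear_params(inner, dim))
    mlp = layernorm_params(dim) + linear_params(dim, mlp_dim) + linear_params(mlp_dim, dim)
    return attn + mlp

dim, num_classes = 192, 100
b, t, n, patch_dim = rearrange_patches((1, 16, 3, 64, 64), 16)
print((b, t, n, patch_dim))   # Rearrange: 2-3 -> (1, 16, 16, 768)
print((b, t, n, dim))         # Linear: 2-4 -> (1, 16, 16, 192)
print((b, t, n + 1, dim))     # Dropout: 1-2, after adding a CLS token per frame
print((b * t, n + 1, dim))    # Transformer: 1-3, frames folded into the batch
print((b, t + 1, dim))        # Transformer: 1-4, temporal CLS token added

print(linear_params(patch_dim, dim))     # 147648
print(layernorm_params(dim))             # 384
print(linear_params(dim, num_classes))   # 19300
print(transformer_layer_params(dim))     # 444288

# Sum of the Param # column: 8 layers (3-1 .. 3-8), the patch Linear,
# three LayerNorms and the classification head
total = (8 * transformer_layer_params(dim) + linear_params(patch_dim, dim)
         + 3 * layernorm_params(dim) + linear_params(dim, num_classes))
print(f"{total:,}")                      # 3,722,404
```

Every row's Param # agrees with the table under these assumptions, and the listed rows add up exactly to Total params.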

