PyTorch 1.x Visualization

  • 1. Introduction
  • 2. TorchSummary
  • 3. TensorWatch
    • 3.1 Saving the model graph
    • 3.2 Viewing layer parameters
    • 3.3 Errors and fixes
      • 3.3.1 AttributeError: module 'torch.onnx' has no attribute 'set_training'
      • 3.3.2 AttributeError: 'Dot' object has no attribute '_repr_svg_'
      • 3.3.3 FileNotFoundError: [Errno 2] "dot" not found in path
  • 4. Netron
  • References

1. Introduction

2. TorchSummary

  • Installation
pip install torchsummary
  • Usage
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
  • Parameters

    • your_model: the model to inspect
    • input_size: the shape of the model's input tensor, without the batch dimension
  • Source code

    • https://github.com/sksq96/pytorch-summary
  • Example

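import torch
import torchvision.models as models
from torchsummary import summary
import tensorwatch as tw  # used by the TensorWatch examples in section 3

print(torch.__version__)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
vgg = models.vgg19().to(device)
# input_size is (channels, H, W); the batch dimension is omitted.
summary(vgg, (3, 224, 224))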
  • Output
1.6.0+cu101
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
           Conv2d-17          [-1, 256, 56, 56]         590,080
             ReLU-18          [-1, 256, 56, 56]               0
        MaxPool2d-19          [-1, 256, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       1,180,160
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
           Conv2d-24          [-1, 512, 28, 28]       2,359,808
             ReLU-25          [-1, 512, 28, 28]               0
           Conv2d-26          [-1, 512, 28, 28]       2,359,808
             ReLU-27          [-1, 512, 28, 28]               0
        MaxPool2d-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
           Conv2d-31          [-1, 512, 14, 14]       2,359,808
             ReLU-32          [-1, 512, 14, 14]               0
           Conv2d-33          [-1, 512, 14, 14]       2,359,808
             ReLU-34          [-1, 512, 14, 14]               0
           Conv2d-35          [-1, 512, 14, 14]       2,359,808
             ReLU-36          [-1, 512, 14, 14]               0
        MaxPool2d-37            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-38            [-1, 512, 7, 7]               0
           Linear-39                 [-1, 4096]     102,764,544
             ReLU-40                 [-1, 4096]               0
          Dropout-41                 [-1, 4096]               0
           Linear-42                 [-1, 4096]      16,781,312
             ReLU-43                 [-1, 4096]               0
          Dropout-44                 [-1, 4096]               0
           Linear-45                 [-1, 1000]       4,097,000
================================================================
Total params: 143,667,240
Trainable params: 143,667,240
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 238.69
Params size (MB): 548.05
Estimated Total Size (MB): 787.31
----------------------------------------------------------------
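
The totals in this table can be cross-checked by hand. The sketch below recomputes them in plain Python, assuming the standard VGG19 layout (3x3 convolutions with bias, a 512x7x7 feature map feeding the classifier) and 4-byte float32 values. The names VGG19_CFG, conv_params, linear_params, and to_mb are illustrative and are not part of torchsummary.

```python
# Recompute the VGG19 parameter totals reported by torchsummary.
# Assumptions: 3x3 conv kernels with bias, float32 (4-byte) values.

# Output channels per conv layer; "M" marks a max-pool (no parameters).
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def conv_params(in_ch, out_ch, kernel=3):
    """Weights (in * out * k * k) plus one bias per output channel."""
    return in_ch * out_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features


def vgg19_total_params(num_classes=1000):
    total, in_ch = 0, 3  # RGB input
    for item in VGG19_CFG:
        if item == "M":
            continue
        total += conv_params(in_ch, item)
        in_ch = item
    # Classifier: 512*7*7 -> 4096 -> 4096 -> num_classes
    total += linear_params(512 * 7 * 7, 4096)
    total += linear_params(4096, 4096)
    total += linear_params(4096, num_classes)
    return total


def to_mb(num_values, bytes_per_value=4):
    """Size in MB (1 MB = 1024 * 1024 bytes) of num_values numbers."""
    return num_values * bytes_per_value / (1024 ** 2)


total = vgg19_total_params()
print(f"Total params: {total:,}")                      # 143,667,240
print(f"Params size (MB): {to_mb(total):.2f}")         # 548.05
print(f"Input size (MB): {to_mb(3 * 224 * 224):.2f}")  # 0.57
```

The three printed values match the "Total params", "Params size (MB)", and "Input size (MB)" rows above.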

3. TensorWatch

  • Description
    • TensorWatch is a debugging and visualization tool from Microsoft, designed for data science, deep learning, and reinforcement learning.
  • Source code
    • https://github.com/microsoft/tensorwatch
  • Signature
def draw_model(model, input_shape=None, orientation='TB', png_filename=None): # orientation = 'LR' for landscape
  • Install dependencies
pip install graphviz
pip install torchvision
pip install scikit-learn
pip install tensorwatch

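3.1 Saving the model graph

  • Example
import torchvision
import tensorwatch as tw

alexnet_model = torchvision.models.alexnet()
# Build the graph for a batch of one 3x224x224 image.
img = tw.draw_model(alexnet_model, [1, 3, 224, 224])
# Write the rendered graph to disk.
img.save(r'D:/alexnet.jpg')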

3.2 Viewing layer parameters

  • Example
import torchvision
import tensorwatch as tw
alexnet_model = torchvision.models.alexnet()
tw.model_stats(alexnet_model, [1, 3, 224, 224])
  • Output

[Figure 1: output of tw.model_stats for AlexNet]

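3.3 Errors and fixes

3.3.1 AttributeError: module 'torch.onnx' has no attribute 'set_training'

  • Cause: the installed PyTorch version is too new. I was on 1.6, and only versions below 1.6 have the set_training attribute on torch.onnx.
  • Fix: downgrade PyTorch, for example directly with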
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch==1.2
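
Before downgrading, it can help to confirm that the installed version is actually in the affected range. The following is a minimal sketch, assuming the version string has the form printed earlier (1.6.0+cu101) and taking the 1.6 cutoff from the cause described above. parse_version and lacks_set_training are hypothetical helper names, not PyTorch functions.

```python
def parse_version(version):
    """Turn '1.6.0+cu101' into (1, 6, 0), ignoring the local build tag."""
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if digits == "":
            break
        parts.append(int(digits))
    return tuple(parts)


def lacks_set_training(version, cutoff=(1, 6)):
    """True if the version is at or above the cutoff named in the text."""
    return parse_version(version)[:2] >= cutoff


for v in ("1.2.0", "1.5.1+cpu", "1.6.0+cu101"):
    status = "affected" if lacks_set_training(v) else "not affected"
    print(v, "->", status)
```

With PyTorch installed, the same check can be run on the real value by passing torch.__version__ to lacks_set_training.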

3.3.2 AttributeError: 'Dot' object has no attribute '_repr_svg_'

  • Cause: also a version problem.
  • Fix: in Anaconda\Lib\site-packages\tensorwatch\model_graph\hiddenlayer\pytorch_draw_model.py, change line 13 to return self.dot.create_svg().decode()
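
The one-line change works because pydot's create_svg() returns the rendered SVG as bytes, while a notebook's _repr_svg_ hook is expected to return a string. The sketch below shows the shape of the patched method using a stand-in object. FakeDot and DotWrapperSketch are hypothetical names used for illustration; they are not TensorWatch's actual classes.

```python
class FakeDot:
    """Stand-in for a pydot graph: create_svg() returns bytes."""

    def create_svg(self):
        return b'<svg xmlns="http://www.w3.org/2000/svg"></svg>'


class DotWrapperSketch:
    """Illustrative wrapper whose _repr_svg_ mirrors the patched line."""

    def __init__(self, dot):
        self.dot = dot

    def _repr_svg_(self):
        # Patched body: render to bytes, then decode to str for the notebook.
        return self.dot.create_svg().decode()


svg = DotWrapperSketch(FakeDot())._repr_svg_()
print(type(svg).__name__, svg[:4])
```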

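3.3.3 FileNotFoundError: [Errno 2] "dot" not found in path

  • Cause: a Graphviz / pydot problem. The dot executable that ships with Graphviz cannot be found on PATH.
  • Fix: install Graphviz itself (pip install graphviz only installs the Python bindings) and make sure its dot executable is on PATH.
  • Once Graphviz is installed, the error goes away.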
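
To check whether Graphviz is visible to Python, you can look for the dot executable on PATH with the standard library. This is a small diagnostic sketch; find_dot is a hypothetical helper name.

```python
import shutil


def find_dot():
    """Return the full path of Graphviz's dot executable, or None."""
    return shutil.which("dot")


path = find_dot()
if path is None:
    print("dot not found: install Graphviz and add its bin directory to PATH")
else:
    print("dot found at", path)
```

If this prints "dot not found" right after installing Graphviz, opening a new terminal (so that the updated PATH is picked up) is usually worth trying first.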

4. Netron

  • Netron is a model visualization tool that can open models from a wide range of neural network frameworks offline.
  • Source code
  • Supported formats
    • ONNX (.onnx, .pb, .pbtxt)
    • Keras (.h5, .keras)
    • Core ML (.mlmodel)
    • Caffe (.caffemodel, .prototxt)
    • Caffe2 (predict_net.pb)
    • Darknet (.cfg)
    • MXNet (.model, -symbol.json)
    • Barracuda (.nn)
    • ncnn (.param)
    • Tengine (.tmfile)
    • TNN (.tnnproto)
    • UFF (.uff)
    • TensorFlow Lite (.tflite)
  • Experimental support
    • TorchScript (.pt, .pth)
    • PyTorch (.pt, .pth)
    • Torch (.t7)
    • Arm NN (.armnn)
    • BigDL (.bigdl, .model)
    • Chainer (.npz, .h5)
    • CNTK (.model, .cntk)
    • Deeplearning4j (.zip)
    • MediaPipe (.pbtxt)
    • ML.NET (.zip)
    • MNN (.mnn)
    • PaddlePaddle (.zip, model)
    • OpenVINO (.xml)
    • scikit-learn (.pkl)
    • TensorFlow.js (model.json, .pb)
    • TensorFlow (.pb, .meta, .pbtxt, .ckpt, .index)
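
Several extensions in the two lists above are shared by more than one framework, so the extension alone does not always identify the format. The sketch below encodes a few of the listed entries as a lookup table to make that visible. FORMATS and candidates are illustrative names and have nothing to do with Netron's own detection logic.

```python
import os

# A few of the formats listed above, keyed by file extension.
# Each value is a list of (framework, support level) pairs from the lists.
FORMATS = {
    ".onnx": [("ONNX", "supported")],
    ".h5": [("Keras", "supported"), ("Chainer", "experimental")],
    ".tflite": [("TensorFlow Lite", "supported")],
    ".pt": [("TorchScript", "experimental"), ("PyTorch", "experimental")],
    ".pth": [("TorchScript", "experimental"), ("PyTorch", "experimental")],
    ".zip": [("Deeplearning4j", "experimental"),
             ("ML.NET", "experimental"),
             ("PaddlePaddle", "experimental")],
}


def candidates(filename):
    """Return the listed frameworks that use this file's extension."""
    ext = os.path.splitext(filename)[1].lower()
    return FORMATS.get(ext, [])


for name in ("alexnet.onnx", "vgg19.pth", "model.zip", "weights.bin"):
    print(name, "->", candidates(name))
```

For a PyTorch model, .pt and .pth both appear only under experimental support, while ONNX is in the supported list.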

References

  • PyTorch neural network structure visualization modules
