pytorch的可视化

7.1可视化网络结构

7.1.1 使用print函数打印模型基础信息

import matplotlib.pyplot as plt
# ^^^ pyforest auto-imports - don't write above this line
import torchvision.models as models
# Instantiate torchvision's ResNet-18 (no arguments are passed) and print the
# module; the printed layer hierarchy follows.
model = models.resnet18()
print(model)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

7.1.2 使用torchinfo可视化网络结构

# torchinfo的安装
# 安装方法一
pip install torchinfo
# 安装方法二
conda install -c conda-forge torchinfo
!pip install torchinfo
Collecting torchinfo
  Downloading torchinfo-1.7.0-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.0
import torchvision.models as models
from torchinfo import summary
resnet18 = models.resnet18() # instantiate the model
print(summary(resnet18, (1, 3, 224, 224)))# input size -> 1: batch_size, 3: image channels, 224: image height and width
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-3                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-13                 [1, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-14            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-15                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-16                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-17            [1, 128, 28, 28]          256
│    │    └─Sequential: 3-18             [1, 128, 28, 28]          8,448
│    │    └─ReLU: 3-19                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-20                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-21            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-22                   [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-23                 [1, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-24            [1, 128, 28, 28]          256
│    │    └─ReLU: 3-25                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-5                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-26                 [1, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-27            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-28                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-29                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-30            [1, 256, 14, 14]          512
│    │    └─Sequential: 3-31             [1, 256, 14, 14]          33,280
│    │    └─ReLU: 3-32                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-6                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-33                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-34            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-35                   [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-36                 [1, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-37            [1, 256, 14, 14]          512
│    │    └─ReLU: 3-38                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-7                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-39                 [1, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-40            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-41                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-42                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-43            [1, 512, 7, 7]            1,024
│    │    └─Sequential: 3-44             [1, 512, 7, 7]            132,096
│    │    └─ReLU: 3-45                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-8                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-46                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-47            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-48                   [1, 512, 7, 7]            --
│    │    └─Conv2d: 3-49                 [1, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-50            [1, 512, 7, 7]            1,024
│    │    └─ReLU: 3-51                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 1.81
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 39.75
Params size (MB): 46.76
Estimated Total Size (MB): 87.11
==========================================================================================
我们可以看到torchinfo提供了更加详细的信息,
包括模块信息(每一层的类型、输出shape和参数量)、
模型整体的参数量、模型大小、
一次前向或者反向传播需要的内存大小等。


7.2CNN的可视化




# 7.2.1 CNN卷积核可视化
import torch
from torchvision.models import vgg11

# Build VGG-11 with `pretrained=True` and print the children of its `features`
# container as a {name: module} dict.
model = vgg11(pretrained=True)
print(dict(model.features.named_children()))
Downloading: "https://download.pytorch.org/models/vgg11-bbd30ac9.pth" to C:\Users\b/.cache\torch\hub\checkpoints\vgg11-bbd30ac9.pth



  0%|          | 0.00/507M [00:00
import math

# Plot the kernels of the `features` child named '3': one figure per output
# channel, one subplot per input-channel slice of that kernel.
conv1 = dict(model.features.named_children())['3']
kernel_set = conv1.weight.detach()  # detached once; slices of it need no further detach
num = len(kernel_set)
print(kernel_set.shape)
for i in range(num):
    i_kernel = kernel_set[i]
    plt.figure(figsize=(20, 17))
    if len(i_kernel) > 1:
        # Keep at least a 9x9 grid, but grow it when the layer has more than
        # 81 input channels so plt.subplot never receives an out-of-range index.
        grid = max(9, math.ceil(math.sqrt(len(i_kernel))))
        for idx, kernel in enumerate(i_kernel):
            plt.subplot(grid, grid, idx + 1)
            plt.axis('off')
            plt.imshow(kernel, cmap='bwr')
torch.Size([128, 64, 3, 3])








7.2.2 CNN特征图可视化方法

class Hook:
    """Callable recorder intended for use as a module forward hook.

    Every call appends the module's class, the hook's input argument and its
    output argument to three parallel lists kept on the instance.
    """

    def __init__(self):
        # Entry k of each list belongs to the k-th call of this hook.
        self.module_name, self.features_in_hook, self.features_out_hook = [], [], []

    def __call__(self, module, fea_in, fea_out):
        """Record one call; nothing is returned."""
        print("hooker working", self)
        recorded = (
            (self.module_name, module.__class__),
            (self.features_in_hook, fea_in),
            (self.features_out_hook, fea_out),
        )
        for store, item in recorded:
            store.append(item)
    

def plot_feature(model, idx, inputs):
    """Plot the output feature maps of ``model.features[idx]`` for ``inputs``.

    A ``Hook`` is attached to the chosen layer as a forward hook, the model is
    switched to eval mode and run once on ``inputs``. The hook is then removed
    again so that repeated calls do not pile up hooks (and captured tensors)
    on the layer. Up to the first 100 feature maps of the first item of the
    captured output are drawn on a 10x10 grid in a new figure.

    Args:
        model: callable model with an indexable ``features`` attribute whose
            items provide ``register_forward_hook``.
        idx: index of the layer inside ``model.features`` to inspect.
        inputs: passed unchanged to ``model``.

    Returns:
        None. The recorded module class and the input/output shapes are
        printed, and the feature maps are drawn with matplotlib.
    """
    hh = Hook()
    handle = model.features[idx].register_forward_hook(hh)
    try:
        model.eval()
        _ = model(inputs)
    finally:
        # Always detach the hook, even if the forward pass raises.
        handle.remove()

    print(hh.module_name)
    print((hh.features_in_hook[0][0].shape))
    print((hh.features_out_hook[0].shape))

    out1 = hh.features_out_hook[0]
    # Assumes the output is laid out as (batch, channels, H, W); each
    # first_item[k] is handed to imshow as a 2-D image -- TODO confirm for
    # layers other than conv/pool/activation.
    first_item = out1[0].cpu().clone()
    shown = min(out1.shape[1], 100)  # the 10x10 grid holds at most 100 maps

    plt.figure(figsize=(20, 17))
    for ftidx in range(shown):
        plt.subplot(10, 10, ftidx + 1)
        plt.axis('off')
        # Pass cmap='gray' here for a grayscale rendering instead.
        plt.imshow(first_item[ftidx].detach())

7.2.3 CNN class activation map可视化方法

!pip install grad-cam
Collecting grad-cam
  Downloading grad-cam-1.4.5.tar.gz (7.8 MB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
    Preparing wheel metadata: started
    Preparing wheel metadata: finished with status 'done'
Requirement already satisfied: scikit-learn in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (0.23.2)
Requirement already satisfied: Pillow in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (8.1.0)
Requirement already satisfied: torchvision>=0.8.2 in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (0.8.2+cpu)
Requirement already satisfied: tqdm in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (4.55.1)
Requirement already satisfied: opencv-python in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (4.5.1.48)
Requirement already satisfied: numpy in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (1.19.5)
Requirement already satisfied: matplotlib in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (3.3.2)
Requirement already satisfied: torch>=1.7.1 in c:\users\b\anaconda3\lib\site-packages (from grad-cam) (1.7.1)
Requirement already satisfied: typing-extensions in c:\users\b\anaconda3\lib\site-packages (from torch>=1.7.1->grad-cam) (3.7.4.3)
Requirement already satisfied: kiwisolver>=1.0.1 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->grad-cam) (1.3.0)
Requirement already satisfied: cycler>=0.10 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->grad-cam) (0.10.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->grad-cam) (2.4.7)
Requirement already satisfied: python-dateutil>=2.1 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->grad-cam) (2.8.1)
Requirement already satisfied: certifi>=2020.06.20 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->grad-cam) (2022.5.18.1)
Requirement already satisfied: six in c:\users\b\anaconda3\lib\site-packages (from cycler>=0.10->matplotlib->grad-cam) (1.15.0)
Requirement already satisfied: scipy>=0.19.1 in c:\users\b\anaconda3\lib\site-packages (from scikit-learn->grad-cam) (1.9.0)
Requirement already satisfied: joblib>=0.11 in c:\users\b\anaconda3\lib\site-packages (from scikit-learn->grad-cam) (1.0.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in c:\users\b\anaconda3\lib\site-packages (from scikit-learn->grad-cam) (2.1.0)
Collecting ttach
  Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB)
Building wheels for collected packages: grad-cam
  Building wheel for grad-cam (PEP 517): started
  Building wheel for grad-cam (PEP 517): finished with status 'done'
  Created wheel for grad-cam: filename=grad_cam-1.4.5-py3-none-any.whl size=37008 sha256=34bcbc7b96cb6d796dd4553f4f0afe487453bf7c52b886b01b56838488fa7366
  Stored in directory: c:\users\b\appdata\local\pip\cache\wheels\64\e5\e9\7c4f8b034a7d7009a3b3baa534084980eec60f39155814278b
Successfully built grad-cam
Installing collected packages: ttach, grad-cam
Successfully installed grad-cam-1.4.5 ttach-0.0.3
import torch
from torchvision.models import vgg11,resnet18,resnet101,resnext101_32x8d
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

model = vgg11(pretrained=True)
img_path = './kobe.jpg'
# Force 3-channel RGB first: a grayscale, CMYK or RGBA file would otherwise
# produce an array whose channel count does not fit the later steps.
# The resize makes the image the same size as the images fed to the network.
img = Image.open(img_path).convert('RGB').resize((224,224))
# The original image has to be converted to np.float32 with values in 0-1.
rgb_img = np.float32(img)/255
plt.imshow(img)

![output_17_1](output_17_1.png)

from pytorch_grad_cam import GradCAM,ScoreCAM,GradCAMPlusPlus,AblationCAM,XGradCAM,EigenCAM,FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

target_layers = [model.features[-1]]
# Pick whichever class-activation-map variant fits; note that ScoreCAM and
# AblationCAM need a batch_size.
cam = GradCAM(model=model,target_layers=target_layers)
# `img_tensor` was never defined in this cell. Build the batch from rgb_img:
# H x W x C -> C x H x W, then add a leading batch axis.
# Assumes rgb_img is an H x W x 3 float32 array -- TODO confirm for non-RGB inputs.
# NOTE(review): no mean/std normalisation is applied here; add it if the
# model expects normalised inputs.
img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).unsqueeze(0)
# `preds` has to be set by the user (it raised NameError when left undefined).
# ImageNet has 1000 classes; 200 is used here as an example class index.
preds = 200
targets = [ClassifierOutputTarget(preds)]
grayscale_cam = cam(input_tensor=img_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]
cam_img = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
print(type(cam_img))
Image.fromarray(cam_img)
---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

 in 
      6 # 选取合适的类激活图,但是ScoreCAM和AblationCAM需要batch_size
      7 cam = GradCAM(model=model,target_layers=target_layers)
----> 8 targets = [ClassifierOutputTarget(preds)]
      9 # 上方preds需要设定,比如ImageNet有1000类,这里可以设为200
     10 grayscale_cam = cam(input_tensor=img_tensor, targets=targets)


NameError: name 'preds' is not defined

7.2.4 使用FlashTorch快速实现CNN可视化

!pip install flashtorch
Collecting flashtorch
  Downloading flashtorch-0.1.3.tar.gz (28 kB)
Requirement already satisfied: matplotlib in c:\users\b\anaconda3\lib\site-packages (from flashtorch) (3.3.2)
Requirement already satisfied: numpy in c:\users\b\anaconda3\lib\site-packages (from flashtorch) (1.19.5)
Requirement already satisfied: Pillow in c:\users\b\anaconda3\lib\site-packages (from flashtorch) (8.1.0)
Requirement already satisfied: torch in c:\users\b\anaconda3\lib\site-packages (from flashtorch) (1.7.1)
Requirement already satisfied: torchvision in c:\users\b\anaconda3\lib\site-packages (from flashtorch) (0.8.2+cpu)
Requirement already satisfied: importlib_resources in c:\users\b\anaconda3\lib\site-packages (from flashtorch) (5.4.0)
Requirement already satisfied: zipp>=3.1.0 in c:\users\b\anaconda3\lib\site-packages (from importlib_resources->flashtorch) (3.4.0)
Requirement already satisfied: cycler>=0.10 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->flashtorch) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->flashtorch) (1.3.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->flashtorch) (2.4.7)
Requirement already satisfied: python-dateutil>=2.1 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->flashtorch) (2.8.1)
Requirement already satisfied: certifi>=2020.06.20 in c:\users\b\anaconda3\lib\site-packages (from matplotlib->flashtorch) (2022.5.18.1)
Requirement already satisfied: six in c:\users\b\anaconda3\lib\site-packages (from cycler>=0.10->matplotlib->flashtorch) (1.15.0)
Requirement already satisfied: typing-extensions in c:\users\b\anaconda3\lib\site-packages (from torch->flashtorch) (3.7.4.3)
Building wheels for collected packages: flashtorch
  Building wheel for flashtorch (setup.py): started
  Building wheel for flashtorch (setup.py): finished with status 'done'
  Created wheel for flashtorch: filename=flashtorch-0.1.3-py3-none-any.whl size=26247 sha256=985fa2c01945aeffbc5156f74e87ea7a9c99eb0f5b3a90eb689d422d49688103
  Stored in directory: c:\users\b\appdata\local\pip\cache\wheels\62\7a\fd\e186c4584835bf57e3b56f8470c018af80c0ac1f5723b4262a
Successfully built flashtorch
Installing collected packages: flashtorch
Successfully installed flashtorch-0.1.3
# 可视化梯度
import matplotlib.pyplot as plt
import torchvision.models as models
from flashtorch.utils import apply_transforms, load_image
from flashtorch.saliency import Backprop

# Build AlexNet with `pretrained=True` and wrap it in flashtorch's Backprop.
model = models.alexnet(pretrained=True)
backprop = Backprop(model)

# Load the image file and run it through flashtorch's apply_transforms.
# The result is named `owl` although the file loaded here is kobe.jpg.
image = load_image('./kobe.jpg')
owl = apply_transforms(image)

# target_class is handed to visualize() as its second argument.
# NOTE(review): use_gpu=True is requested, but the pip logs in this document
# list a CPU build of torchvision (0.8.2+cpu) -- confirm a GPU is available.
# NOTE: in the recorded run this call raised RuntimeError inside flashtorch's
# denormalize() (in-place `mul_` on a view); see the traceback that follows.
target_class = 24
backprop.visualize(owl, target_class, guided=True, use_gpu=True)
---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

 in 
     11 
     12 target_class = 24
---> 13 backprop.visualize(owl, target_class, guided=True, use_gpu=True)


~\Anaconda3\lib\site-packages\flashtorch\saliency\backprop.py in visualize(self, input_, target_class, guided, use_gpu, figsize, cmap, alpha, return_output)
    180             # (title, [(image1, cmap, alpha), (image2, cmap, alpha)])
    181             ('Input image',
--> 182              [(format_for_plotting(denormalize(input_)), None, None)]),
    183             ('Gradients across RGB channels',
    184              [(format_for_plotting(standardize_and_clip(gradients)),


~\Anaconda3\lib\site-packages\flashtorch\utils\__init__.py in denormalize(tensor)
    117 
    118     for channel, mean, std in zip(denormalized[0], means, stds):
--> 119         channel.mul_(std).add_(mean)
    120 
    121     return denormalized


RuntimeError: Output 0 of UnbindBackward is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.
# 可视化卷积核
import torchvision.models as models
from flashtorch.activmax import GradientAscent

# Build VGG-16 with `pretrained=True` and hand its `features` container to
# flashtorch's GradientAscent.
model = models.vgg16(pretrained=True)
g_ascent = GradientAscent(model.features)

# specify layer and filter info
# features[24] is the layer referred to as conv5_1 here; the list holds the
# filter indices passed to visualize().
conv5_1 = model.features[24]
conv5_1_filters = [45, 271, 363, 489]

g_ascent.visualize(conv5_1, conv5_1_filters, title="VGG16: conv5_1")
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\b/.cache\torch\hub\checkpoints\vgg16-397923af.pth



  0%|          | 0.00/528M [00:00

![output_24_2](output_24_2.png)

7.3 使用TensorBoard可视化训练过程

7.3.1 TensorBoard安装

!pip install tensorboardX
Requirement already satisfied: tensorboardX in c:\users\b\anaconda3\lib\site-packages (2.4)
Requirement already satisfied: protobuf>=3.8.0 in c:\users\b\anaconda3\lib\site-packages (from tensorboardX) (3.17.3)
Requirement already satisfied: numpy in c:\users\b\anaconda3\lib\site-packages (from tensorboardX) (1.19.5)
Requirement already satisfied: six>=1.9 in c:\users\b\anaconda3\lib\site-packages (from protobuf>=3.8.0->tensorboardX) (1.15.0)

7.3.2 TensorBoard可视化的基本逻辑

我们可以将TensorBoard看做一个记录员,它可以记录我们指定的数据,包括模型每一层的feature map,权重,以及训练loss等等。
TensorBoard将记录下来的内容保存在一个用户指定的文件夹里,程序不断运行中TensorBoard会不断记录。
记录下的内容可以通过网页的形式加以可视化。

7.3.3 TensorBoard的配置与启动

# SummaryWriter taken from the tensorboardX package.
from tensorboardX import SummaryWriter

# Create a writer with './runs' as its argument.
writer = SummaryWriter('./runs')
# Presumably shown as an alternative source for the same class. Executing this
# import rebinds the name SummaryWriter, so any SummaryWriter(...) call made
# afterwards uses torch.utils.tensorboard; `writer` above is not affected.
from torch.utils.tensorboard import SummaryWriter
# 命令行中启动
tensorboard --logdir=/path/to/logs/ --port=xxxx

7.3.4 TensorBoard模型结构可视化

import torch.nn as nn

class Net(nn.Module):
    """Small CNN: two conv + max-pool stages, adaptive max pooling to 1x1,
    then a 64 -> 32 -> 1 linear head whose output goes through a sigmoid."""

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the order they are listed by
        # print(model).
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.adaptive_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the forward pass and return the sigmoid output."""
        # The same pooling module is applied after each convolution.
        features = self.pool(self.conv1(x))
        features = self.pool(self.conv2(features))
        pooled = self.flatten(self.adaptive_pool(features))
        hidden = self.relu(self.linear1(pooled))
        return self.sigmoid(self.linear2(hidden))

# Instantiate the Net class and print its layer listing.
model = Net()
print(model)
Net(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=64, out_features=32, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=32, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
# Pass the model together with a random tensor of shape (1, 3, 224, 224) to
# add_graph, then close the writer.
writer.add_graph(model, input_to_model = torch.rand(1, 3, 224, 224))
writer.close()
## 7.3.5 TensorBoard图像可视化
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Both pipelines consist of ToTensor() only.
transform_train = transforms.Compose(
    [transforms.ToTensor()])
transform_test = transforms.Compose(
    [transforms.ToTensor()])

# CIFAR10 with "." as root and download=True; batches of 64, training loader shuffled.
train_data = datasets.CIFAR10(".", train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10(".", train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)

# Take the first batch from the training loader.
images, labels = next(iter(train_loader))
 
# Look at a single image only
writer = SummaryWriter('./pytorch_tb')
writer.add_image('images[0]', images[0])
writer.close()
 
# Stitch several images into one image, separated by a black grid
# create grid of images
writer = SummaryWriter('./pytorch_tb')
img_grid = torchvision.utils.make_grid(images)
writer.add_image('image_grid', img_grid)
writer.close()
 
# Write several images directly
writer = SummaryWriter('./pytorch_tb')
writer.add_images("images",images,global_step = 0)
writer.close()
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to .\cifar-10-python.tar.gz



|          | 0/? [00:00 in 
      8     [transforms.ToTensor()])
      9 
---> 10 train_data = datasets.CIFAR10(".", train=True, download=True, transform=transform_train)
     11 test_data = datasets.CIFAR10(".", train=False, download=True, transform=transform_test)
     12 train_loader = DataLoader(train_data, batch_size=64, shuffle=True)


~\Anaconda3\lib\site-packages\torchvision\datasets\cifar.py in __init__(self, root, train, transform, target_transform, download)
     63 
     64         if download:
---> 65             self.download()
     66 
     67         if not self._check_integrity():


~\Anaconda3\lib\site-packages\torchvision\datasets\cifar.py in download(self)
    141             print('Files already downloaded and verified')
    142             return
--> 143         download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
    144 
    145     def extra_repr(self) -> str:


~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_and_extract_archive(url, download_root, extract_root, filename, md5, remove_finished)
    254         filename = os.path.basename(url)
    255 
--> 256     download_url(url, download_root, filename, md5)
    257 
    258     archive = os.path.join(download_root, filename)


~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_url(url, root, filename, md5)
     68         try:
     69             print('Downloading ' + url + ' to ' + fpath)
---> 70             urllib.request.urlretrieve(
     71                 url, fpath,
     72                 reporthook=gen_bar_updater()


~\Anaconda3\lib\urllib\request.py in urlretrieve(url, filename, reporthook, data)
    274 
    275             while True:
--> 276                 block = fp.read(bs)
    277                 if not block:
    278                     break


~\Anaconda3\lib\http\client.py in read(self, amt)
    456             # Amount is given, implement using readinto
    457             b = bytearray(amt)
--> 458             n = self.readinto(b)
    459             return memoryview(b)[:n].tobytes()
    460         else:


~\Anaconda3\lib\http\client.py in readinto(self, b)
    500         # connection, and the user is reading more bytes than will be provided
    501         # (for example, reading in 1k chunks)
--> 502         n = self.fp.readinto(b)
    503         if not n and b:
    504             # Ideally, we would raise IncompleteRead if the content-length


~\Anaconda3\lib\socket.py in readinto(self, b)
    667         while True:
    668             try:
--> 669                 return self._sock.recv_into(b)
    670             except timeout:
    671                 self._timeout_occurred = True


~\Anaconda3\lib\ssl.py in recv_into(self, buffer, nbytes, flags)
   1239                   "non-zero flags not allowed in calls to recv_into() on %s" %
   1240                   self.__class__)
-> 1241             return self.read(nbytes, buffer)
   1242         else:
   1243             return super().recv_into(buffer, nbytes, flags)


~\Anaconda3\lib\ssl.py in read(self, len, buffer)
   1097         try:
   1098             if buffer is not None:
-> 1099                 return self._sslobj.read(len, buffer)
   1100             else:
   1101                 return self._sslobj.read(len)


KeyboardInterrupt: 

7.3.6 TensorBoard连续变量可视化

# Log two scalar series for steps 0..499: "x" is the step itself and "y" is
# its square.
writer = SummaryWriter('./pytorch_tb')
for step in range(500):
    writer.add_scalar("x", step, step)       # value of x recorded at this step
    writer.add_scalar("y", step ** 2, step)  # value of y recorded at this step
writer.close()
C:\Users\b\Anaconda3\lib\site-packages\h5py\__init__.py:36: UserWarning: h5py is running against HDF5 1.10.5 when it was built against 1.10.6, this may cause problems
  _warn(("h5py is running against HDF5 {0} when it was built against {1}, "

7.3.7 TensorBoard参数分布可视化

import torch
import numpy as np

# Create a normally distributed tensor to stand in for a parameter matrix
def norm(mean, std, shape=(100, 20)):
    """Return ``std * torch.randn(shape) + mean``.

    Args:
        mean: value added to every sample (the centre of the distribution).
        std: factor multiplying the standard-normal samples.
        shape: shape of the returned tensor; defaults to (100, 20).

    Returns:
        A tensor of the requested shape.
    """
    return std * torch.randn(shape) + mean
 
# SummaryWriter has to be imported in the running session; without it this
# cell fails with "NameError: name 'SummaryWriter' is not defined".
from torch.utils.tensorboard import SummaryWriter

# Log one histogram per step under the tag "w"; the mean moves from -10 to 9
# while the std argument stays at 1.
writer = SummaryWriter('./pytorch_tb/')
for step, mean in enumerate(range(-10, 10)):
    w = norm(mean, 1)
    writer.add_histogram("w", w, step)
    writer.flush()
writer.close()
---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

 in 
      7     return t
      8 
----> 9 writer = SummaryWriter('./pytorch_tb/')
     10 for step, mean in enumerate(range(-10, 10, 1)):
     11     w = norm(mean, 1)


NameError: name 'SummaryWriter' is not defined
7.3.8 服务器端使用TensorBoard
该方法是将服务器的6006端口重定向到自己的机器上。我们可以在本地的终端里输入以下命令,
其中16006代表映射到本地的端口,6006代表服务器上的端口:
ssh -L 16006:127.0.0.1:6006 username@remote_server_ip
# 在服务器上使用默认的6006端口正常启动tensorboard
tensorboard --logdir=xxx --port=6006
# 在本地的浏览器输入地址
localhost:16006

7.3.9 总结

对于TensorBoard来说,它的功能是很强大的,可以记录的东西不只限于本节所介绍的范围。

主要的实现方案是构建一个SummaryWriter,然后通过add_XXX()函数来实现。

其实TensorBoard的逻辑还是很简单的,它的基本逻辑就是文件的读写逻辑,写入想要可视化的数据,然后TensorBoard自己会读出来。

你可能感兴趣的:(pytorch的可视化)