【pytorch】ResNet中间层可视化

这里用到的是我之前训练好的ResNet18,或者可以
将代码的
model = torch.load(‘model_params(RES_best).pkl’)
这一排改为
model_ft = models.resnet18(pretrained=True)
就可以直接下载ResNet18了。

import torch
from torch import nn
from torchvision import models, transforms
from torchvision.utils import make_grid
import numpy as np
from PIL import Image
import json

model = torch.load('model_params(RES_best).pkl')#可以改为model_ft = models.resnet18(pretrained=True),直接下载ResNet18.
image = Image.open('E:/python_myprojects/zhouyi-projects/test_maize/test/1/b7102d1b-cd66-4303-9544-0ddeefc7ed12.JPG')

transform = transforms.Compose([transforms.Resize((224, 224)),
                         transforms.ToTensor(),
                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])])
img = transform(image)
img = img.unsqueeze(0)

def save_img(tensor, name):
    tensor = tensor.permute((1, 0, 2, 3))
    im = make_grid(tensor, normalize=True, scale_each=True, nrow=8, padding=2).permute((1, 2, 0))
    im = (im.data.numpy() * 255.).astype(np.uint8)
    Image.fromarray(im).save(name + '.jpg')

new_model = nn.Sequential(*list(model.children())[:5])
f3 = new_model(img)
save_img(f3, 'layer1')

new_model = nn.Sequential(*list(model.children())[:6])
f4 = new_model(img)  # [1, 128, 28, 28]
save_img(f4, 'layer2')

new_model = nn.Sequential(*list(model.children())[:7])
print(new_model)
f5 = new_model(img)  # [1, 256, 14, 14]
print(f5.shape)
save_img(f5, 'layer3')

new_model = nn.Sequential(*list(model.children())[:8])
print(new_model)
f6 = new_model(img)  # [1, 256, 14, 14]
print(f6.shape)
save_img(f6, 'layer4')

这里我提取出了Layer1、Layer2、Layer3、Layer4这几层处理后的结果。
比如,用new_model = nn.Sequential(*list(model.children())[:4])这一句,可以提取出模型的前4层,第四层正好为layer1,
对图片运行这个提取后的模型,即可得到layer1处理完后的结果,图像会保存在此程序的python文件中。
同理,要得到layer2处理后的输出结果,则将4改为5即可。new_model = nn.Sequential(*list(model.children())[:5])

原图如下:
【pytorch】ResNet中间层可视化_第1张图片

这是Layer1处理之后的结果:
【pytorch】ResNet中间层可视化_第2张图片

layer2:
【pytorch】ResNet中间层可视化_第3张图片
layer3:
【pytorch】ResNet中间层可视化_第4张图片
layer4:
【pytorch】ResNet中间层可视化_第5张图片

参考文献

你可能感兴趣的:(pytorch,笔记,python,深度学习,可视化)