[Deep Learning Visualization Series] Feature Map Visualization (supports feature maps from ViT-family models and saves the visualization results with TensorBoard)


import sys
import os
import torch
import cv2
import timm
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model.MitUnet import MitUnet  # project-local model definition
from collections import OrderedDict
from typing import Dict, Iterable, Callable
from torch import nn, Tensor
from PIL import Image
from pprint import pprint


# --------------------------------------------------------------------------------------------------------------------------
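# Feature-map extractor. It takes the model and the key names of the layers whose feature maps should be captured;
# the names can be obtained from model.named_modules() or model.named_children().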
# --------------------------------------------------------------------------------------------------------------------------
class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = list(layers)
        # Most recent output of each requested layer, keyed by layer name
        self._features = OrderedDict({layer: torch.empty(0) for layer in self.layers})
        # One handle per registered hook, so that every hook can be removed later
        self.hooks = []

        named_modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            layer = named_modules[layer_id]
            self.hooks.append(layer.register_forward_hook(self.hook_func(layer_id)))

    def hook_func(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            # Transformer layers emit (B, N, C) token sequences; fold them into (B, C, H, W)
            if output.dim() == 3:
                output = self.reshape_transform(in_tensor=output)
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        # The hooks are removed after this forward pass, so the extractor is single-use
        self.remove()
        return self._features

    def remove(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def reshape_transform(self, in_tensor):
        # (B, N, C) -> (B, H, W, C) with H = W = sqrt(N); this assumes N is a perfect
        # square, i.e. the sequence carries no class token
        side = int(np.sqrt(in_tensor.size(1)))
        result = in_tensor.reshape(in_tensor.size(0), side, side, in_tensor.size(2))
        # (B, H, W, C) -> (B, C, H, W)
        result = result.transpose(2, 3).transpose(1, 2)
        return result
    
    
# --------------------------------------------------------------------------------------------------------------------------
# Build the model and extract the features
# --------------------------------------------------------------------------------------------------------------------------
img_mask_size = 256
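device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = UNet(....)
# map_location remaps the checkpoint tensors, so weights saved on CUDA also load on a CPU-only machine
state_dict = torch.load('./state_dict/model.pth', map_location=device)
model.load_state_dict(state_dict['model'])
print('Network ready: the trained weights were loaded successfully.')
model.to(device=device)
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled
model.eval()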
transformer = A.Compose([
    A.Resize(img_mask_size, img_mask_size),
    A.Normalize(
        mean=(0.5835, 0.5820, 0.5841),
        std=(0.1149, 0.1111, 0.1064),
        max_pixel_value=255.0
    ),
    ToTensorV2()
])
return_layers = ["encoder.norm1"]
e_model = FeatureExtractor(model=model, layers=return_layers)
image_file = "images"
image_file_path = os.path.join(image_file, "15.jpg")
# Force three channels so the image matches the 3-channel mean/std used in A.Normalize
img = Image.open(image_file_path).convert("RGB")
img_width, img_height = img.size
image_np = np.array(img)
augmented = transformer(image=image_np)
augmented_img = augmented['image'].to(device)  
# BatchNorm layers in training mode can reject a batch size below 2, so a random dummy image
# of the same size is stacked with the input to make the batch size 2.
virtual_image = torch.randn(size=(3, img_mask_size, img_mask_size), dtype=torch.float32).to(device=device)
augmented_img = torch.stack([augmented_img, virtual_image], dim=0)
print(augmented_img.shape)
output = e_model(augmented_img)
# Keep only the real image (index 0), drop the dummy one, and restore a batch axis of size 1
for keys, values in output.items():
    output[keys] = values[0].unsqueeze(0)
pprint({keys: values[0].shape for keys, values in output.items()})


# --------------------------------------------------------------------------------------------------------------------------
# Save the feature-map visualization with TensorBoard
# --------------------------------------------------------------------------------------------------------------------------
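from torchvision.utils import make_grid
from torch.utils.tensorboard.writer import SummaryWriter

writer = SummaryWriter("runs/test")
for keys, values in output.items():
    # (1, C, H, W) -> (C, H, W), squashed into [0, 1] with a sigmoid
    values = torch.sigmoid(values[0]).cpu().detach().numpy()
    # One JET-colored RGB image per channel, stored as uint8 in NCHW layout
    imgs_ = np.empty(shape=(values.shape[0], 3, values.shape[1], values.shape[2]), dtype=np.uint8)
    for index, batch_img in enumerate(values):
        colored = cv2.applyColorMap(np.uint8(batch_img * 255), cv2.COLORMAP_JET)  # BGR, HWC
        imgs_[index] = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    imgs_grid = make_grid(torch.from_numpy(imgs_), nrow=5, padding=2, pad_value=0)
    # OpenCV displays BGR, so convert the RGB grid back before showing it
    grid_bgr = cv2.cvtColor(imgs_grid.permute(1, 2, 0).contiguous().numpy(), cv2.COLOR_RGB2BGR)
    cv2.namedWindow("imgs_grid", cv2.WINDOW_FULLSCREEN)
    cv2.imshow("imgs_grid", grid_bgr)
    cv2.waitKey()
    cv2.destroyAllWindows()

    # uint8 RGB images are written as-is; float inputs would be rescaled by TensorBoard
    writer.add_images(keys + "_TEST", imgs_, 0, dataformats="NCHW")
writer.close()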
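
The `layers` argument of `FeatureExtractor` expects keys exactly as `model.named_modules()` reports them, for example `"encoder.norm1"`. The sketch below shows one way to list candidate names before choosing a layer. The helper `list_layer_names` and the toy model are hypothetical, introduced only for illustration; the real model's names will differ.

```python
from collections import OrderedDict

import torch.nn as nn


def list_layer_names(model: nn.Module, leaf_only: bool = True) -> list:
    """Return the dotted module names that a forward hook can be attached to.

    Hypothetical helper (not part of the original script).
    """
    names = []
    for name, module in model.named_modules():
        if name == "":
            # The empty name is the root module itself
            continue
        if leaf_only and len(list(module.children())) > 0:
            # Skip containers and keep only modules without children
            continue
        names.append(name)
    return names


# Toy stand-in model, used only to show the naming scheme
toy = nn.Sequential(OrderedDict([
    ("encoder", nn.Sequential(OrderedDict([
        ("conv1", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
        ("norm1", nn.BatchNorm2d(8)),
    ]))),
    ("head", nn.Conv2d(8, 1, kernel_size=1)),
]))

layer_names = list_layer_names(toy)
print(layer_names)  # ['encoder.conv1', 'encoder.norm1', 'head']
```

Any printed name can be passed in `return_layers`. Setting `leaf_only=False` also lists container modules such as `encoder`.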

The visualization results are shown below, using a ground-fissure image as the example:

[Figure 1] Ground-fissure image and its segmentation result

[Figure 2] Selected feature maps from the fissure-extraction model
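
The `reshape_transform` step in `FeatureExtractor` assumes that the token count of a 3-D output is a perfect square. Many ViT variants prepend a class token, so the sequence has shape `(B, 1 + H*W, C)` and that reshape would fail. A minimal sketch of a variant that drops such leading tokens first is given below. The function name `tokens_to_feature_map` and the parameter `num_prefix_tokens` are assumptions made for illustration, not part of the original code.

```python
import math

import torch


def tokens_to_feature_map(tokens: torch.Tensor, num_prefix_tokens: int = 0) -> torch.Tensor:
    """Fold a (B, N, C) token sequence into a (B, C, H, W) feature map.

    num_prefix_tokens is a hypothetical parameter: the number of leading
    non-patch tokens (e.g. 1 for a class token) to drop before reshaping.
    """
    if tokens.dim() != 3:
        raise ValueError("expected a (B, N, C) tensor")
    patch_tokens = tokens[:, num_prefix_tokens:, :]
    batch, n_tokens, channels = patch_tokens.shape
    side = math.isqrt(n_tokens)
    if side * side != n_tokens:
        raise ValueError(f"{n_tokens} patch tokens do not form a square grid")
    # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W)
    fmap = patch_tokens.reshape(batch, side, side, channels)
    return fmap.permute(0, 3, 1, 2).contiguous()


# Dummy ViT-style output: 1 class token + 14 x 14 patch tokens, 192 channels
dummy = torch.randn(2, 1 + 14 * 14, 192)
fmap = tokens_to_feature_map(dummy, num_prefix_tokens=1)
print(fmap.shape)  # torch.Size([2, 192, 14, 14])
```

With `num_prefix_tokens=0` the function matches the original square-grid behaviour, and the explicit check turns a silent mis-reshape into a clear error.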
