【深度学习可视化系列】—— CAM可视化(以语义分割网络为例,支持ViT系列主干网络的分割模型,支持GradCAM、GradCAMPlusPlus、LayerCAM等CAM可视化方法)

CAM相关基础知识可参考链接:https://zhuanlan.zhihu.com/p/269702192

import os
import warnings

import numpy as np
import requests
import torch
import torch.functional as F
import torchvision
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image


# Load the input image, build the normalized network input, and move the
# segmentation model and input onto the available device.

# Build the path portably; a ".\images" literal is Windows-only and "\i" is an
# invalid escape sequence in a normal string literal.
image_file = os.path.join(".", "images")
image_file_path = os.path.join(image_file, "15.jpg")
# Force 3 channels so the 3-element mean/std below and the RGB overlay later
# line up even for grayscale or RGBA source images.
image = Image.open(image_file_path).convert("RGB")
rgb_img = np.float32(image) / 255  # HxWx3 float image in [0, 1]
# Normalization statistics are dataset-specific (presumably computed on the
# crack dataset -- confirm they match the statistics used at training time).
input_tensor = preprocess_image(rgb_img,
                                mean=[0.5835, 0.5820, 0.5841],
                                std=[0.1149, 0.1111, 0.1064])

model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)
model = model.eval()

# Use the first GPU when present, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device=device)
input_tensor = input_tensor.to(device)

class SegmentationModelOutputWrapper(torch.nn.Module):
    """Wrap a segmentation model whose forward returns a mapping.

    The wrapped model's output is indexed with the ``"out"`` key, and only
    that entry is returned, so callers receive a single tensor instead of
    the whole mapping.

    Args:
        model: the module to wrap; its forward result must support ``["out"]``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        return outputs["out"]
    
# Run the wrapped model once to get per-pixel class predictions and build the
# binary "crack" mask that the CAM target below will focus on.
model = SegmentationModelOutputWrapper(model)
# Gradients are not needed just to derive the mask.
with torch.no_grad():
    output = model(input_tensor)

normalized_masks = torch.nn.functional.softmax(output, dim=1).cpu()
# NOTE(review): these labels assume a model whose channel 0 is "crack" and
# channel 1 is "background"; the stock DEFAULT weights loaded above predict a
# different label set -- confirm the intended (fine-tuned) model.
sem_classes = [
    'crack',
    'background'
]
sem_class_to_idx = {cls: idx for idx, cls in enumerate(sem_classes)}

crack_category = sem_class_to_idx["crack"]
# Per-pixel winning class index for the first (only) image in the batch.
crack_mask = normalized_masks[0].argmax(axis=0).detach().numpy()
crack_mask_uint8 = 255 * np.uint8(crack_mask == crack_category)
crack_mask_float = np.float32(crack_mask == crack_category)

# Side-by-side preview: original RGB image next to the mask replicated to 3 channels.
both_images = np.hstack((np.asarray(image.convert("RGB")),
                         np.repeat(crack_mask_uint8[:, :, None], 3, axis=-1)))
img = Image.fromarray(both_images)
# img.show()  # uncomment to preview image and mask side by side

from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

class SemanticSegmentationTarget:
    """Grad-CAM target that sums one class channel's scores inside a mask.

    Args:
        category: index of the output channel to explain. It must be the same
            index that was compared against the argmax to build ``mask``.
        mask: numpy array weighting each pixel (e.g. a float 0/1 mask);
            converted to a tensor with ``torch.from_numpy``.
    """

    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        # model_output is indexed as (channels, H, W) -- presumably a single
        # sample of the batch as passed by the CAM library; TODO confirm.
        # Match the output's device so CPU and GPU runs both work.
        mask = self.mask.to(model_output.device)
        # Use the category index directly: the mask was built from
        # argmax == category, so that same channel holds the class score.
        return (model_output[self.category, :, :] * mask).sum()


# Layer whose activations/gradients Grad-CAM uses. The wrapped DeepLabV3-ResNet50
# exposes its backbone as model.model.backbone; layer4 is the last ResNet stage.
# For a ViT-style backbone, point this at its final norm layer instead and pass
# reshape_transform to the CAM constructor.
target_layers = [model.model.backbone.layer4]

def reshape_transform(in_tensor):
    """Turn a token sequence into a channel-first feature map.

    Reshapes ``(B, N, C)`` to ``(B, sqrt(N), sqrt(N), C)`` and then moves the
    channel axis forward, giving ``(B, C, sqrt(N), sqrt(N))``. N is assumed to
    be a perfect square (no extra class token); otherwise the reshape fails.
    """
    batch, tokens, channels = in_tensor.size(0), in_tensor.size(1), in_tensor.size(2)
    side = int(np.sqrt(tokens))
    grid = in_tensor.reshape(batch, side, side, channels)
    return grid.permute(0, 3, 1, 2)

# Compute a Grad-CAM heatmap for the "crack" pixels and overlay it on the input.
targets = [SemanticSegmentationTarget(crack_category, crack_mask_float)]
# NOTE(review): `use_cuda` is accepted by older pytorch-grad-cam releases; newer
# releases infer the device from the model -- drop it if the installed version
# rejects the keyword.
with GradCAM(model=model,
             target_layers=target_layers,
             use_cuda=torch.cuda.is_available(),
             # reshape_transform=reshape_transform  # only needed for ViT-style backbones; CNN models can omit it.
             ) as cam:
    # input_tensor already carries the batch dimension; [0, :] keeps the first
    # image's HxW heatmap.
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    # Overlay on the HxW float RGB input (values in [0, 1]), not on the
    # double-width side-by-side preview image.
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

cam_img = Image.fromarray(cam_image)
cam_img.show()

GradCAM方法的可视化结果如下:
【深度学习可视化系列】—— CAM可视化(以语义分割网络为例,支持ViT系列主干网络的分割模型,支持GradCAM、GradCAMPlusPlus、LayerCAM等CAM可视化方法)_第1张图片

你可能感兴趣的:(深度学习, 语义分割, PyTorch, 人工智能, 计算机视觉, python)