Visual Explanation of Models in Python

1. Custom Dataset

MakeDataSet.py
First, prepare the dataset: here the images live in a mydata folder, with one subfolder per class, and we implement a custom Dataset over it.
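Assuming an ImageFolder-style layout (the class names ICH and Normal and the sample filename come from the path examples used later in this post), mydata might look like:

mydata/
├── ICH/
│   ├── 1470718-1.JPG
│   └── ...
└── Normal/
    └── ...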

import os
import glob

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class MyDataset(Dataset):
	def __init__(self, resize, root='mydata'):
		super(MyDataset, self).__init__()
		self.resize = resize
		# Map each class folder name (e.g. 'ICH', 'Normal') to an integer label
		self.class_name_index = {name: idx for idx, name in enumerate(sorted(os.listdir(root)))}
		# Collect every image path, e.g. `mydata\\ICH\\1470718-1.JPG`
		self.images = glob.glob(os.path.join(root, '*', '*'))

	def __len__(self):
		return len(self.images)

	def __getitem__(self, idx):
		img = self.images[idx]
		tf = transforms.Compose([
			lambda x: Image.open(x).convert('RGB'),
			transforms.Resize((self.resize, self.resize)),
			transforms.ToTensor(),
			transforms.Normalize(mean=[0.485, 0.456, 0.406],
								 std=[0.229, 0.224, 0.225])
		])
		img_tensor = tf(img)
		# The class name is the parent folder: `mydata\\ICH\\1470718-1.JPG` -> 'ICH'
		label_tensor = torch.tensor(self.class_name_index[img.split(os.sep)[-2]])
		return img_tensor, label_tensor
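
A quick sanity check of the dataset (a minimal sketch, assuming the mydata layout above):

# Build the dataset and inspect one sample
db = MyDataset(resize=128)
img, label = db[0]
print(len(db), img.shape, label)	# e.g. torch.Size([3, 128, 128]) and an integer label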

2. Model Definition and Training

2.1 Model

A toy model serves as the demonstration; it is defined as follows:
MyModel.py

import torch
from torch import nn

class MyResNet(nn.Module):
	def __init__(self):
		super(MyResNet,self).__init__()
		general_features = 32
		
		# Initial convolution block
		self.conv0 = nn.Conv2d(3, general_features, 3, 1, padding=1)
		self.conv1 = nn.Conv2d(general_features, general_features, 3, 1, padding=1)
		self.relu1 = nn.ReLU()
		self.conv2 = nn.Conv2d(general_features, general_features, 3, 1, padding=1)
		self.relu2 = nn.ReLU()

		# Each max-pool halves the spatial resolution
		self.downsample0 = nn.MaxPool2d(2, 2)
		self.downsample1 = nn.MaxPool2d(2, 2)
		self.downsample2 = nn.MaxPool2d(2, 2)
		self.downsample3 = nn.MaxPool2d(2, 2)

		# Four 1/2 downsamplings turn a 128x128 input into 8x8
		self.fc0 = nn.Linear(32 * 8 * 8, 2)
	
	def forward(self,x):
		x = self.conv0(x)			#[1,32,128,128]
		x = self.downsample0(x)		#[1,32,64,64]
		x = self.downsample1(x)		#[1,32,32,32]
		
		x = self.relu1(self.conv1(x)) #[1,32,32,32]
		x = self.downsample2(x)		  # [1,32,16,16]
		
		x = self.relu2(self.conv2(x)) #[1,32,16,16]
		x = self.downsample3(x)		  # [1,32,8,8]
		
		x = x.view(x.shape[0],-1)     # Flatten
		
		x = self.fc0(x)		# [1,2] raw logits; CrossEntropyLoss applies softmax itself
		return x

# Quick shape check (uncomment to run; summary comes from the torchsummary package):
# from torchsummary import summary
# x = torch.randn(1,3,128,128)
# m = MyResNet()
# summary(m,(3,128,128))
# print(m(x).shape)	# expected: torch.Size([1, 2])

2.2 Training

Run train.py to train the model and obtain the weights file.

import torch
from torch import optim,nn
from torch.utils.data import DataLoader
from MakeDataSet import MyDataset

from MyModel import MyResNet

train_db = MyDataset(resize=128)
train_loader = DataLoader(train_db, batch_size=4, shuffle=True)
print('num_train:',len(train_loader.dataset))

model = MyResNet()

optimizer = optim.Adam(model.parameters(),lr =0.001)
criteon = nn.CrossEntropyLoss()

epochs = 5

for epoch in range(epochs):
	for step,(x,y) in enumerate(train_loader):
		model.train()
		logits = model(x)
		loss = criteon(logits,y)
		
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
	print('Epoch:', epoch, 'Loss:', loss.item())

torch.save(model.state_dict(),'weights_MyResNet.mdl')
print('Save Done')
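
After training, you can reload the saved weights and sanity-check accuracy (a minimal sketch; it reuses the training data since the toy example defines no validation split):

# Reload the saved weights and measure accuracy on the training data
model_eval = MyResNet()
model_eval.load_state_dict(torch.load('weights_MyResNet.mdl'))
model_eval.eval()

correct = 0
with torch.no_grad():
	for x, y in train_loader:
		pred = model_eval(x).argmax(dim=1)
		correct += (pred == y).sum().item()
print('Accuracy:', correct / len(train_loader.dataset))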

3. Visualizing Feature Maps with SmoothGradCAMpp

Visualize_feature_map.py — this script demonstrates how to use SmoothGradCAMpp (the commented lines below list sibling CAM methods you can swap in).

import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import (SmoothGradCAMpp, CAM, GradCAM, GradCAMpp,
							  XGradCAM, ScoreCAM, SSCAM, ISCAM, LayerCAM)
from torchcam.utils import overlay_mask
from MyModel import MyResNet
from PIL import Image
import matplotlib.pyplot as plt

tf = transforms.Compose([
		lambda x: Image.open(x).convert('RGB'),
		transforms.Resize((128, 128)),
		transforms.ToTensor(),
		transforms.Normalize(mean=[0.485, 0.456, 0.406],
							 std=[0.229, 0.224, 0.225])
])

# Load the test images and add a batch dimension: [1,3,128,128]
img_ICH_test = tf('ICH_test.jpg').unsqueeze(dim=0)
#print(img_ICH_test.shape)

img_Normal_test = tf('Normal_test.jpg').unsqueeze(dim=0)

model = MyResNet()
model.load_state_dict(torch.load('weights_MyResNet.mdl'))
print('loaded from ckpt')
model.eval()

cam_extractor = SmoothGradCAMpp(model,input_shape=(3,128,128))
# cam_extractor = GradCAMpp(model,input_shape=(3,128,128))
# cam_extractor = XGradCAM(model,input_shape=(3,128,128))
# cam_extractor = ScoreCAM(model,input_shape=(3,128,128))
# cam_extractor = SSCAM(model,input_shape=(3,128,128))
# cam_extractor =ISCAM(model,input_shape=(3,128,128))
# cam_extractor = LayerCAM(model,input_shape=(3,128,128))
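# To target a specific layer rather than the auto-selected one, recent torchcam
# versions accept a layer name (a sketch; the keyword may vary between releases,
# and conv2 is simply the last conv block of MyResNet):
# cam_extractor = SmoothGradCAMpp(model, target_layer='conv2', input_shape=(3,128,128))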
  • Load the test image Normal_test.jpg
    [Figure 1: the Normal_test.jpg input image]
  • Instantiate the model, load the pretrained weights, then run a forward pass:
output = model(img_Normal_test)
print(output)

activation_map = cam_extractor(output.squeeze(0).argmax().item(), output)
print(activation_map[0],activation_map[0].min(),activation_map[0].max(),activation_map[0].shape)

#fused_map = cam_extractor.fuse_cams(activation_map)
#print(fused_map[0],fused_map[0].min(),fused_map[0].max(),fused_map[0].shape)
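# fuse_cams is mainly useful when one extractor targets several layers at once.
# A sketch (LayerCAM accepts a list of layer names; conv1/conv2 come from MyResNet):
#   multi_extractor = LayerCAM(model, ['conv1', 'conv2'], input_shape=(3,128,128))
#   cams = multi_extractor(output.squeeze(0).argmax().item(), output)
#   fused = multi_extractor.fuse_cams(cams)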


result = overlay_mask(to_pil_image(img_Normal_test[0]),
					to_pil_image(activation_map[0],mode='F'),alpha=0.3)
plt.imshow(result)
plt.show()
  • Feed the predicted class index (the argmax of the model output) and the output itself into the cam_extractor object; the extractor returns a list of activation maps, so take element 0
  • Then call overlay_mask to render the visualization: pass in the original image and the activation map, and set the transparency via the alpha parameter
  • Since result is a PIL image, it can be displayed directly with imshow
    [Figure 2: the class activation map overlaid on Normal_test.jpg]
    The hottest regions are the ones the model mainly relies on to decide the class. If you do not specify which layer's feature map to visualize, it defaults to the feature map of the layer just before the fully connected layer
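
The img_ICH_test tensor loaded earlier goes through the exact same pipeline (a sketch reusing the objects defined above):

# Repeat the visualization for the ICH test image
output = model(img_ICH_test)
activation_map = cam_extractor(output.squeeze(0).argmax().item(), output)
result = overlay_mask(to_pil_image(img_ICH_test[0]),
					to_pil_image(activation_map[0], mode='F'), alpha=0.3)
plt.imshow(result)
plt.show()

cam_extractor.remove_hooks()	# detach the extractor's hooks once you are done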

The package's homepage is at https://pypi.org/project/torchcam/ if you want to explore further.
