论文:ICCV 2017《Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization》
代码:https://github.com/yizt/Grad-CAM.pytorch/blob/master/main.py
https://github.com/jacobgil/pytorch-grad-cam/blob/master/grad-cam.py
1、首先定义并训练好CNN网络,原网络结构不用调整。假设网络已经训练好,得到一个best_net。
class GradCAM(object):
    """Compute a Grad-CAM heat map for one layer of a trained network.

    1: a forward hook captures the target layer's feature maps
    2: the score of the chosen class is back-propagated and a backward hook
       captures the gradient flowing out of the target layer
    """

    def __init__(self, net, layer_name):
        """
        :param net: trained network; it is switched to eval mode here
        :param layer_name: name of the target layer as yielded by
                           ``net.named_modules()``
        :raises ValueError: if no module is called ``layer_name``
        """
        self.net = net
        self.layer_name = layer_name
        self.feature = None
        self.gradient = None
        self.net.eval()
        self.handlers = []
        self._register_hook()

    def _get_features_hook(self, module, input, output):
        """Forward hook: remember the output of the target layer."""
        self.feature = output

    def _get_grads_hook(self, module, input_grad, output_grad):
        """Backward hook: remember the gradient w.r.t. the layer output.

        :param input_grad: tuple; the original author observed, for the
                           hooked layer, input_grad[0]: None,
                           input_grad[1]: weight, input_grad[2]: bias
        :param output_grad: tuple, length = 1
        :return: None (the gradient is not modified)
        """
        self.gradient = output_grad[0]

    def _register_hook(self):
        """Attach the forward and backward hooks to the layer ``layer_name``."""
        for name, module in self.net.named_modules():
            if name == self.layer_name:
                self.handlers.append(
                    module.register_forward_hook(self._get_features_hook))
                # NOTE(review): recent PyTorch recommends
                # register_full_backward_hook; the legacy hook is kept so
                # behaviour on existing nets stays the same -- confirm before
                # switching.
                self.handlers.append(
                    module.register_backward_hook(self._get_grads_hook))
                return
        # Without this check a wrong name only surfaced later, as a
        # "'NoneType' object is not subscriptable" error inside __call__.
        raise ValueError(
            "layer {!r} not found in net.named_modules()".format(self.layer_name))

    def remove_handlers(self):
        """Detach every hook registered by this object."""
        for handle in self.handlers:
            handle.remove()
        self.handlers = []

    @staticmethod
    def _normalize(cam):
        """Shift ``cam`` so its minimum is 0 and scale its maximum to 1.

        A constant map has no range to scale by; it is returned as all
        zeros instead of being divided by zero (which produced NaNs).
        """
        cam = cam - np.min(cam)
        peak = np.max(cam)
        if peak > 0:
            cam = cam / peak
        return cam

    def __call__(self, inputs, index=None, size=(256, 256)):
        """
        :param inputs: [1,3,H,W]
        :param index: class id; defaults to the highest scoring class
        :param size: (width, height) passed to ``cv2.resize`` for the
                     returned map
        :return: 2-D heat map resized to ``size``, values in [0, 1]
        :raises RuntimeError: if the hooks did not fire (for example after
                              ``remove_handlers()``)
        """
        # Drop values captured by a previous call so they cannot be reused
        # silently when the hooks are no longer attached.
        self.feature = None
        self.gradient = None

        self.net.zero_grad()
        output = self.net(inputs)  # [1,num_classes]
        if index is None:
            index = np.argmax(output.cpu().data.numpy())
        target = output[0][index]
        target.backward()

        if self.feature is None or self.gradient is None:
            raise RuntimeError(
                "hooks on layer {!r} did not fire; were the handlers "
                "removed?".format(self.layer_name))

        gradient = self.gradient[0].cpu().data.numpy()  # [C,H,W]
        weight = np.mean(gradient, axis=(1, 2))  # [C]
        feature = self.feature[0].cpu().data.numpy()  # [C,H,W]

        cam = feature * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
        cam = np.sum(cam, axis=0)  # [H,W]
        cam = np.maximum(cam, 0)  # ReLU
        cam = self._normalize(cam)
        return cv2.resize(cam, tuple(size))
class GradCamPlusPlus(GradCAM):
    """Grad-CAM variant that weights channels by their positive gradients only.

    Hook registration and removal are inherited from ``GradCAM``; only the
    channel weighting and the post-processing differ.
    """

    def __init__(self, net, layer_name):
        super(GradCamPlusPlus, self).__init__(net, layer_name)

    def __call__(self, inputs, index=None, size=(256, 256)):
        """
        :param inputs: [1,3,H,W]
        :param index: class id; defaults to the highest scoring class
        :param size: (width, height) passed to ``cv2.resize`` for the
                     returned map
        :return: 2-D heat map resized to ``size``, values in [0, 1]
        """
        self.net.zero_grad()
        output = self.net(inputs)  # [1,num_classes]
        if index is None:
            index = np.argmax(output.cpu().data.numpy())
        target = output[0][index]
        target.backward()

        gradient = self.gradient[0].cpu().data.numpy()  # [C,H,W]
        gradient = np.maximum(gradient, 0.)  # ReLU
        indicate = np.where(gradient > 0, 1., 0.)  # indicator function
        norm_factor = np.sum(gradient, axis=(1, 2))  # [C] normalisation
        # Reciprocal only where the channel sum is positive; the other
        # channels keep a factor of 0, which avoids dividing by zero.
        norm_factor = np.divide(
            1., norm_factor,
            out=np.zeros_like(norm_factor),
            where=norm_factor > 0.)
        alpha = indicate * norm_factor[:, np.newaxis, np.newaxis]  # [C,H,W]

        weight = np.sum(gradient * alpha, axis=(1, 2))  # [C] alpha*ReLU(gradient)

        feature = self.feature[0].cpu().data.numpy()  # [C,H,W]

        cam = feature * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
        cam = np.sum(cam, axis=0)  # [H,W]
        # No ReLU here: the map is min-max normalised instead.
        cam -= np.min(cam)
        peak = np.max(cam)
        # A constant map has zero range; leave it at 0 rather than turning
        # it into NaNs by dividing by zero.
        if peak > 0:
            cam /= peak
        return cv2.resize(cam, tuple(size))
class GuidedBackPropagation(object):
    """Guided back-propagation: gradient of a class score w.r.t. the input,
    with negative gradients clamped to zero at every ``nn.ReLU`` module."""

    def __init__(self, net):
        """
        :param net: trained network; a backward hook is attached to each of
                    its ``nn.ReLU`` modules and it is switched to eval mode
        """
        self.net = net
        # Keep the handles so the hooks can be detached again; otherwise
        # every instance would leave its hooks on the network permanently.
        self.handlers = []
        for name, module in self.net.named_modules():
            if isinstance(module, nn.ReLU):
                self.handlers.append(
                    module.register_backward_hook(self.backward_hook))
        self.net.eval()

    def remove_handlers(self):
        """Detach every hook registered by this object."""
        for handle in self.handlers:
            handle.remove()
        self.handlers = []

    @classmethod
    def backward_hook(cls, module, grad_in, grad_out):
        """
        :param module: the ReLU module the hook is attached to
        :param grad_in: tuple, length=1
        :param grad_out: tuple, length=1
        :return: tuple(new_grad_in,) with negative entries set to zero
        """
        return torch.clamp(grad_in[0], min=0.0),

    def __call__(self, inputs, index=None):
        """
        :param inputs: [1,3,H,W]; any gradient already stored on it is
                       discarded so the result reflects this call only
        :param index: class_id; defaults to the highest scoring class
        :return: gradient w.r.t. the first input sample, [3,H,W]
        """
        if not inputs.requires_grad:
            # Track gradients on a detached alias so the caller's tensor is
            # not flagged as a side effect.
            inputs = inputs.detach().requires_grad_(True)
        elif not inputs.is_leaf:
            # Non-leaf tensors (e.g. the result of permute) only keep .grad
            # when asked to; without this .grad stays None after backward().
            inputs.retain_grad()
        # backward() accumulates into .grad; start from a clean slate so
        # repeated calls do not add up.
        inputs.grad = None

        self.net.zero_grad()
        output = self.net(inputs)  # [1,num_classes]
        if index is None:
            index = np.argmax(output.cpu().data.numpy())
        target = output[0][index]
        target.backward()

        return inputs.grad[0]  # [3,H,W]
def show_cam_on_image(img, mask, path="cam.jpg"):
    """Blend a CAM heat map over an image and write the result to disk.

    :param img: image array added to the colour-mapped heat map
                (presumably float in [0, 1] with the same height/width as
                ``mask`` -- confirm against the caller)
    :param mask: 2-D heat map; it is scaled by 255 and cast to uint8, so
                 values are expected in [0, 1]
    :param path: output file name, defaults to "cam.jpg"
    :return: the uint8 overlay that was written
    :raises IOError: if ``cv2.imwrite`` reports that nothing was written
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    # Rescale so the brightest value of the blend maps to 255.
    cam = cam / np.max(cam)
    overlay = np.uint8(255 * cam)
    # cv2.imwrite signals failure through its return value; do not ignore it.
    if not cv2.imwrite(path, overlay):
        raise IOError("could not write CAM image to {!r}".format(path))
    return overlay
# Demo: run the visualisations on a single image with the trained `best_net`.
root = '/data/fjsdata/qtsys/img/sz.002509-20200325.png'
image = cv2.imread(root)
# cv2.imread returns None instead of raising when the file cannot be read.
if image is None:
    raise FileNotFoundError(root)
img_list = []
img_list.append(cv2.resize(image.astype(np.float32), (256, 256)))  # (256, 256) is the model input size
inputs = torch.from_numpy(np.array(img_list)).type(torch.FloatTensor).cuda()  # [1,H,W,3]

# Grad-CAM
#grad_cam = GradCAM(net=best_net, layer_name='conv3')
#mask = grad_cam(inputs.permute(0, 3, 1, 2)) # cam mask
#show_cam_on_image(img_list[0], mask)
#grad_cam.remove_handlers()

# Grad-CAM++
#grad_cam_plus_plus = GradCamPlusPlus(net=best_net, layer_name='conv3')
#mask_plus_plus = grad_cam_plus_plus(inputs.permute(0, 3, 1, 2)) # cam mask
#show_cam_on_image(img_list[0], mask_plus_plus)
#grad_cam_plus_plus.remove_handlers()

# GuidedBackPropagation
gbp = GuidedBackPropagation(best_net)
# Permute to [1,3,H,W] first and make THAT tensor the autograd leaf:
# autograd only fills .grad on leaf tensors, so passing a permuted view of a
# leaf left .grad as None. A fresh leaf has no gradient yet, so there is
# nothing to zero before the call either.
gbp_inputs = inputs.permute(0, 3, 1, 2).contiguous().requires_grad_(True)
grad = gbp(gbp_inputs)
print(grad)
最后,GuidedBackPropagation 尚未完全调通,待详细阅读论文后再处理;前面的 Grad-CAM 和 Grad-CAM++ 可以正常运行。