Grad-CAM (CNN可视化) Python示例

论文:ICCV 2017《Grad-CAM:Visual Explanations from Deep Networks via Gradient-based Localization》

代码:https://github.com/yizt/Grad-CAM.pytorch/blob/master/main.py
           https://github.com/jacobgil/pytorch-grad-cam/blob/master/grad-cam.py

1、首先定义并训练好CNN网络,原网络结构不用调整。假设网络已经训练好,得到一个best_net。

Grad-CAM (CNN可视化) Python示例_第1张图片

class GradCAM(object):
    """
    Grad-CAM heat map for a single layer of a network.

    A forward hook stores the output of the module named ``layer_name`` and a
    backward hook stores the gradient that flows into that output.  Calling
    the instance runs one forward pass, back-propagates the score of one
    class and combines the stored activations and gradients into a map.
    """

    def __init__(self, net, layer_name):
        """
        :param net: network to explain; it is switched to eval mode
        :param layer_name: name of the target module as yielded by ``net.named_modules()``
        :raises ValueError: if ``net`` has no module called ``layer_name``
        """
        self.net = net
        self.layer_name = layer_name
        self.feature = None
        self.gradient = None
        self.net.eval()
        self.handlers = []
        self._register_hook()

    def _get_features_hook(self, module, input, output):
        """Forward hook: remember the output of the target module."""
        self.feature = output

    def _get_grads_hook(self, module, input_grad, output_grad):
        """
        Backward hook: remember the gradient w.r.t. the target module's output.

        :param input_grad: tuple of gradients w.r.t. the module inputs (unused)
        :param output_grad: tuple; only output_grad[0] is read
        :return:
        """
        self.gradient = output_grad[0]

    def _register_hook(self):
        """Attach both hooks to the module whose name equals ``self.layer_name``."""
        for (name, module) in self.net.named_modules():
            if name == self.layer_name:
                self.handlers.append(module.register_forward_hook(self._get_features_hook))
                self.handlers.append(module.register_backward_hook(self._get_grads_hook))
                return
        # Without this check a misspelled layer name would only surface later,
        # as an opaque error when __call__ indexes self.gradient (still None).
        raise ValueError("layer {!r} not found in net.named_modules()".format(self.layer_name))

    def remove_handlers(self):
        """Detach every hook registered by this instance."""
        for handle in self.handlers:
            handle.remove()
        # Forget the removed handles so the list reflects what is still attached.
        self.handlers = []

    @staticmethod
    def _normalize(cam):
        """Shift ``cam`` so its minimum is 0 and, unless it is constant, scale its maximum to 1."""
        cam = cam - np.min(cam)
        peak = np.max(cam)
        # A constant map has peak == 0; dividing would fill the result with NaN.
        if peak > 0:
            cam = cam / peak
        return cam

    def __call__(self, inputs, index=None):
        """
        :param inputs: [1,3,H,W]
        :param index: class id; defaults to the highest-scoring class of the first sample
        :return: heat map resized to 256x256; values lie in [0, 1] (all zeros if the map is flat)
        """
        self.net.zero_grad()
        output = self.net(inputs)  # [1,num_classes]
        if index is None:
            index = int(np.argmax(output[0].detach().cpu().numpy()))
        target = output[0][index]
        target.backward()

        gradient = self.gradient[0].detach().cpu().numpy()  # [C,H,W]
        weight = np.mean(gradient, axis=(1, 2))  # [C], one weight per channel

        feature = self.feature[0].detach().cpu().numpy()  # [C,H,W]

        cam = np.sum(feature * weight[:, np.newaxis, np.newaxis], axis=0)  # [H,W]
        cam = np.maximum(cam, 0)  # ReLU

        cam = self._normalize(cam)
        # resize to 256*256
        cam = cv2.resize(cam, (256, 256))
        return cam


class GradCamPlusPlus(GradCAM):
    """
    Grad-CAM++ style heat map.

    Hook registration and the constructor are inherited from ``GradCAM``;
    only the way channel weights are derived from the gradients differs.
    """

    def __call__(self, inputs, index=None):
        """
        :param inputs: [1,3,H,W]
        :param index: class id; defaults to the highest-scoring class of the first sample
        :return: heat map resized to 256x256; values lie in [0, 1] (all zeros if the map is flat)
        """
        self.net.zero_grad()
        output = self.net(inputs)  # [1,num_classes]
        if index is None:
            index = int(np.argmax(output[0].detach().cpu().numpy()))
        target = output[0][index]
        target.backward()

        gradient = self.gradient[0].detach().cpu().numpy()  # [C,H,W]
        gradient = np.maximum(gradient, 0.)  # ReLU
        indicate = np.where(gradient > 0, 1., 0.)  # indicator function
        channel_sum = np.sum(gradient, axis=(1, 2))  # [C]
        # Reciprocal of each channel sum; channels whose sum is not positive
        # keep the 0 from `out`, which avoids dividing by zero.
        norm_factor = np.divide(1., channel_sum,
                                out=np.zeros_like(channel_sum),
                                where=channel_sum > 0.)  # [C]
        alpha = indicate * norm_factor[:, np.newaxis, np.newaxis]  # [C,H,W]

        weight = np.sum(gradient * alpha, axis=(1, 2))  # [C]  alpha*ReLU(gradient)

        feature = self.feature[0].detach().cpu().numpy()  # [C,H,W]

        cam = np.sum(feature * weight[:, np.newaxis, np.newaxis], axis=0)  # [H,W]
        # No ReLU is applied to the map here (unlike GradCAM.__call__).

        # normalization
        cam -= np.min(cam)
        peak = np.max(cam)
        # A constant map has peak == 0; dividing would fill the result with NaN.
        if peak > 0:
            cam /= peak
        # resize
        cam = cv2.resize(cam, (256, 256))
        return cam
    
class GuidedBackPropagation(object):
    """
    Guided back-propagation: input gradient with negative ReLU gradients removed.

    A backward hook is attached to every ``nn.ReLU`` module of the network;
    the hook clamps the gradient leaving the ReLU to be non-negative.
    """

    def __init__(self, net):
        """
        :param net: network to explain; it is switched to eval mode
        """
        self.net = net
        # Keep the hook handles so the network can be restored later.
        self.handlers = []
        for (name, module) in self.net.named_modules():
            if isinstance(module, nn.ReLU):
                self.handlers.append(module.register_backward_hook(self.backward_hook))

        self.net.eval()

    @classmethod
    def backward_hook(cls, module, grad_in, grad_out):
        """
        :param module:
        :param grad_in: tuple,length=1
        :param grad_out: tuple,length=1
        :return: tuple(new_grad_in,)
        """
        return torch.clamp(grad_in[0], min=0.0),

    def remove_handlers(self):
        """Detach every ReLU hook registered by this instance."""
        for handle in self.handlers:
            handle.remove()
        self.handlers = []

    def __call__(self, inputs, index=None):
        """
        :param inputs: [1,3,H,W]
        :param index: class_id; defaults to the highest-scoring class of the first sample
        :return: gradient of the class score w.r.t. the first input sample, [3,H,W]
        """
        if not inputs.requires_grad:
            # Work on a detached alias so the caller's tensor is left untouched.
            inputs = inputs.detach().requires_grad_(True)

        self.net.zero_grad()
        output = self.net(inputs)  # [1,num_classes]
        if index is None:
            index = int(torch.argmax(output[0]))
        target = output[0][index]

        # torch.autograd.grad returns the gradient directly instead of writing
        # it to inputs.grad, so it also works for non-leaf inputs (e.g. a
        # permuted view) and repeated calls do not accumulate.
        grad, = torch.autograd.grad(target, inputs)

        return grad[0]  # [3,H,W]
def show_cam_on_image(img, mask, path="cam.jpg"):
    """
    Blend a heat map with an image and write the result to disk.

    :param img: image array; values above 1 are treated as 0-255 and rescaled to [0, 1]
    :param mask: heat map with values in [0, 1]; it is scaled to 0-255 for colouring
    :param path: output file passed to ``cv2.imwrite``
    :return:
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255

    img = np.float32(img)
    # The heat map lives in [0, 1]; a 0-255 image would drown it out when the
    # two are summed, so bring the image onto the same scale first.
    if img.size and img.max() > 1:
        img = img / 255

    cam = heatmap + img
    peak = np.max(cam)
    # An all-black blend has peak == 0; skip the division to avoid NaN.
    if peak > 0:
        cam = cam / peak
    cv2.imwrite(path, np.uint8(255 * cam))

# Load one image and resize it to the model input size.
root = '/data/fjsdata/qtsys/img/sz.002509-20200325.png'
image = cv2.imread(root)
if image is None:
    # cv2.imread signals an unreadable file by returning None, not by raising.
    raise FileNotFoundError("cannot read image: {}".format(root))
img_list = [cv2.resize(image.astype(np.float32), (256, 256))]  # (256, 256) is the model input size
inputs = torch.from_numpy(np.array(img_list)).type(torch.FloatTensor).cuda()  # [1,H,W,3]
# Grad-CAM
#grad_cam = GradCAM(net=best_net, layer_name='conv3')
#mask = grad_cam(inputs.permute(0, 3, 1, 2))  # cam mask
#show_cam_on_image(img_list[0], mask)
#grad_cam.remove_handlers()

# Grad-CAM++
#grad_cam_plus_plus = GradCamPlusPlus(net=best_net, layer_name='conv3')
#mask_plus_plus = grad_cam_plus_plus(inputs.permute(0, 3, 1, 2))  # cam mask
#show_cam_on_image(img_list[0], mask_plus_plus)
#grad_cam_plus_plus.remove_handlers()

# GuidedBackPropagation
gbp = GuidedBackPropagation(best_net)
# Permute to [1,3,H,W] first and only then enable gradients: the permuted
# tensor is then the leaf whose .grad is populated by backward().  Enabling
# gradients on `inputs` and passing a permuted view leaves the view's .grad
# as None, and .grad is also None before the first backward pass, so it
# cannot be zeroed up front.
gbp_inputs = inputs.permute(0, 3, 1, 2).contiguous().requires_grad_(True)
grad = gbp(gbp_inputs)
print(grad)

 最后的GuidedBackPropagation部分尚未完全调通,待详细阅读论文后再处理;前面的Grad-CAM和Grad-CAM++可以正常运行。

 

你可能感兴趣的:(机器学习专栏,python专栏)