Visualizing EfficientNet with Grad-CAM and LIME

1. Grad-CAM implementation
Reference: https://github.com/sidml/EfficientNet-GradCam-Visualization

```bash
git clone https://github.com/FrancescoSaverioZuppichini/A-journey-into-Convolutional-Neural-Network-visualization-.git
cd A-journey-into-Convolutional-Neural-Network-visualization-/
```
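
The code also relies on efficientnet_pytorch, which is not shipped with the cloned repo; assuming a standard pip environment (an assumption, not stated in the original), install it separately:

```bash
pip install efficientnet_pytorch
```
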
```python
from torchvision.models import *
from visualisation.core.utils import device
from efficientnet_pytorch import EfficientNet
import glob
import matplotlib.pyplot as plt
import numpy as np
import torch 
from utils import *
import PIL.Image
import cv2

from visualisation.core.utils import image_net_postprocessing

from torchvision.transforms import ToTensor, Resize, Compose, ToPILImage
from visualisation.core import *
from visualisation.core.utils import image_net_preprocessing

# for animation

from IPython.display import Image
from matplotlib.animation import FuncAnimation
from collections import OrderedDict


def efficientnet(model_name='efficientnet-b0', **kwargs):
    # from_pretrained always loads the ImageNet weights; extra kwargs (e.g. `pretrained`) are ignored
    return EfficientNet.from_pretrained(model_name).to(device)

max_img = 5
path = '/home/zls/Documents/pytorch_transfer_learning/data+extractor/fundus_data/test'
interesting_categories = ['Abnormal','normal']

images = [] 
for category_name in interesting_categories:
    image_paths = glob.glob(f'{path}/{category_name}/*')
    category_images = list(map(lambda x: PIL.Image.open(x), image_paths[:max_img]))
    images.extend(category_images)

inputs  = [Compose([Resize((224,224)), ToTensor(), image_net_preprocessing])(x).unsqueeze(0) for x in images]  # add 1 dim for batch
inputs = [i.to(device) for i in inputs]

model_outs = OrderedDict()
model_instances = [alexnet,densenet121, 
                  lambda pretrained:efficientnet(model_name='efficientnet-b0'),
                  lambda pretrained:efficientnet(model_name='efficientnet-b4')]

model_names = [m.__name__ for m in model_instances]
model_names[-2],model_names[-1] = 'EB0','EB4'
print(model_names)
print(model_instances)
images = list(map(lambda x: cv2.resize(np.array(x), (224, 224)), images))  # resize the original images for display

for name, model in zip(model_names, model_instances):
    print(name)
    module = model(pretrained=True).to(device)
    module.eval()

    vis = GradCam(module, device)
    model_outs[name] = list(map(lambda x: tensor2img(vis(x, None, postprocessing=image_net_postprocessing)[0]), inputs))
    del module
    torch.cuda.empty_cache()



# create a figure with five subplots: the original image plus one Grad-CAM panel per model
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1,5,figsize=(20,20))
axes = [ax2, ax3, ax4, ax5]
    
def update(frame):
    all_ax = []
    ax1.set_yticklabels([])
    ax1.set_xticklabels([])
    ax1.text(1, 1, 'Orig. Im', color="white", ha="left", va="top",fontsize=30)
    all_ax.append(ax1.imshow(images[frame]))
    for i,(ax,name) in enumerate(zip(axes,model_outs.keys())):
        ax.set_yticklabels([])
        ax.set_xticklabels([])        
        ax.text(1, 1, name, color="white", ha="left", va="top",fontsize=20)
        all_ax.append(ax.imshow(model_outs[name][frame], animated=True))

    return all_ax

ani = FuncAnimation(fig, update, frames=range(len(images)), interval=1000, blit=True)
fig.tight_layout()
ani.save('../compare_arch.gif', writer='imagemagick')
```
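
The `GradCam` class used above comes from the cloned visualization repo. For readers who want to see what it does internally, below is a minimal, self-contained Grad-CAM sketch written directly with PyTorch hooks; it is an illustrative alternative under assumptions (PyTorch ≥ 1.8 for `register_full_backward_hook`, and efficientnet_pytorch's last convolution `_conv_head` as the target layer), not the repo's implementation.

```python
import torch
import torch.nn.functional as F

def grad_cam(model, x, target_layer, class_idx=None):
    """Return a Grad-CAM heatmap (HxW, values in [0, 1]) for a single 1x3xHxW input."""
    feats, grads = [], []
    fh = target_layer.register_forward_hook(lambda m, i, o: feats.append(o))
    bh = target_layer.register_full_backward_hook(lambda m, gi, go: grads.append(go[0]))

    logits = model(x)                              # forward pass records the feature maps
    if class_idx is None:
        class_idx = logits.argmax(dim=1).item()    # explain the top-scoring class by default
    model.zero_grad()
    logits[0, class_idx].backward()                # backward pass records the gradients
    fh.remove(); bh.remove()

    weights = grads[0].mean(dim=(2, 3), keepdim=True)             # global-average-pool the gradients
    cam = F.relu((weights * feats[0]).sum(dim=1, keepdim=True))   # weighted sum of feature maps
    cam = F.interpolate(cam, size=x.shape[2:], mode='bilinear', align_corners=False)
    cam = cam.squeeze().detach().cpu().numpy()
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

# hypothetical usage with the helpers defined above:
# module = efficientnet('efficientnet-b0')
# heatmap = grad_cam(module, inputs[0], module._conv_head)
```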

2. LIME implementation:


```python
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json

import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F

def get_image(path):
    with open(os.path.abspath(path), 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB') 
        
img = get_image('./images/test/0/60.JPG')
plt.imshow(img)

# resize and center-crop the image to the input size the model expects
def get_input_transform():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])       
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])    

    return transf

def get_input_tensors(img):
    transf = get_input_transform()
    # unsqueeze converts a single image into a batch of one
    return transf(img).unsqueeze(0)
model = torch.load('/home/zls/Documents/pytorch_transfer_learning/pretrained_efficient_pytorch/efficient7_augment_pretrained.pth',map_location='cpu')
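# Note (an assumption, not from the original post): torch.load on a whole-model .pth
# requires the EfficientNet class to be importable in this environment. An alternative
# is to rebuild the architecture and load a state_dict instead, e.g.
#   from efficientnet_pytorch import EfficientNet
#   model = EfficientNet.from_name('efficientnet-b7', num_classes=2)   # kwargs vary by library version
#   model.load_state_dict(torch.load('efficientnet_b7_state_dict.pth', map_location='cpu'))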


idx2label, cls2label, cls2idx = [], {}, {}
with open(os.path.abspath('./imagenet_class_index.json'), 'r') as read_file:
    class_idx = json.load(read_file)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
    cls2label = {class_idx[str(k)][0]: class_idx[str(k)][1] for k in range(len(class_idx))}
    cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))}  

img_t = get_input_tensors(img)
model.eval()
logits = model(img_t)

probs = F.softmax(logits, dim=1)
probs5 = probs.topk(1)
tuple((p,c, idx2label[c]) for p, c in zip(probs5[0][0].detach().numpy(), probs5[1][0].detach().numpy()))

def get_pil_transform(): 
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224)
    ])    

    return transf

def get_preprocess_transform():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])     
    transf = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])    

    return transf    

pill_transf = get_pil_transform()
preprocess_transform = get_preprocess_transform()

def batch_predict(images):
    model.eval()
    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    batch = batch.to(device)
    
    logits = model(batch)
    probs = F.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()
test_pred = batch_predict([pill_transf(img)])
test_pred.squeeze().argmax()


from lime import lime_image
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(pill_transf(img)), 
                                         batch_predict, # classification function
                                         top_labels=1, 
                                         hide_color=0, 
                                         num_samples=2) # number of images that will be sent to classification function
from skimage.segmentation import mark_boundaries  
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=26, hide_rest=False)
img_boundry1 = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry1)
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=16, hide_rest=False)
img_boundry2 = mark_boundaries(temp/255.0, mask,color=(1,1,0))
plt.imshow(img_boundry2)
```
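
The two plots above only draw superpixel boundaries. As a hedged follow-up sketch (not part of the original code), the per-superpixel weights stored in `explanation.local_exp` can also be shown as a heatmap; note that `num_samples=2` above is very small, so re-running with LIME's default of 1000 perturbed samples gives a more stable explanation. `explainer`, `pill_transf`, `batch_predict`, and `img` are reused from the block above.

```python
import numpy as np
import matplotlib.pyplot as plt

# re-run LIME with more perturbed samples for a more stable explanation
explanation = explainer.explain_instance(np.array(pill_transf(img)),
                                         batch_predict,
                                         top_labels=1,
                                         hide_color=0,
                                         num_samples=1000)

# map each superpixel's weight back onto the segmentation map and show it as a heatmap
ind = explanation.top_labels[0]
dict_heatmap = dict(explanation.local_exp[ind])
heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
plt.imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max())
plt.colorbar()
plt.show()
```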

