关于热力图的绘制代码

import cv2
import time
import os
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

def draw_features(width, height, x, savename):
    tic = time.time()
    fig = plt.figure(figsize=(80, 80))
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)
    for i in range(width * height):
        plt.subplot(height, width, i + 1)
        plt.axis('off')
        img = x[0, i, :, :]
        pmin = np.min(img)
        pmax = np.max(img)
        img = ((img - pmin) / (pmax - pmin + 0.000001)) * 255  # float在[0,1]之间,转换成0-255
        img = img.astype(np.uint8)  # 转成unit8
        img = cv2.applyColorMap(img, cv2.COLORMAP_JET)  # 生成heat map
        img = img[:, :, ::-1]  # 注意cv2(BGR)和matplotlib(RGB)通道是相反的
        plt.imshow(img)
        print("{}/{}".format(i, width * height))
    fig.savefig(savename, dpi=100)
    fig.clf()
    plt.close()
    print("time:{}".format(time.time() - tic))

使用的时候直接在输出特征x之后加上下面这一句即可,需要自己设置保存路径。

draw_features(2, 2, x.cpu().numpy()[:,2:10,:,:], "save_in")
#其中的2,2是图中几行几列的热力图,他们的乘积只要小于后面2:10中包含的通道数即可

你可能感兴趣的:(pytorch,人工智能,图像处理)