【python】热力图绘制: intensity_heatmap,density_heatmap

在使用深度学习时,尤其生成模型,我们想知道生成误差或者重构误差出现在什么地方,这时我们就需要用热力图来可视化。在可视化中,除了要显示损失的heatmap以外,我们还要与原始图像进行一个加权,使我们清楚地看到重构误差损失出现在什么位置。

本文提供两个函数解决这个问题:intensity_heatmap和density_heatmap

  • intensity_heatmap:损失强度热力图。像素值的亮暗显示损失的强与弱。

  • density_heatmap:损失密度热力图。只统计损失强度超过一定阈值的点,并且按阈值二值化。然后对这些强损失点添加半径,可视化强损失点的密度。

代码如下:

首先,需要安装三个函数包 cv2,PIL,pyheatmap,其中pyheatmap的安装如下:

pip install pyheatmap
# conda install pyheatmap  会报错,conda中不存在这个程序包

热力图的定义,使用代码如下:

import cv2
from PIL import Image
from pyheatmap.heatmap import HeatMap

def intensity_heatmap(background_img, intensity_map, blue_mask_weight=0.3, heat_map_weight=0.5):
	"""
	
	:param background_img: 背景图
	:param intensity_map: 损失强度图
	:param blue_mask_weight: 
	:param heat_map_weight: 
	:return: 
	"""
	assert background_img.ndim == 3
	assert intensity_map.ndim == 2 or intensity_map.ndim == 3
	assert background_img.shape[0:2] == intensity_map.shape[0:2]
	background_img_norm = ((background_img - background_img.min()) / (background_img.max() - background_img.min())
	                       * 255).astype(np.uint8)
	intensity_map_norm = ((intensity_map - intensity_map.min()) / (intensity_map.max() - intensity_map.min())
	                      * 255).astype(np.uint8)
	
	# 背景图需要加一个蓝色掩码,以更好的显示热力图
	blue_mask = cv2.rectangle(background_img.copy(), (0, 0), background_img.shape[0:2], (0, 0, 255), -1).astype(np.uint8)
	heatmap_base = cv2.addWeighted(blue_mask, blue_mask_weight, background_img_norm, 1 - blue_mask_weight, 0).astype(np.uint8)
	
	# 获取热力图
	hotmap_blue = np.array(cv2.applyColorMap(intensity_map_norm, cv2.COLORMAP_HOT))  # 是蓝色的3通道heatmap
	hotmap_red = cv2.cvtColor(hotmap_blue, cv2.COLOR_RGB2BGR)  # 由蓝色heatmap转为红色heatmap
	
	intensity_hotmap_img = cv2.addWeighted(hotmap_red, heat_map_weight, heatmap_base, 1 - heat_map_weight, 0)
	
	return intensity_hotmap_img

def density_heatmap(background_img, intensity_map, intensity_thr=50, density_radius=5, heat_map_weight=0.5):
	"""
	
	:param background_img: 背景图
	:param intensity_map: 损失map
	:param intensity_thr: 强损失阈值
	:param density_radius: 密度半径
	:param heat_map_weight: 
	:return: 
	"""
	assert background_img.ndim == 3
	assert intensity_map.ndim == 2 or intensity_map.ndim == 3
	assert background_img.shape[0:2] == intensity_map.shape[0:2]
	background_img_norm = ((background_img - background_img.min()) / (background_img.max() - background_img.min())
	                       * 255).astype(np.uint8)
	intensity_map_norm = ((intensity_map - intensity_map.min()) / (intensity_map.max() - intensity_map.min())
	                      * 255).astype(np.uint8)
	hotmap_points = []
	for i in range(intensity_map.shape[0]):  ##获取异常点
		for j in range(intensity_map.shape[1]):
			if intensity_map_norm[i, j] > intensity_thr:
				hotmap_points.append([j, i])
	
	background = Image.new("RGB", (intensity_map.shape[1], intensity_map.shape[0]), color=0)  ## heatmap基底
	hm = HeatMap(hotmap_points)
	density_heatmap_img = np.array(hm.heatmap(base=background, r=density_radius))
	density_heatmap_img = cv2.addWeighted(density_heatmap_img, heat_map_weight, background_img_norm,
	                                      1 - heat_map_weight, 0)
	
	return density_heatmap_img

if __name__ == "__main__":
    img_path = "gt_img.jpg" 
    loss_map_path = "loss_map.npy" 
    
    gt_img = cv2.imread(img_path) # shape = [h,w,c]
    loss_map = np.load(loss_map_path) # shape = [h,w,1] 或 shape = [h,w]
    
    density_heatmap_img = density_heatmap(gt_img,loss_map,intensity_thr=50,density_radius=10,heat_map_weight=0.5)
    intensity_hotmap_img = intensity_heatmap(gt_img,loss_map,blue_mask_weight = 0.1,heat_map_weight = 0.6)
    
    plt.subplot(121)
    plt.imshow(density_heatmap_img)
    plt.subplot(122)
    plt.imshow(intensity_hotmap_img)
    plt.show()

以下使用示例引用自:https://blog.csdn.net/weixin_43289135/article/details/104651047 

# -*- coding: utf-8 -*-
from pyheatmap.heatmap import HeatMap
import numpy as np
x = [10,20,30,80,50,40,70,90,60,30,40,50]
y = [50,40,20,30,60,90,60,20,60,10,70,50]
data = []
for i in range(0,11):
    temperature = [int(x[i]), int(y[i]),1]# 设置每个像素点处的温度
    data.append(temperature)
heat = HeatMap(data)
heat.clickmap(save_as="1.png") #点击图
heat.heatmap(save_as="2.png") #热图

你可能感兴趣的:(python,热力图,heatmap,hotmap)