基于multivariate_normal,在指定点生成heatmap

热力图展示:


![热力图](myplot.png)

利用多元正态分布为每个中心点生成二维 heatmap,将各中心点的 heatmap 逐像素累加,最终输出包含所有中心点的热力图。

# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.stats import multivariate_normal

covs = [40, 50, 60]  # candidate covariance values; not referenced by the code shown here


def CenterLabelHeatMap(img_width, img_height, c_x, c_y, sigma):
    """Return an unnormalised Gaussian bump of shape (img_height, img_width).

    Pixel coordinates are 1-based: columns run 1..img_width and rows run
    1..img_height. The bump is therefore centred at array index
    [c_y - 1, c_x - 1] when the centre is integral.
    Each cell holds exp(-d^2 / (2 * sigma^2)), where d is the Euclidean
    distance from that pixel to (c_x, c_y). The peak value is 1.
    """
    col_offsets = np.linspace(1, img_width, img_width) - c_x
    row_offsets = np.linspace(1, img_height, img_height) - c_y
    # Broadcasting a (1, w) row against an (h, 1) column yields the same
    # (h, w) grid of squared distances that np.meshgrid would produce.
    dx = col_offsets[np.newaxis, :]
    dy = row_offsets[:, np.newaxis]
    squared_dist = dx * dx + dy * dy
    two_sigma_sq = 2.0 * sigma * sigma
    return np.exp(-(squared_dist / two_sigma_sq))



def heatmap_n_point(points, size, scaler=30, variances=(90, 80, 100), rng=None):
    """
    Sum isotropic 2-D Gaussian pdfs (scipy multivariate_normal), one per point.

    :param points: sequence of (x, y) centres, e.g. [[40, 30], [400, 70]].
        x is the column index and y the row index of the output array.
        Lists, tuples and numpy rows are all accepted. An empty sequence
        yields an all-zero heatmap.
    :param size: tuple (h, w): height (rows) and width (columns) of the output.
    :param scaler: multiplier applied to every kernel's variance (default 30).
        Each kernel's covariance is (scaler * v) * I, where v is drawn from
        ``variances``.
    :param variances: candidate base variances. One is picked at random for
        each point. The default (90, 80, 100) matches the previously
        hard-coded values.
    :param rng: optional ``random.Random`` used for that pick. Defaults to the
        global ``random`` module; pass a seeded instance for reproducible output.
    :return: np.ndarray of shape (h, w) holding the summed pdf values.
    """
    # Validation keeps AssertionError so callers see the same exception type
    # as before.
    assert isinstance(size, tuple) and len(size) == 2, "size 输入错误"
    assert isinstance(points, (list, tuple, np.ndarray)) and all(
        np.shape(p) == (2,) for p in points
    ), "points input error"

    h, w = size
    if len(points) == 0:
        # The sum over zero kernels is an all-zero map. The previous code
        # indexed points[0] here and raised IndexError.
        return np.zeros((h, w))

    chooser = random if rng is None else rng
    xx, yy = np.meshgrid(range(w), range(h))
    # One (x, y) row per pixel, in row-major order, so reshape((h, w))
    # restores the grid.
    xxyy = np.c_[xx.ravel(), yy.ravel()]

    kernel = np.zeros(h * w)
    for point in points:
        # A scalar cov means an isotropic covariance: (scaler * v) * identity.
        cov = scaler * chooser.choice(variances)
        kernel += multivariate_normal(point, cov).pdf(xxyy)
    return kernel.reshape((h, w))

if __name__ == '__main__':
    import time

    # perf_counter() is the monotonic, high-resolution clock meant for measuring
    # elapsed time; time.time() follows the wall clock and can jump.
    start = time.perf_counter()
    # size is (h, w) = (800, 1200) and each point is (x, y) = (column, row).
    # NOTE: [200, 800] sits at row 800, one past the last row (799), so only
    # the upper part of that blob appears in the image.
    img = heatmap_n_point([[40, 30], [400, 70], [200, 800], [80, 300]], (800, 1200))
    print(time.perf_counter() - start)
    plt.imshow(img)
    plt.show()

你可能感兴趣的:(基于multivariate_normal,在指定点生成heatmap)