Python之相关性分析热力图

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr


def randomPlot(filename='Random.png'):
    """
    Draw a heat map of an 8x8 matrix of uniform random numbers in [0, 1).

    The raw matrix is printed to stdout, then rendered with seaborn (values
    rounded to 4 decimals and annotated in every cell, colour range [0, 1],
    rows/columns labelled 'A'..'H') and saved to ``filename``.

    :param filename: output image path; defaults to 'Random.png', the file
        the original version always wrote.
    """
    data = np.random.rand(8, 8)
    print(data)
    key_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        frame = pd.DataFrame(np.round(data, 4), columns=key_list, index=key_list)
        # Draw on the axes created above explicitly rather than relying on
        # pyplot's implicit "current axes".
        sns.heatmap(frame, ax=ax, annot=True, vmax=1, vmin=0,
                    xticklabels=True, yticklabels=True, square=True,
                    cmap="YlGnBu")
        ax.set_title(' Heat Map ', fontsize=18)
        ax.set_ylabel('Y', fontsize=18)
        ax.set_xlabel('X', fontsize=18)
        fig.savefig(filename)
    finally:
        # Release the figure; otherwise each call leaves an open figure
        # behind in pyplot's global figure registry.
        plt.close(fig)


def _draw_heatmap(matrix, labels, filename, vmin, vmax):
    """Render ``matrix`` as an annotated square heat map and save it to ``filename``."""
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        frame = pd.DataFrame(np.round(matrix, 4), columns=labels, index=labels)
        sns.heatmap(frame, ax=ax, annot=True, vmax=vmax, vmin=vmin,
                    xticklabels=True, yticklabels=True, square=True,
                    cmap="YlGnBu")
        ax.set_title(' Heat Map ', fontsize=18)
        ax.set_ylabel('Y', fontsize=18)
        ax.set_xlabel('X', fontsize=18)
        fig.savefig(filename)
    finally:
        # Always release the figure so repeated calls do not accumulate
        # open figures in pyplot.
        plt.close(fig)


def dataPlot(which='data1'):
    """
    Draw an annotated heat map for one of three embedded matrices.

    The matrices are hard-coded below:

    * ``data1`` -- 6x6 reciprocal matrix (``a[j][i] == 1 / a[i][j]``); it looks
      like an AHP-style pairwise-comparison matrix. Drawn with the original
      colour range [0, 4]; the entries 5 and 7 exceed that range and share
      the top colour, but the annotations still show their exact values.
    * ``data2`` / ``data3`` -- 8x8 matrices with (near-)unit diagonal and
      entries in [-1, 1], presumably correlation-coefficient matrices. Drawn
      with the range [-1, 1] so negative coefficients are not clipped.

    The image is written to ``'<which>.png'``; with the default argument this
    is ``data1.png``, exactly as before.

    :param which: ``'data1'`` (default), ``'data2'`` or ``'data3'``.
    :raises ValueError: if ``which`` does not name an embedded matrix.
    """
    data1 = [[1,3,1,1/5,1/2,3],[1/3,1,2,1/3,1,2],[1,1/2,1,1/5,1/2,5],[5,3,5,1,3,7],[2,1,2,1/3,1,3],[1/3,1/2,1/5,1/7,1/3,1]]
    data2 = [[0.9999999999999999, 0.8345554576233073, 0.9243854856819417, 0.9330027648545496, 0.9482209505391218,
              -0.6482501837911674, -0.8818546914603946, -0.6482501837911674],
             [0.8345554576233073, 0.9999999999999999, 0.9003585423988755, 0.7649830893396171, 0.7979521001324441,
              -0.4691190784521079, -0.6480284658216039, -0.4691190784521079],
             [0.9243854856819417, 0.9003585423988756, 1.0, 0.8506386730833658, 0.8830994621736679, -0.7006298478257242,
              -0.7676395373967932, -0.7006298478257242],
             [0.9330027648545496, 0.7649830893396171, 0.8506386730833659, 0.9999999999999999, 0.9951726440650351,
              -0.5278605167290854, -0.9273089460158745, -0.5278605167290854],
             [0.9482209505391219, 0.7979521001324442, 0.8830994621736679, 0.9951726440650353, 1.0, -0.5724194583833783,
              -0.9112198965249181, -0.5724194583833783],
             [-0.6482501837911674, -0.46911907845210793, -0.7006298478257242, -0.5278605167290855, -0.5724194583833783,
              1.0, 0.49515922475047763, 1.0],
             [-0.8818546914603947, -0.6480284658216039, -0.7676395373967932, -0.9273089460158744, -0.9112198965249181,
              0.49515922475047763, 1.0, 0.49515922475047763],
             [-0.6482501837911674, -0.46911907845210793, -0.7006298478257242, -0.5278605167290855, -0.5724194583833783,
              1.0, 0.49515922475047763, 1.0]]
    data3 = [[1.0, 0.6447733742285494, 0.7648449086941359, 0.7781871747188993, 0.8077705036447606, -0.45314582645069557,
              -0.7063398069618333, -0.45314582645069557],
             [0.6447733742285494, 1.0, 0.723107089412363, 0.5724058545087037, 0.6071562414937144, -0.3151362329600627,
              -0.47203186715609546, -0.3151362329600627],
             [0.7648449086941359, 0.7231070894123629, 1.0, 0.6636168802979515, 0.7036503531839524, -0.5094145396347782,
              -0.5801992046091554, -0.5094145396347782],
             [0.7781871747188993, 0.5724058545087037, 0.6636168802979514, 1.0, 0.9417388451103401, -0.35110754201106087,
              -0.818036667286979, -0.35110754201106087],
             [0.8077705036447606, 0.6071562414937144, 0.7036503531839524, 0.9417388451103401, 1.0, -0.3871446641295609,
              -0.7851151107221745, -0.3871446641295609],
             [-0.45314582645069557, -0.3151362329600627, -0.5094145396347782, -0.35110754201106087, -0.3871446641295609,
              1.0, 0.331222719795913, 1.0],
             [-0.7063398069618333, -0.4720318671560954, -0.5801992046091554, -0.818036667286979, -0.7851151107221744,
              0.331222719795913, 1.0, 0.331222719795913],
             [-0.45314582645069557, -0.3151362329600627, -0.5094145396347782, -0.35110754201106087, -0.3871446641295609,
              1.0, 0.331222719795913, 1.0]]
    six_labels = ['A', 'B', 'C', 'D', 'E', 'F']
    eight_labels = six_labels + ['G', 'H']
    # name -> (matrix, axis labels, vmin, vmax)
    datasets = {
        'data1': (data1, six_labels, 0, 4),
        'data2': (data2, eight_labels, -1, 1),
        'data3': (data3, eight_labels, -1, 1),
    }
    try:
        matrix, labels, vmin, vmax = datasets[which]
    except KeyError:
        raise ValueError('which must be one of {}, got {!r}'.format(sorted(datasets), which)) from None
    _draw_heatmap(np.array(matrix), labels, '{}.png'.format(which), vmin, vmax)


if __name__ == '__main__':
    # Script entry point: render the random demo first, then the data matrix.
    for plot_fn in (randomPlot, dataPlot):
        plot_fn()

运行效果如下（原文中的两张图片未随本文保留，仅剩图题）：
图 1：Python之相关性分析热力图（第 1 张图片）
图 2：Python之相关性分析热力图（第 2 张图片）

你可能感兴趣的：Python之相关性分析热力图