PyTorch 自带 ResNet-50 的特征图热力图（heat map）可视化

代码如下：

import cv2
import time
import os
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np

# Output directory for the rendered feature-map figures.
# NOTE(review): hard-coded absolute Windows path; adjust for the local machine.
savepath = r"D:/hexiaojuan/coding/lianxi/hatmap_RGB"
# os.makedirs (unlike os.mkdir) also creates any missing parent directories,
# and exist_ok=True removes the race between an existence check and creation.
os.makedirs(savepath, exist_ok=True)


def draw_features(width, height, x, savename):
    """Draw the first ``width * height`` channels of a feature map as a subplot grid.

    NOTE(review): this definition is truncated in this source. The min-max
    normalization line at the end has an unclosed parenthesis, and the rest of
    the body (presumably rendering each channel and saving the figure) is
    missing. Restore it from the original before use.

    Args:
        width: number of subplot columns in the grid.
        height: number of subplot rows in the grid.
        x: feature data indexed as ``x[0, i, :, :]``. Presumably a NumPy array
            shaped (batch, channels, H, W) with at least ``width * height``
            channels; TODO confirm against the caller.
        savename: not used in the visible lines; presumably the output image
            path. Confirm.
    """
    # Start timestamp; not used in the visible lines (presumably for a timing printout).
    tic = time.time()
    fig = plt.figure(figsize=(16, 16))
    # Narrow margins and spacing so the grid of channel images fills the figure.
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, wspace=0.05, hspace=0.05)
    # One subplot per channel, filled row by row; matplotlib subplot indices are 1-based.
    for i in range(width * height):
        plt.subplot(height, width, i + 1)
        plt.axis('off')
        # Channel i of the first sample in the batch.
        img = x[0, i, :, :]
        pmin = np.min(img)
        pmax = np.max(img)
        # Min-max normalize to roughly [0, 1]; the small epsilon avoids
        # division by zero when the channel is constant.
        # NOTE(review): the expression below is missing a closing parenthesis,
        # and the function body ends here in this source.
        img = ((img - pmin) / (pmax - pmin + 0.000001)

你可能感兴趣的:(Pytorch自带Resnet50特征图heat map热力图可视化)