三、生成heatmap(二)基于patch画热力图

太长不看版

1、将PIL.Image转换成批训练的DataLoader

  • 为什么一批一批进去处理

2、载入网络( torch.load('Resnet.pkl') ),并将数据放入网络,通过 outputs = model(images) 得到预测值,放在对应的对象中

3、按对象的数字顺序排序,生成热力图保存

 

耐心看完版

事先准备:

test_path = r'C:\Users\BME419\Desktop\resnet\slide\patch'
background_path = r'E:\WSI\CAMELYON16\Processed\patch-based-classification\raw-data\test\tumor091heatmaps  patches\none'

pre_savename = r'C:\Users\BME419\Desktop\resnet\slide\heatmap'
savename = os.path.join(pre_savename, 'heatmap')   #保存名字为'heatmap'

batch_size = 64
classes = ['negative','positive']
global positive_prob
positive_prob = []                      #positive_prob类型为list
def reload_net(model_name):    #可选择四种网络
    if  model_name == "VGG":
        trainednet = torch.load('VGGnet.pkl')
    elif model_name == "Google":
        trainednet = torch.load('Google.pkl')
    elif model_name == "Res":
        trainednet = torch.load('Resnet.pkl')
    elif model_name == "Alex":
        trainednet = torch.load('Alexnet.pkl')
    return trainednet

 主函数:

test_data("Res", 224)     #内涵调用heatmap_gen()

具体函数 ↓ ↓ 

def test_data(model_name, input_size):       # input_size = 224
    # 先转换成 torch 能识别的 Dataset
    testset = torchvision.datasets.ImageFolder(test_path,
                            transform = transforms.Compose([
                            transforms.Resize((input_size, input_size)),
                            # 将图片缩放到指定大小(h,w)或者保持长宽比并缩放最短的边到int大小
                            transforms.ToTensor(),
                            ]))      
    # 把 dataset 放入 DataLoader                                
    testloader = torch.utils.data.DataLoader(testset, batch_size = batch_size,            
                                             shuffle = False, num_workers = 0)    # shuffle = False(不打乱),按顺序取patch,否则 = True,随机取
    model = reload_net(model_name)   # load模型,函数具体定义见“事先准备”   
    model.eval()         # 把BN和Dropout固定住,不会取平均,而是用训练好的值
三、生成heatmap(二)基于patch画热力图_第1张图片 产生的Dataset和DataLoader
    #将testset.imgs从tuple变为list(用于append网络产生的概率(outputs))
    for j in range(len(testset.imgs)):      
        testset.imgs[j] = list(testset.imgs[j])
    #利用训练好的网络预测patch概率
    for i, data in enumerate(testloader, 0):
        images, labels = data
        print(labels)
        images = Variable(images, requires_grad=True)   # 转换数据格式用Variable
        if torch.cuda.is_available():
            images = images.cuda()       # 转换数据格式用Variable
            model.cuda()
        with torch.no_grad():
            outputs = model(images)
            outputs = outputs.cpu().numpy()        # 将outputs由GPU转化为numpy
            positive_prob.extend(outputs[:, 1])    # positive_prob(list):各个patch肿瘤positive的概率
            if i <= (len(testset.imgs) / 64-1):    # 将testset.imgs与各自概率一一对应
                for j in range(64):
                    testset.imgs[j + 64 * i].append(outputs[j, 1])
            else:
                for j in range(len(testset.imgs) - 64 * i):
                    testset.imgs[j + 64 * i].append(outputs[j, 1])

    probmin = np.min(positive_prob)  # 用于背景显示
    heatmap_gen(positive_prob, testset,input_size,probmin) #调用heatmap_gen(),具体见↓
三、生成heatmap(二)基于patch画热力图_第2张图片 0表示'negative', 1表示'positive'

def heatmap_gen(positive_prob,testset,input_size,probmin):
    fig = plt.figure(figsize=(172, 153))   #figsize以英寸为单位 width=172,height=153
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)

    # 读背景patch
    testset2 = torchvision.datasets.ImageFolder(background_path,
                            transform=transforms.Compose([
                            transforms.Resize((input_size, input_size)),
                          # 将图片缩放到指定大小(h,w)或者保持长宽比并缩放最短的边到int大小
                            transforms.ToTensor(),   # 把一个取值范围是[0,255]的PIL.Image 转换成 Tensor
                            ]))
    for n in range(len(testset2.imgs)):   # 将背景patch概率减10,并append到testset
        testset2.imgs[n] = list(testset2.imgs[n])
        testset2.imgs[n].append(probmin-10)     
        testset.imgs.append(testset2.imgs[n])   
 
    #使路径减成数字,排序
    for n in range(len(testset.imgs)):
        testset.imgs[n][0] = os.path.basename(testset.imgs[n][0])
    testset.imgs.sort(key=lambda x: int(x[0][:-6]))

    positive_prob = [None] * len(testset.imgs)
    positive_prob = np.array(positive_prob)

    # 按照片名从小到大生成positive_prob
    for n in range(len(testset.imgs)):
        positive_prob[n] = testset.imgs[n][2]














    probmin = np.min(positive_prob)
    probmax = np.max(positive_prob)
    heatmap = ((positive_prob - probmin) / (probmax - probmin + 0.000001)) * 255  # float在[0,1]之间,转换成0-255
    heatmap = heatmap.astype(np.uint8)  # 转成unit8
    heatmap = heatmap.reshape(153, 172)      # y为行,x为列
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 生成heat map
    heatmap = heatmap[:, :, ::-1]  # 注意cv2(BGR)和matplotlib(RGB)通道是相反的
    plt.imshow(heatmap)
    fig.savefig(savename, dpi=10)  #dpi指每英寸有多少个像素,save路径见最上面

 

你可能感兴趣的:(PBL,Pytorch,heatmap)