pytorch特征图可视化

本文基于https://blog.csdn.net/GrayOnDream/article/details/99090247的博客进行了进一步的修改

上述博客是按照 network 文件中 class 的定义顺序依次读取网络层的,这并不适用于我的网络(我的网络是先定义许多基础模块,再把它们拼接起来的)。大多数人定义网络层的顺序与前向传播时真实的执行顺序并不一致,所以我在其基础上做了修改:直接让网络的 forward 返回想要可视化的各层输出。

完整代码如下,网络是一个类似u-net的网络

import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2


class FeatureExtractor(nn.Module):
    """Label the intermediate outputs returned by a wrapped network.

    The wrapped ``submodule`` must return exactly as many values as there
    are entries in ``_LAYER_NAMES``, in that order. ``forward`` pairs each
    value with its name and returns the result as a dict.
    """

    # Names for the values returned by the wrapped network, in return order.
    # To adapt this to your own network, make its forward() return every
    # layer you want to inspect and list the matching names here.
    _LAYER_NAMES = (
        "x1", "x2", "x3",
        "x4", "x5", "x6",
        "up7", "merge7", "conv7",
        "up8", "merge8", "conv8",
        "up9", "merge9", "conv9",
        "up10", "merge10", "conv10",
        "up11", "merge11", "conv11",
        "conv12", "mask", "x2_0",
    )

    def __init__(self, submodule, extracted_layers):
        super().__init__()
        self.submodule = submodule
        # Kept for interface compatibility; forward() does not consult it.
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Run ``submodule`` on ``x`` and return ``{layer_name: output}``.

        Raises:
            ValueError: if ``submodule`` does not return exactly
                ``len(_LAYER_NAMES)`` values.
        """
        layer_outputs = tuple(self.submodule(x))
        expected = len(self._LAYER_NAMES)
        if len(layer_outputs) != expected:
            raise ValueError(
                "expected {} outputs from submodule, got {}".format(
                    expected, len(layer_outputs)
                )
            )
        return dict(zip(self._LAYER_NAMES, layer_outputs))


def get_picture(pic_name, transform, size=(224, 224)):
    """Read an image, resize it and hand it to ``transform``.

    Args:
        pic_name: Path of the image file, passed to ``skimage.io.imread``.
        transform: Callable applied to the resized ``float32`` array; its
            return value is returned unchanged.
        size: Target ``(rows, cols)`` passed to ``skimage.transform.resize``.
            Defaults to ``(224, 224)``, the size that used to be hard-coded.

    Returns:
        Whatever ``transform`` returns for the resized image.
    """
    img = skimage.io.imread(pic_name)
    img = skimage.transform.resize(img, size)
    # resize() returns a floating-point array; cast it to float32 before
    # handing it to the transform.
    img = np.asarray(img, dtype=np.float32)
    return transform(img)


def make_dirs(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    Uses ``exist_ok=True`` instead of a separate existence check, so there is
    no window between "check" and "create" in which another process could
    create the directory and make ``os.makedirs`` fail.

    Raises:
        FileExistsError: if ``path`` already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def get_feature(pic_path='./input_images/1.jpg', model_path='./models/1_70/19.pth',
                dst='./features', min_size=256):
    """Run one image through a saved network and dump every feature map.

    Each output returned by ``FeatureExtractor`` gets its own sub-directory
    of ``dst``; each channel is written there as a JET-coloured PNG named
    ``<channel>.png``. Maps whose height is below ``min_size`` are also
    written enlarged as ``<channel>_<min_size>.png``.

    Args:
        pic_path: Image fed to the network.
        model_path: File containing the whole pickled model (not a state dict).
        dst: Root directory for the output images.
        min_size: Side length that small feature maps are enlarged to.
    """
    transform = transforms.ToTensor()
    img = get_picture(pic_path, transform)

    # Prefer the second GPU as before, but fall back to the first one on
    # single-GPU machines, where "cuda:1" does not exist.
    if torch.cuda.is_available():
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")

    # Add the batch dimension.
    img = img.unsqueeze(0).to(device)

    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from a trusted source. Recent PyTorch releases may also need
    # weights_only=False to load a whole pickled model; confirm for your version.
    # map_location lets a model saved on a GPU be loaded on a CPU-only host.
    net = torch.load(model_path, map_location=device)
    net.to(device)
    # Inference mode: freeze BatchNorm statistics and disable Dropout.
    net.eval()

    # exact_list = None
    exact_list = ['conv1_block', ""]

    myexactor = FeatureExtractor(net, exact_list)
    with torch.no_grad():  # no autograd graph is needed for visualisation
        outs = myexactor(img)

    for k, v in outs.items():
        if 'fc' in k:
            continue

        # First sample of the batch, converted to NumPy once per layer rather
        # than once per channel.
        feature = v[0].detach().cpu().numpy()
        if feature.ndim != 3:
            # Only (channel, height, width) maps can be drawn as images.
            continue

        dst_path = os.path.join(dst, k)
        make_dirs(dst_path)

        for i in range(feature.shape[0]):
            # Clip to [0, 1] before scaling: casting out-of-range floats to
            # uint8 wraps around and gives misleading colours.
            feature_img = (np.clip(feature[i, :, :], 0.0, 1.0) * 255).astype(np.uint8)
            feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)

            if feature_img.shape[0] < min_size:
                # Some maps are tiny; also save an enlarged copy.
                tmp_file = os.path.join(dst_path, str(i) + '_' + str(min_size) + '.png')
                tmp_img = cv2.resize(feature_img, (min_size, min_size),
                                     interpolation=cv2.INTER_NEAREST)
                cv2.imwrite(tmp_file, tmp_img)

            dst_file = os.path.join(dst_path, str(i) + '.png')
            cv2.imwrite(dst_file, feature_img)


# Script entry point: dump the feature maps only when this file is run
# directly, not when it is imported.
if __name__ == '__main__':
    get_feature()

最后的文件夹内容是这样的:

pytorch特征图可视化_第1张图片

可视化效果截图

pytorch特征图可视化_第2张图片

pytorch特征图可视化_第3张图片

 

你可能感兴趣的:(pytorch,python,日常笔记)