本文在博客 https://blog.csdn.net/GrayOnDream/article/details/99090247 的基础上做了进一步修改。
上述博客是按照 network 文件中各层(class)的定义顺序依次执行网络的,这不适用于我的网络(我的网络先定义了许多基础模块,再把它们拼接起来)。由于大多数人定义网络层的顺序与实际运行的顺序并不一致,我在此基础上做了修改:直接让网络的 forward 返回需要可视化的各层输出。
完整代码如下(所用网络是一个类似 U-Net 的网络):
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2
class FeatureExtractor(nn.Module):
    """Wrap a network whose forward returns its intermediate layers, and name them.

    The wrapped ``submodule`` must return exactly as many outputs as there are
    entries in ``_OUTPUT_NAMES``, in that order. ``forward`` returns them as a
    dict keyed by those names, in the same order.

    Args:
        submodule: the network; its forward must return the layer outputs
            listed in ``_OUTPUT_NAMES``.
        extracted_layers: stored on the instance but NOT consulted by
            ``forward`` -- every returned layer is always included.
    """

    # Must match, one-to-one and in order, the tuple returned by the network's
    # forward (edit network.py to return the layers you want, then this list).
    _OUTPUT_NAMES = (
        "x1", "x2", "x3", "x4", "x5", "x6",
        "up7", "merge7", "conv7",
        "up8", "merge8", "conv8",
        "up9", "merge9", "conv9",
        "up10", "merge10", "conv10",
        "up11", "merge11", "conv11",
        "conv12", "mask", "x2_0",
    )

    def __init__(self, submodule, extracted_layers):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Run ``submodule`` on ``x`` and return ``{layer_name: output}``.

        Raises:
            ValueError: if the submodule does not return exactly one value per
                name in ``_OUTPUT_NAMES``.
        """
        layer_values = tuple(self.submodule(x))
        expected = len(self._OUTPUT_NAMES)
        if len(layer_values) != expected:
            raise ValueError(
                "expected %d outputs from submodule, got %d"
                % (expected, len(layer_values))
            )
        return dict(zip(self._OUTPUT_NAMES, layer_values))
def get_picture(pic_name, transform, size=(224, 224)):
    """Load an image file, resize it, and convert it with ``transform``.

    Args:
        pic_name: path of the image file to read.
        transform: callable applied to the resized float32 array
            (e.g. ``transforms.ToTensor()``).
        size: ``(height, width)`` passed to ``skimage.transform.resize``;
            defaults to the previously hard-coded 224x224.

    Returns:
        Whatever ``transform`` returns for the resized float32 image.
    """
    img = skimage.io.imread(pic_name)
    # With default arguments, resize() also converts the image to floating point.
    img = skimage.transform.resize(img, size)
    img = np.asarray(img, dtype=np.float32)
    return transform(img)
def make_dirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` avoids the race of checking ``os.path.exists`` before
    creating. Raises ``FileExistsError`` if ``path`` exists but is not a
    directory, instead of silently continuing.
    """
    os.makedirs(path, exist_ok=True)
def _default_device():
    """Return cuda:1 when a second GPU exists, else cuda:0, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # The original always used cuda:1, which fails on single-GPU machines.
    index = 1 if torch.cuda.device_count() > 1 else 0
    return torch.device("cuda:%d" % index)


def _save_channel(feature_img, dst_path, i, therd_size):
    """Colour-map one 2-D channel and write it (plus an upscaled copy if small)."""
    # Clip before the uint8 cast: out-of-range activations would otherwise wrap
    # around (e.g. 1.2 * 255 -> 50). NOTE(review): per-channel min-max
    # normalisation may give more informative images for unbounded layers.
    feature_img = np.clip(feature_img, 0.0, 1.0)
    feature_img = np.asarray(feature_img * 255, dtype=np.uint8)
    feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
    if feature_img.shape[0] < therd_size:
        tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
        tmp_img = cv2.resize(feature_img, (therd_size, therd_size),
                             interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(tmp_file, tmp_img)
    dst_file = os.path.join(dst_path, str(i) + '.png')
    cv2.imwrite(dst_file, feature_img)


def get_feature(pic_dir='./input_images/1.jpg',
                model_path='./models/1_70/19.pth',
                dst='./features',
                therd_size=256,
                device=None):
    """Run one image through a saved model and dump each returned feature map.

    For every named output of ``FeatureExtractor`` (keys containing ``'fc'``
    are skipped), each channel ``i`` of the first batch element is written as
    a JET-colour-mapped PNG to ``<dst>/<name>/<i>.png``. Channels whose height
    is below ``therd_size`` are additionally written, upscaled to
    ``therd_size x therd_size``, as ``<i>_<therd_size>.png``.

    Args:
        pic_dir: path of the input image.
        model_path: path of a whole model saved with ``torch.save(model)``.
        dst: root output directory.
        therd_size: small maps also get an upscaled copy of this size.
        device: torch device to run on; defaults to cuda:1 when two or more
            GPUs exist, else cuda:0, else CPU.
    """
    if device is None:
        device = _default_device()
    transform = transforms.ToTensor()
    # Add the batch dimension before moving to the device.
    img = get_picture(pic_dir, transform).unsqueeze(0).to(device)

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. map_location lets a GPU-saved model load on a CPU-only host.
    net = torch.load(model_path, map_location=device)
    net.to(device)
    # Inference mode: BatchNorm/Dropout must not behave as during training.
    net.eval()

    exact_list = ['conv1_block', ""]  # currently ignored by FeatureExtractor.forward
    myexactor = FeatureExtractor(net, exact_list)
    with torch.no_grad():
        outs = myexactor(img)

    for k, v in outs.items():
        if 'fc' in k:
            continue
        dst_path = os.path.join(dst, k)
        make_dirs(dst_path)
        # Copy the first batch element to host memory once, not once per channel.
        feature = v[0].detach().cpu().numpy()
        for i in range(feature.shape[0]):
            _save_channel(feature[i, :, :], dst_path, i, therd_size)
if __name__ == '__main__':
    # Script entry point: dump feature maps using the default paths.
    get_feature()
最终生成的文件夹内容如下:
可视化效果截图