参考 PyTorch 论坛帖子：How to extract features of an image from a trained model
定义一个用于提取中间层特征的类：
# Intermediate-feature extraction
class FeatureExtractor(nn.Module):
    """Collect the outputs of selected direct children of a module.

    ``forward`` does NOT call ``submodule.forward``. Instead it feeds the input
    through each direct child of ``submodule``, in the order they appear in
    ``submodule._modules``, chaining each output into the next child. This
    reproduces the wrapped model only when its forward is a plain chain of its
    children. The one extra step handled here is a flatten to ``(batch, -1)``
    right before a child named ``"fc"``. This presumably mirrors the flatten in
    torchvision ResNet's own forward; confirm for other backbones.

    Args:
        submodule: the model whose direct children are executed.
        extracted_layers: names of the children whose outputs are returned.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Run every child of ``submodule`` on ``x`` and return selected outputs.

        Returns:
            A list of tensors, one per child whose name is in
            ``extracted_layers``. The list is ordered by module position, not
            by the order of ``extracted_layers``. Every child is executed,
            even after the last requested one, and each child's name is
            printed as it runs.
        """
        outputs = []
        for name, module in self.submodule._modules.items():
            # String equality, not identity: `name is "fc"` only works when
            # CPython happens to intern both strings.
            if name == "fc":
                # reshape (unlike view) also works on non-contiguous tensors.
                x = x.reshape(x.size(0), -1)
            x = module(x)
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs
# Input data: a loader over `test_dataset` yielding one image per batch.
test_loader = DataLoader(test_dataset, batch_size=1)
# DataLoader iterators have no `.next()` method in current PyTorch, so use
# the `next()` builtin. The deprecated `Variable(..., volatile=True)` wrapping
# is replaced by running inference under `torch.no_grad()` below.
img, label = next(iter(test_loader))

# Feature extraction
myresnet = resnet18(pretrained=False)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; the numpy-based visualization below needs CPU tensors anyway.
myresnet.load_state_dict(torch.load('cafir_resnet18_1.pkl', map_location='cpu'))
# Inference mode: BatchNorm must use its running statistics. In training
# mode it would use (and update) statistics from this single image, and it
# raises when a feature map shrinks to 1x1 (presumably the case for 32x32
# CIFAR inputs in the last ResNet stage; confirm input size).
myresnet.eval()
exact_list = ["conv1", "layer1", "avgpool"]
myexactor = FeatureExtractor(myresnet, exact_list)
with torch.no_grad():
    x = myexactor(img)

# Visualize the extracted feature maps.
import matplotlib.pyplot as plt

# x[0] is the first captured output in module order. For torchvision ResNet
# this is assumed to be conv1's output, with at least 64 channels; confirm.
# Take the first image of the batch and convert it once, outside the loop.
first_feature_map = x[0][0].detach().cpu().numpy()
for i in range(64):
    ax = plt.subplot(8, 8, i + 1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    ax.imshow(first_feature_map[i, :, :], cmap='jet')
plt.show()
完整代码（尚未测试）：
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# Preprocessing: convert images to tensors, then normalize each RGB channel
# with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
# NOTE(review): the network used with this data is ImageNet-pretrained, and
# torchvision documents mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
# normalization for its pretrained models. Confirm (0.5, 0.5, 0.5) is intended.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# The training split is not used for feature visualization; kept for reference.
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
#                                         download=True, transform=transform)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
#                                           shuffle=True, num_workers=2)

# CIFAR-10 test split under ./data (download=True fetches it if missing).
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
# One image per batch, in a fixed order, loaded in the main process.
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                         shuffle=False, num_workers=0)

# CIFAR-10 class names (presumably in label-index order).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
class FeatureExtractor(nn.Module):
    """Collect the outputs of selected direct children of a module.

    ``forward`` does NOT call ``submodule.forward``. Instead it feeds the input
    through each direct child of ``submodule``, in the order they appear in
    ``submodule._modules``, chaining each output into the next child. This
    reproduces the wrapped model only when its forward is a plain chain of its
    children. The one extra step handled here is a flatten to ``(batch, -1)``
    right before a child named ``"fc"``. This presumably mirrors the flatten in
    torchvision ResNet's own forward; confirm for other backbones.

    Args:
        submodule: the model whose direct children are executed.
        extracted_layers: names of the children whose outputs are returned.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        """Run every child of ``submodule`` on ``x`` and return selected outputs.

        Returns:
            A list of tensors, one per child whose name is in
            ``extracted_layers``. The list is ordered by module position, not
            by the order of ``extracted_layers``. Every child is executed,
            even after the last requested one, and each child's name is
            printed as it runs.
        """
        outputs = []
        for name, module in self.submodule._modules.items():
            # String equality, not identity: `name is "fc"` only works when
            # CPython happens to intern both strings.
            if name == "fc":
                # reshape (unlike view) also works on non-contiguous tensors.
                x = x.reshape(x.size(0), -1)
            x = module(x)
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        # Without this return, forward yields None and callers indexing the
        # result (e.g. `x[0]`) fail.
        return outputs
# ImageNet-pretrained ResNet-18 in inference mode (BatchNorm uses running
# statistics rather than statistics of the single input image).
net = models.resnet18(pretrained=True).eval()
# DataLoader iterators have no `.next()` method in current PyTorch, so use
# the `next()` builtin. This takes the first batch (one image) from testloader.
inputs, labels = next(iter(testloader))
# Names of the children of `net` whose outputs should be captured.
exact_list = ["conv1", "layer1", "avgpool"]
myexactor = FeatureExtractor(net, exact_list)
# Features are only visualized, so skip autograd bookkeeping.
with torch.no_grad():
    x = myexactor(inputs)

# x[0] is the first captured output in module order. For torchvision ResNet
# this is assumed to be conv1's output, with at least 64 channels; confirm.
# Take the first image of the batch and convert it once, outside the loop.
first_feature_map = x[0][0].detach().cpu().numpy()
for i in range(64):
    ax = plt.subplot(8, 8, i + 1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    ax.imshow(first_feature_map[i, :, :], cmap='jet')
plt.show()