PyTorch提取中间层的特征(Resnet)

     特征提取在深度学习的训练中是经常要做的事情,之前的一篇blog有写到使用pytorch提取Vgg、Resnet、Densenet三种模型下的特征,这里所述的是提取全连接层(FC层)的特征,详情可见:https://blog.csdn.net/qq_34611579/article/details/84330968。

     在本文中,主要是介绍提取中间层的特征,对于特征的提取,可以先把模型的结构输出,不同的模型结构是不一样的;下面拿resnet作为示例;由于pytorch模型很多用到nn.sequential,所以对各层的特征提取要自己去修改forward函数。

# 中间层特征提取
class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    # 自己修改forward函数
    def forward(self, x):
        outputs = []
        for name, module in self.submodule._modules.items():
            if name is "fc": x = x.view(x.size(0), -1)
            x = module(x)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs

    这里可以看到,我们自定义了forward函数,由此可以选择在哪一层提取特征。

extract_list = ["conv1", "maxpool", "layer1", "avgpool", "fc"]
img_path = "./1_00001.jpg"
saved_path = "./1_00001.txt"
resnet = models.resnet50(pretrained=True)
# print(resnet) 可以打印看模型结构

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()]
)

img = Image.open(img_path)
img = transform(img)

x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)

if use_gpu:
    x = x.cuda()
    resnet = resnet.cuda()

extract_result = FeatureExtractor(resnet, extract_list)
print(extract_result(x)[4])  # [0]:conv1  [1]:maxpool  [2]:layer1  [3]:avgpool  [4]:fc

  对于模型的结构我们可以打印出来看看,这里是采用的Resnet,此时forward函数也是针对该模型进行了修改;若比如其他的VGG,DenseNet也可以用类似的方法进行修改。

参考这里所说:https://blog.csdn.net/qq_24306353/article/details/82995320.

还有一些比如可视化,可参考:https://blog.csdn.net/xz1308579340/article/details/85622579.

源代码:https://github.com/Messi-Q/Pytorch-extract-feature/blob/master/feature_extract.py.

你可能感兴趣的:(python,深度学习,神经网络)