pytorch 提取中间层的特征

一、背景

需要提取网络中间层的特征,用于特征工程或者可视化

二、解决方案

先说好,有很多解决的方法呢,这里给出一种我认为是简单的,官方提供的功能

https://pytorch.org/vision/main/generated/torchvision.models.feature_extraction.create_feature_extractor.html#torchvision.models.feature_extraction.create_feature_extractor

核心代码如下

from torchvision.models.feature_extraction import create_feature_extractor

 # Feature extraction with resnet
model = torchvision.models.resnet18()
# extract layer1 and layer3, giving as names `feat1` and feat2`
model = create_feature_extractor(
	model, {'layer1': 'feat1', 'layer3': 'feat2'})
out = model(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
#     [('feat1', torch.Size([1, 64, 56, 56])),
#     ('feat2', torch.Size([1, 256, 14, 14]))]

你可能感兴趣的:(deeplearning,pytorch,pytorch,人工智能,python)