基于Pytorch特征提取

在Pytorch开源的网络以及权重的基础上进行特征提取

就用VGG16网络举个例子 官方开源的vgg网络

我们想提取全链接层的特征时,只需要将官方的代码注释掉一部分

 def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            # nn.ReLU(True),
            # nn.Dropout(),
            # nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

然后在读取网络权重的时候

def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:			# 只读取需要的权重
        pretrained_dict = model_zoo.load_url(model_urls['vgg16'])  # 预训练模型参数保存地址
        model_dict = model.state_dict()  # 自己的模型参数变量

        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}  # 去除一些不需要的参数
        model_dict.update(pretrained_dict)  # 参数更新
        model.load_state_dict(model_dict)  # 加载

    return model

完整的代码 ——>github
README中有更详细的介绍.

其中根据.TXT文件读取图像和标签的方式 转载自
https://blog.csdn.net/MiaoB226/article/details/88262484

你可能感兴趣的:(基于Pytorch特征提取)