torchvision.models之实现一个通用接口调用各种classifier

使用pytorch的童鞋们应该对torchvision很熟悉了,其中的torchvision.models
支持的大部分的分类器,主要是alexnet、resnet、vgg 、inception、densenet、googlenet、mobilenet、shufflenetv2,但是在使用时,需要根据自己的类别数,修改最后一层的输出个数,由于这些网络的最后一层的实现有些不同,因此不能直接使用model.fc,model.classifer等来修改,因此我写了一个简单的通用接口,不管是哪种网络,只要传入名字和类别数,就可以直接初始化。

1、分类网络中三种不同形式的输出

第一种:

# 例如:alexnet、vgg、mobilenet、mnasnet
self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

第二种:

#例如:resnet,inception,googlenet、shfflenetv2
self.fc = nn.Linear(512 * block.expansion, num_classes)

第三种:

#例如:densenet
self.classifier = nn.Linear(num_features, num_classes)

1、分类器通用调用接口

注意:该代码不支持squeezenet系列

def get_models_last(model):
    last_name = list(model._modules.keys())[-1]
    last_module = model._modules[last_name]
    return last_module, last_name

class CustomClassifier(nn.Module):
    def __init__(self, arch: str, num_classes: int, pretrained: bool = True):
        super().__init__()
        if pretrained:
           self.model = torchvision.models.__dict__[arch](pretrained = pretrained)
        else:
           self.model = torchvision.models.__dict__[arch]()
        last_module, last_name = get_models_last(self.model)
        if isinstance(last_module, nn.Linear):
            n_features = last_module.in_features
            self.model._modules[last_name] = nn.Linear(n_features, num_classes)
        elif isinstance(last_module, nn.Sequential):
            seq_last_module, seq_last_name = get_models_last(last_module)
            n_features = seq_last_module.in_features
            last_module._modules[seq_last_name] = nn.Linear(n_features, num_classes)

        #just for test
        self.last = list(self.model.named_modules())[-1][1]


    def forward(self, input_neurons):
        # TODO: add dropout layers, or the likes.
        output_predictions = self.model(input_neurons)
        return output_predictions


supported_arch = [  'alexnet', 'AlexNet', 'resnet18',
                     'resnet34', 'resnet50', 'resnet101', 'resnet152',
                     'resnext50_32x4d', 'resnext101_32x8d',
                     'wide_resnet50_2', 'wide_resnet101_2',
                     'vgg11', 'vgg11_bn','vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
                     'vgg19_bn', 'vgg19', 'Inception3', 'inception_v3',
                     'DenseNet', 'densenet121', 'densenet169',
                     'densenet201', 'densenet161', 'googlenet', 'GoogLeNet',
                     'MobileNetV2', 'mobilenet_v2','mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0',
                     'mnasnet1_3','shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
                     'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0']

for arch in supported_arch:
   print(arch)
   net = CustomClassifier(arch, 2, False)
   print(net.last)

你可能感兴趣的:(pytorch,pytorch,神经网络)