使用 PyTorch 的童鞋们应该对 torchvision 很熟悉了,其中的 torchvision.models 支持大部分常用的分类网络,主要有 alexnet、resnet、vgg、inception、densenet、googlenet、mobilenet、shufflenetv2。但是在使用时,需要根据自己的类别数修改最后一层的输出个数。由于这些网络最后一层的实现有些不同,不能统一使用 model.fc 或 model.classifier 来修改,因此我写了一个简单的通用接口:不管是哪种网络,只要传入名字和类别数,就可以直接初始化。
第一种:
# 例如:alexnet、vgg、mobilenet、mnasnet
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
第二种:
# 例如:resnet、inception、googlenet、shufflenetv2
self.fc = nn.Linear(512 * block.expansion, num_classes)
第三种:
#例如:densenet
self.classifier = nn.Linear(num_features, num_classes)
注意:该代码不支持 squeezenet 系列(其分类层用卷积层 nn.Conv2d 输出类别,而不是 nn.Linear)。
def get_models_last(model):
    """Return the last registered direct child of ``model`` together with its name.

    Args:
        model: An ``nn.Module`` (anything exposing a ``_modules`` mapping).

    Returns:
        A ``(module, name)`` tuple for the final entry of ``model._modules``.

    Raises:
        IndexError: If ``model`` has no registered children.
    """
    # ``_modules`` keeps children in registration order, so the final key
    # names the child that was registered last.
    child_names = tuple(model._modules)
    tail_name = child_names[-1]
    return model._modules[tail_name], tail_name
class CustomClassifier(nn.Module):
    """Torchvision classification network whose final layer emits ``num_classes`` outputs.

    The builder named ``arch`` is looked up in ``torchvision.models`` and the
    network's final ``nn.Linear`` is swapped for a freshly initialised one
    with the same input width and ``num_classes`` outputs.  Two head layouts
    are handled:

    * the model's last child is itself an ``nn.Linear``;
    * the model's last child is an ``nn.Sequential`` whose last child is an
      ``nn.Linear``.

    Args:
        arch: Name of a callable in ``torchvision.models``.
        num_classes: Number of outputs of the replacement layer.
        pretrained: When true, ``pretrained=True`` is forwarded to the
            builder; otherwise the builder is called with no arguments.

    Raises:
        KeyError: If ``arch`` is not a name in ``torchvision.models``.
        ValueError: If the network's head matches neither supported layout.
            Previously such a network was either left with its original
            output size or failed with an ``AttributeError``.
    """

    def __init__(self, arch: str, num_classes: int, pretrained: bool = True):
        super().__init__()
        builder = torchvision.models.__dict__[arch]
        if pretrained:
            self.model = builder(pretrained=pretrained)
        else:
            # Called without arguments so that entries which do not accept
            # ``pretrained`` can still be built.
            self.model = builder()

        # Locate the module that owns the final layer, and that layer's name.
        last_module, last_name = get_models_last(self.model)
        if isinstance(last_module, nn.Linear):
            owner, head_name, head = self.model, last_name, last_module
        elif isinstance(last_module, nn.Sequential):
            owner = last_module
            head, head_name = get_models_last(last_module)
        else:
            raise ValueError(
                f"Unsupported architecture {arch!r}: last module {last_name!r} "
                f"is {type(last_module).__name__}, expected nn.Linear or nn.Sequential"
            )
        if not isinstance(head, nn.Linear):
            raise ValueError(
                f"Unsupported architecture {arch!r}: final layer {head_name!r} "
                f"is {type(head).__name__}, expected nn.Linear"
            )
        owner._modules[head_name] = nn.Linear(head.in_features, num_classes)
        # NOTE(review): only the final head is replaced; any auxiliary
        # classifiers inside the network presumably keep their original
        # output size -- confirm before relying on auxiliary outputs.

        # just for test: last entry of named_modules(), i.e. the new head.
        # NOTE: assigning a module to an attribute also registers it as a
        # child named "last", so it appears under that name as well.
        self.last = list(self.model.named_modules())[-1][1]

    def forward(self, input_neurons):
        """Run the wrapped network on ``input_neurons`` and return its output."""
        # TODO: add dropout layers, or the likes.
        output_predictions = self.model(input_neurons)
        return output_predictions
# Names in torchvision.models exercised by the smoke test below, grouped by
# network family.
supported_arch = [
    # AlexNet
    'alexnet', 'AlexNet',
    # ResNet / ResNeXt / Wide ResNet
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext50_32x4d', 'resnext101_32x8d',
    'wide_resnet50_2', 'wide_resnet101_2',
    # VGG
    'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
    # Inception v3
    'Inception3', 'inception_v3',
    # DenseNet
    'DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161',
    # GoogLeNet
    'googlenet', 'GoogLeNet',
    # MobileNetV2
    'MobileNetV2', 'mobilenet_v2',
    # MNASNet
    'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3',
    # ShuffleNetV2
    'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
]

# Smoke test: build every architecture without pretrained weights, ask for a
# 2-class head, and print the module stored in ``last``.
for arch in supported_arch:
    print(arch)
    net = CustomClassifier(arch, num_classes=2, pretrained=False)
    print(net.last)