很多人第一次看到 DenseNet 时最迷糊的一句:“model.classifier.in_features
是啥鬼?为啥不能直接用数字?”
我来给你拆开讲透彻,说明它的本质是什么。
num_ftrs = model.classifier.in_features
从 DenseNet121
这个预训练模型中 取出最后一层分类器(classifier)输入的维度数。
这个数字我们要用来创建我们自己的新分类器:
model.classifier = nn.Linear(num_ftrs, num_classes)
你可以理解成:
我要把模型最后的“方向盘”换掉,但我不能乱来,要知道“方向盘底座多大”(输入维度),不然接口对不上。
我们来验证一下:
import torchvision.models as models
model = models.densenet121(weights=None)
print(model.classifier)
输出结果是:
Linear(in_features=1024, out_features=1000, bias=True)
说明 DenseNet121
的分类层是一个全连接层:
nn.Linear(1024, 1000)
那 model.classifier.in_features
就是:
1024
1024
?当然你可以硬编码:
model.classifier = nn.Linear(1024, 2)
但是:
resnet50
,它最后一层的输入维度就是 2048
,你就得改代码。model.classifier.in_features
可以让代码具有更强的泛化能力,不怕模型结构变化。你可以把模型比作一条流水线工厂:
你现在想重新定义出口数量(比如只输出“新冠 or 正常”):
nn.Linear(num_ftrs, 2)
但是你不能瞎接线,要看前面传来的有多少条“输送带”过来,也就是 in_features
。
表达式 | 解释 |
---|---|
model.classifier |
是最后的分类器层(一个 nn.Linear ) |
model.classifier.in_features |
是该层的输入维度数(前面层输出的特征数量) |
num_ftrs = model.classifier.in_features |
获取这个值,为替换分类层做准备 |
完整示例:
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 2) # 我只要两个类别
.in_features
该怎么取,比如 resnet50
、mobilenet_v2
、efficientnet_b0
?常用的 图像分类模型 .in_features
写法表,你拿来替换分类器时直接复制粘贴就行,不用猜维度,也不用硬编码数字。
.in_features
获取方式对照表:模型名称 | 获取 in_features 写法 |
默认分类数 | 说明 |
---|---|---|---|
ResNet18/34 | model.fc.in_features |
1000 | 最后一层是 model.fc |
ResNet50/101/152 | model.fc.in_features |
1000 | 和 ResNet18 一样 |
DenseNet121/169/201 | model.classifier.in_features |
1000 | 最后一层是 model.classifier |
VGG16/19 | model.classifier[6].in_features |
1000 | 最后一层是 classifier[6] |
AlexNet | model.classifier[6].in_features |
1000 | 和 VGG 一样的结构 |
MobileNetV2 | model.classifier[1].in_features |
1000 | 最后一层是 classifier[1] |
EfficientNet_B0~B7 | model.classifier[1].in_features |
1000 | 同样是 classifier[1] |
Swin Transformer | model.head.in_features |
1000 | 最后一层叫 head |
num_ftrs = model.<最后一层位置>.in_features
model.<最后一层位置> = nn.Linear(num_ftrs, num_classes)
import torchvision.models as models
import torch.nn as nn
model = models.mobilenet_v2(weights=None)
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 2) # 替换为 2 类分类
很多时候我们加载预训练模型,只训练最后一层:
for param in model.parameters():
param.requires_grad = False # 冻结所有层
# 只让分类层训练
model.classifier[1].requires_grad = True