微调(fine-tune)

迁移学习不是一种算法而是一种机器学习思想,应用到深度学习就是微调(Fine-tune)。

通过修改预训练网络模型结构(如修改样本类别输出个数),选择性载入预训练网络模型权重(通常是载入除最后的全连接层的之前所有层 ,也叫瓶颈层)再用自己的数据集重新训练模型就是微调的基本步骤。 微调能够快速训练好一个模型,用相对较小的数据量,还能达到不错的结果。
迁移学习有几种方式

1)Transfer Learning:冻结预训练模型的全部卷积层,只训练自己定制的全连接层。
2)Extract Feature Vector:先计算出预训练模型的卷积层对所有训练和测试数据的特征向量,然后抛开预训练模型,只训练自己定制的简配版全连接网络。
3)Fine-tune:冻结预训练模型的部分卷积层(通常是靠近输入的多数卷积层),训练剩下的卷积层(通常是靠近输出的部分卷积层)和全连接层。
* 注:Transfer Learning关心的问题是:什么是“知识”以及如何更好地运用之前得到的“知识”,这可以有很多方法和手段,eg:SVM,贝叶斯,CNN等。而fine-tune只是其中的一种手段,更常用于形容迁移学习的后期微调中。


模型微调:
模型微调就是一个迁移学习的过程,模型中训练学习得到的权值,就是迁移学习中所谓的知识,而这些知识是可以进行迁移的,把这些知识迁移到新任务中,这就完成了迁移学习

微调的原因:
在新任务中,数据量太小,不足以去训练一个较大的模型,从而选择Model Finetune去辅助训练一个较好的模型,使得训练更快

模型微调步骤:

获取预训练模型参数
加载模型( load_state_dict)
修改输出层


模型微调训练方法

固定预训练的参数,两种方法:
requires_grad =False
lr=0
Features Extractor部分设置较小学习率( params_group)

说明:
优化器中可以管理不同的参数组,这样就可以为不同的参数组设置不同的超参数,对Features Extractor部分设置较小学习率

其实就是把参数分为两部分(base、fc)。

base是预训练模型的不改的部分,fc是模型要修改的最后几层。

在参数更新优化的时候,对base的学习率设为0或者requires_grad设为False。将base部分freaze住,而对于fc部分正常的学习率且True,正常优化。
 

# 1/3 构建模型
resnet18_ft = models.resnet18()

# 2/3 加载参数
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# 法1 : 冻结卷积层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))


# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features            # 从原始的resnet18从获取输入的结点数
resnet18_ft.fc = nn.Linear(num_ftrs, classes)


resnet18_ft.to(device)        # 将模型迁移到设置的设备上
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
# 法2 : conv 小学习率
flag = 0
# flag = 1
if flag:
    # 划分模型参数为两个部分:resnet18_ft.fc.parameters()和base_params
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))     # 返回的是parameters的 内存地址
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())

    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR*0.1},   # 0
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)

else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)               # 选择优化器

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)     # 设置学习率下降策略

你可能感兴趣的:(深度学习,pytorch)