24模型微调(finetune)

一、Transfer Learning & Model Finetune

1.1 Transfer Learning

Transfer Learning:机器学习分支,研究源域(source domain)的知识如何应用到目标域(targetdomain)
24模型微调(finetune)_第1张图片

传统的机器学习:
对不同的任务分别训练学习得到不同的learning system,即模型,如上图有三个不同任务,就得到三个不同的模型

迁移学习:
先对源任务进行学习,得到知识,然后在目标任务中,会使用再源任务上学习得到的知识来学习训练模型,也就是说该模型不仅用到了target tasks,也用到了source tasks

1.2 Model Finetune

1.2.1 Model Finetune概念

Model Finetune:模型的迁移学习24模型微调(finetune)_第2张图片
模型微调:
模型微调就是一个迁移学习的过程,模型中训练学习得到的权值,就是迁移学习中所谓的知识,而这些知识是可以进行迁移的,把这些知识迁移到新任务中,这就完成了迁移学习

微调的原因:
在新任务中,数据量太小,不足以去训练一个较大的模型,从而选择Model Finetune去辅助训练一个较好的模型,使得训练更快

卷积神经网络的迁移:
24模型微调(finetune)_第3张图片
将卷积神经网络分成两部分:features extractor + classifier

  • features extractor:模型的共性部分,通常对其进行保留
  • classifier:根据不同任务要求对输出层进行finetune

1.2.2 Model Finetune步骤

24模型微调(finetune)_第4张图片
Model Finetune:
先进行模型微调,加载模型参数,并根据任务要求修改模型,此过程称预训练,然后进行正式训练,此时要注意预训练的参数的保持,具体步骤和方法如下

模型微调步骤:

  1. 获取预训练模型参数
  2. 加载模型( load_state_dict)
  3. 修改输出层

模型微调训练方法:

  • 固定预训练的参数,两种方法:
    • requires_grad =False
    • lr=0
  • Features Extractor部分设置较小学习率( params_group)

说明:
优化器中可以管理不同的参数组,这样就可以为不同的参数组设置不同的超参数,对Features Extractor部分设置较小学习率

二、Pytorch中的Finetune

2.1 Model Finetune实例

24模型微调(finetune)_第5张图片
数据: https://download.pytorch.org/tutorial/hymenoptera_data.zip
模型: https://download.pytorch.org/models/resnet18-5c106cde.pth

2.1.1 目录结构

24模型微调(finetune)_第6张图片
模型和数据的存放位置如上图所示

2.1.1 代码详解

my_dataset.py

# -*- coding: utf-8 -*-
import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {
   "1": 0, "100": 1}


class AntsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.label_name = {
   "ants": 0, "bees": 1}
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img,label

    def __len__(self):
        return len(self.data_info)

    def get_img_info(self, data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = self.label_name[sub_dir]
                    data_info.append((path_img, 

你可能感兴趣的:(#,Pytorch)