Pytorch 训练自己的图像分类网络

文章目录

  • 前言
  • 一、如何配置
    • 1.运行环境:
    • 2.准备数据集
    • 3.生成训练数据索引
    • 4. 训练
    • 5.评估
    • 6.日志查看
    • 7. 预测
    • 8.训练技巧和练丹
  • 总结


前言

本篇博客主要是讲述如何,训练自己的分类网络。使用的是Pytorch 框架。所有代码是使用自己学习到的方法拼凑出来的。包含了我之前博客中提到的ResNet,MLP-Mixer,ConvNeXt,ConvMixer。

网络 博客链接
ResNet 链接
MLP-Mixer 链接
ConvNeXt 链接
ConvMixer 链接

项目地址:https://github.com/jiantenggei/torch-classification
猫狗数据集链接:https://pan.baidu.com/s/1y7Vjy3-RhlEFvtH6RY8JMQ 提取码:90f0

一、如何配置

1.运行环境:

仓库中提供了yaml文件,conda虚拟环境导入,torch是GPU版本:

conda env create -f torch.yaml

配置好后检查GPU是否可用:

import torch
#返回当前设备索引
print(torch.cuda.current_device())
#返回GPU的数量
print(torch.cuda.device_count())
#cuda是否可用 True 表示可用
print(torch.cuda.is_available())

2.准备数据集

深度学习图像分类数据集,一般是以文件夹作为类别命,下面以猫狗分类为例。数据集拜访如下所示:

─dataset
    ├─train
    │	└─cats
    │		└─xxjj.jpg
    │	└─dogs
    │	 	└─xxx.jpg
    ├─test
    │	└─cats
    │	└─dogs

cats 和dogs 文件夹下直接是图片即可。
Pytorch 训练自己的图像分类网络_第1张图片

3.生成训练数据索引

且还需要在classes.txt文件 按顺序写下该顺序如下图:
Pytorch 训练自己的图像分类网络_第2张图片
classes = [“cats”, “dogs”] 数据集中文件夹的命名顺序保持一致。 运行txt_annotation.py 生成上图中的
cls_train.txt 和cls_text.txt。

import os
from os import getcwd
#---------------------------------------------------#
#   训练自己的数据集的时候一定要注意修改classes
#   修改成自己数据集所区分的种类
#   种类顺序需要和训练时用到的classes.txt一样
#   生成cls_train.txt, cls_text.txt
#---------------------------------------------------#
classes = ["cats", "dogs"]
sets    = ["train", "test"]
if __name__ == "__main__":
    wd = getcwd()
    for se in sets:
        list_file = open('cls_' + se + '.txt', 'w')

        datasets_path = "datasets/" + se
        types_name = os.listdir(datasets_path)
        for type_name in types_name:
            if type_name not in classes:
                continue
            cls_id = classes.index(type_name)
            
            photos_path = os.path.join(datasets_path, type_name)
            photos_name = os.listdir(photos_path)
            for photo_name in photos_name:
                _, postfix = os.path.splitext(photo_name)
                if postfix not in ['.jpg', '.png', '.jpeg']:
                    continue
                list_file.write(str(cls_id) + ";" + '%s/%s'%(wd, os.path.join(photos_path, photo_name)))
                list_file.write('\n')
        list_file.close()

4. 训练

训练时,需要在config.py 中修改训练配置参数,当前所提供的参数如下:

Cuda             = True  #是否使用GPU 没有为Flase

input_shape      = [112,112]  # 输入图片大小

batch_size      = 4 # 自己可以更改
lr              = 1e-3         

classes_path    = 'classes.txt'


num_workers     = 0  # 是否开启多进程


annotation_path     = 'cls_train.txt'  



val_split       = 0.1  #验证集比率


resume          =''  # 加载训练权重路径

log_dir         = 'logs' # 日志路径 tensorboard 保存

#------------------------------------------#
#   FocalLoss :处理样本不均衡
#   alpha   
#   gamma >0 当 gamma=0 时就是交叉熵损失函数
#   论文中gamma = [0,0.5,1,2,5]
#   一般而言当γ增加的时候,a需要减小一点
#   reduction : 就平均:'mean' 求和 'sum'
#------------------------------------------#
# 还未配置成功
#Focal_loss      = True  # True Focal loss 处理原本不均衡 False  使用 CrossEntropyLoss() 

#label_smoothing 防止过拟合
label_smoothing =  False #

smoothing_value = 0.1  #[0,1] 之间

#学习率变化策略
scheduler   = 'cos' #[None,reduce,cos] None保持不变 reduce  按epoch 来减少 cos 余弦下降算法

配置好上述参数后,运行train.py 文件开始训练。训练效果控制台输出:
在这里插入图片描述

5.评估

目前支持eval_top1 和eval_top5, 在eval.py 中设置好 模型和权重(模型和权重一定要对应)。代码如下:

if __name__ == "__main__":

    # 读取测试集路劲和标签
    with open("./cls_test.txt","r") as f: 
        lines = f.readlines()
    #---------------------------------------------------#
    #   权重和模型
    #   注意:训练时设置的模型需要和权重匹配,
    #   也就是训练的啥模型使用啥权重
    #---------------------------------------------------#
    model_path = '' #训练好的权重路径
    model = ConvMixer_768_32(n_classes=2) # 自己训练好的模型

    mode = load_dict(model_path,model) # 加载权重
    eval = eval_top(anno_lines=lines,model=model)
    #---------------------------------------------------#
    #   top1 预测概率最好高的值与真实标签一致 √
    #   top5 预测概率前五个值由一个与真实标签一致 √
    #---------------------------------------------------#
    print('start eval.....')
    top1 = eval.eval_top1()
    
    top5 = eval.eval_top5()
    print('top1:%.3f,top5:%3.f'%(top1,top5))
    print('Eval Finished')

6.日志查看

由于每次启动训练时,会在logs 文件下按照时间创建一个日志文件。使用tensorboard 记录,查看方法:

tensorboard --logdir=logs\loss_2022_03_06_12_11_30

在这里插入图片描述
浏览器打开链接查看。

7. 预测

在predict.py, 设置好模型和对应的权重。运行

from PIL import Image
from eval import eval_top
from nets.ConvMixer import ConvMixer_768_32
from utils.utils import load_dict
#加载模型
model_path = 'logs\ep050-loss0.414-val_loss0.376.pth'
model = ConvMixer_768_32(n_classes=2)
model = load_dict(model_path,model)
eval = eval_top(anno_lines=None,model=model)

while True:
    img = input('Input image filename:')
    try:
        image = Image.open(img)
    except:
        print('Open Error! Try again!')
        continue
    else:
        class_name = eval.detect_img(image,mode='predict')
        print(class_name)

在控制台出入图片路劲即可。

8.训练技巧和练丹

  • Focl_loss(样本不均衡策略)
  • label_smoothing (训练样本偏少时,防止过拟合策略)
  • 学习率衰减(使模型收敛更充分)

存在bug及其他问题私信:[email protected]

总结

后续会不断学习和维护该仓库。加入更多的训练方法技巧和网络。希望大家在使用过程中能及时反馈,或者留下一些代码修改意见。我们一起让它变得更好。如果觉得有用 请给我点star 。

你可能感兴趣的:(深度学习入门,分类,pytorch,深度学习)