本篇博客主要是讲述如何,训练自己的分类网络。使用的是Pytorch 框架。所有代码是使用自己学习到的方法拼凑出来的。包含了我之前博客中提到的ResNet,MLP-Mixer,ConvNeXt,ConvMixer。
网络 | 博客链接 |
---|---|
ResNet | 链接 |
MLP-Mixer | 链接 |
ConvNeXt | 链接 |
ConvMixer | 链接 |
项目地址:https://github.com/jiantenggei/torch-classification
猫狗数据集链接:https://pan.baidu.com/s/1y7Vjy3-RhlEFvtH6RY8JMQ 提取码:90f0
仓库中提供了yaml文件,conda虚拟环境导入,torch是GPU版本:
conda env create -f torch.yaml
配置好后检查GPU是否可用:
import torch
#返回当前设备索引
print(torch.cuda.current_device())
#返回GPU的数量
print(torch.cuda.device_count())
#cuda是否可用 True 表示可用
print(torch.cuda.is_available())
深度学习图像分类数据集,一般是以文件夹作为类别命,下面以猫狗分类为例。数据集拜访如下所示:
─dataset
├─train
│ └─cats
│ └─xxjj.jpg
│ └─dogs
│ └─xxx.jpg
├─test
│ └─cats
│ └─dogs
且还需要在classes.txt文件 按顺序写下该顺序如下图:
classes = [“cats”, “dogs”] 数据集中文件夹的命名顺序保持一致。 运行txt_annotation.py 生成上图中的
cls_train.txt 和cls_text.txt。
import os
from os import getcwd
#---------------------------------------------------#
# 训练自己的数据集的时候一定要注意修改classes
# 修改成自己数据集所区分的种类
# 种类顺序需要和训练时用到的classes.txt一样
# 生成cls_train.txt, cls_text.txt
#---------------------------------------------------#
classes = ["cats", "dogs"]
sets = ["train", "test"]
if __name__ == "__main__":
wd = getcwd()
for se in sets:
list_file = open('cls_' + se + '.txt', 'w')
datasets_path = "datasets/" + se
types_name = os.listdir(datasets_path)
for type_name in types_name:
if type_name not in classes:
continue
cls_id = classes.index(type_name)
photos_path = os.path.join(datasets_path, type_name)
photos_name = os.listdir(photos_path)
for photo_name in photos_name:
_, postfix = os.path.splitext(photo_name)
if postfix not in ['.jpg', '.png', '.jpeg']:
continue
list_file.write(str(cls_id) + ";" + '%s/%s'%(wd, os.path.join(photos_path, photo_name)))
list_file.write('\n')
list_file.close()
训练时,需要在config.py 中修改训练配置参数,当前所提供的参数如下:
Cuda = True #是否使用GPU 没有为Flase
input_shape = [112,112] # 输入图片大小
batch_size = 4 # 自己可以更改
lr = 1e-3
classes_path = 'classes.txt'
num_workers = 0 # 是否开启多进程
annotation_path = 'cls_train.txt'
val_split = 0.1 #验证集比率
resume ='' # 加载训练权重路径
log_dir = 'logs' # 日志路径 tensorboard 保存
#------------------------------------------#
# FocalLoss :处理样本不均衡
# alpha
# gamma >0 当 gamma=0 时就是交叉熵损失函数
# 论文中gamma = [0,0.5,1,2,5]
# 一般而言当γ增加的时候,a需要减小一点
# reduction : 就平均:'mean' 求和 'sum'
#------------------------------------------#
# 还未配置成功
#Focal_loss = True # True Focal loss 处理原本不均衡 False 使用 CrossEntropyLoss()
#label_smoothing 防止过拟合
label_smoothing = False #
smoothing_value = 0.1 #[0,1] 之间
#学习率变化策略
scheduler = 'cos' #[None,reduce,cos] None保持不变 reduce 按epoch 来减少 cos 余弦下降算法
配置好上述参数后,运行train.py 文件开始训练。训练效果控制台输出:
目前支持eval_top1 和eval_top5, 在eval.py 中设置好 模型和权重(模型和权重一定要对应)。代码如下:
if __name__ == "__main__":
# 读取测试集路劲和标签
with open("./cls_test.txt","r") as f:
lines = f.readlines()
#---------------------------------------------------#
# 权重和模型
# 注意:训练时设置的模型需要和权重匹配,
# 也就是训练的啥模型使用啥权重
#---------------------------------------------------#
model_path = '' #训练好的权重路径
model = ConvMixer_768_32(n_classes=2) # 自己训练好的模型
mode = load_dict(model_path,model) # 加载权重
eval = eval_top(anno_lines=lines,model=model)
#---------------------------------------------------#
# top1 预测概率最好高的值与真实标签一致 √
# top5 预测概率前五个值由一个与真实标签一致 √
#---------------------------------------------------#
print('start eval.....')
top1 = eval.eval_top1()
top5 = eval.eval_top5()
print('top1:%.3f,top5:%3.f'%(top1,top5))
print('Eval Finished')
由于每次启动训练时,会在logs 文件下按照时间创建一个日志文件。使用tensorboard 记录,查看方法:
tensorboard --logdir=logs\loss_2022_03_06_12_11_30
在predict.py, 设置好模型和对应的权重。运行
from PIL import Image
from eval import eval_top
from nets.ConvMixer import ConvMixer_768_32
from utils.utils import load_dict
#加载模型
model_path = 'logs\ep050-loss0.414-val_loss0.376.pth'
model = ConvMixer_768_32(n_classes=2)
model = load_dict(model_path,model)
eval = eval_top(anno_lines=None,model=model)
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
class_name = eval.detect_img(image,mode='predict')
print(class_name)
在控制台出入图片路劲即可。
存在bug及其他问题私信:[email protected]
后续会不断学习和维护该仓库。加入更多的训练方法技巧和网络。希望大家在使用过程中能及时反馈,或者留下一些代码修改意见。我们一起让它变得更好。如果觉得有用 请给我点star 。