项目地址:https://github.com/argusswift/YOLOv4-pytorch
这份代码实现的逻辑非常清楚,主要一些数据集处理的代码需要相应的改动:
这里的数据集label格式:
train_annotation:
image_name1 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id
image_name2 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id
改动内容:
(1)路径
DATA_PATH = ""
PROJECT_PATH = ""
DETECTION_PATH=""
MODEL_TYPE = {
"TYPE": "YOLOv4"
}
(2)训练参数
TRAIN = {
"DATA_TYPE": "Customer", # DATA_TYPE: VOC ,COCO or Customer
"TRAIN_IMG_SIZE": 608,
"AUGMENT": True,
"BATCH_SIZE": 16,
"MULTI_SCALE_TRAIN": False,
"IOU_THRESHOLD_LOSS": 0.5,
"YOLO_EPOCHS": 50,
"Mobilenet_YOLO_EPOCHS": 120,
"NUMBER_WORKERS": 0,
"MOMENTUM": 0.9,
"WEIGHT_DECAY": 0.0005,
"LR_INIT": 1e-4,
"LR_END": 1e-6,
"WARMUP_EPOCHS": 2, # or None
}
(3)VAL 参数
VAL = {
"TEST_IMG_SIZE": 608, #同train
"BATCH_SIZE": 1,
"NUMBER_WORKERS": 0,
"CONF_THRESH": 0.005,
"NMS_THRESH": 0.45,
"MULTI_SCALE_VAL": False, #-----
"FLIP_VAL": False, #------ 因为数据集里有信号灯,所以关闭翻转
"Visual": True,
}
(4)Customer_DATA 目标列表
Customer_DATA = {
"NUM": **, # your dataset number
"CLASSES": [** ], # your dataset class
}
(1)设 img_path
(2)如果训练中报resize的错,注意检查训练数据集,可能是由于:
如果是公开数据集可能不会出这种错,如果是自己做的数据集,标注过程可能会出现这种。
(3)index 越界
这个问题是因为index越界,是datasets.py中
这个位置出现xind,或者yind的越界,比如网络输入是608大小,到第一级anchor层的stride是8,这一层的特征图大小就是608/8=76。
所以xind或者yind的取值范围应该在[0-75],报错是因为这里xind或yind取到了76,做一些越界判断处理即可,例如设定xind/yind上限为75。
由于自己的label格式和文件夹并不是按照voc的格式,所以这里的evaluator.py和voc_eval.py都需要进行相应的修改。
(1)self.val_data_path
(2)img_inds_file
(3)img_path
(4)annopath 和 imagesetfile (存放的是图像名字列表)配置val数据集的label路径和图像路径
(1)parse_gt 函数,因为我没有用voc的xml格式,所以在解析label的时候自己重写了这个函数
相应的,下面调用的时候:
(2)下面都是针对数据集label不是voc的xml格式带来的不同需要修改的地方
这样修改完,基本就可以训练跑起来了:
训练命令:
CUDA_VISIBLE_DEVICES=0 nohup python -u train.py --weight_path weight/yolov4.weights --gpu_id 0
CUDA_VISIBLE_DEVICES=0 python3 video_test.py --weight_path ./weight/best.pt --gpu_id 0 --video_path video.mp4 --output_dir .