centernet下训练自己的数据

目录

一.数据准备

1.制作COCO数据集

2.计算数据集的均值方差

二.代码修改

1.新建类别

 2.加入dataset

 3.修改/src/lib/opts.py

4.修改src/lib/utils/debugger.py文件

二 训练与测试:

1训练:

2测试:

3绘制loss曲线


参照博客:

https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit

https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit

一.数据准备

1.制作COCO数据集

这里我用的是VOC数据集转COCO

参照博客:

https://blog.csdn.net/weixin_41765699/article/details/100124689

主要trian,val,test三个文件夹下txt转化为json

2.计算数据集的均值方差

import cv2, os, argparse
import numpy as np
from tqdm import tqdm


def main():
    dirs = '/home/zbb/CenterNet/data/plane/images'   # 修改你自己的图片路径
    img_file_names = os.listdir(dirs)
    m_list, s_list = [], []
    for img_filename in tqdm(img_file_names):
        img = cv2.imread(dirs + '/' + img_filename)
        img = img / 255.0
        m, s = cv2.meanStdDev(img)
        m_list.append(m.reshape((3,)))
        s_list.append(s.reshape((3,)))
    m_array = np.array(m_list)
    s_array = np.array(s_list)
    m = m_array.mean(axis=0, keepdims=True)
    s = s_array.mean(axis=0, keepdims=True)
    print("mean = ", m[0][::-1])
    print("std = ", s[0][::-1])

if __name__ == '__main__':
    main()

二.代码修改

1.新建类别

src/lib/datasets/dataset里面新建一个“plane. py”,文件内容照着文件夹下coco.py改成自己的

1).把COCO关键字改为Plane

centernet下训练自己的数据_第1张图片

2)路径格式

centernet下训练自己的数据_第2张图片

使用相对路径报错,改成了绝对路径

3)训练修改

修改为val,train,测试再修改回来

类别名字和类别id改成自己

centernet下训练自己的数据_第3张图片

 2.加入dataset

将数据集加入src/lib/datasets/dataset_factory里面

一定要记得import,否则会报你的类别未定义

centernet下训练自己的数据_第4张图片

 3.修改/src/lib/opts.py

将自己的数据集设为默认数据集,加入到help里面

 修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):

4.修改src/lib/utils/debugger.py文件

变成自己数据的类别和名字,前后数据集名字一定保持一致

centernet下训练自己的数据_第5张图片

再加上自己数据的类别,不包括背景__background__ centernet下训练自己的数据_第6张图片

二 训练与测试:

1训练:

 输入命令:

python main.py ctdet --exp_id coco_dla --batch_size 4 --master_batch 1 --lr 1.25e-4  --gpus 0,1

如果显示显存不够之类的那种错误,需要在opts.py文件中将--num_workers改成0,batch_size小

2测试:

  建立的plane.py中修改如下部分,加入if split == ‘test’:…,作用是当test时指定标签文件为之前建立的测试文件     

centernet下训练自己的数据_第7张图片

   运行test.py

       python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /home/zbb/CenterNet/exp/ctdet/coco_dla/model_best.pth

结果:

centernet下训练自己的数据_第8张图片

其中,一般使用的是第二行,也就是IOU=0.5,全区域的AP值,其他的分别是不同IOU以及不同目标尺寸区域的结果。 

3绘制loss曲线

训练生成的日志文件一般在exp/ctdet/../../logs.txt

参照博主但是,val—loss绘制不好,先绘制total—loss

import matplotlib.pyplot as plt
import numpy as np


def plot_loss_curve(log_file):
    loss_data = open(log_file)
    all_lines = loss_data.readlines()
    print(all_lines[4].split(' '))
    # losses
    total_loss = []  # 4
    hm_loss = []  # 7
    wh_loss = []  # 10
    off_loss = []  # 13
    val_loss = []  # 19
    spend_time = []  # 16
    num_lines = len(all_lines)
    for line in range(num_lines):
        total_loss1 = all_lines[line].split(' ')[4]
        hm_loss1 = all_lines[line].split(' ')[7]
        wh_loss1 = all_lines[line].split(' ')[10]
        off_loss1 = all_lines[line].split(' ')[13]
        #val_loss1 = all_lines[line].split(' ')[19]
        spend_time1 = all_lines[line].split(' ')[16]
        print(total_loss1)
        print(spend_time1)

        total_loss.append(float(total_loss1))
        #val_loss.append(float(val_loss1))
        hm_loss.append(float(hm_loss1))
        wh_loss.append(float(wh_loss1))
        off_loss.append(float(off_loss1))
        spend_time.append(float(spend_time1))
    return total_loss

if __name__ == '__main__':
    # 标准图形绘制
    # sns.set()
    loss_res18 = plot_loss_curve(
        '/home/zbb/CenterNet/exp/ctdet/coco_dla/logs_2019-10-17-15-41/log.txt')  # 读取训练时生成的日志文件
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_subplot(111)
    ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1)  # 这个label是图线自己的标签;

    # ax.set_xlim([0, 800])                                      # 设置刻度;
    # ax.set_xticks(range(0, 500, 100))                          # 设置显示的刻度;
    # ax.set_yticklabels(['jan', 'feb', 'mar'])                  # 设置刻度标签;
    ax.set_xlabel('epochs')  # 设置坐标轴标签;
    ax.set_ylabel('loss_value')
    ax.text(8750, 20, "plane", color='red')  # 加入文本
    ax.set_title('loss_of_CenterNet')
    ax.legend(loc='best')  # 将图例摆放在不遮挡图线的位置即可
    ax.grid()  # 添加网格
    plt.savefig('/home/zbb/CenterNet/loss_of_CenterNet.png')  # 保存文件到指定文件夹
    plt.show()

total——loss结果图:

centernet下训练自己的数据_第9张图片

你可能感兴趣的:(深度学习)