(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

我的任务是在行人头肩数据上训练并测试centernet网络,先证明一下我是真的训练了哈,这是用centernet检测的一张结果(训练了10个epochs的结果,大家放心使用,网络功能还是很强大的):

我参考的这篇博客,对我自己的实验帮助很大:https://blog.csdn.net/weixin_42634342/article/details/97756458

论文作者代码:https://github.com/xingyizhou/CenterNet

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第1张图片

这个博客是整个训练的过程,可能会有点长。

1. 准备数据集

    0.(我用的数据是VOC格式的,需要将其转化为COCO格式)

    详细过程在我的另一个博客里,本来想写在这里,发现太长了,就移到另一个里了。

    链接:https://blog.csdn.net/weixin_41765699/article/details/100124689

    1. 当我们生成三个json文件之后,来到CenterNet这个工程里,在data文件夹下新建一个文件夹,名字就是你数据集的名字,如下图:

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第2张图片

     再在这个文件夹里面建两个文件夹(annotations里面存放的是我们之前生成的那三个json文件;images存放的是所有的图片,包括训练测试验证三个,所有的):

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第3张图片

    2. 在src/lib/datasets/dataset里面新建一个“ped. py”,文件内容照着文件夹下coco.py改成自己的

       0. 将COCO类改成自己的名字

       

       1. 第14行num_classes=80改成自己的类别数
       2. 第15行default_resolution(这个参数有两种(300,300)或者(512,512),很明显512的参数计算量大,300计算量小,我用的是512,之后打算训练一个300的对比一下)
       3. 接下来的mean和std改成自己图片数据集的均值和方差,脚本链接:                            https://blog.csdn.net/weixin_41765699/article/details/100118660

       4. 修改数据和图片路径,data_dir 输入的是咱们之前建立的数据集文件夹的名字,img_dir 输入的是 images 图片文件夹

       5. 修改json文件路径如下:

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第4张图片

       6.  类别名字和类别id改成自己

       我就改了以上六点内容。

    3. 将数据集加入src/lib/datasets/dataset_factory里面

       1. 在dataset_facto字典里加入自己的数据集名字 (格式为   '你之前创建的Python文件的名字':你自己数据集类的名字,因为要从你创建的py文件里找到你的数据类,名字必须对应上)

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第5张图片

    4. 修改/src/lib/opts.py

         1.第一步,将自己的数据集设为默认数据集,加入到help里面

self.parser.add_argument('--dataset', default='ped',  
                                 help='coco | kitti | coco_hp | pascal | ped)

         2.修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):

   

         3. 修改src/lib/utils/debugger.py文件(变成自己数据的类别和名字,前后数据集名字一定保持一致)

       (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第6张图片

            再加上自己数据的类别,不包括背景__background__

            (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第7张图片

      到这里,准备数据集的工作就告一段落了!

2. 搭建pytorch0.4.1+cuda90+cudnn7.6.1(版本不能改,还有就是numpy的版本必须在1.13以上,建议装最新的)

     我搭建这个环境也费老大劲了,pytorch1.2貌似直接pip安装就自动装上了cuda和cudnn,0.4.1版本的我没看见有自动安装的,所以就苦哈哈自己动手装了,关于这个,我也记录了一下,大家也可以自己上网查查别的方法

cuda和cudnn安装链接:https://blog.csdn.net/weixin_41765699/article/details/99966617

torch0.4.1安装链接:https://blog.csdn.net/weixin_41765699/article/details/99756697

3. 克隆工程并运行demo

    关于工程里面这个作者写的很详细了,我是按照一步步来的,没有出错。https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md

    程序里面在运行demo.py之前,会下载一个预训练权重(比如dla34,resnet18,resnet101之类的),这个不用管,等他下载完,因为我们训练的时候也要用的。(下载的时候可能会很慢,如果是在龟速的话,将他下载的网址用QQ浏览器打开自己下载,下载完放到这个它自动创建的文件夹里就可以了,QQ浏览器下载确实比其他的稍快一些)

    (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第8张图片

   改完之后在MODEL_ZOO.md里面下载参数,ctdet_coco_dla_2x,下载完毕后放在models文件夹里面。
到这里,环境基本搭建成功,接下来可以跑一下代码了

(模型文件下载貌似要Google drive,这是我下载的:

链接:https://pan.baidu.com/s/1QOmIwy8lXJBuLv5hH5j3ag 
提取码:vwk4 )

    运行demo.py

python demo.py ctdet --demo /home/CenterNet/images/ --load_model /home/CenterNet/models/ctdet_coco_dla_2x.pth

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第9张图片嘿嘿嘿

     要注意的是,当弹出第一站图片的时候,按esc除外的任意键可以继续检测下一张图,想要保存检测结果的话,只需要在src/lib/detectors/cdet.py文件中:

    def show_results(self, debugger, image, results):  # demo文件會調用這個函數,本函`python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1`數是demo時顯示圖片並保存圖片
        debugger.add_img(image, img_id='ctdet')
        for j in range(1, self.num_classes + 1):
            for bbox in results[j]:
                if bbox[4] > self.opt.vis_thresh:
                    debugger.add_coco_bbox(bbox[:4], j - 1, bbox[4], img_id='ctdet')
        debugger.show_all_imgs(pause=self.pause)
        debugger.save_all_imgs(path='/home/czb/CenterNet-master/output/', genID=True)

      加上一行代码,就是最后一行debugger.save_all_imgs(path='/home/CenterNet/output/', genID=True) ,path是输出路径,需要在CenterNet文件夹下新建一个文件夹output,然后再运行一遍发现检测后的图片就会保存在这个文件夹里面了。当然,去掉倒数第二行show_all_imgs,那么运行的时候就不会弹出照片了。

4. 训练阶段

     1. 定位一下发现前面自己建立的ped.py文件(修改下面的代码):

   if split == 'val':
            self.annot_path = os.path.join(
                self.data_dir, 'annotations',
                'val.json').format(split) # 修改test的json文件位置
        else:
            if opt.task == 'exdet':
                self.annot_path = os.path.join(
                    self.data_dir, 'annotations',
                    'train.json').format(split)
            else:
                self.annot_path = os.path.join(
                    self.data_dir, 'annotations',
                    'train.json').format(split) # 这才是train文件

     要把第一行if split 改为 ==‘val’,这样validate时就会调用val.json文件。把最后一行要调用的文件改为‘train.json’,这样训练的时候才会调用train.json文件。改完之后数据集导入就正常了。
     2. 运行main.py

python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1
(如果显示显存不够之类的那种错误,需要在opts.py文件中将--num_workers改成0,batch_size改成16或者更小

   

     这时候会下载一个预训练模型,可能会很慢,我是。。下载的,这是百度盘链接,需要的可以用:

    链接:https://pan.baidu.com/s/1I1oW_l2Xe2-LV1gIjViPTg 
    提取码:2pt0 

     下载完之后放在/root/.torch/models里面

(我的是在这个里,你也可以看看他自动下载的那个在哪个文件夹里,然后把权重放在那个文件夹下)

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第10张图片

    没有意外的话,经过上面的步骤,就开始训练了::::

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第11张图片

5. 测试部分

     当训练完之后(我训练了两天,泰坦X,140个epochs,有点憨批,其实最好的模型是在第55个epoch出现的),在./exp/ctdet/coco_dla/文件夹下会出现如下文件

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第12张图片

     其中,model_last是最后一次epoch的模型;model_best是val最好的模型,我选的是model_best模型;

然后开始测试。。。。。。

   1. 在我们之前建立的ped.py中修改如下部分,加入if split == ‘test’:…,作用是当test时指定标签文件为之前建立的测试文件       test.json

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第13张图片

   2. 运行test.py

       python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /root/CenterNet/exp/ctdet/coco_dla/model_best.pth

   不出意外的话会出现下面的画面(出现一系列AP值),其中,一般使用的是第二行,也就是IOU=0.5,全区域的AP值,其他的分别是不同IOU以及不同目标尺寸区域的结果。

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第14张图片

  完事了。。。

2019.9.6

附加1:

我想换个骨干网络试试,作者的源代码支持resnet和hourglass,我尝试替换成resnet18,记录一下替换方法:

在原来的训练命令行命令里添加上两个参数:(顺便把exp_id 改一下,保证每个模型不乱)

python main.py ctdet --exp_id coco_res_18 --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1 --arch res_18 --head_conv 64

开始训练时也会下载相应的预训练模型,如果下载速度慢,也参照上面说的方法下载。

训练之后,在测试和运行demo的命令行代码里也要加上两个参数:--arch res_18 --head_conv 64

附加2:

训练完成的时候,我们需要绘制出loss值得曲线,以下代码可以实现该功能:

训练生成的日志文件一般在exp/ctdet/../../logs.txt,找到这个文件,打开之后会出现如下:

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)_第15张图片

我们需要读取这些loss值并可视化(一般情况下,该代码只需要改变日志文件的路径即可):

import matplotlib.pyplot as plt
import numpy as np


def plot_loss_curve(log_file):

    loss_data = open(log_file)
    all_lines = loss_data.readlines()
    print(all_lines[4].split(' '))
    # losses
    total_loss = []          # 4
    hm_loss = []             # 7
    wh_loss = []             # 10
    off_loss = []            # 13
    val_loss = []            # 19
    spend_time = []          # 16
    num_lines = len(all_lines)
    for line in range(num_lines):
        total_loss1 = all_lines[line].split(' ')[4]
        hm_loss1 = all_lines[line].split(' ')[7]
        wh_loss1 = all_lines[line].split(' ')[10]
        off_loss1 = all_lines[line].split(' ')[13]
        spend_time1 = all_lines[line].split(' ')[16]

        total_loss.append(float(total_loss1))
        hm_loss.append(float(hm_loss1))
        wh_loss.append(float(wh_loss1))
        off_loss.append(float(off_loss1))
        spend_time.append(float(spend_time1))

    index_val = np.linspace(0, 140, 29) - 1
    index_val = np.delete(index_val, 0, 0)

    for id in index_val:

        val_loss1 = all_lines[int(id)].split(' ')[19]
        val_loss.append(float(val_loss1))
    return val_loss, total_loss


if __name__ == '__main__':
    # 标准图形绘制
    # sns.set()
    vloss_res18, loss_res18 = plot_loss_curve('logres18.txt')              # 读取训练时生成的日志文件
    # vloss_resdcn18, loss_resdcn18 = plot_loss_curve('logresdcn18.txt')
    # vloss_dla, loss_dla = plot_loss_curve('logdla34.txt')
    # vloss_res101, loss_res101 = plot_loss_curve('logres101.txt')
    # vloss_dla34p, loss_dla34p = plot_loss_curve('logdla34p.txt')
    # vloss_hg, loss_hg = plot_loss_curve('loghourglass.txt')
    
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_subplot(111)
    ax.plot(range(len(loss_res18)), loss_res18, 'c', label='res_18_train_loss', linewidth=1)         # 这个label是图线自己的标签;
    # ax.plot(range(len(loss_resdcn18)), loss_resdcn18, 'y', label='resdcn_18_train_loss', linewidth=1)
    # ax.plot(range(len(loss_dla)), loss_dla, 'b', label='dla_34_train_loss', linewidth=1)
    # ax.plot(range(len(loss_res101)), loss_res101, 'g', label='res_101_train_loss', linewidth=1)
    # ax.plot(range(len(loss_dla34p)), loss_dla34p, 'r', label='dla_34_train_loss', linewidth=1)
    # ax.plot(range(len(loss_hg)), loss_hg, 'r', label='hourglass_train_loss', linewidth=1)

    # ax.plot(index_val+1, val_loss, label='val_loss')
    # ax.plot(np.random.randn(1000).cumsum(), label='line2')
    # ax.set_xlim([0, 800])                                      # 设置刻度;
    # ax.set_xticks(range(0, 500, 100))                          # 设置显示的刻度;
    # ax.set_yticklabels(['jan', 'feb', 'mar'])                  # 设置刻度标签;
    ax.set_xlabel('epochs')                                    # 设置坐标轴标签;
    ax.set_ylabel('loss_value')
    # ax.text(8750, 20, "海拔", color='red')                     # 加入文本
    ax.set_title('loss_of_CenterNet')
    ax.legend(loc='best')                                      # 将图例摆放在不遮挡图线的位置即可
    ax.grid()                                                  # 添加网格
    plt.savefig('loss_of_CenterNet.png')                                    # 保存文件到指定文件夹
    plt.show()

    fig1 = plt.figure(figsize=(12, 8))
    ax1 = fig1.add_subplot(111)
    ax1.plot(range(len(vloss_res18)), vloss_res18, 'c', label='res_18_val_loss', linewidth=2)         # 这个label是图线自己的标签;
    # ax1.plot(range(len(vloss_resdcn18)), vloss_resdcn18, 'y', label='resdcn_18_val_loss', linewidth=2)
    # ax1.plot(range(len(vloss_dla)), vloss_dla, 'b', label='dla_34_val_loss', linewidth=2)
    # ax1.plot(range(len(vloss_res101)), vloss_res101, 'g', label='res_101_val_loss', linewidth=2)
    # ax1.plot(range(len(vloss_dla34p)), vloss_dla34p, 'r', label='dla_34_val_loss_p', linewidth=2)
    # ax.plot(index_val+1, val_loss, label='val_loss')
    # ax.plot(np.random.randn(1000).cumsum(), label='line2')
    # ax.set_xlim([0, 800])                                      # 设置刻度;
    # ax.set_xticks(range(0, 500, 100))                          # 设置显示的刻度;
    # ax.set_yticklabels(['jan', 'feb', 'mar'])                  # 设置刻度标签;
    ax1.set_xlabel('epochs')                                    # 设置坐标轴标签;
    ax1.set_ylabel('loss_value')
    # ax.text(8750, 20, "海拔", color='red')                     # 加入文本
    ax1.set_title('val_loss_of_CenterNet')
    ax1.legend(loc='best')                                      # 将图例摆放在不遮挡图线的位置即可
    ax1.grid()                                                  # 添加网格
    plt.savefig('val_loss_of_CenterNet.png')                                    # 保存文件到指定文件夹
    plt.show()

 

你可能感兴趣的:(目标检测)