目录
一.数据准备
1.制作COCO数据集
2.计算数据集的均值方差
二.代码修改
1.新建类别
2.加入dataset
3.修改/src/lib/opts.py
4.修改src/lib/utils/debugger.py文件
二 训练与测试:
1训练:
2测试:
3绘制loss曲线
参照博客:
https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit
https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit
这里我用的是VOC数据集转COCO
参照博客:
https://blog.csdn.net/weixin_41765699/article/details/100124689
主要trian,val,test三个文件夹下txt转化为json
import cv2, os, argparse
import numpy as np
from tqdm import tqdm
def main():
dirs = '/home/zbb/CenterNet/data/plane/images' # 修改你自己的图片路径
img_file_names = os.listdir(dirs)
m_list, s_list = [], []
for img_filename in tqdm(img_file_names):
img = cv2.imread(dirs + '/' + img_filename)
img = img / 255.0
m, s = cv2.meanStdDev(img)
m_list.append(m.reshape((3,)))
s_list.append(s.reshape((3,)))
m_array = np.array(m_list)
s_array = np.array(s_list)
m = m_array.mean(axis=0, keepdims=True)
s = s_array.mean(axis=0, keepdims=True)
print("mean = ", m[0][::-1])
print("std = ", s[0][::-1])
if __name__ == '__main__':
main()
src/lib/datasets/dataset
里面新建一个“plane. py”,文件内容照着文件夹下coco.py改成自己的
1).把COCO关键字改为Plane
2)路径格式
使用相对路径报错,改成了绝对路径
3)训练修改
修改为val,train,测试再修改回来
类别名字和类别id改成自己
将数据集加入src/lib/datasets/dataset_factory
里面
一定要记得import,否则会报你的类别未定义
将自己的数据集设为默认数据集,加入到help里面
修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):
变成自己数据的类别和名字,前后数据集名字一定保持一致
再加上自己数据的类别,不包括背景__background__
输入命令:
python main.py ctdet --exp_id coco_dla --batch_size 4 --master_batch 1 --lr 1.25e-4 --gpus 0,1
如果显示显存不够之类的那种错误,需要在opts.py文件中将--num_workers改成0,batch_size小
建立的plane.py中修改如下部分,加入if split == ‘test’:…,作用是当test时指定标签文件为之前建立的测试文件
运行test.py
python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /home/zbb/CenterNet/exp/ctdet/coco_dla/model_best.pth
结果:
其中,一般使用的是第二行,也就是IOU=0.5,全区域的AP值,其他的分别是不同IOU以及不同目标尺寸区域的结果。
训练生成的日志文件一般在exp/ctdet/../../logs.txt
参照博主但是,val—loss绘制不好,先绘制total—loss
import matplotlib.pyplot as plt
import numpy as np
def plot_loss_curve(log_file):
loss_data = open(log_file)
all_lines = loss_data.readlines()
print(all_lines[4].split(' '))
# losses
total_loss = [] # 4
hm_loss = [] # 7
wh_loss = [] # 10
off_loss = [] # 13
val_loss = [] # 19
spend_time = [] # 16
num_lines = len(all_lines)
for line in range(num_lines):
total_loss1 = all_lines[line].split(' ')[4]
hm_loss1 = all_lines[line].split(' ')[7]
wh_loss1 = all_lines[line].split(' ')[10]
off_loss1 = all_lines[line].split(' ')[13]
#val_loss1 = all_lines[line].split(' ')[19]
spend_time1 = all_lines[line].split(' ')[16]
print(total_loss1)
print(spend_time1)
total_loss.append(float(total_loss1))
#val_loss.append(float(val_loss1))
hm_loss.append(float(hm_loss1))
wh_loss.append(float(wh_loss1))
off_loss.append(float(off_loss1))
spend_time.append(float(spend_time1))
return total_loss
if __name__ == '__main__':
# 标准图形绘制
# sns.set()
loss_res18 = plot_loss_curve(
'/home/zbb/CenterNet/exp/ctdet/coco_dla/logs_2019-10-17-15-41/log.txt') # 读取训练时生成的日志文件
fig = plt.figure(figsize=(10, 4))
ax = fig.add_subplot(111)
ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1) # 这个label是图线自己的标签;
# ax.set_xlim([0, 800]) # 设置刻度;
# ax.set_xticks(range(0, 500, 100)) # 设置显示的刻度;
# ax.set_yticklabels(['jan', 'feb', 'mar']) # 设置刻度标签;
ax.set_xlabel('epochs') # 设置坐标轴标签;
ax.set_ylabel('loss_value')
ax.text(8750, 20, "plane", color='red') # 加入文本
ax.set_title('loss_of_CenterNet')
ax.legend(loc='best') # 将图例摆放在不遮挡图线的位置即可
ax.grid() # 添加网格
plt.savefig('/home/zbb/CenterNet/loss_of_CenterNet.png') # 保存文件到指定文件夹
plt.show()
total——loss结果图: