mmdetection将训练的结果保存在****.log.json文件中,绘制其中的mAP、loss指标可使用mmdetection自带的py文件,官方可视化api如下:https://mmdetection.readthedocs.io/en/latest/useful_tools.htmlhttps://mmdetection.readthedocs.io/en/latest/useful_tools.html但我用了总是报错,干脆自己写一个~
import json
import matplotlib.pyplot as plt
import argparse
'''
解析参数
'''
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, default='val')
parser.add_argument("--select", type=str, default='bbox_mAP_50')
parser.add_argument("--json_paths", type=str, nargs='+')
parser.add_argument("--line_names", type=str, nargs='+')
parser.add_argument("--out_dir", type=str, default='./')
parser.add_argument("--epoch_num", type=int, default=20)
parser.add_argument("--pic_name", type=str,default="result")
args = parser.parse_args()
select=args.select
pic_name=args.pic_name
mode = args.mode # 选择log文件中的模式
json_paths = args.json_paths
line_names = args.line_names
out_dir = args.out_dir
epoch_num = args.epoch_num
plt.figure(figsize=(12, 8), dpi=300)
for i, json_path in enumerate(json_paths):
epoch_now = 0
x = [] # 存放epoch
y = [] # 存放指标
y_min = 1000000 # 存放指标最大值 ap不会超过1 绘制loss可自由更改
y_max = -1 # 存放指标最小值 ap不会小于-1 绘制loss可自由更改
x_min = 0 # 出现最小值的epoch
x_max = 0 # 出现最大值的epoch
isFirst = True
with open(json_path, 'r') as f:
for jsonstr in f.readlines():
if epoch_now == epoch_num:
break
if isFirst: # mmdetection生成的log json文件第一行是配置信息 跳过
isFirst = False
continue
row_data = json.loads(jsonstr)
if row_data['mode'] == mode: # 选择train或者val模式中的指标数据
epoch_now = epoch_now + 1
item_select = float(row_data[select])
x_select = int(row_data['epoch'])
x.append(x_select)
y.append(item_select)
if item_select >= y_max: # 选择最大值 为什么不用numpy.argmin呢? 因为epoch可能不从1开始 xmin和ymin可能匹配错误 比较麻烦
y_max = item_select
x_max = x_select
if item_select <= y_min: # 选择最大值
y_min = item_select
x_min = x_select
plt.grid(True, linestyle='--', alpha=0.5)
plt.plot(x, y, label=line_names[i])
plt.plot(x_min, y_min, 'g-p', x_max, y_max, 'r-p')
show_min = '[' + str(x_min) + ' , ' + str(y_min) + ']'
show_max = '[' + str(x_max) + ' , ' + str(y_max) + ']'
plt.annotate(show_min, xy=(x_min, y_min), xytext=(x_min, y_min))
plt.annotate(show_max, xy=(x_max, y_max), xytext=(x_max, y_max))
plt.xlabel('epoch')
plt.legend()
plt.ylabel(select)
# plt.ylim(0.8, 1.0) # 设置y轴坐标范围
plt.savefig(args.out_dir + '/' + pic_name + '.jpg', dpi=300)
python [visual_log.py文件位置] [--mode train或val模式] [--select 绘制的指标名称] [--json_paths log.json文件位置] [--line_names 绘制曲线的名称] [--out_dir 图片保存的父文件夹] [--epoch_num 绘制的世代数] [--pic_name 图片名称]