Visualizing the loss curves of Faster-RCNN_TF (TensorFlow version of Faster R-CNN)

The Faster R-CNN I use is the TensorFlow version; GitHub repo: Faster-RCNN_TF.
The loss-curve scripts I found online all read the values from a training log txt file, but my setup does not produce such a log, so I had to modify the code myself. My idea is to write the accumulated losses to a CSV file every cfg.TRAIN.SNAPSHOT_ITERS iterations, and then read that file back to plot the curves.
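
The whole trick is just a pandas round trip: accumulate the per-iteration losses in Python lists, periodically dump them with DataFrame.to_csv, and later load them again with read_csv for plotting. A minimal sketch of that round trip (the file name and loss values here are made up purely for illustration):

import pandas as pd

# dummy per-iteration losses, accumulated in a plain Python list during training
all_loss = [2.1, 1.8, 1.5]

# dump to CSV; index_label gives the iteration column an explicit header
pd.DataFrame({'all_loss': all_loss}).to_csv('loss_demo.csv', index_label='index')

# read it back later for plotting
df = pd.read_csv('loss_demo.csv')
print(df)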

The concrete steps are as follows:

1. Modify train.py

In /lib/fast_rcnn/train.py, inside the train_model(self, sess, max_iters) function,
add the marked lines around the line for iter in range(max_iters):. The code after the changes looks like this:

        #### add: lists that accumulate the per-iteration losses
        all_loss = []
        rpn_cls_loss = []
        rpn_box_loss = []
        cls_loss = []
        box_loss = []
        for iter in range(max_iters):
            # get one batch
            blobs = data_layer.forward()

            # Make one SGD update
            feed_dict={self.net.data: blobs['data'], self.net.im_info: blobs['im_info'], self.net.keep_prob: 0.5, \
                           self.net.gt_boxes: blobs['gt_boxes']}

            run_options = None
            run_metadata = None
            if cfg.TRAIN.DEBUG_TIMELINE:
                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            timer.tic()

            rpn_loss_cls_value, rpn_loss_box_value,loss_cls_value, loss_box_value, _ = sess.run([rpn_cross_entropy, rpn_loss_box, cross_entropy, loss_box, train_op],
                                                                                                feed_dict=feed_dict,
                                                                                                options=run_options,
                                                                                                run_metadata=run_metadata)

            timer.toc()
            #### add: record this iteration's loss values
            all_loss.append(rpn_loss_cls_value + rpn_loss_box_value + loss_cls_value + loss_box_value)
            rpn_cls_loss.append(rpn_loss_cls_value)
            rpn_box_loss.append(rpn_loss_box_value)
            cls_loss.append(loss_cls_value)
            box_loss.append(loss_box_value)


            if cfg.TRAIN.DEBUG_TIMELINE:
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                trace_file = open(str(long(time.time() * 1000)) + '-train-timeline.ctf.json', 'w')
                trace_file.write(trace.generate_chrome_trace_format(show_memory=False))
                trace_file.close()

            if (iter+1) % (cfg.TRAIN.DISPLAY) == 0:
                print 'iter: %d / %d, total loss: %.4f, rpn_loss_cls: %.4f, rpn_loss_box: %.4f, loss_cls: %.4f, loss_box: %.4f, lr: %f'%\
                        (iter+1, max_iters, rpn_loss_cls_value + rpn_loss_box_value + loss_cls_value + loss_box_value ,rpn_loss_cls_value, rpn_loss_box_value,loss_cls_value, loss_box_value, lr.eval())
                print 'speed: {:.3f}s / iter'.format(timer.average_time)

            if (iter+1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                self.snapshot(sess, iter)
                #### add: dump every loss recorded so far to a CSV file
                # (the lists are never cleared, so loss%d.csv holds the full history up to that iteration)
                dataframe = pd.DataFrame({'all_loss': all_loss, 'rpn_cls_loss': rpn_cls_loss, 'rpn_box_loss': rpn_box_loss, 'cls_loss': cls_loss, 'box_loss': box_loss})
                # index_label names the index column so the plotting script can read it back as df['index']
                dataframe.to_csv("xxxxxxxx/loss%d.csv" % (iter+1), index_label='index')
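
The additions above rely on pandas, so it has to be imported near the top of train.py if it is not imported there already:

# near the other imports at the top of lib/fast_rcnn/train.py
import pandas as pd

Each CSV then has one row per iteration: an index column plus the columns all_loss, rpn_cls_loss, rpn_box_loss, cls_loss and box_loss (the column order may differ depending on the pandas version).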

2. Plotting

The plotting code is as follows:

#!/usr/bin/env python
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import host_subplot
import pandas as pd

# read the loss CSV written during training (replace the path with your own)
df = pd.read_csv('xxxxxxxxx/loss2500.csv')
train_iterations = df['index']
train_loss = df['all_loss']

host = host_subplot(111)  
plt.subplots_adjust(right=0.8) 

# set labels
host.set_xlabel("iterations")
host.set_ylabel("total loss")

# plot curves
p1, = host.plot(train_iterations, train_loss, label="train total loss")
host.legend(loc=1)  

# set label color  
host.axis["left"].label.set_color(p1.get_color())  
host.set_xlim([-150,2500])  
host.set_ylim([0., 4])  

plt.draw()  
plt.show()

The resulting plot looks like this:
[Figure 1: the total training loss curve]
Hmm, maybe the learning rate is a bit too large...

I only plotted the all_loss curve here; the other losses can be plotted the same way, for example with the sketch below.
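
A minimal sketch (assuming the same CSV path and the column names written in step 1) that draws every loss column on one set of axes:

import matplotlib.pyplot as plt
import pandas as pd

df = pd.read_csv('xxxxxxxxx/loss2500.csv')   # same placeholder path as in the plotting script above
loss_columns = ['all_loss', 'rpn_cls_loss', 'rpn_box_loss', 'cls_loss', 'box_loss']

# one curve per loss column, sharing the same iteration axis
for name in loss_columns:
    plt.plot(df['index'], df[name], label=name)

plt.xlabel('iterations')
plt.ylabel('loss')
plt.legend(loc=1)
plt.show()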
Feel free to reach out with any questions. Thanks!
