caffe训练预读取数据及结果可视化

  1. caffe训练预读取数据,加快训练速度
1) 把图片转为lmdb,可以加快数据读取速度,提高IO效率;
2) 在1的基础上,可以使用prefetch参数对数据进行预取
data_param {
source: “./data/ilsvrc12/ilsvrc12_train_lmdb”
batch_size: 32
backend: LMDB
prefetch: 20
} 

1.caffe训练时添加输出日志

#!/bin/sh
mkdir -p models 
mkdir -p log
LOG=log/train-`date +%Y-%m-%d-%H-%M-%S`.log
/home/yang/MobileNet-YOLO1/build/tools/caffe train -solver="yuface_v1_solver.prototxt" -gpu 0  \ 2>&1|tee $LOG

输出的日志保存在log/下
2.解析训练日志文件
增加了一个 倍数time 的变量,因为有时候输出波动太大,按一定倍数取平均会让曲线平滑一点。
第一个参数是log文件路径。
需要修改代码中display和test_iterval的数值个solver.prototxt中一致。
time是倍数,想看原始数据曲线的话就设置为1。

# -*- coding: utf-8 -*-

"""
      python log.py -p log/train-2019-12-25-18-10-58.log
"""
 
import matplotlib.pyplot as plt
import numpy as np
import commands
import argparse
 
parser = argparse.ArgumentParser()
parser.add_argument(
    '-p','--log_path',
    type = str,
    default = '',
    help = """\
    path to log file\
    """
)
 
FLAGS = parser.parse_args()
 
train_log_file = FLAGS.log_path
 
display = 100 #solver
test_interval = 1000 #solver
 
time = 10
 
train_output = commands.getoutput("cat " + train_log_file + " | grep 'Train net output #0' | awk '{print $11}'")  #train mbox_loss
accu_output = commands.getoutput("cat " + train_log_file + " | grep 'Test net output #0' | awk '{print $11}'") #test detection_eval
 
train_loss = train_output.split("\n")
test_accu = accu_output.split("\n")
  
def reduce_data(data):
  iteration = len(data)/time*time
  _data = data[0:iteration]
  if time > 1:
    data_ = []
    for i in np.arange(len(data)/time):
      sum_data = 0
      for j in np.arange(time):
        index = i*time + j
        sum_data += float(_data[index])
      data_.append(sum_data/float(time))
  else:
    data_ = data
  return data_
 
train_loss_ = reduce_data(train_loss)
test_accu_ = reduce_data(test_accu)
 
_,ax1 = plt.subplots()
ax2 = ax1.twinx()
 
ax1.plot(time*display*np.arange(len(train_loss_)), train_loss_)
ax2.plot(time*test_interval*np.arange(len(test_accu_)), test_accu_, 'r')
 
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Train Loss')
ax2.set_ylabel('Test Accuracy')
plt.show()

结果:


参考文档:https://blog.csdn.net/renhanchi/article/details/78411095

你可能感兴趣的:(caffe训练预读取数据及结果可视化)