Caffe-Windows下画loss与accuracy曲线

本篇博客主要讲述怎样在Windows下利用Caffe提供的脚本程序和Caffe训练日志画loss曲线与accuracy曲线。如果你是在Linux下使用Caffe可以参考这篇博客:http://blog.csdn.net/fx409494616/article/details/53197209?ref=myread。

如果你还没有Caffe训练日志,请参考上一篇博客http://blog.csdn.net/sunshine_in_moon/article/details/53529028,生成自己的训练日志。

好了废话少说,直接上干货!!!

1、修改上一篇博客中的一行代码,目的是使生成的日志文件的后缀名为".log"。因为Caffe提供的脚本处理的文件默认后缀是".log",当然,我们也可以不用修改,生成日志文件后直接认为修改后缀名即可了。一劳永逸我们还是改一下吧。

void initGlog() {
	FLAGS_log_dir = "E:\\caffe\\caffe-windows\\log\\";//存放日志文件的文件夹路径,我们可以自己指定  
	_mkdir(FLAGS_log_dir.c_str());
	std::string LOG_INFO_FILE;
	std::string LOG_WARNING_FILE;
	std::string LOG_ERROR_FILE;
	std::string LOG_FATAL_FILE;
	std::string now_time = boost::posix_time::to_iso_extended_string(boost::posix_time::second_clock::local_time());
	now_time[13] = '-';
	now_time[16] = '-';
	//LOG_INFO_FILE = FLAGS_log_dir + "INFO" + now_time + ".txt";
        /************将txt改成log*********/
        LOG_INFO_FILE = FLAGS_log_dir + "INFO" + now_time + ".log";
        /*****************************/
	google::SetLogDestination(google::GLOG_INFO, LOG_INFO_FILE.c_str());
	LOG_WARNING_FILE = FLAGS_log_dir + "WARNING" + now_time + ".txt";
	google::SetLogDestination(google::GLOG_WARNING, LOG_WARNING_FILE.c_str());
	LOG_ERROR_FILE = FLAGS_log_dir + "ERROR" + now_time + ".txt";
	google::SetLogDestination(google::GLOG_ERROR, LOG_ERROR_FILE.c_str());
	LOG_FATAL_FILE = FLAGS_log_dir + "FATAL" + now_time + ".txt";
	google::SetLogDestination(google::GLOG_FATAL, LOG_FATAL_FILE.c_str());
}
OK!第一步完成,重新编译就好了。

2、修改tools/extra/plot_training_log.py,这里面需要修改的东西太多了,我们分步讲解,可能代码优点乱,大家不要介意。

2.1、生成*****log.test,*****log.train两个文件

方法一:利用tools/extra/parse_log.py文件

python parse_log.py ****.log save_path
第一个参数:我们的训练日志,后缀名必须是".log",其实这也不是必须的,我们可以修改plot_training_log.py中子函数

def get_log_file_suffix():
    return '.log'#可以返回其他后缀名
第二个参数:保存路径,执行上述命令后会生成两个文件****.log.test,****.log.train。

方法二:将生成这两个文件集成到plot_training_log.py中。我们首先看一下两个plot_training_log.py文件中的子函数

def get_log_parsing_script():
    dirname = os.path.dirname(os.path.abspath(inspect.getfile(
        inspect.currentframe())))
    return dirname + '/parse_log.sh'
返回的是parse_log.sh脚本的路径,看来要调用这个脚本,但是我们知道在Windows下是无法使用shell脚本的。所以我们需要修改调用这个shell脚本的地方。就在下面这个子函数

def plot_chart(chart_type, path_to_png, path_to_log_list):
    for path_to_log in path_to_log_list:
        #os.system('%s %s' % (get_log_parsing_script(), path_to_log))
        ######################自己修改#############################
        train_dict_list, test_dict_list = parse_log.parse_log(path_to_log)
        parse_log.save_csv_files(path_to_log, './', train_dict_list,test_dict_list)
        #####################记得要在前面导入parse_log模块########
        data_file = get_data_file(chart_type, path_to_log)
        x_axis_field, y_axis_field = get_field_descriptions(chart_type)
        x, y = get_field_indices(x_axis_field, y_axis_field)
        data = load_data(data_file, x, y)
        ## TODO: more systematic color cycle for lines
        color = [random.random(), random.random(), random.random()]
        label = get_data_label(path_to_log)
        linewidth = 0.75
        ## If there too many datapoints, do not use marker.
##        use_marker = False
        use_marker = True
        if not use_marker:
            plt.plot(data[0], data[1], label = label, color = color,
                     linewidth = linewidth)
        else:
            ok = False
            ## Some markers throw ValueError: Unrecognized marker style
            while not ok:
                try:
                    marker = random_marker()
                    plt.plot(data[0], data[1], label = label, color = color,
                             marker = marker, linewidth = linewidth)
                    ok = True
                except:
                    pass
    legend_loc = get_legend_loc(chart_type)
    plt.legend(loc = legend_loc, ncol = 1) # ajust ncol to fit the space
    plt.title(get_chart_type_description(chart_type))
    plt.xlabel(x_axis_field)
    plt.ylabel(y_axis_field)
    plt.savefig(path_to_png)
    plt.show()
看到了第一句就是调用shell脚本,我们将其注释掉,然后利用parse_log.py文件中的子函数来实现相同的功能。
2.2、Caffe提供的工具可以生成8种不同的曲线

Caffe-Windows下画loss与accuracy曲线_第1张图片

2.3、修改子函数creat_field_index()

def create_field_index():
    train_key = 'Train'
    test_key = 'Test'
    field_index = {train_key:{'Iters':0, 'Seconds':1, train_key + ' learning rate':2,
                              train_key + ' loss':3},#根据自己的**.log.train文件修改了2和3的顺序
                   test_key:{'Iters':0, 'Seconds':1, 'learning rate':2,test_key + ' accuracy':3,
                             test_key + ' loss':4}}#自己增加test_key 中learning rate
    fields = set()
    for data_file_type in field_index.keys():
        fields = fields.union(set(field_index[data_file_type].keys()))
    fields = list(fields)
    fields.sort()
    return field_index, fields
主要修改的地方就是field_index,这要根据你前面生成的****.log.test和****.log.train两个文件中第一行的单词的顺序修改字典对应顺序。我此处的修改是根据我的文件,切记一定要和你的文件核对,否则生成的曲线是不对的。我已经测试过8种曲线都能正确画出。

2.4、修改load_data()

def load_data(data_file, field_idx0, field_idx1):
    data = [[], []]
    fr = open(data_file,'r')
    lines = fr.readlines()
    for i in range(1,len(lines)):
        line = lines[i].strip()
        if line[0] != '#':
            fields = line.split(',')
            data[0].append(float(fields[field_idx0].strip()))
            data[1].append(float(fields[field_idx1].strip()))
    fr.close()
    return data
之所以修改这个函数,因为原函数是从****.log.test和****.log.train的第一行读取数据,但是第一行是单词如法转换成浮点数,必须从第二行开始读取数据。

OK,到此为止,需要修改的地方基本上已经没有了。

需要注意两点:

1、保存的图片默认后缀名.png,如果你想保存成其他后缀名,可修改下面的代码

path_to_png = sys.argv[2]
if not path_to_png.endswith('.png'):#此处检查后缀名,可以改成你想要的后缀
      print 'Path must ends with png' % path_to_png
2、Windows命令格式

python plot_training_log.py 7 train.png INFO2016-12-09T12-54-26.log
结果如下:

Caffe-Windows下画loss与accuracy曲线_第2张图片

是不是很酷!

修改后的完整代码请到此处下载:http://download.csdn.net/detail/sunshine_in_moon/9706954

下载积分为5分,毕竟辛辛苦苦改了很长时间,请多多支持。如果你的积分确实有限,可以给我留言并附上邮箱。







你可能感兴趣的:(caffe)