Caffe: plotting loss and accuracy curves from the training log

This post covers only Ubuntu 16.04:
plotting loss and accuracy curves from a Caffe training log.
If you are on Windows, skip to the end of the post.

Caffe already ships with small tools for this:

  • caffe-master/tools/extra/parse_log.sh
  • caffe-master/tools/extra/extract_seconds.py
  • caffe-master/tools/extra/plot_training_log.py.example

Copy the files above into the current log directory (the directory where the training log is saved),
and rename plot_training_log.py.example to plot_training_log.py.
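
The copy-and-rename step can also be scripted. The following is a minimal sketch, not part of Caffe: the helper name stage_plot_tools and both directory arguments are assumptions made for illustration, and it relies only on the three file names listed above.

```python
import os
import shutil

# File names as listed above, relative to the Caffe source tree.
TOOL_FILES = [
    "tools/extra/parse_log.sh",
    "tools/extra/extract_seconds.py",
    "tools/extra/plot_training_log.py.example",
]


def stage_plot_tools(caffe_root, log_dir):
    """Copy the plotting helpers into log_dir (hypothetical helper).

    Returns the list of destination paths. The ".example" suffix is
    dropped so the plotting script ends up as plot_training_log.py.
    """
    staged = []
    for rel_path in TOOL_FILES:
        name = os.path.basename(rel_path)
        if name.endswith(".example"):
            name = name[: -len(".example")]
        dst = os.path.join(log_dir, name)
        shutil.copy(os.path.join(caffe_root, rel_path), dst)
        staged.append(dst)
    return staged
```

For example, stage_plot_tools("caffe-master", ".") would place the three files in the current directory, assuming caffe-master is where the Caffe sources live.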

Step 1: save the log to a file. Redirecting the output is enough:

  • If you already have a log file, skip this step; otherwise save the training log (the command below is one way to do it)
    $TOOLS/caffe train --solver=$SOLVERFILE 2>&1 |tee out.log
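
Where a shell with tee is not available, the same effect can be approximated from Python. This is a sketch under assumptions: the helper name run_and_tee is hypothetical, and the command to run is whatever launches your training. It merges stderr into stdout, as 2>&1 does in the command above, and writes every line both to the console and to the log file.

```python
import subprocess
import sys


def run_and_tee(cmd, log_path):
    """Run cmd, echo its output, and save a copy to log_path.

    stderr is merged into stdout, mirroring `2>&1 | tee out.log`.
    Returns the exit code of the command.
    """
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            universal_newlines=True)
    with open(log_path, "w") as log_file:
        for line in proc.stdout:
            sys.stdout.write(line)   # show progress on the console
            log_file.write(line)     # keep a copy for plotting later
    return proc.wait()
```

A call might look like run_and_tee(["caffe", "train", "--solver=solver.prototxt"], "out.log"), where the executable and solver paths are placeholders for your own setup.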

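Step 2: plot directly:

  • Note: the script requires Python 2.7, not Python 3
python plot_training_log.py 2 testloss.png out.log
  • Notes: the argument 2 in the command above selects the chart type; the script's help message lists what each number means: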
    1. Supporting multiple logs.
    2. Log file name must end with the lower-cased ".log". 
  • Supported chart types:
    0: Test accuracy  vs. Iters
    1: Test accuracy  vs. Seconds
    2: Test loss  vs. Iters
    3: Test loss  vs. Seconds
    4: Train learning rate  vs. Iters
    5: Train learning rate  vs. Seconds
    6: Train loss  vs. Iters
    7: Train loss  vs. Seconds
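
The numbering in the help text follows a simple pattern: four y-axis quantities, each paired first with Iters and then with Seconds. The sketch below reproduces that mapping and assembles the command line with the ".log" naming rule checked up front. The function names are hypothetical and this is not code taken from plot_training_log.py.

```python
# Y-axis and x-axis fields in the order implied by the help text above:
# each y field is paired with Iters first, then Seconds.
Y_FIELDS = ["Test accuracy", "Test loss", "Train learning rate", "Train loss"]
X_FIELDS = ["Iters", "Seconds"]


def chart_type_name(chart_type):
    """Return a description such as 'Test loss vs. Iters' for an id 0-7."""
    if not 0 <= chart_type < len(Y_FIELDS) * len(X_FIELDS):
        raise ValueError("chart type must be in 0..7, got %r" % chart_type)
    y_field = Y_FIELDS[chart_type // len(X_FIELDS)]
    x_field = X_FIELDS[chart_type % len(X_FIELDS)]
    return "%s vs. %s" % (y_field, x_field)


def build_plot_command(chart_type, png_path, log_paths):
    """Assemble the argument list for plot_training_log.py."""
    chart_type_name(chart_type)  # raises ValueError for an invalid id
    for path in log_paths:
        if not path.endswith(".log"):
            raise ValueError("log file name must end with '.log': %s" % path)
    return (["python", "plot_training_log.py", str(chart_type), png_path]
            + list(log_paths))
```

Passing several log paths corresponds to the "multiple logs" note in the help text.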

Change the argument to suit your needs.

If you run into errors like the following:

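onds.py: Permission denied
paste: aux4.txt: No such file or directory
rm: cannot remove 'aux4.txt': No such file or directory
/home/th/data/TH/CARlog/parse_log.sh: line 47: /home/th/data/TH/CARlog/extract_seconds.py: Permission denied
paste: aux3.txt: No such file or directory
rm: cannot remove 'aux3.txt': No such file or directory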

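aux4.txt is generated from aux3.txt, so when aux3.txt is missing, aux4.txt cannot be generated either,
which is why paste and rm report that aux4.txt does not exist.
extract_seconds.py also determines the start time by searching for the "Solving" line.
If you use parse_log.py on its own to generate the parsed log files, the aux4.txt error does not appear, but the extract_seconds.py error still does.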
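
The permission errors in the output above most likely mean that extract_seconds.py does not have its execute bit set, since the shell is refusing to run it as a program. Granting execute permission (the equivalent of chmod +x) may clear that part of the error before any script is edited. A minimal sketch of that check; the helper name ensure_executable is hypothetical:

```python
import os
import stat


def ensure_executable(path):
    """Add the execute bits to path if the current user cannot run it.

    Returns True if the mode was changed, False if the file was
    already executable.
    """
    if os.access(path, os.X_OK):
        return False
    mode = os.stat(path).st_mode
    os.chmod(path, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
    return True
```

Calling it on extract_seconds.py and parse_log.sh in the log directory would be the natural use.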

In parse_log.sh, lines 28-31, change two lines (comment out the 'Testing net' grep and add a 'Train net' grep):

grep '] Solving ' $1 > aux3.txt
# grep 'Testing net' $1 >> aux3.txt
grep 'Train net' $1 >> aux3.txt
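
To make the role of the '] Solving ' line concrete, here is a sketch of how elapsed seconds can be derived from glog-style timestamps: the Solving line supplies the start time, and each later matching line is measured against it. This is an illustration under assumptions, not the code of extract_seconds.py; the function names are hypothetical and the year must be supplied because the log lines do not contain it.

```python
import datetime


def parse_glog_time(line, year):
    """Parse the timestamp of a glog-style line such as
    'I0210 13:39:22.381027 25210 solver.cpp:204] Solving LeNet'.
    """
    date_part, time_part = line.split()[:2]
    # date_part[1:] drops the leading severity letter, leaving MMDD
    stamp = "%d%s %s" % (year, date_part[1:], time_part)
    return datetime.datetime.strptime(stamp, "%Y%m%d %H:%M:%S.%f")


def elapsed_seconds(log_lines, year, marker="Train net"):
    """Seconds from the '] Solving ' line to each line containing marker."""
    start = None
    seconds = []
    for line in log_lines:
        if start is None:
            # Everything before the Solving line is ignored
            if "] Solving " in line:
                start = parse_glog_time(line, year)
            continue
        if marker in line:
            delta = parse_glog_time(line, year) - start
            seconds.append(delta.total_seconds())
    if start is None:
        raise ValueError("no '] Solving ' line found; start time unknown")
    return seconds
```

If the log has no Solving line, there is no start time to measure from, which matches the dependency described above.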

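You also need to modify the load_data() function in plot_training_log.py.
The original function reads data starting from the first line of the .log.test and .log.train files,
but the first line is a header of words that cannot be converted to floating-point numbers, so the data must be read from the second line onward.

The modified load_data() is as follows: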

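def load_data(data_file, field_idx0, field_idx1):
    """Read two numeric columns from a parsed log table."""
    data = [[], []]
    with open(data_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and the '#'-prefixed header line
            if line and line[0] != '#':
                # Collapse runs of spaces so every field is comma-separated
                line = ','.join(filter(lambda x: x, line.split(' ')))
                print(line)    # debug output
                fields = line.split(',')
                print(fields)  # debug output
                data[0].append(float(fields[field_idx0].strip()))
                data[1].append(float(fields[field_idx1].strip()))
    return data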
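
To see what such a loader expects, it helps to look at a tiny table in the same layout: a '#'-prefixed header followed by whitespace-aligned numeric columns. The sketch below uses made-up numbers and a simplified loader named load_columns (a hypothetical name, written without the debug prints) to show the header being skipped.

```python
# Made-up values in the layout of a .log.train table
SAMPLE_TABLE = (
    "#Iters Seconds TrainingLoss LearningRate\n"
    "0      0.12    2.30         0.01\n"
    "100    3.50    1.85         0.01\n"
)


def load_columns(data_file, field_idx0, field_idx1):
    """Read two numeric columns, skipping blank and '#' header lines."""
    data = [[], []]
    with open(data_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            fields = line.split()  # splits on any run of whitespace
            data[0].append(float(fields[field_idx0]))
            data[1].append(float(fields[field_idx1]))
    return data
```

With SAMPLE_TABLE written to a file, columns 0 and 2 give the iteration numbers and the training loss.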

On Windows, where there is no shell, you can use the parse_log.py tool instead.

Its source code is as follows:

#!/usr/bin/env python

"""
Parse training log

Evolved from parse_log.sh
"""

import os
import re
import extract_seconds
import argparse
import csv
from collections import OrderedDict


def parse_log(path_to_log):
    """Parse log file
    Returns (train_dict_list, test_dict_list)

    train_dict_list and test_dict_list are lists of dicts that define the table
    rows
    """

    regex_iteration = re.compile(r'Iteration (\d+)')
    regex_train_output = re.compile(r'Train net output #(\d+): (\S+) = ([\.\deE+-]+)')
    regex_test_output = re.compile(r'Test net output #(\d+): (\S+) = ([\.\deE+-]+)')
    regex_learning_rate = re.compile(r'lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')

    # Pick out lines of interest
    iteration = -1
    learning_rate = float('NaN')
    train_dict_list = []
    test_dict_list = []
    train_row = None
    test_row = None

    logfile_year = extract_seconds.get_log_created_year(path_to_log)
    with open(path_to_log) as f:
        start_time = extract_seconds.get_start_time(f, logfile_year)

        for line in f:
            iteration_match = regex_iteration.search(line)
            if iteration_match:
                iteration = float(iteration_match.group(1))
            if iteration == -1:
                # Only start parsing for other stuff if we've found the first
                # iteration
                continue

            try:
                time = extract_seconds.extract_datetime_from_line(line,
                                                                  logfile_year)
            except ValueError:
                # Skip lines with bad formatting, for example when resuming solver
                continue

            seconds = (time - start_time).total_seconds()

            learning_rate_match = regex_learning_rate.search(line)
            if learning_rate_match:
                learning_rate = float(learning_rate_match.group(1))

            train_dict_list, train_row = parse_line_for_net_output(
                regex_train_output, train_row, train_dict_list,
                line, iteration, seconds, learning_rate
            )
            test_dict_list, test_row = parse_line_for_net_output(
                regex_test_output, test_row, test_dict_list,
                line, iteration, seconds, learning_rate
            )

    fix_initial_nan_learning_rate(train_dict_list)
    fix_initial_nan_learning_rate(test_dict_list)

    return train_dict_list, test_dict_list


def parse_line_for_net_output(regex_obj, row, row_dict_list,
                              line, iteration, seconds, learning_rate):
    """Parse a single line for training or test output

    Returns a tuple with (row_dict_list, row)
    row: may be either a new row or an augmented version of the current row
    row_dict_list: may be either the current row_dict_list or an augmented
    version of the current row_dict_list
    """

    output_match = regex_obj.search(line)
    if output_match:
        if not row or row['NumIters'] != iteration:
            # Push the last row and start a new one
            if row:
                # If we're on a new iteration, push the last row
                # This will probably only happen for the first row; otherwise
                # the full row checking logic below will push and clear full
                # rows
                row_dict_list.append(row)

            row = OrderedDict([
                ('NumIters', iteration),
                ('Seconds', seconds),
                ('LearningRate', learning_rate)
            ])

        # output_num is not used; may be used in the future
        # output_num = output_match.group(1)
        output_name = output_match.group(2)
        output_val = output_match.group(3)
        row[output_name] = float(output_val)

    if row and len(row_dict_list) >= 1 and len(row) == len(row_dict_list[0]):
        # The row is full, based on the fact that it has the same number of
        # columns as the first row; append it to the list
        row_dict_list.append(row)
        row = None

    return row_dict_list, row


def fix_initial_nan_learning_rate(dict_list):
    """Correct initial value of learning rate

    Learning rate is normally not printed until after the initial test and
    training step, which means the initial testing and training rows have
    LearningRate = NaN. Fix this by copying over the LearningRate from the
    second row, if it exists.
    """

    if len(dict_list) > 1:
        dict_list[0]['LearningRate'] = dict_list[1]['LearningRate']


def save_csv_files(logfile_path, output_dir, train_dict_list, test_dict_list,
                   delimiter=',', verbose=False):
    """Save CSV files to output_dir

    If the input log file is, e.g., caffe.INFO, the names will be
    caffe.INFO.train and caffe.INFO.test
    """

    log_basename = os.path.basename(logfile_path)
    train_filename = os.path.join(output_dir, log_basename + '.train')
    write_csv(train_filename, train_dict_list, delimiter, verbose)

    test_filename = os.path.join(output_dir, log_basename + '.test')
    write_csv(test_filename, test_dict_list, delimiter, verbose)


def write_csv(output_filename, dict_list, delimiter, verbose=False):
    """Write a CSV file
    """

    if not dict_list:
        if verbose:
            print('Not writing %s; no lines to write' % output_filename)
        return

    dialect = csv.excel
    dialect.delimiter = delimiter

    with open(output_filename, 'w') as f:
        dict_writer = csv.DictWriter(f, fieldnames=dict_list[0].keys(),
                                     dialect=dialect)
        dict_writer.writeheader()
        dict_writer.writerows(dict_list)
    if verbose:
        print('Wrote %s' % output_filename)


def parse_args():
    description = ('Parse a Caffe training log into two CSV files '
                   'containing training and testing information')
    parser = argparse.ArgumentParser(description=description)

    parser.add_argument('logfile_path',
                        help='Path to log file')

    parser.add_argument('output_dir',
                        help='Directory in which to place output CSV files')

    parser.add_argument('--verbose',
                        action='store_true',
                        help='Print some extra info (e.g., output filenames)')

    parser.add_argument('--delimiter',
                        default=',',
                        help=('Column delimiter in output files '
                              '(default: \'%(default)s\')'))

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    train_dict_list, test_dict_list = parse_log(args.logfile_path)
    save_csv_files(args.logfile_path, args.output_dir, train_dict_list,
                   test_dict_list, delimiter=args.delimiter,
                   verbose=args.verbose)


if __name__ == '__main__':
    main()
Running it generates two files, log.test and log.train.
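
Going by parse_args() and save_csv_files() in the source above, the script takes the log path and an output directory as positional arguments (for example: python parse_log.py out.log .) and writes CSV files named after the log with .train and .test appended. The columns start with NumIters, Seconds and LearningRate, followed by one column per net output. The sketch below reads one curve back from such a file; the helper name read_curve is hypothetical, and the output column name (for example "loss") depends on your network.

```python
import csv


def read_curve(csv_path, x_field, y_field, delimiter=","):
    """Extract (xs, ys) from a CSV file written by parse_log.py.

    Rows that have no value for y_field are skipped.
    """
    xs, ys = [], []
    with open(csv_path) as f:
        for row in csv.DictReader(f, delimiter=delimiter):
            value = row.get(y_field)
            if value in (None, ""):
                continue  # this row carries no such output
            xs.append(float(row[x_field]))
            ys.append(float(value))
    return xs, ys
```

The two returned lists can then be passed to any plotting library, with NumIters or Seconds as the x-axis.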

For the remaining steps on Windows, see this blog post:
https://blog.csdn.net/sunshine_in_moon/article/details/53541573

The complete modified code from this post can be downloaded here:
https://download.csdn.net/download/tanghong1996/10601483
