Plotting Loss and Accuracy Curves from Caffe Training Logs on Windows with Python 3

In deep learning, the learning curves can be used to judge the current training state (a rough code sketch of these rules follows the list):

  1. train loss keeps decreasing, test loss keeps decreasing: the network is still learning.
  2. train loss keeps decreasing, test loss levels off: the network is overfitting.
  3. train loss levels off, test loss levels off: learning has hit a bottleneck; reduce the learning rate or the batch size.
  4. train loss levels off, test loss keeps decreasing: the dataset almost certainly has a problem.
  5. train loss keeps increasing, test loss keeps increasing (eventually NaN): possibly caused by a badly designed network architecture, badly chosen training hyperparameters, a bug in the code, etc., and needs further investigation.
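
As a rough illustration (not part of the original workflow), the five rules above can be approximated by comparing the trend of each curve. The helper below is a hypothetical sketch: it classifies a curve as going down, up, or flat by comparing the mean of its first and second halves, and maps the pair of trends to the rules of thumb above.

def trend(values, eps=1e-3):
    """Classify a loss curve as 'down', 'up' or 'flat' by comparing its halves."""
    half = len(values) // 2
    delta = sum(values[half:]) / (len(values) - half) - sum(values[:half]) / half
    if delta < -eps:
        return 'down'
    if delta > eps:
        return 'up'
    return 'flat'

def diagnose(train_loss, test_loss):
    """Map (train trend, test trend) to the rules of thumb listed above."""
    rules = {
        ('down', 'down'): 'network is still learning',
        ('down', 'flat'): 'overfitting',
        ('flat', 'flat'): 'bottleneck: lower the learning rate or batch size',
        ('flat', 'down'): 'dataset problem',
        ('up',   'up'):   'check architecture / hyperparameters / code',
    }
    return rules.get((trend(train_loss), trend(test_loss)), 'no clear pattern')

# Example: decreasing train loss with a flat test loss suggests overfitting.
print(diagnose([2.0, 1.2, 0.7, 0.4], [1.4, 1.4, 1.4, 1.4]))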

MATLAB code for Linux:

% Shell command to extract the loss values from the log file:
% cat train_log_file | grep "Train net output " | awk '{print $11}'

clear;
clc;
close all;
train_log_file = 'train.log';
train_interval = 100;   % training loss is logged every train_interval iterations
test_interval = 200;    % test loss is logged every test_interval iterations
% Extract the train loss values (system() rather than dos(), since cat/grep/awk are shell tools)
[~, train_string_output] = system(['cat ', train_log_file, ' | grep ''Train net output #0'' | awk ''{print $11}''']);
train_loss = str2num(train_string_output);
n = 1 : length(train_loss);
idx_train = (n - 1) * train_interval;
% Extract the test loss values
[~, test_string_output] = system(['cat ', train_log_file, ' | grep ''Test net output #1'' | awk ''{print $11}''']);
test_loss = str2num(test_string_output);
m = 1 : length(test_loss);
idx_test = (m - 1) * test_interval;
figure;
plot(idx_train, train_loss);
hold on;
plot(idx_test, test_loss);

grid on;
legend('Train Loss', 'Test Loss');
xlabel('iterations');
ylabel('loss');
title('Train & Test Loss Curve');

Python 3 code for Windows (Anaconda3 + PyCharm):

"./bin/caffe.exe" train --solver=./examples/mnist/lenet_solver.prototxt >./examples/mnist/log/mnist_Lenet_train_test.log 2>&1
pause

The ">./examples/mnist/log/mnist_Lenet_train_test.log 2>&1" part redirects the training output (both stdout and stderr) to the log file.
parse_log.py and extract_seconds.py are used to parse the training log:
parse_log.py source:

import re
import os  # needed by save_csv_files() below
import csv
from collections import OrderedDict

from examples.mnist.log.extract_seconds import *

def parse_log(log_file_name):
    """
    Parse log file
    :param log_file_name: the name of log file
    :return: (train_dict_list, test_dict_list)
    """
    regex_iteration = re.compile(r'Iteration (\d+)')
    regex_train_output = re.compile(r'Train net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_test_output = re.compile(r'Test net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_learning_rate = re.compile(r'lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')

    # Pick out lines of interest
    iteration = -1
    learning_rate = float('NaN')
    train_dict_list = []
    test_dict_list = []
    train_row = None
    test_row = None

    logfile_year = get_log_created_year(log_file_name)
    with open(log_file_name) as f:
        start_time = get_start_time(f, logfile_year)
        last_time = start_time

        for line in f:
            iteration_match = regex_iteration.search(line)
            if iteration_match:
                iteration = float(iteration_match.group(1))
            if iteration == -1:
                # Only start parsing for other stuff if we've found the first iteration
                continue

            try:
                time = extract_datetime_from_line(line, logfile_year)
            except ValueError:
                # Skip lines with bad formatting, for example when resuming solver
                continue

            # if it's another year
            if time.month < last_time.month:
                logfile_year += 1
                time = extract_datetime_from_line(line, logfile_year)
            last_time = time

            seconds = (last_time - start_time).total_seconds()

            learning_rate_match = regex_learning_rate.search(line)
            if learning_rate_match:
                learning_rate = float(learning_rate_match.group(1))

            train_dict_list, train_row = parse_line_for_net_output(
                regex_train_output, train_row, train_dict_list, line, iteration, seconds, learning_rate)
            test_dict_list, test_row = parse_line_for_net_output(
                regex_test_output, test_row, test_dict_list, line, iteration, seconds, learning_rate)

        fix_initial_nan_learning_rate(train_dict_list)
        fix_initial_nan_learning_rate(test_dict_list)

        return train_dict_list, test_dict_list

def parse_line_for_net_output(regex_obj, row, row_dict_list, line, iteration, seconds, learning_rate):
    """Parse a single line for training or test output

    Returns a tuple with (row_dict_list, row)
    row: may be either a new row or an augmented version of the current row
    row_dict_list: may be either the current row_dict_list or an augmented
    version of the current row_dict_list
    """
    output_match = regex_obj.search(line)
    if output_match:
        if not row or row['NumIters'] != iteration:
            # Push the last row and start a new one
            if row:
                # If we're on a new iteration, push the last row
                # This will probably only happen for the first row; otherwise
                # the full row checking logic below will push and clear full
                # rows
                row_dict_list.append(row)

            row = OrderedDict(
                [
                 ('NumIters', iteration),
                 ('Seconds', seconds),
                 ('LearningRate', learning_rate)
                ]
            )

        # output_num is not used; may be used in the future
        output_name = output_match.group(2)
        output_val = output_match.group(3)
        row[output_name] = float(output_val)

    if row and len(row_dict_list) >= 1 and len(row) == len(row_dict_list[0]):
        # The row is full, based on the fact that it has the same number of columns as the first row;
        # append it to the list
        row_dict_list.append(row)
        row = None

    return row_dict_list, row

def fix_initial_nan_learning_rate(dict_list):
    """Correct initial value of learning rate
    Learning rate is normally not printed until after the initial test and
    training step, which means the initial testing and training rows have
    LearningRate = NaN. Fix this by copying over the LearningRate from the
    second row, if it exists.
    """
    if len(dict_list) > 1:
        dict_list[0]['LearningRate'] = dict_list[1]['LearningRate']

def save_csv_files(logfile, output_dir, train_dict_list, test_dict_list, delimiter=',', verbose=False):
    """Save CSV files to output_dir
    If the input log file is, e.g., caffe.INFO, the names will be
    caffe.INFO.train and caffe.INFO.test
    """
    log_basename = os.path.basename(logfile)
    train_filename = os.path.join(output_dir, log_basename + '.train')
    write_csv(train_filename, train_dict_list, delimiter, verbose)

    test_filename = os.path.join(output_dir, log_basename + '.test')
    write_csv(test_filename, test_dict_list, delimiter, verbose)

def write_csv(output_filename, dict_list, delimiter, verbose=False):
    """Write a CSV file
    """
    if not dict_list:
        if verbose:
            print('Not writing %s; no lines to write' % output_filename)
        return

    dialect = csv.excel
    dialect.delimiter = delimiter

    # Opened without newline='', so on Windows csv inserts a blank line after
    # every row; load_data() in the plotting script compensates for this.
    with open(output_filename, 'w') as f:
        dict_writer = csv.DictWriter(f, fieldnames=dict_list[0].keys(), dialect=dialect)
        dict_writer.writeheader()
        dict_writer.writerows(dict_list)
    if verbose:
        print('Wrote %s' % output_filename)

def main():
    log_file_name = 'mnist_Lenet_train_test.log'
    output_dir = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'  # directory where the parsed CSV files are saved
    train_dict_list, test_dict_list = parse_log(log_file_name)
    save_csv_files(log_file_name, output_dir, train_dict_list, test_dict_list, delimiter=',')

if __name__ == '__main__':
    main()
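
For a quick check of the parser, here is a minimal usage sketch (reusing the file name and output directory from main() above; the printed values are only illustrative):

# Hypothetical inspection of the parsed rows before writing the CSV files.
train_rows, test_rows = parse_log('mnist_Lenet_train_test.log')

# Each entry is an OrderedDict, e.g.
# OrderedDict([('NumIters', 100.0), ('Seconds', 4.3), ('LearningRate', 0.00992565), ('loss', 0.21)])
print(len(train_rows), 'train rows,', len(test_rows), 'test rows')
print(train_rows[0])

save_csv_files('mnist_Lenet_train_test.log',
               'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\',
               train_rows, test_rows, delimiter=',')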

extract_seconds.py source:

import datetime
import os

def extract_datetime_from_line(line, year):
    """
    extract datetime from line
    :param line: the lines
    :param year: the year
    :return: datetime
    """
    # Expected format: I0210 13:39:22.381027 25210 solver.cpp:204] Iteration 100, lr = 0.00992565
    line = line.strip().split()
    month = int(line[0][1:3])
    day = int(line[0][3:])
    timestamp = line[1]
    pos = timestamp.rfind('.')
    ts = [int(x) for x in timestamp[:pos].split(':')]
    hour = ts[0]
    minute = ts[1]
    second = ts[2]
    microsecond = int(timestamp[pos + 1:])
    dt = datetime.datetime(year, month, day, hour, minute, second, microsecond)
    return dt

def get_log_created_year(input_file):
    """
    get the year from log file system timestamp
    :param input_file: the input log file
    :return: the created year of the log file
    """
    log_created_time = os.path.getctime(input_file)
    log_created_year = datetime.datetime.fromtimestamp(log_created_time).year
    return log_created_year

def get_start_time(line_iterable, year):
    """
    find start time from group of lines
    :param line_iterable: the lines of log file
    :param year: the created year of log file
    :return: the start datetime
    """
    start_datetime = None
    for line in line_iterable:
        line = line.strip()
        if line.find('Solving') != -1:
            start_datetime = extract_datetime_from_line(line, year)
            break
    return start_datetime
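
As a quick sanity check (a sketch only; the log lines are made up but follow the format quoted in extract_datetime_from_line), the two helpers can be exercised on their own:

# Sample line in the format shown in the comment inside extract_datetime_from_line.
sample = 'I0210 13:39:22.381027 25210 solver.cpp:204] Iteration 100, lr = 0.00992565'
print(extract_datetime_from_line(sample, 2018))   # 2018-02-10 13:39:22.381027 (year is assumed)

# get_start_time() accepts any iterable of lines and looks for the first 'Solving' line.
lines = ['I0210 13:39:22.300000 25210 solver.cpp:228] Solving LeNet', sample]
print(get_start_time(lines, 2018))                # 2018-02-10 13:39:22.300000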

Plotting source:

import matplotlib.pyplot as plt
import random
import itertools
def load_data(data_file, phase):
    """
    Load the parsed CSV data
    :param data_file: the .train or .test CSV file produced by parse_log.py
    :param phase: 'Train' or 'Test'
    :return: data
    """
    # islice(f, 2, None, 2) skips the header row; the step of 2 also skips the
    # blank lines that csv.DictWriter leaves between rows on Windows because
    # write_csv() opens the file without newline=''.
    if phase == 'Train':
        # Train CSV columns: NumIters, Seconds, LearningRate, loss
        data = [[], [], []]
        with open(data_file, 'r') as f:
            for line in itertools.islice(f, 2, None, 2):
                line = line.strip()
                fields = line.split(",")
                data[0].append(float(fields[0].strip()))  # iterations
                data[1].append(float(fields[2].strip()))  # learning rate
                data[2].append(float(fields[3].strip()))  # loss
    else:
        # Test CSV columns: NumIters, Seconds, LearningRate, accuracy, loss
        data = [[], [], [], []]
        with open(data_file, 'r') as f:
            for line in itertools.islice(f, 2, None, 2):
                line = line.strip()
                fields = line.split(",")
                data[0].append(float(fields[0].strip()))  # iterations
                data[1].append(float(fields[2].strip()))  # learning rate
                data[2].append(float(fields[3].strip()))  # accuracy
                data[3].append(float(fields[4].strip()))  # loss
    return data

def plot_chart(path_to_png, data, phase):
    """
    plot the chart according the log file
    :param path_to_png: the save path of the png chart
    :param data: the data of chart
    :param phase: plot the chart of train phase or test phase
    :return: None
    """

    line_width = 1.0 # the line width

    if phase == 'Train':
        train_num_iteration = data[0]
        train_learning_rate = data[1]
        train_loss = data[2]
        # plot the Iteration VS Loss
        train_color = [random.random(), random.random(), random.random()]  # the color of line
        figure_1 = plt.figure('Train Iterations VS Loss')
        plt.plot(train_num_iteration, train_loss, color=train_color, linewidth=line_width)
        plt.title('Train Iterations VS Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.savefig(path_to_png + 'Train Iterations VS Loss.png')

        # plot the Iteration VS learning rate
        train_color = [random.random(), random.random(), random.random()]  # the color of line
        figure_2 = plt.figure('Train Iterations VS LearningRate')
        plt.plot(train_num_iteration, train_learning_rate, color=train_color, linewidth=line_width)
        plt.title('Train Iterations VS LearningRate')
        plt.xlabel('Iterations')
        plt.ylabel('LearningRate')
        plt.savefig(path_to_png + 'Train Iterations VS LearningRate.png')

    else:
        test_num_iteration = data[0]
        test_learning_rate = data[1]
        test_accuracy = data[2]
        test_loss = data[3]

        # plot the Iteration VS Loss
        test_color = [random.random(), random.random(), random.random()]  # the color of line
        figure_1 = plt.figure('Test Iterations VS Loss')
        plt.plot(test_num_iteration, test_loss, color=test_color, linewidth=line_width)
        plt.title('Test Iterations VS Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.savefig(path_to_png + 'Test Iterations VS Loss.png')

        # plot the Iteration VS LearningRate
        test_color = [random.random(), random.random(), random.random()]  # the color of line
        figure_2 = plt.figure('Test Iterations VS LearningRate')
        plt.plot(test_num_iteration, test_learning_rate, color=test_color, linewidth=line_width)
        plt.title('Test Iterations VS LearningRate')
        plt.xlabel('Iterations')
        plt.ylabel('LearningRate')
        plt.savefig(path_to_png + 'Test Iterations VS LearningRate.png')

        # plot the Iteration VS Accuracy
        test_color = [random.random(), random.random(), random.random()]  # the color of line
        figure_3 = plt.figure('Test Iterations VS Accuracy')
        plt.plot(test_num_iteration, test_accuracy, color=test_color, linewidth=line_width)
        plt.title('Test Iterations VS Accuracy')
        plt.xlabel('Iterations')
        plt.ylabel('Accuracy')
        plt.savefig(path_to_png + 'Test Iterations VS Accuracy.png')

def main():
    train_log = 'mnist_Lenet_train_test.log.train'
    test_log = 'mnist_Lenet_train_test.log.test'
    path_to_png = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'

    # load the train data
    train_data = load_data(train_log, phase='Train')
    # plot the train chart
    plot_chart(path_to_png, train_data, phase='Train')
    # load the test data
    test_data = load_data(test_log, phase='Test')
    # plot the test chart
    plot_chart(path_to_png, test_data, phase='Test')

if __name__ == '__main__':
    main()
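
If you also want the combined Train/Test loss figure that the MATLAB script above produces, a minimal sketch reusing load_data() and the same file names as main() could look like this (the output file name is only an example):

# Hypothetical overlay of train and test loss, analogous to the MATLAB plot.
train_data = load_data('mnist_Lenet_train_test.log.train', phase='Train')
test_data = load_data('mnist_Lenet_train_test.log.test', phase='Test')

plt.figure('Train & Test Loss Curve')
plt.plot(train_data[0], train_data[2], label='Train Loss')   # iterations vs. train loss
plt.plot(test_data[0], test_data[3], label='Test Loss')      # iterations vs. test loss
plt.grid(True)
plt.legend()
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Train & Test Loss Curve')
plt.savefig('C:\\Programming Code\\Caffe\\examples\\mnist\\log\\Train and Test Loss Curve.png')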
