强化学习训练奖励曲线绘制(平滑曲线,不带标准差阴影带)

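使用时只需修改 main() 中的日志文件路径 log_file(示例为 xx/xx/xx.log)和正则表达式 train_pattern。输入文件是训练日志(.log),其中每条训练记录需匹配 train_pattern,例如 "episode:1, reward:0.5, memory size:100, time:12.3, info:Timeout",奖励值取自第二个捕获组。平滑窗口可通过命令行参数 --window_size 指定(默认 500)。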

#!/usr/bin/env python
# encoding: utf-8
"""
@author: xxx
@license: (C) Copyright 2013-2017, Node Supply Chain Manager Corporation Limited.
@software: pycharm
@desc:
"""
import argparse
import re

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import rcParams

# Apply the seaborn theme FIRST: sns.set() rewrites matplotlib's rcParams,
# including font.family (seaborn's default font is 'sans-serif'), so a font
# configured before it would be silently overwritten.
sns.set()
rcParams['font.family'] = 'Times New Roman'

parser = argparse.ArgumentParser()
parser.add_argument('--window_size', type=int, default=500,
                    help='moving-average window (in episodes) used to smooth the rewards')
# Only consume sys.argv when this file is executed as a script; when it is
# imported (e.g. by a test runner whose argv carries unrelated options) use the
# defaults instead of aborting with an argparse error.
args = parser.parse_args(None if __name__ == '__main__' else [])


def running_mean(x, n):
    """Return the simple moving average of ``x`` over windows of ``n`` points.

    A cumulative-sum trick keeps the cost O(len(x)) regardless of ``n``.
    Only complete windows are produced, so the result has
    ``max(len(x) - n + 1, 0)`` elements; a series shorter than the window
    yields an empty array.

    :param x: 1-D sequence of numbers (list or numpy array).
    :param n: window length in samples; must be a positive integer.
    :return: 1-D float ``numpy.ndarray`` with the mean of each window.
    :raises ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError('window size must be >= 1, got {}'.format(n))
    # Prepend a 0 so that cumsum[i] is the sum of the first i elements;
    # cumsum[i + n] - cumsum[i] is then the sum of the window x[i:i + n].
    cumsum = np.cumsum(np.insert(x, 0, 0))
    return (cumsum[n:] - cumsum[:-n]) / float(n)


def _parse_rewards(log_file, train_pattern, max_episodes):
    """Return the first ``max_episodes`` rewards (group 2 of ``train_pattern``) in ``log_file``."""
    with open(log_file, 'r') as file:
        log = file.read()
    # re.findall yields one tuple of captured groups per matching record.
    rewards = [float(groups[1]) for groups in re.findall(train_pattern, log)]
    return rewards[:max_episodes]


def _plot_curve(smoothed):
    """Draw ``smoothed`` against its index with Times New Roman labels and show it."""
    _, ax = plt.subplots()
    # A single unlabelled curve: no legend is drawn, since an empty legend
    # would only add a blank box to the figure.
    ax.plot(range(len(smoothed)), smoothed)

    ax.spines['top'].set_visible(False)  # hide the top border
    ax.spines['right'].set_visible(False)  # hide the right border
    ax.patch.set_alpha(0.5)
    ax.set_xlabel('Updates', fontproperties='Times New Roman')
    ax.set_ylabel('Eprewmean', fontproperties='Times New Roman')
    ax.set_title(' ')
    # Set the tick font explicitly so it does not depend on global rcParams.
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname('Times New Roman')

    plt.show()


def main(log_file='xx/xx/xx.log',
         train_pattern=r"episode:(.*), reward:(.*), memory size:(.*), time:(.*), info:(.*)",
         max_episodes=1000000,
         window_size=None):
    """Plot the smoothed training-reward curve parsed from an RL training log.

    Every log record matching ``train_pattern`` contributes one reward (its
    second capture group). At most ``max_episodes`` rewards are kept, smoothed
    with a moving average and drawn as a single curve (no shaded band).

    :param log_file: path of the training log to read.
    :param train_pattern: regex whose second capture group is the reward.
    :param max_episodes: keep only the first ``max_episodes`` rewards.
    :param window_size: moving-average window; ``None`` uses the
        ``--window_size`` command-line option. It is clamped to the number of
        parsed rewards so a short log still yields a curve instead of an
        empty plot.
    :raises SystemExit: if no record in the log matches ``train_pattern``.
    """
    rewards = _parse_rewards(log_file, train_pattern, max_episodes)
    if not rewards:
        raise SystemExit('No record in {} matches {!r}'.format(log_file, train_pattern))

    if window_size is None:
        window_size = args.window_size
    window_size = min(window_size, len(rewards))

    _plot_curve(running_mean(rewards, window_size))


# Entry point: draw the curve only when this file is executed directly.
if __name__ == "__main__":
    main()

图1:上述脚本绘制的平滑训练奖励曲线(不带阴影)

你可能感兴趣的:(python,开发语言)