Python: plotting a loss curve from train_log.txt (multiple epochs, multiple batches)

Problems to solve:

1. Read the training records in the txt file line by line

2. Extract the loss and epoch values from each line

3. Compute one mean_loss over the multiple batches of each epoch
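Step 3 can be sketched on its own, separate from file reading and parsing. `mean_loss_per_epoch` is a hypothetical helper and does not come from the article. It takes (epoch, loss) pairs in log order and averages each run of consecutive records that share an epoch. This assumes all batches of one epoch are written contiguously, which is the same assumption the article's loop makes when it closes an epoch as soon as a new epoch number appears.

```python
from itertools import groupby
from statistics import mean


def mean_loss_per_epoch(records):
    """Average the per-batch losses of each epoch.

    `records` is an iterable of (epoch, loss) pairs in log order.
    groupby() groups *consecutive* pairs with the same epoch, so an epoch
    that reappears later in the log would form a second group.
    """
    epochs, losses = [], []
    for epoch, group in groupby(records, key=lambda r: r[0]):
        epochs.append(epoch)
        losses.append(mean(loss for _, loss in group))
    return epochs, losses


# Hypothetical records: three batches in epoch 0, two in epoch 1
print(mean_loss_per_epoch([(0, 1.0), (0, 0.8), (0, 0.6), (1, 0.5), (1, 0.3)]))
```

If the epochs can be interleaved in the log, a dictionary of running sums and counts keyed by epoch is the safer choice over `groupby`.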

Data and format of the records in train_log:

[Figure: screenshot of train_log.txt showing the format of its training records]
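Step 2 can be illustrated by parsing a single line. The article's code finds the epoch between `epoch: ` and `, batch:`, and the loss between `train_loss: ` and `, time:`. The minimal sketch below captures both with one regular expression. `RECORD_RE`, `parse_record` and the sample line are hypothetical. The sample is assembled from those field names and is not copied from the real log.

```python
import re

# Hypothetical pattern built from the field markers the article searches for
RECORD_RE = re.compile(r'epoch: (.*?), batch:.*?train_loss: (.*?), time:')


def parse_record(line):
    """Return (epoch, loss) for a training record, or None for any other line."""
    m = RECORD_RE.search(line)
    if m is None:
        return None
    return float(m.group(1)), float(m.group(2))


# Hypothetical log line, not taken from the author's train_log.txt
sample = '2021-08-10 14:03:22 epoch: 3, batch: 120, train_loss: 0.4521, time: 0.31'
print(parse_record(sample))                     # (3.0, 0.4521)
print(parse_record('2021-08-10 14:03:22 done'))  # None
```

Returning `None` for lines without the markers lets the caller skip headers and other messages, instead of crashing when `re.search` finds nothing.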

Python code

import re
import os.path as osp

import matplotlib.pyplot as plt

fullpath = osp.abspath('./train_log.txt')
filedir, filename = osp.split(fullpath)

# A training record contains "epoch: <n>, batch: ... train_loss: <x>, time: ..."
pattern = re.compile(r'epoch: (.*?), batch:.*?train_loss: (.*?), time:')

# Per-epoch running totals: sum of batch losses and number of batches
loss_sum, batch_count = {}, {}

with open(fullpath, 'r') as f:
    for line in f:  # read the log line by line
        match = pattern.search(line)
        if match is None:  # not a training record (header, warning, blank line)
            continue
        current_epoch = float(match.group(1))
        current_loss = float(match.group(2))
        loss_sum[current_epoch] = loss_sum.get(current_epoch, 0.0) + current_loss
        batch_count[current_epoch] = batch_count.get(current_epoch, 0) + 1

# One mean_loss per epoch, in ascending epoch order
epochs = sorted(loss_sum)
mean_loss = [loss_sum[e] / batch_count[e] for e in epochs]

plt.plot(epochs, mean_loss)
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
pngName = osp.splitext(filename)[0]
plt.savefig(osp.join(filedir, pngName + '.png'))
plt.show()
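The reading and averaging logic can be checked without a real log or a display. The sketch below wraps it in a function and runs it on a small synthetic log written to a temporary file. `epoch_mean_losses` and the log lines are hypothetical and are not the author's data. The epochs start at 1 on purpose. The original loop seeds its epoch set with `{0}`, so if the first record's epoch is not 0 it takes the `else` branch with `count == 0` and raises `ZeroDivisionError`. Keying the totals by epoch avoids that edge case.

```python
import os
import re
import tempfile

RECORD_RE = re.compile(r'epoch: (.*?), batch:.*?train_loss: (.*?), time:')


def epoch_mean_losses(path):
    """Read a log file line by line and return {epoch: mean batch loss}."""
    sums, counts = {}, {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:  # the file object yields one line at a time
            m = RECORD_RE.search(line)
            if m is None:  # skip headers, blank lines and other messages
                continue
            epoch, loss = float(m.group(1)), float(m.group(2))
            sums[epoch] = sums.get(epoch, 0.0) + loss
            counts[epoch] = counts.get(epoch, 0) + 1
    return {e: sums[e] / counts[e] for e in sorted(sums)}


# Usage on a small synthetic log (hypothetical lines)
lines = [
    'start training\n',
    '2021-08-10 10:00:00 epoch: 1, batch: 1, train_loss: 2.0, time: 0.1\n',
    '2021-08-10 10:00:01 epoch: 1, batch: 2, train_loss: 1.0, time: 0.1\n',
    '2021-08-10 10:00:02 epoch: 2, batch: 1, train_loss: 0.5, time: 0.1\n',
]
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False,
                                 encoding='utf-8') as tmp:
    tmp.writelines(lines)
print(epoch_mean_losses(tmp.name))  # {1.0: 1.5, 2.0: 0.5}
os.remove(tmp.name)
```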
