PyTorch + tensorboardX：获取 SummaryWriter（writer）实例

自定义函数：在带时间戳（及主机名）的子目录中创建 tensorboardX 的 SummaryWriter 实例，并用它记录训练数据。

import os
from datetime import datetime
import socket
from tensorboardX import SummaryWriter

def get_summary_writer(log_dir, prefix=None):
    """Create a tensorboardX ``SummaryWriter`` in a timestamped run directory.

    The run directory is
    ``<log_dir>/[<prefix>_]<Mon><DD>_<HH>-<MM>-<SS>_<hostname>``, so runs
    started at different times (or on different hosts) get separate
    directories under the same ``log_dir``.

    Args:
        log_dir: Root directory under which the run directory is created.
            Missing parent directories are created as needed.
        prefix: Optional label prepended (followed by ``_``) to the run
            directory name. ``None`` means no label.

    Returns:
        A ``SummaryWriter`` whose log directory is the new run directory.
    """
    # Timestamp format unchanged from the original; note that %b (abbreviated
    # month name) is locale-dependent.
    run_name = datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()
    if prefix is not None:
        run_name = prefix + '_' + run_name
    run_dir = os.path.join(log_dir, run_name)
    # makedirs creates every missing parent (os.mkdir raised FileNotFoundError
    # for a nested log_dir such as 'runs/exp1' when 'runs' did not exist), and
    # exist_ok=True avoids the check-then-create race of os.path.exists + mkdir.
    os.makedirs(run_dir, exist_ok=True)
    return SummaryWriter(run_dir)

# Example usage. `log_dir`, `prefix`, `loss_data` and `iteration` are assumed
# to be provided by the surrounding training code (e.g. inside the training
# loop) -- they are not defined in this snippet.
# Use the helper above so each run logs into its own timestamped directory.
writer = get_summary_writer(log_dir, prefix)
# TensorBoard tag names use '/' as the grouping separator. os.path.join would
# insert '\\' on Windows and raises TypeError when prefix is None, so build
# the tag explicitly (an empty/None prefix yields plain 'train_loss').
tag = prefix + '/train_loss' if prefix else 'train_loss'
writer.add_scalar(tag, loss_data, iteration)

你可能感兴趣的标签：pytorch、人工智能、python