PyTorch + TensorboardX：绘制 Loss 曲线和 attention 图

def test_2():
    """Log dummy loss curves and attention heatmaps to TensorBoard.

    Writes two runs:
      * ``./runs/data_loss`` -- scalar curves ``data/mel_loss`` and
        ``data/linear_loss``, one point per step so TensorBoard draws a curve.
      * ``./runs/png`` -- for each of 3 fake encoder layers, ``num_slices``
        attention maps taken at evenly spaced indices along the first
        dimension (for a first dimension of 512: 127, 255, 383, 511).

    NOTE(review): relies on module-level ``torch``, ``SummaryWriter``
    (tensorboardX or ``torch.utils.tensorboard``) and ``vutils``
    (presumably ``torchvision.utils``), none of which are imported in the
    visible snippet -- confirm.
    """
    num_layers = 3
    num_steps = 50
    num_slices = 4  # attention maps shown per layer

    # Fake attention tensors, one per encoder layer, each of shape (512, 100, 100).
    attns_en = [torch.randn([512, 100, 100]) for _ in range(num_layers)]
    print(len(attns_en))
    print(attns_en[0].shape)

    # Every point needs its own step: with a constant step all values are
    # stacked on a single x position and no curve is drawn.
    with SummaryWriter('./runs/data_loss') as writer1:
        for step in range(num_steps):
            m_loss = step
            l_loss = step + 10
            writer1.add_scalars('data/mel_loss', {'mel_loss': m_loss}, step)
            writer1.add_scalars('data/linear_loss', {'linear_loss': l_loss}, step)

    image_step = num_steps - 1  # attach the images to the last logged step
    with SummaryWriter('./runs/png') as writer2:
        for i, prob in enumerate(attns_en):  # i-th layer
            n = prob.size(0)
            for j in range(1, num_slices + 1):  # j-th map, 1..num_slices
                # Evenly spaced indices ending at the last row, e.g. for
                # n=512: 127, 255, 383, 511.
                idx = j * n // num_slices - 1
                if idx < 0:  # first dimension smaller than num_slices
                    continue
                print(f"attention index:{idx}")
                # normalize=True min-max scales the 2-D map into [0, 1].
                # add_image multiplies float images by 255 itself, so
                # pre-multiplying by 255 would saturate the picture.
                x = vutils.make_grid(prob[idx], normalize=True)
                writer2.add_image(f'Png/Attention_enc_layer{i}_{j}', x, image_step)

结果:

pytorch TensorboardX 画Loss曲线 和 attention图_第1张图片

pytorch TensorboardX 画Loss曲线 和 attention图_第2张图片

你可能感兴趣的:(pytorch,Python数据分析理论,python)