【学习笔记】pytorch 深度学习训练如何显示进度条

2021/8/17
通过使用tadm,实时显示训练进度,并显示当前训练集正确率以及损失
效果图如下:【学习笔记】pytorch 深度学习训练如何显示进度条_第1张图片
实现代码:

def train(model, criterion, optimizer, trainloader, Epoch, EPOCHS, BATCH_SIZE):
    model.train()
    loop = tqdm(enumerate(trainloader), total =len(trainloader))
    running_loss = 0.0
    right = 0
    for step, (batch_x, batch_y) in loop:
        batch_x, batch_y = batch_x.cuda(), batch_y.cuda()

        output = model(batch_x)

        optimizer.zero_grad()
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        # 累加识别正确的样本数
        right += (predicted == batch_y).sum()
        #更新信息
        loop.set_description(f'Epoch [{Epoch}/{EPOCHS}]')
        loop.set_postfix(loss=running_loss/(step+1), acc=float(right)/float(BATCH_SIZE*step+len(batch_x)))

参考:https://zhuanlan.zhihu.com/p/378474516

你可能感兴趣的:(深度学习,pytorch)