在 PyTorch 中使用 TensorBoard

本文内容出自《A Complete Guide to Using TensorBoard with PyTorch》。建议使用 Jupyter Notebook 或 JupyterLab 运行代码，这样可以逐段执行代码并查看每一段的结果。

1. 使用 conda 安装 TensorBoard

conda install -c conda-forge tensorboard

也可以使用 pip 安装：

pip install tensorboard

2. 导入一些包

其中与 TensorBoard 直接相关的是 from torch.utils.tensorboard import SummaryWriter，其余的包用于构建示例中的小型卷积网络和加载数据。

import torch
import torch.nn as nn
import torch.optim as opt
torch.set_printoptions(linewidth=120)
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

3.1 定义一个函数，统计预测正确的样本数（用于计算分类正确率）

def get_num_correct(preds, labels):
    """Return how many rows of ``preds`` predict the matching entry of ``labels``.

    Args:
        preds: Tensor of per-class scores; the predicted class of each row is
            the index of its largest score along dimension 1.
        labels: Tensor of target class indices, one per row of ``preds``.

    Returns:
        The number of correct predictions as a Python ``int``.
    """
    predicted_classes = torch.argmax(preds, dim=1)
    matches = predicted_classes == labels
    return matches.sum().item()

3.2 定义一个卷积网络模型

class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then three linear layers.

    Expects single-channel images. ``fc1`` takes ``12 * 4 * 4`` features, which
    matches 28x28 inputs (28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4).
    ``forward`` returns raw scores for 10 classes (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the state_dict key order and the order in
        # which random initialisation consumes the RNG, so keep it as is.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 12, 5)

        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, x):
        # Each stage: convolution -> ReLU -> 2x2 max pooling with stride 2.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)

        # Keep the batch dimension, flatten everything else.
        x = x.flatten(start_dim=1)

        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))

        return self.out(x)

3.3 导入数据并创建训练加载器

# FashionMNIST training split, downloaded into ./data on first use; every
# image is converted to a tensor by ToTensor().
train_set = torchvision.datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches of 100 samples for training.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=100,
    shuffle=True,
)

4. 在 TensorBoard 中展示图像和模型计算图

# Log one batch of training images (as a single grid) and the model's
# computation graph to a new TensorBoard run.
# Using the writer as a context manager closes and flushes it even if one of
# the calls raises (for example if tracing the model in add_graph fails).
with SummaryWriter() as tb:
    model = CNN()
    images, labels = next(iter(train_loader))
    grid = torchvision.utils.make_grid(images)
    tb.add_image("images", grid)
    # add_graph traces the model by running it on this example batch.
    tb.add_graph(model, images)

这段代码执行后，项目目录（即运行代码时的当前工作目录）下会生成一个名为 runs 的文件夹。打开命令行，进入项目目录（确保已激活安装了 PyTorch 和 TensorBoard 的环境），然后输入：

tensorboard --logdir runs

命令行窗口中会打印出一个 http 网址（默认为 http://localhost:6006/），复制该网址并在浏览器中打开即可。
（此处原文为 TensorBoard 界面截图。）

5. 在 TensorBoard 中可视化训练过程中各项指标和参数的变化

# Train the CNN for 10 epochs, logging per-epoch loss/accuracy scalars and
# histograms of the convolution parameters to a new TensorBoard run.

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CNN().to(device)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)
optimizer = opt.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# The context manager closes (and flushes) the writer even if training is
# interrupted, e.g. by stopping the notebook cell.
with SummaryWriter() as tb:
    for epoch in range(10):
        total_loss = 0     # summed per-sample loss over the epoch
        total_correct = 0  # number of correctly classified training samples

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)

            loss = criterion(preds, labels)
            # CrossEntropyLoss averages over the batch by default; multiply by the
            # actual batch size (the last batch may be smaller) so total_loss is a
            # per-sample sum that does not depend on the number of batches.
            total_loss += loss.item() * images.size(0)
            total_correct += get_num_correct(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tb.add_scalar("Loss", total_loss, epoch)
        tb.add_scalar("Correct", total_correct, epoch)
        tb.add_scalar("Accuracy", total_correct / len(train_set), epoch)

        # Distributions of the convolution weights and biases after this epoch.
        tb.add_histogram("conv1.bias", model.conv1.bias, epoch)
        tb.add_histogram("conv1.weight", model.conv1.weight, epoch)
        tb.add_histogram("conv2.bias", model.conv2.bias, epoch)
        tb.add_histogram("conv2.weight", model.conv2.weight, epoch)

        print("epoch:", epoch, "total_correct:", total_correct, "loss:", total_loss)

代码运行完后，重复第 4 步的操作；如果第 4 步打开的网页还没有关闭，直接刷新网页即可。

6. 超参数调整

先定义一个字典，存放每个超参数的候选取值，然后用 itertools.product 枚举所有取值组合：

from itertools import product

# Candidate values for each hyperparameter. Dict insertion order fixes the
# order of the values inside every combination produced by product().
parameters = {
    "lr": [0.01, 0.001],
    "batch_size": [32, 64, 128],
    "shuffle": [True, False],
}

param_values = list(parameters.values())
print(param_values)

# Every (lr, batch_size, shuffle) combination: 2 * 3 * 2 = 12 in total.
for lr, batch_size, shuffle in product(*param_values):
    print(lr, batch_size, shuffle)

对每一种超参数组合分别训练一个新模型，并将每次训练的结果写入 TensorBoard 中各自的运行目录：

# Train a fresh model for every hyperparameter combination. Each run gets its
# own SummaryWriter; the comment is appended to the run's directory name so
# the runs can be told apart in TensorBoard.
for run_id, (lr, batch_size, shuffle) in enumerate(product(*param_values)):
    print("run id:", run_id + 1)
    model = CNN().to(device)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
    optimizer = opt.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    comment = f' batch_size = {batch_size} lr = {lr} shuffle = {shuffle}'

    # Close each run's writer when that run finishes. Closing only once after
    # the loop would leave every writer except the last one open.
    with SummaryWriter(comment=comment) as tb:
        for epoch in range(5):
            total_loss = 0     # summed per-sample loss over the epoch
            total_correct = 0  # number of correctly classified training samples
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images)

                loss = criterion(preds, labels)
                # CrossEntropyLoss averages over the batch by default. Weight it
                # by the actual batch size so the loss is comparable across
                # batch sizes. Without this, smaller batches (more batches per
                # epoch) would report a larger total for the same fit.
                total_loss += loss.item() * images.size(0)
                total_correct += get_num_correct(preds, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            tb.add_scalar("Loss", total_loss, epoch)
            tb.add_scalar("Correct", total_correct, epoch)
            tb.add_scalar("Accuracy", total_correct / len(train_set), epoch)

            print("batch_size:", batch_size, "lr:", lr, "shuffle:", shuffle)
            print("epoch:", epoch, "total_correct:", total_correct, "loss:", total_loss)
        print("__________________________________________________________")

        # Record this run's hyperparameters together with its final-epoch metrics.
        tb.add_hparams(
            {"lr": lr, "bsize": batch_size, "shuffle": shuffle},
            {
                "accuracy": total_correct / len(train_set),
                "loss": total_loss,
            },
        )

代码运行完后，重复第 4 步的操作；如果第 4 步打开的网页还没有关闭，直接刷新网页即可。

你可能感兴趣的:(Python,pytorch)