参考博文:
class SummaryWriter(builtins.object)
| SummaryWriter(log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix='')
# log dir
if os.path.exists('logs'):
shutil.rmtree('logs')# 如果文件存在,则递归的删除文件内容
print('Remove log dir')
采用的源码安装
git clone https://github.com/lanpa/tensorboardX && cd tensorboardX && python setup.py install
似乎不能下面这样搞?!!!
导入库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from torch.utils.tensorboard import SummaryWriter
共用的模型实例化
device = torch.device('cpu')
net = nn.Linear(1,1).to(device) # 最简单全链接模型
loss_fn =nn.MSELoss().to(device)# 损失函数
optimizer = optim.SGD(net.parameters(),lr=0.001)
if os.path.exists('log1'):
shutil.rmtree('log1')# 如果文件存在,则递归的删除文件内容
print('Remove log dir')
# 1. 输入数据
x_train = np.linspace(-1, 1, 100).reshape(100, 1)
y_train = 3 * np.power(x_train, 2) + 2 + 0.2 * np.random.randn(100, 1)
# 2. 实例化sw
sw1 = SummaryWriter(log_dir='log1')
# 3. 模型训练
for epoch in trange(40):
inputs = torch.from_numpy(x_train).type(torch.float32).to(device) # 所有输入
targets = torch.from_numpy(y_train).type(torch.float32).to(device) # 标签
output = net(inputs) # 模型的输入输出
loss = loss_fn(output, targets) # 预测则和真实值的loss
loss.backward() # 模型向后传递
optimizer.step() # 优化器优化
sw1.add_scalar("train_loss", loss, epoch)
定义Dataset数据
class Mydataset(Dataset):
def __init__(self,):
x_train =np.linspace(-1,1,100).reshape(100,1)
y_train= 3*np.power(x_train,2)+2+0.2*np.random.randn(100,1)
self.x= torch.from_numpy(x_train)
self.y=torch.from_numpy(y_train)
def __len__(self,):
return (100,1)
def __getitem__(self,index):
return self.x[index],self.y[index]
使用DataLoader设置batch_size,批量训练
dataset = Mydataset()
train_loader = DataLoader(dataset,batch_size=5,shuffle=True)
# 训练模型比记录lossz值
for epoch range(30):
running_loss =0.0
for x_t,y_t in enumerate(train_loader):
optimizer.zero_grad()
output=net(x_t)
loss=loss_fn(output,y_t)
loss.backward()
optimizer.step()
running_loss+=loss.item()
sw.add_scalar("train_loss",epoch,running_loss)
cmd
打开终端Terminal
tensorboardX
所在的虚拟环境tensorboard
下激活logdir