这里首先介绍TensorBoard中各个功能如何使用。
add_scalar(tag, scalar_value, global_step=None, walltime=None)
参数
tag (string): 数据名称,不同名称的数据使用不同曲线展示
scalar_value (float): 要记录的标量数值
global_step (int, optional): 训练的步数(step),作为曲线的横坐标
walltime (float, optional): 记录发生的时间,默认为 time.time()
基础使用参考这篇文章
https://blog.csdn.net/bigbennyguo/article/details/87956434
本地浏览器使用tensorboard查看远程服务器训练情况
参考这篇博文
https://blog.csdn.net/u010626937/article/details/107747070
参考pytorch官方文档
https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
这表明6006端口已被占用:使用 lsof -i:6006 找出占用该端口的进程号(PID),之后使用 kill <PID> 结束该进程
如图
这是路径问题:将启动 TensorBoard 的命令中 --logdir 后的路径改为绝对路径,例如:
tensorboard --logdir=/home/sgyj/code/FrequecyTransformer/runs/FTmodel --port=6006
import torch
import torch.nn as nn
import matplotlib as mpl
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
mpl.use('Agg')
import matplotlib.pyplot as plt
import Copy_data as dataload
from FTmodelEasy import FTModel
from vit_seg_modeling import VisionTransformer
from vit_seg_modeling import CONFIGS
from tensorboardX import SummaryWriter
import time
import numpy as np
import os
# Mini-batch size passed to the training DataLoader.
BATCH = 16
# Learning rate handed to the SGD optimizer.
LR = 1e-3
# Number of training epochs (upper bound of the epoch loop in train()).
EPOCHES = 1
# Compute the precision ("acc"), recall and F1 score of a predicted mask.
# output -> [batch, 1, 256, 256]
# img_gt -> [batch, 1, 256, 256]
def calprecise(output, img_gt, threshold=0.3, eps=0.0001):
    """Return (acc, recall, f1) for a batch of predicted masks.

    ``output`` is passed through a sigmoid and binarised at ``threshold``;
    the resulting mask is compared element-wise with ``img_gt``.

    Args:
        output: Raw (pre-sigmoid) network output tensor.
        img_gt: Ground-truth mask tensor, multiplied element-wise with the
            binarised prediction, so it must be shape-compatible with
            ``output``.
        threshold: Probability above which a pixel counts as positive.
            Defaults to 0.3, the value previously hard-coded.
        eps: Small constant added to every denominator to avoid division
            by zero. Defaults to 0.0001, the value previously hard-coded.

    Returns:
        A tuple ``(acc, recall, f1)`` of 0-dim tensors, where ``acc`` is
        true positives / predicted positives (i.e. precision), ``recall`` is
        true positives / ground-truth positives, and ``f1`` is their
        harmonic-mean style combination.
    """
    probabilities = torch.sigmoid(output)
    mask = probabilities > threshold

    # Sum of the ground truth over the pixels predicted positive.
    true_positive = torch.mul(mask.float(), img_gt).sum()
    predicted_positive = mask.sum()
    actual_positive = img_gt.sum()

    acc = true_positive / (predicted_positive + eps)
    recall = true_positive / (actual_positive + eps)
    f1 = 2 * acc * recall / (acc + recall + eps)
    return acc, recall, f1
# TensorBoardX setup: a single module-level writer constructed with 'runs/FTmodel'.
writer=SummaryWriter('runs/FTmodel')
def train():
    """Train FTModel on the "train" split and log per-epoch metrics.

    The routine builds the data loader, model, SGD optimizer and MultiStepLR
    scheduler, optionally resumes from a checkpoint, runs the epoch loop,
    writes the per-epoch mean loss / precision / recall / F1 to the
    module-level TensorBoardX ``writer``, saves periodic checkpoints and
    model weights, and finally saves the last weights.

    Returns:
        dict: per-epoch history with keys ``"loss"``, ``"precision"``,
        ``"recall"`` and ``"f1"``, each a list of Python floats.

    Raises:
        ValueError: if the training DataLoader yields no batches.
    """
    # Output locations (same absolute paths the script has always used).
    project_root = "/home/sgyj/code/FrequecyTransformer"
    checkpoint_dir = os.path.join(project_root, "checkpoint")
    model_dir = os.path.join(project_root, "tem")
    # The weight files below are written unconditionally, so make sure the
    # target directory exists instead of failing at the very end of training.
    os.makedirs(model_dir, exist_ok=True)

    copy_train = dataload.Copy_DATA("train")
    train_loader = torch.utils.data.DataLoader(copy_train, batch_size=BATCH, shuffle=True)
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("training DataLoader is empty; nothing to train on")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = FTModel().to(device)
    net.train()

    # BCEWithLogitsLoss fuses the sigmoid with the BCE term; it is the
    # numerically stable equivalent of sigmoid followed by BCELoss.
    lossfunction = nn.BCEWithLogitsLoss()
    # 'initial_lr' lets a scheduler be constructed with last_epoch != -1.
    optimizer = torch.optim.SGD(
        [{'params': net.parameters(), 'initial_lr': LR}],
        lr=LR, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[40, 80], gamma=0.1)

    losses = []
    precises = []
    recalles = []
    f1es = []

    # Set RESUME to True to continue training from a saved checkpoint.
    start_epoch = 0
    RESUME = False
    if RESUME:
        path_checkpoint = os.path.join(checkpoint_dir, "ckpt_best_55.pth")
        # map_location keeps the load working when the checkpoint was written
        # on a different device than the one available now.
        checkpoint = torch.load(path_checkpoint, map_location=device)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The stored epoch has already been completed; continue with the next.
        start_epoch = checkpoint['epoch'] + 1
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            # Older checkpoints carry no scheduler state: realign the epoch
            # counter so the milestones still fire at the same absolute epochs.
            scheduler.last_epoch = start_epoch

    for epoch in range(start_epoch, EPOCHES):
        total_loss = 0.0
        precise = 0.0
        recall_score = 0.0
        f1_score = 0.0
        st = time.time()
        for step, data in enumerate(train_loader):
            img, img_gt = data
            img = img.to(device)
            img_gt = img_gt.to(device)

            pred_mask = net(img)
            loss = lossfunction(pred_mask.reshape(-1), img_gt.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics are for reporting only, so no autograd graph is needed.
            with torch.no_grad():
                acc, recall, f1 = calprecise(pred_mask, img_gt)
            acc, recall, f1 = acc.item(), recall.item(), f1.item()
            loss_value = loss.item()
            lrr = optimizer.param_groups[0]['lr']
            print("(train)epoch%d->step%d loss:%.6f acc:%.6f recall:%.6f f1:%.6f lr:%.6f cost time:%ds" % (
                epoch, step, loss_value, acc, recall, f1, lrr, time.time() - st))

            total_loss += loss_value
            precise += acc
            recall_score += recall
            f1_score += f1

        # The milestones are expressed in epochs, so step once per epoch.
        scheduler.step()

        # Per-epoch mean of each metric.
        mean_loss = total_loss / num_batches
        mean_precise = precise / num_batches
        mean_recall = recall_score / num_batches
        mean_f1 = f1_score / num_batches
        losses.append(mean_loss)
        precises.append(mean_precise)
        recalles.append(mean_recall)
        f1es.append(mean_f1)

        cost = time.time() - st
        print("(train)epoch%d-> loss:%.6f acc:%.6f recall:%.6f f1:%.6f cost time:%ds" %
              (epoch, mean_loss, mean_precise, mean_recall, mean_f1, cost))

        writer.add_scalar('training loss', mean_loss, epoch)
        writer.add_scalar('precises', mean_precise, epoch)
        writer.add_scalar('recalles', mean_recall, epoch)
        writer.add_scalar('f1es', mean_f1, epoch)

        # Save a resumable checkpoint every 5 epochs.
        if epoch != 0 and epoch % 5 == 0:
            checkpoint = {
                "net": net.state_dict(),
                'optimizer': optimizer.state_dict(),
                "epoch": epoch,
                'scheduler': scheduler.state_dict(),
            }
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(checkpoint, os.path.join(checkpoint_dir, 'ckpt_best_%s.pth' % (str(epoch))))

        # Save the bare model weights every 20 epochs.
        if epoch != 0 and epoch % 20 == 0:
            torch.save(net.state_dict(),
                       os.path.join(model_dir, "FrequecyTransformer-copy_epoch_%d.pth" % (epoch)))

    torch.save(net.state_dict(), os.path.join(model_dir, 'FrequecyTransformer-copy_final.pth'))

    # Flush pending TensorBoard events to disk.
    writer.close()

    # The curves formerly drawn by the disabled matplotlib code are logged to
    # TensorBoard above; the raw history is returned for any further use.
    return {"loss": losses, "precision": precises, "recall": recalles, "f1": f1es}
# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    train()