跟踪并保存训练过程中的损失函数
以CornerNet为例,他的源码并没有使用损失函数可视化的功能,有时候需要查看损失函数的变化趋势来确定超参。那么此时就需要手动去记录并储存损失函数值
在train.py中,关于损失函数的部分如下所示:
with stdout_to_tqdm() as save_stdout:
for iteration in tqdm(range(start_iter + 1, max_iteration + 1), file=save_stdout, ncols=80):
training = pinned_training_queue.get(block=True)
training_loss = nnet.train(**training)
if display and iteration % display == 0:
print("training loss at iteration {}: {}".format(iteration, training_loss.item()))
del training_loss
# if val_iter and validation_db.db_inds.size and iteration % val_iter == 0:
# nnet.eval_mode()
# validation = pinned_validation_queue.get(block=True)
# validation_loss = nnet.validate(**validation)
# print("validation loss at iteration {}: {}".format(iteration, validation_loss.item()))
# nnet.train_mode()
if iteration % snapshot == 0:
nnet.save_params(iteration)
if iteration % stepsize == 0:
learning_rate /= decay_rate
nnet.set_lr(learning_rate)
我们需要在其中添加如下的代码:
loss = training_loss.cpu()
loss_ = str(loss.data.numpy())
with open('./loss.txt', 'a') as f:
f.write(str(iteration))
f.write(' ')
f.write(loss_)
if iteration < max_iteration:
f.write(' \r\n')
添加后如下所示:
with stdout_to_tqdm() as save_stdout:
for iteration in tqdm(range(start_iter + 1, max_iteration + 1), file=save_stdout, ncols=80):
training = pinned_training_queue.get(block=True)
training_loss = nnet.train(**training)
loss = training_loss.cpu()
loss_ = str(loss.data.numpy())
with open('./loss.txt', 'a') as f:
f.write(str(iteration))
f.write(' ')
f.write(loss_)
if iteration < max_iteration:
f.write(' \r\n')
if display and iteration % display == 0:
print("training loss at iteration {}: {}".format(iteration, training_loss.item()))
del training_loss
# if val_iter and validation_db.db_inds.size and iteration % val_iter == 0:
# nnet.eval_mode()
# validation = pinned_validation_queue.get(block=True)
# validation_loss = nnet.validate(**validation)
# print("validation loss at iteration {}: {}".format(iteration, validation_loss.item()))
# nnet.train_mode()
if iteration % snapshot == 0:
nnet.save_params(iteration)
if iteration % stepsize == 0:
learning_rate /= decay_rate
nnet.set_lr(learning_rate)
解释一下代码:
由于深度学习中loss的计算值都是储存在cuda中的variable变量,是一种特殊的变量,主要是用于后续backpropogation自动计算grad的一种变量,所以要出存下来首先应该把cuda中的变量提取到cpu中,使用.cpu()函数
而后需要把variable变量变成一个普通的tensor变量,使用.data函数
然后把tensor变量转为numpy变量,使用.numpy()
由于f.write()中只能是string类型的变量,所以要使用str()函数
参照博客:Python读写txt文本文件, 由于open中的参数为’w‘的时候,是抹去之前的内容重新写人,会导致最后只有最后一次训练的数据,因此需要使用参数a
这里推荐使用txt文件,方便后续处理
绘制损失函数曲线
直接上代码:
"""
Note: The code is used to show the change trende via the whole training procession.
First: You need to mark all the loss of every iteration
Second: You need to write these data into a txt file with the format like:
......
iter loss
iter loss
......
Third: the path is the txt file path of your loss
"""
import matplotlib.pyplot as plt
def read_txt(path):
with open(path, 'r') as f:
lines = f.readlines()
splitlines = [x.strip().split(' ') for x in lines]
return splitlines
# Referenced from Tensorboard(a smooth_loss function:https://blog.csdn.net/charel_chen/article/details/80364841)
def smooth_loss(path, weight=0.85):
iter = []
loss = []
data = read_txt(path)
for value in data:
iter.append(int(value[0]))
loss.append(int(float(value[1])))
# Note a str like '3.552' can not be changed to int type directly
# You need to change it to float first, can then you can change the float type ton int type
last = loss[0]
smoothed = []
for point in loss:
smoothed_val = last * weight + (1 - weight) * point
smoothed.append(smoothed_val)
last = smoothed_val
return iter, smoothed
if __name__ == "__main__":
path = './loss.txt'
loss = []
iter = []
iter, loss = smooth_loss(path)
plt.plot(iter, loss, linewidth=2)
plt.title("Loss-iters", fontsize=24)
plt.xlabel("iters", fontsize=14)
plt.ylabel("loss", fontsize=14)
plt.tick_params(axis='both', labelsize=14)
plt.savefig('./loss_func.png')
plt.show()
这里主要借鉴了Tensorboard中的计算方法:tensorboard 平滑损失曲线代码
这里需要普及一下基础知识:当batch_size比较小的时候,损失函数特别波动,此时需要有一种计算方式来削弱这种波动,来显示总体的变化趋势.Tensorboard中采用的算法就是函数smooth_loss所示.
参考资料
https://blog.csdn.net/charel_chen/article/details/80364841
https://www.cnblogs.com/hackpig/p/8215786.html