最近总会有人提出画图的需求,然后自己又得从网上找matplotlib画图的代码,又重新写一遍,索性自己写一个博客,自己找自己写过的资料,以后留着自己用哈哈,我的代码如下:
import numpy as np
import matplotlib.pyplot as plt
def data_read(dir_path):
with open(dir_path, "r") as f:
raw_data = f.read()
data = raw_data.split()[-1]
return np.asfarray(data, float)
if __name__ == "__main__":
train_loss_path = r"results/gmt_v1.txt"
y_train_loss = data_read(train_loss_path)
x_train_loss = range(len(y_train_loss))
plt.figure()
# 去除顶部和右边框框
ax=plt.axes()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('Round') # x轴标签
plt.ylabel('Accuracy') # y轴标签
# 默认颜色,如果想更改颜色,可以增加参数color='red',这是红色。
plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
plt.legend()
plt.title('Loss curve')
plt.savefig('acc_round.png')
plt.show()
有人会说,数据的格式是啥,示例数据如下:
1 2.3258324 0.1107
2 2.311135 0.1125
3 2.2763004 0.13019999999999998
4 2.2557983 0.15910000000000002
5 2.1366177 0.20629999999999998
6 2.0213234 0.2783
7 1.960824 0.33849999999999997
8 1.8013808 0.3797
9 1.7840961 0.40110000000000007
10 1.7455825 0.4301
11 1.6700351 0.4473999999999999
.......
其实就是我的一些实验数据,取的最后一列,最后一列是accuracy。把数据换成loss,就是loss曲线图了。