目的:统计神经网络模型的参数分布情况
使用seaborn
给我们提供的distplot
函数来绘制,即调用sns.distplot()
,并传入相关参数即可,这里我们用np.random.normal()
函数来生成10000
个均值为0
,方差为1
的数据,并将其传入sns.distplot()
from matplotlib import pyplot as plt
from scipy.stats import norm
import seaborn as sns
import numpy as np
w = np.random.normal(0, 1, 10000)
with plt.style.context(['science', 'no-latex']): # 如果没有安装SciencePlot库,可以把这行去掉
sns.distplot(w, bins=100, fit=norm)
plt.title("honest parameters")
plt.savefig('./distribution.svg', format='svg', dpi=300)
plt.show()
可以看到已经成功画出来啦!
这时候如果我们想对比诚实模型和恶意模型的参数分布,并给其添加对应的标签,label="honest"
, label="attacker"
,代码如下:
from matplotlib import pyplot as plt
from scipy.stats import norm
import seaborn as sns
import numpy as np
w = np.random.normal(0, 1, 10000)
w_bad = np.random.normal(1, 2, 10000)
with plt.style.context(['science', 'no-latex']): # 如果没有安装SciencePlot库,可以把这行去掉
sns.distplot(w, bins=100, fit=norm, label="honest")
sns.distplot(w_bad, bins=100, fit=norm, label="attacker")
plt.savefig('./distribution.svg', format='svg', dpi=300)
plt.show()
可以看到诚实模型和恶意模型的参数分布已经绘制出来,但似乎设置的标签不起作用label="honest"
, label="attacker"
,在图中无法显示。
解决方法:在plt.show()
之前调用plt.legend()
即可解决此问题
from matplotlib import pyplot as plt
from scipy.stats import norm
import seaborn as sns
import numpy as np
w = np.random.normal(0, 1, 10000)
w_bad = np.random.normal(1, 2, 10000)
with plt.style.context(['science', 'no-latex']): # 如果没有安装SciencePlot库,可以把这行去掉
sns.distplot(w, bins=100, fit=norm, label="honest")
sns.distplot(w_bad, bins=100, fit=norm, label="attacker")
plt.savefig('./distribution.svg', format='svg', dpi=300)
plt.legend() # 这行没有则无法显示标签
plt.show()