python matplotlib库画子图踩坑心得

看论文的时候看到一个很高大上的图
python matplotlib库画子图踩坑心得_第1张图片
看到这个图的第一感觉就是,视觉冲击力很强,虽然在很多方面表现其实不如准确率,损失等性能指标来的直接,但是这个图放到论文里还是会增添一分色彩。
下面的程序接上个博客只不过是加上了上面那个样子的画图方式

from keras import models
from keras import layers
import matplotlib.pyplot as plt
import random
import numpy as np
# 载入数据
from keras.datasets import boston_housing
(train_data, train_targets), (test_data, test_targets) = boston_housing.load_data()
print(train_data.shape)
# 数据预处理,去均值和除标准差得到平均值为0和标准差为1
mean = train_data.mean(axis=0)
train_data -= mean
std = train_data.std(axis=0)
train_data /= std
test_data -= mean
test_data /= std
# --------------------------------------------


# 建立模型子函数,三层神经网络
def build_model(dense_1, dense_2, dense_3):
    model = models.Sequential()
    model.add(layers.Dense(dense_1, activation='relu',input_shape=(train_data.shape[1],)))
    model.add(layers.Dense(dense_2, activation='relu'))
    model.add(layers.Dense(dense_3, activation='relu'))
    model.add(layers.Dense(1))
    model.compile(optimizer='rmsprop', loss='mse', metrics=['mae'])
    return model
# ---------------------------------------------

w_x = [16, 32, 64, 128, 256, 512]

A = []  # 储存每个粒子loss变化
fore_data = []  # 储存每个子网的预测数据
num_epochs = 100  # 设置迭代次数
for i in range(6):
    model = build_model(w_x[i], w_x[i], w_x[i])  # 建立模型
    history = model.fit(train_data, train_targets, epochs=num_epochs, batch_size=50, verbose=1)
    fore_d = model.predict(test_data)  # 通过predict函数输出网络的预测值
    fore_data.append(fore_d)
    # 显示网络loss
    A.append(history.history['loss'])

test_targets = test_targets.reshape((102, 1))
# 画左图
plt.axes([0.025, 0.15, 0.1, 0.7])
c = test_targets - fore_data[0]
a = plt.plot(c, range(len(c)), '-.', color='g')

plt.title('p1 loss')
plt.legend()
# 画右图
plt.axes([0.875, 0.15, 0.1, 0.7])
d = test_targets - fore_data[1]
plt.plot(d, range(len(d)), ':', color='r')
plt.title('p2 loss')
plt.legend()

plt.axes([0.15, 0.75, 0.7, 0.15])
a = test_targets - fore_data[0]
plt.stem(a)
plt.title('p3 loss')
plt.legend()

plt.axes([0.15, 0.1, 0.7, 0.15])
b = test_targets - fore_data[1]
plt.stem(b)
plt.title('p4 loss')
plt.legend()

plt.axes([0.15, 0.3, 0.7, 0.4])
plt.plot(test_targets, color='skyblue', label='real data')
plt.plot(fore_data[0], color='b', label='p1')
plt.plot(fore_data[1], color='r', label='p2')
plt.plot(fore_data[2], color='g', label='p3')
plt.plot(fore_data[3], color='c', label='p4')
# plt.plot(A[5], color='m', label='p6')
# plt.plot(A[6], color='y', label='p7')
# plt.plot(A[7], color='k', label='p8')
# plt.plot(A[8], color='peachpuff', label='p9')
# plt.plot(A[9], color='orange', label='p10')
plt.legend()
plt.show()

python matplotlib库画子图踩坑心得_第2张图片
运行效果就是上图
画上面的图踩了很多坑。

第一个坑:plot、stem、hist的使用

如果是使用plot函数,可以将横纵坐标转换,也就是将图竖起来,但是使用stem就不行。
因为进入stem代码,其实封装好的stem函数只接受两个参数的输入,所以ta不接受换成横向的或者线型换成别的,或者连变颜色都不行。

def stem(*args, **kwargs):
    ax = gca()
    # Deprecated: allow callers to override the hold state
    # by passing hold=True|False
    washold = ax._hold
    hold = kwargs.pop('hold', None)
    if hold is not None:
        ax._hold = hold
        from matplotlib.cbook import mplDeprecation
        warnings.warn("The 'hold' keyword argument is deprecated since 2.0.",
                      mplDeprecation)
    try:
        ret = ax.stem(*args, **kwargs)
    finally:
        ax._hold = washold

    return ret

然后hist函数,这个东西就不是用来干这个活的,不是说你想输出多大的数值,就能直接在图上显示,python matplotlib hist函数也就是统计学中的直方图,表示某个数值区间内有几个数。

第二个坑:怎么在一张图上显示大小不一样的子图

这个问题的解决方式应该很多,不局限于我的方法。我知道的还有一种适用plt.subplot的方法,本着先入为主的原则,只要这个方法能解决,就不再重新接受其他方法。
大家可以先运行下面的代码试一下效果

import numpy as np
import matplotlib.pyplot as plt

x1 = np.linspace(0.0, 5.0)
x2 = np.linspace(0.0, 2.0)

y1 = np.cos(2 * np.pi * x1) * np.exp(-x1)
y2 = np.cos(2 * np.pi * x2)

plt.axes([0.025, 0.15, 0.1, 0.7])  # [0.05, 0.05, 0.05, 0.9]第一个参数代表x轴的起点,
                                   # 第二个参数代表y轴的起点,第三个参数代表x轴的长度,第四个参数代表y轴的长度
plt.plot(y1, x1, 'yo-')

plt.axes([0.15, 0.1, 0.7, 0.1])
plt.plot(x2, y2, 'r.-')

plt.axes([0.15, 0.3, 0.7, 0.4])
plt.plot(x2, y2, 'r.-')

plt.axes([0.15, 0.8, 0.7, 0.1])
plt.plot(x2, y2, 'r.-')

plt.axes([0.875, 0.15, 0.1, 0.7])
plt.plot(y2, x2, 'r.-')

plt.show()

我使用的方法是plt.axes,axes的使用方法很简单,第一个参数代表的是x轴的起点,第二个参数代表的是y轴的起点,第三个参数代表x轴的长度,第四个参数代表y轴的长度。
上面代码的运行效果是
python matplotlib库画子图踩坑心得_第3张图片
从图的左下角起,是(0,0),也就是说,上面那个大图,有自己的坐标,那些子图会按照大图上的坐标确定自己的x轴,y轴起点和x轴,y轴的长度。
python matplotlib库画子图踩坑心得_第4张图片
然后,如果想把图像竖起来显示,直接将x和y输入数据对换。

其他小坑

诸如如何在一张图上显示多条线啊,横坐标间隔的修改啊,转换yticks的显示方向啊都是些很容易百度到的东西,不多说了。

你可能感兴趣的:(python,pycharm,matplotlib库,多图显示,plot函数)