Xception是inception处于极端假设的一种网络结构。当卷积层试图在三维空间(两个空间维度和一个通道维度)进行卷积过程时,一个卷积核需要同时绘制跨通道相关性和空间相关性。
前面分享的inception模块的思想就是将这一卷积过程分解成一系列相互独立的操作,使其更为便捷有效。典型的inception模块假设通道相关性和空间相关性的绘制有效脱钩,而Xception的思想则是inception模块思想的一种极端情况,即卷积神经网络的特征图中的跨通道相关性和空间相关性的绘制可以完全脱钩。
Xception实现迁移学习也是基于微调的方式,和InceptionV3实现迁移学习一样,在获取基于imageNet预训练完毕的Xception模型后,用自己搭建的全连接层(包括输出层)代替xception模型的全连接层和输出层,进而得到一个新的网络模型,固定新网络模型的部分参数,使其不参与训练,基于mnist数据集训练余下未固定的参数。
代码实现:
from keras.applications.xception import Xception
from keras.datasets import mnist
from keras.utils import np_utils
from keras.layers import Dense,GlobalAveragePooling2D,Dropout,Input,UpSampling3D
from keras.models import Model
from matplotlib import pyplot as plt
import numpy as np
(X_train,Y_train),(X_test,Y_test)=mnist.load_data()
X_test1=X_test
Y_test1=Y_test
X_train=X_train.reshape(-1,28,28,1).astype("float32")/255.0
X_test=X_test.reshape(-1,28,28,1).astype("float32")/255.0
Y_test=np_utils.to_categorical(Y_test,10)
Y_train=np_utils.to_categorical(Y_train,10)
#搭建xception模型
#weight="imagenet",xcception权重使用基于imagenet训练获得的权重,include_to=false代表不包含顶层的全连接层
base_model=Xception(weights="imagenet",include_top=False)
input_xception=Input(shape=(28,28,1),dtype="float32",name="xception imput")
#对数据进行上采样,沿着数据的3个维度分别重复size[0],size[1],size[2]
x=UpSampling3D(size=(3,3,3),data_format="channels_last")(input_xception)
#将数据送入网络
x=base_model(x)
#此时模型没有全连接层,需要自己搭建全连接层
#通过GlobalAveragePooling2D对每张二维特征图进行全局平均池化,输出对应一维数值
x=GlobalAveragePooling2D()(x)
x=Dense(1024,activation="relu")(x)
x=Dropout(0.5)(x)
pre=Dense(10,activation="softmax")(x)
#调用Model,定义一个新的模型Xception_model
xception_model=Model(inputs=input_xception,outputs=pre)
#查看每一层的名称和对应的层数
for i,layer in enumerate(base_model.layers):
print(i,layer.name)
#固定base_model中前36层的参数,使其不参与训练
for layer in base_model.layers[:36]:
layer.trainable=False
#查看模型的摘要
xception_model.summary()
#编译
xception_model.compile(
loss="categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
#训练
training=xception_model.fit(
X_train,
Y_train,
epochs=5,
batch_size=64,
validation_split=0.2,
verbose=1
)
test=xception_model.evaluate(X_test,Y_test)
print("误差:",test[0])
print("准确值:",test[1])
#画出训练集和验证集的随着时期的变化曲线
def plot_history(training_history,train,validation):
plt.plot(training.history[train],linestyle="-",color="b")
plt.plot(training.history[validation],linestyle="--",color="r")
plt.title("xception_model accuracy")
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.legend(["train","validation"],loc="lower right")
plt.show()
plot_history(training,"accuracy","val_accuracy")
def plot_history1(training_history,train,validation):
plt.plot(training.history[train],linestyle="-",color="b")
plt.plot(training.history[validation],linestyle="--",color="r")
plt.title("xception_model accuracy")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.legend(["train","validation"],loc="upper right")
plt.show()
plot_history1(training,"loss","val_loss")
#预测值
prediction=xception_model.predict(X_test)
#打印图片
def plot_image(image):
fig=plt.gcf()
fig.set_size_inches(2,2)
plt.imshow(image,cmap="binary")
plt.show()
def result(i):
plot_image(X_test1[i])
print("真实值:",Y_test1[i])
print("预测值:",np.argmax(prediction[i]))
result(0)
result(1)