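For the basic concepts of CNNs, I recommend first reading the illustrated explanations in the following two articles:
《吊炸天的CNNs,这是我见过最详尽的图解!(上)》 ("Awesome CNNs: the most detailed illustrated guide I have seen", Part 1)
《吊炸天的CNNs,这是我见过最详尽的图解!(下)》 ("Awesome CNNs: the most detailed illustrated guide I have seen", Part 2)
[Figure: CNN structure diagram, borrowed from the articles above]
Below, the Keras framework is used to extract the intermediate feature maps of an MNIST classifier.
Preparation: load the MNIST data: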
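from keras.datasets import mnist
from keras.utils import np_utils

# Load MNIST and add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
# Scale pixel values from [0, 255] to [0, 1]
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
# One-hot encode the 10 digit classes
Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)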
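To make the last preprocessing step concrete, here is a minimal NumPy sketch of what one-hot encoding produces. The helper name `one_hot` is made up for this illustration; the article itself relies on Keras' `np_utils.to_categorical`.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Hypothetical helper: turn integer class labels into one-hot rows."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.shape[0], num_classes), dtype='float32')
    # Put a 1 in the column given by each label
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

# Three example labels; the first row has its 1 at index 7
print(one_hot([7, 2, 1], 10))
```

Each row has exactly one non-zero entry, at the index of the digit it represents.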
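Step 1: view the first 20 original MNIST images, size (28, 28):
[Figure: origin_0-10 image]
# Multiply into a new array: an in-place "*=" on the reshaped view would also rescale X_test
X_Show = X_test.reshape(X_test.shape[0], 28, 28) * 255
print("img shape:{}".format(X_Show[0].shape))
# Concatenate the first 20 images side by side
test_img = X_Show[0]
for item in X_Show[1:20]:
    test_img = np.append(test_img, item, axis=1)
cv2.imshow("test1", test_img)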
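One detail worth checking in this step: for a contiguous array, NumPy's `reshape` normally returns a view that shares memory with the original, so an in-place `*=` on the reshaped array also rescales the source array. The sketch below uses a small stand-in array (not the MNIST data) to show the difference between multiplying into a new array and multiplying in place.

```python
import numpy as np

# Stand-in for X_test: 2 "images" of 2x2 pixels with a channel axis, values in [0, 1]
x = np.full((2, 2, 2, 1), 0.5, dtype='float32')

view = x.reshape(x.shape[0], 2, 2)   # shares memory with x
scaled = view * 255                  # new array; x is untouched
print(np.shares_memory(x, view), x.max())

view *= 255                          # in-place; x is modified as well
print(x.max())
```

If the same array is later fed to the model, the in-place form would pass it inputs in the 0-255 range instead of 0-1.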
Step 2: view the feature maps after the first convolution, taking the 32 feature maps of each image, size (26, 26):
[Figure: feature maps after layer0]
Relevant code:
# One backend function per layer, each mapping that layer's input to its output
layer_1 = K.function([model.layers[0].input], [model.layers[0].output])  # conv out 26*26
layer_2 = K.function([model.layers[1].input], [model.layers[1].output])  # activation out 26*26
layer_3 = K.function([model.layers[2].input], [model.layers[2].output])  # conv out 24*24
layer_4 = K.function([model.layers[3].input], [model.layers[3].output])  # activation out 24*24
layer_5 = K.function([model.layers[4].input], [model.layers[4].output])  # pooling out 12*12
layer_6 = K.function([model.layers[5].input], [model.layers[5].output])  # flatten
layer_7 = K.function([model.layers[6].input], [model.layers[6].output])  # fully connected
layer_8 = K.function([model.layers[7].input], [model.layers[7].output])  # dropout
layer_9 = K.function([model.layers[8].input], [model.layers[8].output])  # fully connected
# layer_5 = K.function([model.layers[4].input], [model.layers[5].output])
# Feed the first 20 test images through the layers one by one
f1 = layer_1([X_test[0:20]])[0]
f2 = layer_2([f1])[0]
f3 = layer_3([f2])[0]
f4 = layer_4([f3])[0]
f5 = layer_5([f4])[0]
f6 = layer_6([f5])[0]
f7 = layer_7([f6])[0]
f8 = layer_8([f7])[0]
f9 = layer_9([f8])[0]
# f1 = f3
f1 *= 255
# Tile the maps: one row per image, 32 feature maps side by side in each row
test_img_total = []
for v in range(20):
    test_img2 = []
    for ch in range(32):
        item = f1[v][:, :, ch]
        if len(test_img2) == 0:
            test_img2 = item
        else:
            test_img2 = np.append(test_img2, item, axis=1)
    if len(test_img_total) == 0:
        test_img_total = test_img2
    else:
        test_img_total = np.append(test_img_total, test_img2, axis=0)
cv2.imwrite("layer1.png", test_img_total)
cv2.imshow("test2", test_img_total)
Step 3: uncomment the line "# f1 = f3" to see the feature maps after the layer2 convolution (model.layers[2]), size (24, 24):
[Figure: feature maps after layer2]
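The same approach works for the feature maps after layer4. Because this layer is 2*2 max pooling, the width and height are each halved, giving size (12, 12):
[Figure: feature maps after layer4]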
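The halving can be reproduced with plain NumPy. The following is a minimal sketch of non-overlapping 2*2 max pooling on a single feature map; `max_pool_2x2` is a name invented for this illustration and is not how Keras implements `MaxPooling2D` internally.

```python
import numpy as np

def max_pool_2x2(fmap):
    """Non-overlapping 2x2 max pooling on a (H, W) map with even H and W."""
    h, w = fmap.shape
    # Split into 2x2 blocks, then take the maximum of each block
    blocks = fmap.reshape(h // 2, 2, w // 2, 2)
    return blocks.max(axis=(1, 3))

# Stand-in 24x24 feature map (the size after the second convolution)
fmap = np.arange(24 * 24, dtype='float32').reshape(24, 24)
pooled = max_pool_2x2(fmap)
print(pooled.shape)
```

Every output pixel summarizes a 2*2 block of the input, so a 24*24 map becomes 12*12.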
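After layer5 the feature maps are flattened, size 12*12*32 = 4608, where 12*12 is the size of each feature map and 32 is the depth, i.e. the number of feature maps (the number of filters in the preceding convolution layer).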
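The sizes quoted so far (26, 24, 12, 4608) follow from the usual output-size formula for unpadded ("valid") convolution and pooling. A small sketch of that arithmetic, assuming the 3*3 kernels with stride 1 and the 2*2 pooling with stride 2 used by the model in this article; `conv_out` is a helper name made up here:

```python
def conv_out(size, kernel, stride=1):
    """Output width/height of a 'valid' (unpadded) convolution or pooling."""
    return (size - kernel) // stride + 1

size = 28
size = conv_out(size, 3)            # first 3x3 convolution  -> 26
size = conv_out(size, 3)            # second 3x3 convolution -> 24
size = conv_out(size, 2, stride=2)  # 2x2 max pooling        -> 12
flat = size * size * 32             # flatten 32 feature maps
print(size, flat)
```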
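layer6 is a fully connected layer with 128 nodes, so its output size is 128. If N images are trained or inferred together, the output of this layer has size N*128.
This output can also be rendered as an image, but the digit features are no longer recognizable. Below is the output for N=20:
[Figure: output of the 128-node fully connected layer for 20 images, size 20*128]
layer7 is dropout; the size is unchanged.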
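Why dropout leaves the size unchanged can be seen from a sketch of "inverted" dropout, the variant commonly used by deep learning frameworks. This is an illustration under that assumption, not Keras' actual implementation; the function name and the `training` flag are invented for the example, and the rate 0.6 is the one used by the model in this article.

```python
import numpy as np

def dropout(x, rate, training, rng):
    """Sketch of inverted dropout; at inference it returns x unchanged."""
    if not training:
        return x
    keep = rng.random_sample(x.shape) >= rate  # keep each unit with probability 1 - rate
    return x * keep / (1.0 - rate)             # rescale so the expected value is unchanged

rng = np.random.RandomState(0)
x = np.ones((20, 128), dtype='float32')        # stand-in for the 20*128 dense output
print(dropout(x, 0.6, training=False, rng=rng).shape)
print(dropout(x, 0.6, training=True, rng=rng).shape)
```

In both modes the output has the same shape as the input; only the values differ during training.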
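layer8 is a fully connected layer with 10 nodes, so its output size is 10. If N images are trained or inferred together, the output of this layer has size N*10. This output is already the inference result:
[Figure: inference results for the 20 images]
Each row contains one white dot, which represents the inference result for one image. In the first row the white dot sits at position 7 among positions 0-9, so the first image is recognized as 7. The white-dot positions of the 20 rows are 7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 8, 4, which can be compared with the original images:
[Figure: input images]
In fact the second-to-last digit is misrecognized: I checked, and its label is 3 while it was recognized as 8, which is understandable.
Many people will wonder why a white dot represents the inferred digit. In the labels y, class 7 is encoded as [0,0,0,0,0,0,0,1,0,0], and every output of the last layer lies in the range 0-1, so it can be read as the probability that the image is each of the digits 0-9. Finally the values are multiplied by 255 and treated as pixels, which makes it easy to match the results against the original images.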
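Reading off the "white dot" in each row amounts to taking the index of the largest output. A minimal sketch with made-up probabilities (these are not the model's real outputs, and the labels are hypothetical):

```python
import numpy as np

# Made-up softmax outputs for two images (each row sums to 1)
probs = np.array([
    [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.91, 0.01, 0.01],
    [0.02, 0.02, 0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.03, 0.03],
])
labels = np.array([7, 3])                 # hypothetical true labels

predicted = np.argmax(probs, axis=1)      # column of the brightest "dot" in each row
pixels = (probs * 255).astype(np.uint8)   # grayscale rows, as in the result image
print(predicted, predicted != labels)
```

Comparing `predicted` with the label indices marks the misclassified rows, which is the same check done by eye against the input images.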
Complete code:
import numpy as np
from keras import backend as K
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Flatten, Dense, Dropout
from keras.utils import np_utils
from keras.datasets import mnist
from matplotlib import pyplot as plt
import cv2
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1)))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.6))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
# Set to True to train and save the weights; otherwise load previously saved weights
if False:
    model.fit(X_train, Y_train,
              batch_size=32, epochs=1, validation_split=0.3)
    score = model.evaluate(X_test, Y_test)
    print(score)
    model.save_weights("mnist_wight.wf")
else:
    model.load_weights("mnist_wight.wf")
from keras.utils import plot_model
plot_model(model, to_file='model.png')
# serialize model to YAML (JSON alternative: model_json = model.to_json())
yaml_string = model.to_yaml()
with open("model.yaml", "w") as yaml_file:
    yaml_file.write(yaml_string)
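# Multiply into a new array: an in-place "*=" on the reshaped view would also rescale X_test
X_Show = X_test.reshape(X_test.shape[0], 28, 28) * 255
print("img shape:{}".format(X_Show[0].shape))
# Concatenate the first 20 images side by side
test_img = X_Show[0]
for item in X_Show[1:20]:
    test_img = np.append(test_img, item, axis=1)
cv2.imwrite("origin.png", test_img)
cv2.imshow("test1", test_img)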
ret = model.predict(X_test[0:20])
# One backend function per layer, each mapping that layer's input to its output
layer_1 = K.function([model.layers[0].input], [model.layers[0].output])  # conv out 26*26
layer_2 = K.function([model.layers[1].input], [model.layers[1].output])  # activation out 26*26
layer_3 = K.function([model.layers[2].input], [model.layers[2].output])  # conv out 24*24
layer_4 = K.function([model.layers[3].input], [model.layers[3].output])  # activation out 24*24
layer_5 = K.function([model.layers[4].input], [model.layers[4].output])  # pooling out 12*12
layer_6 = K.function([model.layers[5].input], [model.layers[5].output])  # flatten
layer_7 = K.function([model.layers[6].input], [model.layers[6].output])  # fully connected
layer_8 = K.function([model.layers[7].input], [model.layers[7].output])  # dropout
layer_9 = K.function([model.layers[8].input], [model.layers[8].output])  # fully connected
# layer_5 = K.function([model.layers[4].input], [model.layers[5].output])
f1 = layer_1([X_test[0:20]])[0]
f2 = layer_2([f1])[0]
f3 = layer_3([f2])[0]
f4 = layer_4([f3])[0]
f5 = layer_5([f4])[0]
f6 = layer_6([f5])[0]
f7 = layer_7([f6])[0]
f8 = layer_8([f7])[0]
f9 = layer_9([f8])[0]
print("f1.size:{}".format(len(f1)))
print("f1.item.shape:{}".format(f1[0].shape))
print("f2.size:{}".format(len(f2)))
print("f2.item.shape:{}".format(f2[0].shape))
print("f3.size:{}".format(len(f3)))
print("f3.item.shape:{}".format(f3[0].shape))
print("f4.size:{}".format(len(f4)))
print("f4.item.shape:{}".format(f4[0].shape))
print("f5.size:{}".format(len(f5)))
print("f5.item.shape:{}".format(f5[0].shape))
print("f6.size:{}".format(len(f6)))
print("f6.item.shape:{}".format(f6.shape))
print("f7.size:{}".format(len(f7)))
print("f7.item.shape:{}".format(f7.shape))
print("f8.size:{}".format(len(f8)))
print("f8.item.shape:{}".format(f8.shape))
print("f9.size:{}".format(len(f9)))
print("f9.item.shape:{}".format(f9.shape))
print("f9:{}".format(np.around(f9, decimals=3)))
cv2.imwrite("layer_f7.png", f7*255)
cv2.imwrite("layer_f9.png", f9*255)
# Choose which 4-D layer output to visualize (f1..f5); here the pooling output
f1 = f5
f1 *= 255.
# Tile the maps: one row per image, 32 feature maps side by side in each row
test_img_total = []
for v in range(20):
    test_img2 = []
    for ch in range(32):
        item = f1[v][:, :, ch]
        if len(test_img2) == 0:
            test_img2 = item
        else:
            test_img2 = np.append(test_img2, item, axis=1)
    if len(test_img_total) == 0:
        test_img_total = test_img2
    else:
        test_img_total = np.append(test_img_total, test_img2, axis=0)
cv2.imwrite("layer5.png", test_img_total)
cv2.imshow("test2", test_img_total)
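label_true = Y_test[0:20]
print("ret:{}".format(np.around(ret, decimals=3)))
print("label:{}".format(label_true))
# Compare predicted and true class indices; True marks a misclassified image
print("label_diff:{}".format(np.argmax(label_true, axis=1) != np.argmax(ret, axis=1)))
cv2.waitKey(0)
# raw_input('Press Enter to exit...')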
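As a side note, the nested tiling loops in the listing can also be written as a single transpose and reshape. The sketch below is one possible alternative, not the author's method; `tile_feature_maps` is a name made up here, and random data stands in for a real layer output such as `f5`.

```python
import numpy as np

def tile_feature_maps(fmaps):
    """Arrange (N, H, W, C) maps into one (N*H, C*W) image: one row of C maps per input."""
    n, h, w, c = fmaps.shape
    # (N, H, W, C) -> (N, H, C, W) -> (N*H, C*W)
    return fmaps.transpose(0, 1, 3, 2).reshape(n * h, c * w)

# Random stand-in for a layer output: 20 images, 12x12 maps, 32 channels
fmaps = np.random.rand(20, 12, 12, 32).astype('float32')
grid = tile_feature_maps(fmaps)
print(grid.shape)
```

The result has the same layout as the loop version: image `v` occupies rows `v*H` to `(v+1)*H`, and channel `ch` occupies columns `ch*W` to `(ch+1)*W`.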