卷积层特征的可视化

前言与背景:

很多论文都提到可视化卷积层的特征图,如这篇SCI二区的IEEE论文,最早提出可视化卷积层工作的好像是2014年的一篇顶会。

卷积层特征的可视化_第1张图片卷积层特征的可视化_第2张图片

这篇论文有一部分的工作就是可视化卷积层的输出,它选择了可视化第一层卷积层和最后一层池化层,认为这样有助于理解与解释卷积神经网络。一般观点都认为卷积神经网络是先提取低层次特征(例如点、线),再提前高层次的抽象特征。

实现方法:

我用的深度学习框架是TensorFlow,关键的函数是tf.transpose,实现反卷积,具体如下:

注-这段代码是在模型已经训练完成的后面进行

            # imput image
            fig2, ax2 = plt.subplots(figsize=(299, 299))
            ax2.imshow(np.reshape(training_images[1], (299, 299, 3)))
            plt.xticks([])
            plt.yticks([])
            plt.show()

            # 第一层的卷积输出的特征图
            input_image = training_images[1:2]
            conv1_32 = sess.run(net_1, feed_dict={images: input_image})  # 第一层有32个卷积核,valid填充,149*149
            conv1_transpose = sess.run(tf.transpose(conv1_32, [3, 0, 1, 2]))  # 3,0,1,2是把格式变为[channels,batch, height, width]
            print(conv1_transpose.shape)  # (32, 1, 149, 149)
            fig3, ax3 = plt.subplots(nrows=8, ncols=4, figsize=(40, 80))
            conv1_index = 0
            for i in range(8):
                for k in range(4):
                    ax3[i][k].imshow(conv1_transpose[conv1_index][0])  # tensor的切片[row, column]
                    ax3[i][k].set_xticks([])
                    ax3[i][k].set_yticks([])
                    conv1_index = conv1_index + 1

            plt.subplots_adjust(wspace=0.1, hspace=0.1)
            save_file = RESULT_PATH + 'conv1.jpg'
            plt.savefig(save_file, format='jpg')

我是以inception_v4模型为演示模型,net_1是第一个卷积层的结果,如下图的深色处

卷积层特征的可视化_第3张图片

实验结果:

输入图像:这原本是1个70*70的大脑白质连接矩阵,通过插值算法变成了299*299,是为了满足实验模型的迁移学习要求。

卷积层特征的可视化_第4张图片

第一个卷积层的特征图,32个(因为有32个卷积核)                           第四个卷积层(并行),96+64=160

卷积层特征的可视化_第5张图片                               卷积层特征的可视化_第6张图片

在前面的卷积层还可以看出一些连接矩阵的模样,在后期就看不出来了

21卷积层(并行),96+96+96+96=384个

卷积层特征的可视化_第7张图片

 

一些小bug

(1)一开始我是定义net_1为全局变量,在代码的开头写了global net_1, 可是在主函数里反卷积还是找不到net_1这个变量,于是我改成了返回的方法,即让inception_v4_base这个函数多返回出net_1,就可以了。

(2)一开始我测试的时候是,犯了一个白痴问题,需要写成ax3[0][0]。

(3)最关键的问题,GUI的显示为了适应窗口大小,显示的图片不准确!需要保存出来,怪不得我之前怎么都无法消除子图的间隔。

例如这是在matplotlib的GUI界面里显示是左图,然后直接保存出来后就如右图

左图:卷积层特征的可视化_第8张图片                                     右图:      卷积层特征的可视化_第9张图片

 

你可能感兴趣的:(深度学习)