本文将讲解如何可视化caffe网络中的层的参数及数据,即只要输入的规格为(n, height, width)或(n, height, width, 3)都可以通过如下函数可视化。
def vis_square(data):
“”“Take an array of shape (n, height, width) or (n, height, width, 3)
and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)”“”
# normalize data for display
data = (data - data.min()) / (data.max() - data.min())
# force the number of filters to be square
n = int(np.ceil(np.sqrt(data.shape[0])))
padding = (((0, n ** 2 - data.shape[0]),
(0, 1), (0, 1)) # add some space between filters
+ ((0, 0),) * (data.ndim - 3)) # don't pad the last dimension (if there is one)
data = np.pad(data, padding, mode='constant', constant_values=1) # pad with ones (white)
# tile the filters into an image
data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
plt.imshow(data); plt.axis('off')
filters = net.params['conv1'][0].data
vis_square(filters.transpose(0, 2, 3, 1))`
这是官方给的例子,我们假设net.params[‘conv1’][0].data.shape为(96,3,11,11)。
filters.transpose(0, 2, 3, 1) #此时shape为(96,11,11,3)
data = (data - data.min()) / (data.max() - data.min()) #预处理,减去平均值,并除以值的变化范围
n = int(np.ceil(np.sqrt(data.shape[0]))) #n为10
padding = (((0, n * 2 - data.shape[0]),(0, 1), (0, 1))+ ((0, 0),) (data.ndim - 3)) #pad为填充函数,第一维扩展为100,第二第三维分别加1,这是为了留出图之间的空隙,第四维不变
data = np.pad(data, padding, mode=’constant’, constant_values=1) # 开始进行扩展,扩展的值为1,即全为白色,此时shape为(100,12,12,3)
data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) #shape由(100,12,12,3)变为(10,12,10,12,3)
data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:]) #shape由(10,12,10,12,3)变为(120,120,3)
plt.imshow(data); plt.axis(‘off’) #显示图像
接下来是全连接层的可视化。
feat = net.blobs[‘fc6’].data[0] #获取全连接层处理前的数据
plt.subplot(2, 1, 1) #指定将要绘制的图像的位置,在这里是两行一列中的第一行第一列的位置。
plt.plot(feat.flat) #绘制图像,feat.flat是numpy中的迭代器,如下图1
plt.subplot(2, 1, 2)
plt.hist(feat.flat[feat.flat > 0], bins=100) #绘制统计直方图,如下图2,统计每个区间内数据的个数