要保存 feature map,只需对你的代码修改两个地方即可。
第一个地方:
在搭建网络的地方,想保存哪一层的输出,就在该层之后添加如下代码:
# First convolutional layer: 32 kernels of size 5x5 followed by ReLU.
# With padding="same" and the default stride of 1 this layer does not change
# the spatial size (e.g. 100x100 stays 100x100); the "100 -> 50" reduction in
# the original note presumably comes from a following pooling layer.
conv1 = tf.layers.conv2d(
    inputs=x,
    filters=32,
    kernel_size=[5, 5],
    padding="same",
    activation=tf.nn.relu,
    kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
)
# Register this layer's output in the 'activations' collection so the
# training loop can fetch it later with tf.get_collection('activations').
tf.add_to_collection('activations', conv1)
这里的第一个参数是集合(collection)的名称,可以随意修改,但后面调用 tf.get_collection() 时必须使用同一个名称;第二个参数是该层输出的变量名。
第二个地方:
在训练代码的 epoch 循环内的 batch 循环里,添加如下代码:
# Names of the layers registered via tf.add_to_collection('activations', ...).
# They must be listed in the same order as those add_to_collection calls,
# because the fetched activations are matched to these names by position.
visualize_layers = ['conv1']
# Directory the feature maps are saved under (change this to your own path).
plot_dir = '/Users/zhangyiming/Desktop/cnn/'
# The collection is fixed once the graph is built, so look it up only once.
activations = tf.get_collection('activations')
if len(activations) < len(visualize_layers):
    raise ValueError('visualize_layers names more layers than the '
                     '"activations" collection contains')

for epoch in range(n_epoch):
    start_time = time.time()
    # training
    train_loss, train_acc, n_batch = 0, 0, 0
    for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
        _, err, ac = sess.run([train_op, loss, acc],
                              feed_dict={x: x_train_a, y_: y_train_a})
        train_loss += err
        train_acc += ac
        n_batch += 1
        # Evaluate the registered layers on the same batch (after the weight update).
        # NOTE(review): this runs for every batch and overwrites the same image
        # files each time, so only the last batch's maps survive; consider
        # moving it out of the batch loop if that is the intent.
        conv_out = sess.run(activations, feed_dict={x: x_train_a, y_: y_train_a})
        for i, layer in enumerate(visualize_layers):
            layer_dir = os.path.join(plot_dir, layer)
            if not os.path.exists(layer_dir):
                # makedirs also creates plot_dir itself if it does not exist yet
                os.makedirs(layer_dir)
            # Save one image per channel, named after the channel index.
            for j in range(conv_out[i].shape[3]):
                feature_map.plot_conv_output(conv_out[i], layer_dir, str(j),
                                             filters_all=False, filters=[j])
上面代码的前半部分就是原始的训练过程;我们在训练步骤之后添加的部分,用于把通过 tf.add_to_collection() 函数加入集合的变量(即各层的 feature map)保存为图片。
注意新增代码的第一行:如果你在第一步中还添加了其他层(例如 conv2 层),这里也要相应加上,并且顺序要与 tf.add_to_collection() 的调用顺序一致:
# Layer names, in the same order as the tf.add_to_collection('activations', ...) calls.
visualize_layers = ["conv1", "conv2"]
import tensorflow as tf
import h5py
import scipy.misc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
import utils
def plot_conv_output(conv_img, plot_dir, name, filters_all=True, filters=None):
    """Plot feature maps of one convolutional layer and save them as a PNG.

    The selected channels are drawn on a single grid of panels that share one
    colour scale (min/max taken over the whole ``conv_img`` array).

    :param conv_img: layer output indexed as [sample, row, col, channel]; only
        sample 0 is plotted (presumably NHWC layout - confirm with the caller)
    :param plot_dir: directory the image is written to; it must already exist
    :param name: file name without extension; the image is saved as '<name>.png'
    :param filters_all: if True, plot every channel and ignore ``filters``
    :param filters: channel indices to plot when ``filters_all`` is False;
        defaults to ``[0]``
    :return: nothing
    """
    if filters is None:
        filters = [0]
    # One colour scale shared by every panel.
    w_min = np.min(conv_img)
    w_max = np.max(conv_img)
    if filters_all:
        filters = range(conv_img.shape[3])
    num_filters = len(filters)

    # get_grid_dim returns two factors whose product is num_filters,
    # so there is exactly one panel per selected channel.
    grid_r, grid_c = utils.get_grid_dim(num_filters)
    # squeeze=False always yields a 2-D array of axes, so the single-panel
    # case needs no separate code path.
    fig, axes = plt.subplots(min(grid_r, grid_c), max(grid_r, grid_c), squeeze=False)
    try:
        for ax, channel in zip(axes.flat, filters):
            ax.imshow(conv_img[0, :, :, channel], vmin=w_min, vmax=w_max,
                      interpolation='bicubic', cmap=cm.hot)
            # remove any labels from the axes
            ax.set_xticks([])
            ax.set_yticks([])
        out_path = os.path.join(plot_dir, '{}.png'.format(name))
        print(out_path)
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Release the figure; this function is called in tight loops and
        # unclosed figures accumulate in pyplot's figure manager.
        plt.close(fig)
其中用到的 utils.py 如下:
import math
import os
import errno
import shutil
def get_grid_dim(x):
    """
    Split x into two integer factors suitable for a plotting grid.

    The two factors are taken from the middle of x's ascending divisor list,
    so they are as close to each other as possible; the larger one is first.
    :param x: int
    :return: two ints whose product is x
    """
    divisors = prime_powers(x)
    middle = len(divisors) // 2
    if len(divisors) % 2:
        # An odd number of divisors means x is a perfect square: use the root twice.
        return divisors[middle], divisors[middle]
    return divisors[middle], divisors[middle - 1]
def prime_powers(n):
    """
    Compute all positive divisors of a non-negative integer.

    Despite the name, this returns every divisor of n, not only prime powers.
    Algorithm from https://rosettacode.org/wiki/Factors_of_an_integer#Python
    :param n: int
    :return: list of the divisors of n in ascending order ([] for n == 0)
    """
    factors = set()
    # Divisors come in pairs (x, n // x) with x <= sqrt(n), so scanning up to
    # sqrt(n) finds all of them; the set removes the duplicate root of squares.
    for x in range(1, int(math.sqrt(n)) + 1):
        if n % x == 0:
            factors.add(x)
            factors.add(int(n // x))
    return sorted(factors)
def empty_dir(path):
    """
    Delete all files and folders inside a directory, keeping the directory itself.

    Symbolic links (including links to directories and broken links) are
    removed as links; their targets are left untouched. Entries that cannot be
    removed are reported with a warning and skipped.
    :param path: string, path to directory
    :return: nothing
    """
    for the_file in os.listdir(path):
        file_path = os.path.join(path, the_file)
        try:
            # Check islink as well: broken links are neither isfile() nor
            # isdir(), and rmtree() refuses to operate on a symlinked directory.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            # best-effort: report and continue with the remaining entries
            print('Warning: {}'.format(e))
def create_dir(path):
    """
    Create a directory, including any missing parent directories.

    Does nothing if the directory already exists.
    :param path: string
    :return: nothing
    :raises OSError: if the directory cannot be created, including when
        ``path`` already exists but is not a directory
    """
    try:
        os.makedirs(path)
    except OSError as exc:
        # EEXIST is only harmless when the existing path really is a directory.
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise
def prepare_dir(path, empty=False):
    """
    Make sure a directory exists, optionally clearing its contents.

    :param path: string, path to desired directory
    :param empty: boolean, if True delete everything inside the directory
        (this also runs right after the directory has been created)
    :return: nothing
    """
    if not os.path.exists(path):
        create_dir(path)
    if not empty:
        return
    empty_dir(path)