Saving the feature maps of a TensorFlow deep learning network as images

To save feature maps, you only need to modify your code in two places.

First:

Where you build the network, add one line of code after each layer whose output you want to save:

# First convolutional layer (100 -> 50)
conv1=tf.layers.conv2d(
      inputs=x,
      filters=32,
      kernel_size=[5, 5],
      padding="same",
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
tf.add_to_collection('activations', conv1)

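The first argument is the collection name, which you can choose freely; the second argument is the variable that holds the layer's output.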

Second:

Inside the batch loop of your training epoch loop, add the following code:

for epoch in range(n_epoch):
    start_time = time.time()
    
    #training
    train_loss, train_acc, n_batch = 0, 0, 0
    for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
        _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
        train_loss += err; train_acc += ac; n_batch += 1
        
        visualize_layers = ['conv1']
        conv_out = sess.run(tf.get_collection('activations'), feed_dict={x: x_train_a, y_: y_train_a})
        for i, layer in enumerate(visualize_layers):
            plot_dir = '/Users/zhangyiming/Desktop/cnn/'  # directory to save the images in
            if not os.path.exists(plot_dir + layer):  # create the folder if it does not exist
                os.makedirs(plot_dir + layer)
            for j in range(conv_out[i].shape[3]):  # save one image per filter
                feature_map.plot_conv_output(conv_out[i], plot_dir + layer, str(j), filters_all=False, filters=[j])

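The first part of the code above is the original training procedure. The part added after the training step fetches the tensors you registered with tf.add_to_collection() and saves them as images.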
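
As an illustration of where the images end up, the sketch below lists the file paths that the loop above writes for one layer. The helper name feature_map_paths and the example directory /tmp/cnn/ are hypothetical and not part of the original code; the sketch assumes, as in the loop, that the filter index j is used as the file name. Because the names contain neither the epoch nor the batch number, each batch overwrites the images written by the previous one.

```python
import os


def feature_map_paths(plot_dir, layer, num_filters):
    """Return the image paths written for one layer (hypothetical helper)."""
    layer_dir = plot_dir + layer  # same concatenation as in the training loop
    return [os.path.join(layer_dir, '{}.png'.format(j)) for j in range(num_filters)]


# Example values only: a layer named conv1 with 32 filters.
paths = feature_map_paths('/tmp/cnn/', 'conv1', 32)
print(len(paths))
print(paths[0])
print(paths[-1])
```

If you want to keep the images from different batches, one option is to include the epoch or batch counter in plot_dir.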

Note the first line of the added code: if you registered other layers in the first step, for example conv2, you must add them here as well:

        visualize_layers = ['conv1','conv2']
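
tf.get_collection returns the tensors in the order in which they were added, so the names in visualize_layers have to be listed in that same order. The following is a minimal sketch of a guard against a mismatch. The helper pair_layers_with_outputs is hypothetical, and plain strings stand in for the arrays that sess.run would return.

```python
def pair_layers_with_outputs(layer_names, outputs):
    """Map each layer name to its fetched output, refusing mismatched lengths."""
    if len(layer_names) != len(outputs):
        raise ValueError(
            'expected {} outputs, got {}'.format(len(layer_names), len(outputs)))
    return dict(zip(layer_names, outputs))


# Stand-ins for the arrays that sess.run would return (assumption for illustration).
fake_outputs = ['conv1 activations', 'conv2 activations']
paired = pair_layers_with_outputs(['conv1', 'conv2'], fake_outputs)
print(paired['conv2'])
```

A check like this only catches a wrong number of names; it cannot detect names listed in the wrong order.
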
Finally, the feature_map.plot_conv_output function:
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

import utils

def plot_conv_output(conv_img, plot_dir, name, filters_all=True, filters=[0]):
    w_min = np.min(conv_img)
    w_max = np.max(conv_img)

    # get number of convolutional filters
    if filters_all:
        num_filters = conv_img.shape[3]
        filters = range(conv_img.shape[3])
    else:
        num_filters = len(filters)

    # get number of grid rows and columns
    grid_r, grid_c = utils.get_grid_dim(num_filters)

    # create figure and axes
    fig, axes = plt.subplots(min([grid_r, grid_c]),
                             max([grid_r, grid_c]))

    # iterate filters
    if num_filters == 1:
        img = conv_img[0, :, :, filters[0]]
        axes.imshow(img, vmin=w_min, vmax=w_max, interpolation='bicubic', cmap=cm.hot)
        # remove any labels from the axes
        axes.set_xticks([])
        axes.set_yticks([])
    else:
        for l, ax in enumerate(axes.flat):
            # get a single image
            img = conv_img[0, :, :, filters[l]]
            # put it on the grid
            ax.imshow(img, vmin=w_min, vmax=w_max, interpolation='bicubic', cmap=cm.hot)
            # remove any labels from the axes
            ax.set_xticks([])
            ax.set_yticks([])
    # save figure
    out_path = os.path.join(plot_dir, '{}.png'.format(name))
    print(out_path)
    plt.savefig(out_path, bbox_inches='tight')
    # close the figure so repeated calls during training do not pile up open figures
    plt.close(fig)



    

The utils.py module it uses:

import math
import os
import errno
import shutil


def get_grid_dim(x):
    """
    Transforms x into product of two integers
    :param x: int
    :return: two ints
    """
    factors = prime_powers(x)
    if len(factors) % 2 == 0:
        i = len(factors) // 2
        return factors[i], factors[i - 1]

    i = len(factors) // 2
    return factors[i], factors[i]


def prime_powers(n):
    """
    Compute the factors of a positive integer
    Algorithm from https://rosettacode.org/wiki/Factors_of_an_integer#Python
    :param n: int
    :return: sorted list of ints
    """
    factors = set()
    for x in range(1, int(math.sqrt(n)) + 1):
        if n % x == 0:
            factors.add(int(x))
            factors.add(int(n // x))
    return sorted(factors)


def empty_dir(path):
    """
    Delete all files and folders in a directory
    :param path: string, path to directory
    :return: nothing
    """
    for the_file in os.listdir(path):
        file_path = os.path.join(path, the_file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print('Warning: {}'.format(e))


def create_dir(path):
    """
    Creates a directory
    :param path: string
    :return: nothing
    """
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


def prepare_dir(path, empty=False):
    """
    Creates a directory if it does not exist
    :param path: string, path to desired directory
    :param empty: boolean, delete all directory content if it exists
    :return: nothing
    """
    if not os.path.exists(path):
        create_dir(path)

    if empty:
        empty_dir(path)
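
To make the grid layout concrete, here is a self-contained sketch that repeats get_grid_dim and prime_powers from above and evaluates them for a few filter counts. The chosen counts are only examples.

```python
import math


def prime_powers(n):
    """Return all divisors of n as a sorted list (same logic as utils.py)."""
    factors = set()
    for x in range(1, int(math.sqrt(n)) + 1):
        if n % x == 0:
            factors.add(x)
            factors.add(n // x)
    return sorted(factors)


def get_grid_dim(x):
    """Split x into the two divisors closest to its square root."""
    factors = prime_powers(x)
    i = len(factors) // 2
    if len(factors) % 2 == 0:
        return factors[i], factors[i - 1]
    return factors[i], factors[i]


for n in (1, 12, 16, 32, 64):
    print(n, get_grid_dim(n))
# 1 (1, 1)
# 12 (4, 3)
# 16 (4, 4)
# 32 (8, 4)
# 64 (8, 8)
```

plot_conv_output passes the smaller value as the number of rows and the larger as the number of columns, so 32 filters are drawn on a grid of 4 rows and 8 columns.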

