1、导入相关python包
import numpy as np
import gzip
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
2、数据读取与显示
def data_read(filename1, filename2, num_images):
    """Load images and labels from a pair of gzip-compressed MNIST files.

    Args:
        filename1: Path to the gzip-compressed image file.
        filename2: Path to the gzip-compressed label file.
        num_images: Number of images (and labels) to read from the start
            of the files.

    Returns:
        A tuple ``(data, labels)``. ``data`` is a float32 array of shape
        ``(num_images, 28, 28, 1)`` whose pixel values are rescaled from
        [0, 255] to [-0.5, 0.5]. ``labels`` is an int64 array of shape
        ``(num_images,)``.

    Raises:
        ValueError: If either file contains fewer than ``num_images``
            records after its header.
    """
    IMAGE_SIZE = 28
    NUM_CHANNELS = 1
    PIXEL_DEPTH = 255
    # The leading bytes of each file are a header, not sample data:
    # 16 bytes are skipped in the image file and 8 bytes in the label file.
    IMAGE_HEADER_BYTES = 16
    LABEL_HEADER_BYTES = 8

    # Load the images.
    expected_image_bytes = IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS
    with gzip.open(filename1) as bytestream:
        bytestream.read(IMAGE_HEADER_BYTES)
        buf = bytestream.read(expected_image_bytes)
    if len(buf) != expected_image_bytes:
        raise ValueError(
            "image file {!r} holds {} pixel bytes, expected {}".format(
                filename1, len(buf), expected_image_bytes))
    # One unsigned byte per pixel, converted to float32 for the network.
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    # Rescale pixel values from [0, 255] to [-0.5, 0.5].
    data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
    # Reshape to a 4-D tensor [image index, y, x, channels].
    data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)

    # Load the labels (one unsigned byte per label).
    with gzip.open(filename2) as bytestream:
        bytestream.read(LABEL_HEADER_BYTES)
        buf = bytestream.read(1 * num_images)
    # Without this check a short label file would silently yield fewer
    # labels than images.
    if len(buf) != num_images:
        raise ValueError(
            "label file {!r} holds {} labels, expected {}".format(
                filename2, len(buf), num_images))
    labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
    return data, labels
# Split sizes and the mini-batch size used by the rest of the script.
train_num = 60000
test_num = 10000
batch_size = 100

# Load both splits; the sizes come from the constants above so the numbers
# are defined in exactly one place.
train_data, train_labels = data_read(
    './minist_data/train-images-idx3-ubyte.gz',
    './minist_data/train-labels-idx1-ubyte.gz',
    train_num)
test_data, test_labels = data_read(
    './minist_data/t10k-images-idx3-ubyte.gz',
    './minist_data/t10k-labels-idx1-ubyte.gz',
    test_num)
print('训练集参数:',train_data.shape,train_labels.shape)
print('测试集参数:',test_data.shape,test_labels.shape)

# Preview the first 25 training images in a 5x5 grid, labelled with their class.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # Drop the trailing channel axis so imshow receives a 2-D image.
    plt.imshow(np.squeeze(train_data[i]), cmap=plt.cm.binary)
    plt.xlabel(train_labels[i])
plt.show()
3、搭建网络。tf.get_variable 用于创建或获取变量:在 with tf.variable_scope(name, reuse) 中,当 reuse=False 时会在该命名空间内创建变量,若同名变量已存在则报错;当 reuse=True 时会在该命名空间内获取已有变量,若变量不存在则报错。
class minist_inference(object):
    """Small convolutional classifier built with the TF1 graph API.

    Layer order: conv(3x3, 32) -> max-pool -> conv(3x3, 64) -> max-pool ->
    flatten -> fc(512) -> fc(num_class) + softmax.  The softmax output is
    stored in ``self.output``.
    """

    def __init__(self, input_image, trainable, regularizer=None):
        """Build the network on top of ``input_image``.

        Args:
            input_image: Input tensor; the first convolution uses a
                [3, 3, 1, 32] filter, i.e. one input channel.
            trainable: Flag forwarded to the first fully connected layer,
                where it decides whether dropout is added.
            regularizer: Optional callable applied to each fully connected
                weight matrix; its result is added to the 'losses'
                collection.
        """
        self.num_class = 10
        self.output = self.build_network(input_image, trainable, regularizer)

    def convolution(self, name, input_tensor, filter_shape, strides, padding,
                    activate=True, bn=True):
        """Add a conv layer followed by batch-norm (or bias) and optional ReLU.

        Args:
            name: Variable scope for the layer.
            input_tensor: Tensor fed to ``tf.nn.conv2d``.
            filter_shape: Shape of the convolution kernel; its last entry is
                also used as the bias size when ``bn`` is false.
            strides: Strides passed to ``tf.nn.conv2d``.
            padding: Padding mode passed to ``tf.nn.conv2d``.
            activate: Apply ReLU after the layer when true.
            bn: Use batch normalization when true, otherwise a plain bias.

        Returns:
            The output tensor of the layer.
        """
        with tf.variable_scope(name):
            weights = tf.get_variable(
                "weight", filter_shape,
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            input_tensor = tf.nn.conv2d(input_tensor, weights, strides, padding)
            if bn:
                # NOTE(review): training=True is fixed here, so batch
                # statistics are used in every graph built from this class.
                # Switching to a flag would also require the training script
                # to run the UPDATE_OPS collection - confirm before changing.
                input_tensor = tf.layers.batch_normalization(
                    input_tensor,
                    beta_initializer=tf.zeros_initializer(),
                    gamma_initializer=tf.ones_initializer(),
                    moving_mean_initializer=tf.zeros_initializer(),
                    moving_variance_initializer=tf.ones_initializer(),
                    training=True)
            else:
                biases = tf.get_variable(
                    "bias", filter_shape[-1],
                    initializer=tf.constant_initializer(0.0))
                input_tensor = tf.nn.bias_add(input_tensor, biases)
            if activate:
                input_tensor = tf.nn.relu(input_tensor)
        print(input_tensor.name,input_tensor.shape)
        return input_tensor

    def pool(self, name, input_tensor, ksize, strides, padding, flatten=False):
        """Add a max-pooling layer.

        ``flatten`` is accepted for backward compatibility but is not used;
        flattening is done by :meth:`flatten`.
        """
        with tf.name_scope(name):
            input_tensor = tf.nn.max_pool(input_tensor, ksize, strides, padding)
        print(input_tensor.name,input_tensor.shape)
        return input_tensor

    def flatten(self, name, input_tensor):
        """Reshape a 4-D tensor to 2-D, keeping the first axis."""
        with tf.name_scope(name):
            input_tensor_shape = input_tensor.get_shape().as_list()
            # Product of the three non-leading dimensions.
            nodes = (input_tensor_shape[1] * input_tensor_shape[2]
                     * input_tensor_shape[3])
            input_tensor = tf.reshape(input_tensor, [-1, nodes])
        print(input_tensor.name,input_tensor.shape)
        return input_tensor

    def full_connect(self, name, input_tensor, size, regularizer, trainable,
                     softmax):
        """Add a fully connected layer.

        Hidden layers (``softmax`` false) get a ReLU activation.  The output
        layer (``softmax`` true) feeds the affine result straight into the
        softmax: a ReLU there would clamp every negative logit to zero.

        Args:
            name: Variable scope for the layer.
            input_tensor: 2-D input tensor; its last dimension sets the
                number of weight rows.
            size: Number of output units.
            regularizer: Optional callable applied to the weights; the result
                is added to the 'losses' collection.
            trainable: Dropout (keep probability 0.5) is added only when this
                compares equal to ``True``.
            softmax: Apply softmax instead of ReLU when true.

        Returns:
            The output tensor of the layer.
        """
        with tf.variable_scope(name):
            in_features = input_tensor.get_shape().as_list()[-1]
            weights = tf.get_variable(
                "weight", [in_features, size],
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            if regularizer is not None:
                tf.add_to_collection('losses', regularizer(weights))
            biases = tf.get_variable(
                "bias", [size], initializer=tf.constant_initializer(0.1))
            input_tensor = tf.matmul(input_tensor, weights) + biases
            if not softmax:
                input_tensor = tf.nn.relu(input_tensor)
            # Explicit comparison kept on purpose: `trainable` may be a
            # tensor rather than a Python bool, and plain truthiness is not
            # meaningful for a graph tensor.
            # NOTE(review): a tensor flag presumably never compares equal to
            # True here, so dropout is only added for a Python True - confirm.
            if trainable == True:
                input_tensor = tf.nn.dropout(input_tensor, 0.5)
            if softmax:
                input_tensor = tf.nn.softmax(input_tensor)
        print(input_tensor.name,input_tensor.shape)
        return input_tensor

    def build_network(self, input_image, trainable, regularizer=None):
        """Assemble all layers and return the softmax output tensor."""
        print(input_image.name)
        input_image = self.convolution('conv1', input_image, [3, 3, 1, 32], [1, 1, 1, 1], 'SAME')
        input_image = self.pool('pool1', input_image, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
        input_image = self.convolution('conv2', input_image, [3, 3, 32, 64], [1, 1, 1, 1], 'SAME')
        input_image = self.pool('pool2', input_image, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID', True)
        input_image = self.flatten('flatten1', input_image)
        input_image = self.full_connect('fc1', input_image, 512, regularizer, trainable, False)
        # The output layer never uses dropout, hence the literal False.
        input_image = self.full_connect('fc2', input_image, self.num_class, regularizer, False, True)
        return input_image
4、数据准备,定义具有迭代器方法的类,注意在__init__函数中,一定要将数据随机打乱,不然网络很难训练。
class dataset(object):
    """Re-iterable mini-batch iterator over shuffled samples and labels.

    The sample order is shuffled once at construction.  Each pass yields
    ``batch_num`` batches of ``batch_size`` items; when the sample count is
    not a multiple of ``batch_size`` the final batch is smaller.  Every
    sample is yielded exactly once per pass.
    """

    def __init__(self, datas, labels, batch_size):
        """Shuffle ``datas`` and ``labels`` together and set up batching.

        Args:
            datas: NumPy array of samples, indexed along axis 0.
            labels: NumPy array of labels aligned with ``datas`` on axis 0.
            batch_size: Maximum number of samples per batch.
        """
        order = np.arange(datas.shape[0])
        np.random.shuffle(order)
        # Apply one permutation to the arguments themselves (not to any
        # module-level arrays) so samples and labels stay aligned.
        self.labels = labels.take(order, 0)
        self.datas = datas.take(order, 0)
        self.batch_size = batch_size
        self.batch_num = int(np.ceil(datas.shape[0] / batch_size))
        # Size of the final partial batch (0 when the split is even).
        self.res = datas.shape[0] % batch_size
        self.batch_count = 0

    def __len__(self):
        """Return the number of batches produced by one pass."""
        return self.batch_num

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(batch_image, batch_label)`` pair.

        Raises:
            StopIteration: After the last batch; the position is reset so
                the object can be iterated again.
        """
        if self.batch_count >= self.batch_num:
            self.batch_count = 0
            raise StopIteration
        start = self.batch_count * self.batch_size
        stop = start + self.batch_size
        self.batch_count += 1
        # Slicing past the end simply yields the shorter final batch.
        batch_image = self.datas[start:stop]
        batch_label = self.labels[start:stop]
        return batch_image, batch_label
5、模型训练与保存
# Reset the default graph so the model is built from a clean slate.
tf.reset_default_graph()

# Placeholders: image batch, training-mode flag and integer class labels.
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
training = tf.placeholder(tf.bool)
y_ = tf.placeholder(tf.int64, [None])

# L2 regularizer for the fully connected weights.
regularizer = tf.contrib.layers.l2_regularizer(0.001)

# Forward pass.
model = minist_inference(x, training, regularizer)
y = model.output

# Loss = cross entropy + regularization terms from the 'losses' collection.
# `y` already holds softmax probabilities, while the cross-entropy op applies
# a softmax to its `logits` argument itself.  Passing log-probabilities keeps
# that internal softmax from being applied on top of the first one; the clip
# guards against log(0).
log_probs = tf.log(tf.clip_by_value(y, 1e-10, 1.0))
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=log_probs, labels=y_)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

# Optimizer.  Batch-normalization layers register their moving-average
# updates in UPDATE_OPS; they only run if the train op depends on them.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    Optimizer = tf.train.AdamOptimizer().minimize(loss)

# Accuracy: fraction of samples whose arg-max prediction equals the label.
accuracy = tf.equal(tf.argmax(y, 1), y_)
accuracy = tf.reduce_mean(tf.cast(accuracy, tf.float32))

# Saver used to write the checkpoint.
saver = tf.train.Saver()

num_epochs = 5
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(num_epochs):
        aver_train_acc = []
        aver_train_loss = []
        # A new dataset object per epoch gives a fresh shuffle.
        train_dataset = dataset(train_data, train_labels, batch_size)
        pbar = tqdm(train_dataset)
        for traindata in pbar:
            # Training steps feed the mode flag as True; evaluation below
            # feeds False.
            _, loss_value, train_acc = sess.run(
                [Optimizer, loss, accuracy],
                feed_dict={x: traindata[0], training: True, y_: traindata[1]})
            aver_train_acc.append(train_acc)
            aver_train_loss.append(loss_value)
        test_acc = sess.run(accuracy, feed_dict={x: test_data, training: False, y_: test_labels})
        print('\n the iter is {},the aver-loss is {},the aver-train_acc is {},the test_acc is {}'
              .format(i,np.mean(aver_train_loss),np.mean(aver_train_acc),test_acc))
    saver.save(sess,'./minist.ckpt')
6、模型加载与测试
def plot_image(i, predictions_array, true_label, img):
    """Draw sample ``i`` on the current axes with a prediction caption.

    Args:
        i: Index of the sample inside ``true_label`` and ``img``.
        predictions_array: Per-class scores for this single sample.
        true_label: Sequence of ground-truth labels, indexed by ``i``.
        img: Sequence of images, indexed by ``i``.

    The caption reads "<predicted> <max score as percent>% (<true>)" and is
    blue when the prediction matches the label, red otherwise.
    """
    expected = true_label[i]
    picture = img[i]

    # Bare image: no grid, no axis ticks.
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(np.squeeze(picture), cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    color = 'blue' if predicted_label == expected else 'red'
    caption = "{} {:2.0f}% ({})".format(predicted_label,
                                         100*np.max(predictions_array),
                                         expected)
    plt.xlabel(caption, color=color)
def plot_value_array(i, predictions_array, true_label):
    """Draw the ten class scores of one sample as a bar chart.

    Args:
        i: Index of the sample inside ``true_label``.
        predictions_array: Per-class scores for this single sample.
        true_label: Sequence of ground-truth labels, indexed by ``i``.

    The predicted class bar is coloured red and the true class bar blue;
    blue is applied last, so a correct prediction ends up blue.
    """
    expected = true_label[i]
    class_ids = range(10)

    plt.grid(False)
    plt.xticks(class_ids)
    plt.yticks([])
    bars = plt.bar(class_ids, predictions_array, color="#777777")
    plt.ylim([0, 1])

    winner = np.argmax(predictions_array)
    bars[winner].set_color('red')
    bars[expected].set_color('blue')
# Rebuild the network in a fresh graph (inference configuration: no dropout
# flag, no regularizer) and load the trained weights from the checkpoint.
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
model = minist_inference(x, trainable=False, regularizer=None)
y = model.output
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './minist.ckpt')
    # Class scores for the whole test set.
    test_out = sess.run(y, feed_dict={x: test_data})

# Show the first num_rows * num_cols test samples; every sample occupies two
# neighbouring grid cells: the image on the left, its score bars on the right.
num_rows = 7
num_cols = 7
num_images = num_rows * num_cols
grid_cols = 2 * num_cols
plt.figure(figsize=(2 * grid_cols, 2 * num_rows))
for i in range(num_images):
    image_cell = 2 * i + 1
    plt.subplot(num_rows, grid_cols, image_cell)
    plot_image(i, test_out[i], test_labels, test_data)
    plt.subplot(num_rows, grid_cols, image_cell + 1)
    plot_value_array(i, test_out[i], test_labels)
plt.tight_layout()
plt.show()
7、查看ckpt文件中的网络信息——网络参数与网络节点变量
# Part 1: list every variable stored in the checkpoint and keep its value.
reader = pywrap_tensorflow.NewCheckpointReader('./minist.ckpt')
var_to_shape_map = reader.get_variable_to_shape_map()
param_dict = {}
for key in var_to_shape_map:
    print('param_name',key)
    param_dict[key] = np.array(reader.get_tensor(key))

# Part 2: list every node of the saved graph.
# Start from an empty default graph so the imported nodes do not collide with
# whatever graph was built earlier in the script and keep their saved names.
tf.reset_default_graph()
meta_graph = tf.train.import_meta_graph("./minist.ckpt.meta")
with tf.Session() as sess:
    meta_graph.restore(sess, "./minist.ckpt")
    # Entries of the GraphDef are nodes, so these are node names.
    tensor_name_list = [node.name for node in tf.get_default_graph().as_graph_def().node]
    for tensor_name in tensor_name_list:
        print('tensor_name',tensor_name)
8、保存为pb文件
# Start from an empty default graph.
tf.reset_default_graph()
# Recreate the saved graph structure from the .meta file.
saver = tf.train.import_meta_graph('./minist.ckpt.meta')
with tf.Session() as sess:
    # Load the trained variable values into the imported graph.
    saver.restore(sess, './minist.ckpt')
    graph_def = tf.get_default_graph().as_graph_def()
    # 'Placeholder' is the input node and 'fc2/Softmax' the output node;
    # variables these nodes depend on are replaced by constants.
    kept_nodes = ['Placeholder','fc2/Softmax']
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, kept_nodes)
    # Serialize the frozen graph to a single .pb file.
    serialized = output_graph_def.SerializeToString()
    with tf.gfile.GFile("./minst.pb", "wb") as f:
        f.write(serialized)
9、利用pb模型进行测试
# Evaluate the frozen model in a fresh graph so only the contents of the .pb
# file are involved.
tf.reset_default_graph()
model_filename = "./minst.pb"
# tf.gfile.GFile replaces the deprecated FastGFile and matches the class used
# to write the file.
with tf.gfile.GFile(model_filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# result[0] is the input placeholder, result[1] the softmax output.
result = tf.import_graph_def(graph_def, return_elements=["Placeholder:0","fc2/Softmax:0"])
with tf.Session() as sess:
    test_out = sess.run(result[1], feed_dict={result[0]: test_data})

# Show the first num_rows * num_cols test samples: image on the left, class
# score bars on the right.
num_rows = 7
num_cols = 7
num_images = num_rows * num_cols
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, test_out[i], test_labels, test_data)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_value_array(i, test_out[i], test_labels)
plt.tight_layout()
plt.show()