CIFAR-10共有60000张32x32的彩色图片,分为10个类别,每个类别6000张。
其中训练集50000张(每个类别5000张),测试集10000张(每个类别1000张)。
训练集分为5个批次,每个批次10000张图片,但单个批次内各类别的数量不一定相等;
测试集只有一个批次,共10000张图片。
import tensorflow as tf
import problem_unittests as tests
import helper #这是CIFAR10使用的帮助文件,见下
import numpy as np
import matplotlib.pyplot as plt #导入画笔
pickle模块是能够让我们直接在文件中存储几乎任何Python对象的高级工具
import pickle
探索数据集
# Explore the dataset: check that the download is complete, then print the
# statistics of one training batch and show one sample image from it.
cifar10_dataset_folder_path = "./cifar-10-python/cifar-10-batches-py"
tests.test_folder_path(cifar10_dataset_folder_path)

batch_id = 2  # training batch to inspect
sample_id = 20  # index of the sample inside that batch
helper.display_stats(
    cifar10_dataset_folder_path,
    batch_id,
    sample_id,
)
标准化
def normalize(x):
    """
    Min-max scale image data into the range [0, 1].

    The minimum and maximum are taken over the whole of ``x``, so when a
    batch is passed every image in it is scaled with the same range. The
    shape of the input is preserved.

    :param x: image data, e.g. one image of shape (32, 32, 3) or a batch
    :return: float NumPy array with the same shape as ``x``, values in [0, 1]
    """
    x = np.asarray(x)
    low = x.min()
    span = x.max() - low
    shifted = (x - low) * 1.0
    if span == 0:
        # Constant input: dividing by the zero range would produce NaN, so
        # return the (all-zero) shifted values instead.
        return shifted
    return shifted / span
标签进行one-hot编码
def one_hot_encode(x, n_classes=10):
    """
    One-hot encode a sequence of integer class labels.

    :param x: 1-D sequence of label ids, each in ``range(n_classes)``
    :param n_classes: number of classes, i.e. the width of each encoded row
    :return: float NumPy array of shape (len(x), n_classes) with a single 1
        per row; an empty input yields an array of shape (0, n_classes)
    """
    labels = np.asarray(x, dtype=np.int64)
    encoded = np.zeros((len(labels), n_classes))
    # Set column labels[i] of row i in one vectorised assignment.
    encoded[np.arange(len(labels)), labels] = 1
    return encoded
预处理训练数据集,验证数据集,测试数据集
# Preprocess the training, validation and test sets and cache them on disk
# (first checkpoint): 45000 training images, 5000 validation images used for
# hyper-parameter tuning, and 10000 test images.
helper.preprocess_and_save_data(
    cifar10_dataset_folder_path, normalize, one_hot_encode)

# Load the cached validation set. The context manager closes the file handle
# as soon as the data has been read.
with open("./preprocess_sets/preprocess_validation.p", mode="rb") as file:
    valid_features, valid_labels = pickle.load(file)
# valid_features.shape == (5000, 32, 32, 3)
# valid_labels.shape == (5000, 10)

# NOTE: naming a placeholder makes it possible to look it up by name after a
# saved graph is reloaded. A None entry in a TensorFlow shape means that the
# dimension has a dynamic size.
图片输入
def neural_net_image_input(image_shape):
    """
    Create the placeholder that receives batches of images.

    :param image_shape: per-image shape as (height, width, depth)
    :return: float32 placeholder named "x" with shape
        (None, height, width, depth); None is the batch dimension
    """
    height, width, depth = image_shape[0], image_shape[1], image_shape[2]
    return tf.placeholder(
        tf.float32, shape=[None, height, width, depth], name="x")
标签输入
def neural_net_label_input(n_classes):
    """
    Create the placeholder that receives batches of one-hot labels.

    :param n_classes: number of label classes
    :return: float32 placeholder named "y" with shape (None, n_classes)
    """
    label_shape = [None, n_classes]
    return tf.placeholder(tf.float32, shape=label_shape, name="y")
dropout 留存率
def neural_net_keep_prob_input():
    """
    Create the placeholder for the dropout keep probability.

    :return: float32 placeholder named "keep_prob" with no declared shape
    """
    return tf.placeholder(tf.float32, name="keep_prob")
def conv2d_maxpool(
        x_tensor, conv_kernels_num, conv_ksize, conv_strides,
        pool_ksize, pool_strides, padding="SAME", std=0.1,
        activation=tf.nn.relu):
    """
    Apply a 2-D convolution followed by max pooling.

    :param x_tensor: 4-D input tensor; the size of axis 3 is used as the
        number of input channels
    :param conv_kernels_num: number of convolution kernels (output channels)
    :param conv_ksize: (height, width) of the convolution window
    :param conv_strides: (height, width) convolution strides
    :param pool_ksize: (height, width) of the pooling window
    :param pool_strides: (height, width) pooling strides
    :param padding: padding mode passed to both the convolution and the pooling
    :param std: standard deviation of the truncated-normal weight initializer
    :param activation: callable applied to the pooled result; pass None to
        return the pooled values unchanged
    :return: tensor produced by convolution, bias add, max pooling and the
        optional activation
    """
    # All kernels of the layer live in one weight tensor of shape
    # (kernel_height, kernel_width, in_channels, out_channels).
    in_channels = x_tensor.get_shape().as_list()[3]
    filter_weights = tf.Variable(tf.truncated_normal(
        shape=[conv_ksize[0], conv_ksize[1], in_channels, conv_kernels_num],
        stddev=std))
    biases = tf.Variable(tf.zeros(conv_kernels_num))

    conv_layer = tf.nn.conv2d(
        x_tensor, filter_weights,
        strides=[1, conv_strides[0], conv_strides[1], 1],
        padding=padding)
    conv_layer = tf.nn.bias_add(conv_layer, biases)

    maxpool_logits = tf.nn.max_pool(
        conv_layer,
        ksize=[1, pool_ksize[0], pool_ksize[1], 1],
        strides=[1, pool_strides[0], pool_strides[1], 1],
        padding=padding)

    # Apply whatever activation the caller supplied, not only tf.nn.relu;
    # a non-callable value (e.g. None) disables the activation.
    if callable(activation):
        return activation(maxpool_logits)
    return maxpool_logits
def conv2d_avgpool(
        x_tensor, conv_kernels_num, conv_ksize, conv_strides,
        pool_ksize, pool_strides, padding="SAME", std=1,
        activation=tf.nn.relu):
    """
    Apply a 2-D convolution followed by average pooling.

    :param x_tensor: 4-D input tensor; the size of axis 3 is used as the
        number of input channels
    :param conv_kernels_num: number of convolution kernels (output channels)
    :param conv_ksize: (height, width) of the convolution window
    :param conv_strides: (height, width) convolution strides
    :param pool_ksize: (height, width) of the pooling window
    :param pool_strides: (height, width) pooling strides
    :param padding: padding mode passed to both the convolution and the pooling
    :param std: standard deviation of the truncated-normal weight initializer
    :param activation: callable applied to the pooled result; pass None to
        return the pooled values unchanged
    :return: tensor produced by convolution, bias add, average pooling and
        the optional activation
    """
    # NOTE(review): the default std here is 1 while conv2d_maxpool uses 0.1;
    # kept as-is for backward compatibility - confirm it is intentional.

    # All kernels of the layer live in one weight tensor of shape
    # (kernel_height, kernel_width, in_channels, out_channels).
    in_channels = x_tensor.get_shape().as_list()[3]
    filter_weights = tf.Variable(tf.truncated_normal(
        shape=[conv_ksize[0], conv_ksize[1], in_channels, conv_kernels_num],
        stddev=std))
    biases = tf.Variable(tf.zeros(conv_kernels_num))

    conv_layer = tf.nn.conv2d(
        x_tensor, filter_weights,
        strides=[1, conv_strides[0], conv_strides[1], 1],
        padding=padding)
    conv_layer = tf.nn.bias_add(conv_layer, biases)

    avgpool_logits = tf.nn.avg_pool(
        conv_layer,
        ksize=[1, pool_ksize[0], pool_ksize[1], 1],
        strides=[1, pool_strides[0], pool_strides[1], 1],
        padding=padding)

    # Apply whatever activation the caller supplied, not only tf.nn.relu;
    # a non-callable value (e.g. None) disables the activation.
    if callable(activation):
        return activation(avgpool_logits)
    return avgpool_logits
# Reshape x_tensor from 4-D to 2-D with shape (m, n), where m is the number
# of samples and n is the number of features of one flattened image.
def flatten(x_tensor):
    """
    Flatten every sample of a 4-D tensor into a feature vector.

    :param x_tensor: tensor whose axes 1, 2 and 3 have statically known sizes
    :return: 2-D tensor of shape (-1, size1 * size2 * size3)
    """
    dims = x_tensor.get_shape().as_list()
    feature_count = dims[1] * dims[2] * dims[3]
    return tf.reshape(x_tensor, [-1, feature_count])
def fully_connect(x_tensor, num_outputs, activation=tf.nn.relu):
    """
    Apply a fully connected layer to a 2-D tensor.

    :param x_tensor: 2-D input tensor; the size of axis 1 is the input width
    :param num_outputs: number of output neurons
    :param activation: callable applied to the affine result; pass None to
        return the raw values
    :return: tensor of shape (batch, num_outputs)
    """
    in_features = x_tensor.get_shape().as_list()[1]
    full_weights = tf.Variable(tf.truncated_normal(
        shape=[in_features, num_outputs], stddev=0.1))
    full_biases = tf.Variable(tf.zeros(num_outputs))
    logits = tf.add(tf.matmul(x_tensor, full_weights), full_biases)

    # Apply whatever activation the caller supplied, not only tf.nn.relu;
    # a non-callable value (e.g. None) disables the activation.
    if callable(activation):
        return activation(logits)
    return logits
def output(x_tensor, num_outputs, activation=None):
    """
    Apply the output layer to a 2-D tensor.

    :param x_tensor: 2-D input tensor; the size of axis 1 is the input width
    :param num_outputs: number of classes
    :param activation: optional callable applied to the logits; the default
        None returns the raw logits
    :return: tensor of shape (batch, num_outputs)
    """
    in_features = x_tensor.get_shape().as_list()[1]
    out_weights = tf.Variable(tf.truncated_normal(
        shape=[in_features, num_outputs], stddev=0.1))
    out_biases = tf.Variable(tf.zeros([num_outputs]))
    out_logits = tf.add(tf.matmul(x_tensor, out_weights), out_biases)

    # Apply whatever activation the caller supplied, not only tf.nn.relu.
    if callable(activation):
        return activation(out_logits)
    return out_logits
def conv_net(x_tensor, keep_prob):
    """
    Build the classifier: three conv + max-pool stages, one fully connected
    layer with dropout, and a 10-way output layer.

    :param x_tensor: image batch; the shape notes below assume
        (None, 32, 32, 3)
    :param keep_prob: dropout keep probability for the fully connected layer
    :return: logits tensor of shape (None, 10)
    """
    # Every stage uses 3x3 kernels with stride 1 and 2x2 pooling with
    # stride 2, so with "SAME" padding each stage halves height and width:
    # (None, 32, 32, 3) -> (None, 16, 16, 32) -> (None, 8, 8, 64)
    # -> (None, 4, 4, 64)
    net = x_tensor
    for kernels in (32, 64, 64):
        net = conv2d_maxpool(net, kernels, (3, 3), [1, 1],
                             pool_ksize=(2, 2),
                             pool_strides=(2, 2), std=0.1)

    # (None, 4, 4, 64) -> (None, 1024)
    net = flatten(net)

    # (None, 1024) -> (None, 1024); dropout is applied to this layer only
    net = fully_connect(net, num_outputs=1024)
    net = tf.nn.dropout(net, keep_prob=keep_prob)

    # (None, 1024) -> (None, 10)
    return output(net, 10)
# Input placeholders of the graph.
x = neural_net_image_input([32, 32, 3])
y = neural_net_label_input(10)
keep_prob = neural_net_keep_prob_input()

# Hyper-parameters.
epochs = 20
batch_size = 256
batch_num = 5  # number of preprocessed training batch files
drop_keep = 0.8  # dropout keep probability fed while training

# Feed the keep_prob placeholder into the network. Passing the constant
# drop_keep here would bake a fixed dropout rate into the graph, so feeding
# keep_prob: 1.0 during evaluation would have no effect.
logits = conv_net(x, keep_prob=keep_prob)
# Name the logits so they can be fetched by name after reloading the graph.
logits = tf.identity(logits, name="logits")

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=y))
optimizer = tf.train.AdamOptimizer().minimize(loss)

correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(
    tf.cast(correct_pred, tf.float32), name="accuracy")

init_op = tf.global_variables_initializer()
save_model_path = "./models/image_classification.ckpt"

print("开始训练..........")
with tf.Session() as sess:
    sess.run(init_op)
    for epoch_i in range(epochs):
        for batch_id in range(batch_num):
            # Batch files are numbered from 1.
            for batch_features, batch_labels in \
                    helper.load_preprocess_training_batch(
                        batch_id + 1, batch_size):
                sess.run(optimizer,
                         feed_dict={x: batch_features,
                                    y: batch_labels,
                                    keep_prob: drop_keep})

        # Report once per epoch, with dropout disabled: loss and accuracy on
        # the last training mini-batch, accuracy on the first batch_size
        # validation samples.
        train_loss = sess.run(
            loss,
            feed_dict={x: batch_features, y: batch_labels, keep_prob: 1.})
        valid_accuracy = sess.run(
            accuracy,
            feed_dict={x: valid_features[:batch_size],
                       y: valid_labels[:batch_size],
                       keep_prob: 1.})
        train_accuracy = sess.run(
            accuracy,
            feed_dict={x: batch_features, y: batch_labels, keep_prob: 1.})
        print("Epoch:{:<4}--Training loss:{:<4}--Training accuracy:{:<4}--Validation accuracy:{:<4}".
              format(epoch_i, train_loss, train_accuracy, valid_accuracy))

    # Save the trained variables together with the graph definition.
    saver = tf.train.Saver()
    save_path = saver.save(sess, save_model_path)
import tensorflow as tf
import pickle
import helper
import random
import numpy as np
# Evaluation settings.
batch_size = 256  # number of test samples evaluated per session run
load_model_path = "./models/image_classification.ckpt"
n_samples = 4  # number of random test images to visualise
# How many of the highest-probability classes are fetched per sample.
top_n_predictions = 3
# NOTE: opened at import time and read inside test_model().
test_file = open("./preprocess_sets/preprocess_test.p", mode="rb")
def test_model():
    """
    Evaluate the saved model on the preprocessed test set.

    Restores the graph and variables from ``load_model_path``, prints the
    accuracy over the whole test set, then prints and plots the top
    ``top_n_predictions`` softmax predictions for ``n_samples`` random test
    images.
    """
    # Read the test features and labels, closing the module-level file
    # handle once the data is in memory.
    with test_file:
        test_features, test_labels = pickle.load(test_file)

    loaded_graph = tf.Graph()
    with tf.Session(graph=loaded_graph) as sess:
        # Rebuild the graph from the .meta file, then restore the saved
        # weights and biases. restore() assigns every saved variable, so no
        # separate initializer run is needed.
        load_model = tf.train.import_meta_graph(load_model_path + ".meta")
        load_model.restore(sess, load_model_path)

        # Look up the tensors of the restored graph by name.
        loaded_x = loaded_graph.get_tensor_by_name("x:0")
        loaded_y = loaded_graph.get_tensor_by_name("y:0")
        loaded_keep_prob = loaded_graph.get_tensor_by_name("keep_prob:0")
        loaded_logits = loaded_graph.get_tensor_by_name("logits:0")
        loaded_accuracy = loaded_graph.get_tensor_by_name("accuracy:0")

        # Evaluate batch by batch to bound memory use. Each batch accuracy is
        # weighted by its size because the last batch can be smaller; a plain
        # mean of batch accuracies would over-weight it.
        weighted_acc_total = 0.0
        sample_total = 0
        for batch_features, batch_labels in helper.batch_features_labels(
                test_features, test_labels, batch_size):
            batch_acc = sess.run(
                loaded_accuracy,
                feed_dict={loaded_x: batch_features,
                           loaded_y: batch_labels,
                           loaded_keep_prob: 1})
            weighted_acc_total += batch_acc * len(batch_features)
            sample_total += len(batch_features)
        test_accuracy = weighted_acc_total / sample_total if sample_total else 0.0
        print("test accuracy:{:<3}".format(test_accuracy))

        # Show the predictions for a few random test samples.
        random_test_features, random_test_labels = tuple(zip(*random.sample(
            list(zip(test_features, test_labels)), n_samples)))
        # indices and values both have shape (n_samples, top_n_predictions)
        random_test_predictions = sess.run(
            tf.nn.top_k(tf.nn.softmax(loaded_logits), k=top_n_predictions),
            feed_dict={loaded_x: random_test_features,
                       loaded_y: random_test_labels,
                       loaded_keep_prob: 1.})
        print(random_test_predictions.indices)
        print(random_test_predictions.values)
        helper.display_image_predictions(
            random_test_features, random_test_labels, random_test_predictions)


test_model()
import pickle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
返回CIFAR-10的10个标签名(直接写在代码中,并非从文件加载)
def _load_label_names():
"""
Load the label names from file
"""
return ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
加载cifar10的批次样本
def load_cfar10_batch(cifar10_dataset_folder_path, batch_id):
    """
    Load one training batch of the dataset.

    :param cifar10_dataset_folder_path: folder holding the ``data_batch_*``
        files (no trailing slash)
    :param batch_id: number appended to ``data_batch_`` to pick the file
    :return: (features, labels); features is the batch's ``data`` array
        reshaped to (n, 3, 32, 32) and transposed to (n, 32, 32, 3), labels
        is the batch's ``labels`` entry
    """
    batch_path = cifar10_dataset_folder_path + '/data_batch_' + str(batch_id)
    with open(batch_path, mode='rb') as batch_file:
        batch = pickle.load(batch_file, encoding='latin1')

    raw = batch['data']
    # Rows are stored channel-first; move the channel axis to the end.
    channels_first = raw.reshape((len(raw), 3, 32, 32))
    features = channels_first.transpose(0, 2, 3, 1)
    return features, batch['labels']
显示数据集的状态信息
def display_stats(cifar10_dataset_folder_path, batch_id, sample_id):
    """
    Print statistics of one training batch and show one of its images.

    Prints a message and returns None early when ``batch_id`` is not in 1..5
    or ``sample_id`` is outside the loaded batch.

    :param cifar10_dataset_folder_path: folder holding the batch files
    :param batch_id: training batch to load (1 to 5)
    :param sample_id: index of the sample to display
    """
    valid_batch_ids = list(range(1, 6))
    if batch_id not in valid_batch_ids:
        print('Batch Id out of Range. Possible Batch Ids: {}'.format(valid_batch_ids))
        return None

    features, labels = load_cfar10_batch(cifar10_dataset_folder_path, batch_id)
    sample_total = len(features)
    if sample_id < 0 or sample_id >= sample_total:
        print('{} samples in batch {}.  {} is out of range.'.format(sample_total, batch_id, sample_id).replace('.  ', '. '))
        return None

    # Batch-level summary.
    label_values, label_counts = np.unique(labels, return_counts=True)
    print('\nStats of batch {}:'.format(batch_id))
    print('Samples: {}'.format(sample_total))
    print('Label Counts: {}'.format(dict(zip(label_values, label_counts))))
    print('First 20 Labels: {}'.format(labels[:20]))

    # Details of the requested sample.
    sample_image = features[sample_id]
    sample_label = labels[sample_id]
    sample_name = _load_label_names()[sample_label]
    print('\nExample of Image {}:'.format(sample_id))
    print('Image - Min Value: {} Max Value: {}'.format(sample_image.min(), sample_image.max()))
    print('Image - Shape: {}'.format(sample_image.shape))
    print('Label - Label Id: {} Name: {}'.format(sample_label, sample_name))

    plt.axis('off')
    plt.imshow(sample_image)
    plt.show()
预处理数据并保存到文件
def _preprocess_and_save(normalize, one_hot_encode, features, labels, filename):
"""
Preprocess data and save it to file
"""
features = normalize(features)
labels = one_hot_encode(labels)
# pickle模块是能够让我们直接在文件中存储几乎任何Python对象的高级工具
#存
pickle.dump((features, labels), open(filename, 'wb'))
预处理训练集、验证集数据
def preprocess_and_save_data(cifar10_dataset_folder_path, normalize, one_hot_encode):
    """
    Preprocess the training, validation and test data and save them to disk.

    The last 10% of every training batch is held out for validation. Results
    are written to ``./preprocess_sets/`` as ``preprocess_batch_<i>.p``,
    ``preprocess_validation.p`` and ``preprocess_test.p``.

    :param cifar10_dataset_folder_path: folder holding the raw batch files
    :param normalize: callable applied to the features of every set
    :param one_hot_encode: callable applied to the labels of every set
    """
    n_batches = 5
    valid_features = []
    valid_labels = []

    for batch_i in range(1, n_batches + 1):
        features, labels = load_cfar10_batch(cifar10_dataset_folder_path, batch_i)
        validation_count = int(len(features) * 0.1)
        # Split at an explicit index. Slicing with -validation_count would
        # invert the split when the count is 0, because [:-0] is empty and
        # [-0:] is the whole batch.
        split_at = len(features) - validation_count

        # Preprocess and save the training part of the batch.
        _preprocess_and_save(
            normalize,
            one_hot_encode,
            features[:split_at],
            labels[:split_at],
            './preprocess_sets/preprocess_batch_' + str(batch_i) + '.p')

        # Keep the tail of the batch for validation.
        valid_features.extend(features[split_at:])
        valid_labels.extend(labels[split_at:])

    # Preprocess and save all validation data.
    _preprocess_and_save(
        normalize,
        one_hot_encode,
        np.array(valid_features),
        np.array(valid_labels),
        './preprocess_sets/preprocess_validation.p')

    with open(cifar10_dataset_folder_path + '/test_batch', mode='rb') as file:
        batch = pickle.load(file, encoding='latin1')

    # Load the test data with the same layout as the training batches:
    # (n, 3, 32, 32) transposed to (n, 32, 32, 3).
    test_features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    test_labels = batch['labels']

    # Preprocess and save all test data.
    _preprocess_and_save(
        normalize,
        one_hot_encode,
        np.array(test_features),
        np.array(test_labels),
        './preprocess_sets/preprocess_test.p')
分割批次样本的特征和标签
def batch_features_labels(features, labels, batch_size):
    """
    Yield (features, labels) slices of at most ``batch_size`` items each.

    :param features: sliceable sequence of features
    :param labels: sliceable sequence of labels, aligned with ``features``
    :param batch_size: maximum number of items per yielded batch
    """
    total = len(features)
    for begin in range(0, total, batch_size):
        stop = begin + batch_size
        if stop > total:
            stop = total  # the final batch may be shorter
        yield features[begin:stop], labels[begin:stop]
加载预处理训练集数据
def load_preprocess_training_batch(batch_id, batch_size):
    """
    Load one preprocessed training batch and return it in mini-batches.

    :param batch_id: number of the ``preprocess_batch_<id>.p`` file to load
    :param batch_size: maximum number of samples per mini-batch
    :return: generator of (features, labels) mini-batches of size
        ``batch_size`` or less
    """
    filename = './preprocess_sets/preprocess_batch_' + str(batch_id) + '.p'
    # Close the file as soon as the data is in memory; the generator returned
    # below only works on the loaded arrays.
    with open(filename, mode='rb') as batch_file:
        features, labels = pickle.load(batch_file)

    return batch_features_labels(features, labels, batch_size)
图片预测展示
def display_image_predictions(features, labels, predictions):
    """
    Plot each sample image next to a bar chart of its top predictions.

    :param features: sequence of images to show
    :param labels: one-hot labels of the images (10 classes)
    :param predictions: object with ``indices`` and ``values`` attributes,
        each holding one row of top-k class ids / probabilities per image
    """
    n_classes = 10
    label_names = _load_label_names()
    label_binarizer = LabelBinarizer()
    label_binarizer.fit(range(n_classes))
    # Convert the one-hot labels back to class ids.
    label_ids = label_binarizer.inverse_transform(np.array(labels))
    print("labels:", labels)
    print("label_id", label_ids)

    # Size the figure from the data instead of assuming 4 images with 3
    # predictions each. squeeze=False keeps `axies` 2-D for a single image.
    n_rows = len(features)
    n_predictions = np.shape(predictions.indices)[1]
    fig, axies = plt.subplots(nrows=n_rows, ncols=2, squeeze=False)
    fig.tight_layout()
    fig.suptitle('Softmax Predictions', fontsize=20, y=1.1)

    margin = 0.05
    ind = np.arange(n_predictions)
    width = (1. - 2. * margin) / n_predictions

    for image_i, (feature, label_id, pred_indicies, pred_values) in enumerate(
            zip(features, label_ids, predictions.indices, predictions.values)):
        pred_names = [label_names[pred_i] for pred_i in pred_indicies]
        correct_name = label_names[label_id]

        # Left column: the image, titled with its true class.
        axies[image_i][0].imshow(feature)
        axies[image_i][0].set_title(correct_name)
        axies[image_i][0].set_axis_off()

        # Right column: horizontal bars. Values and names are reversed
        # together so that each bar keeps its own label.
        axies[image_i][1].barh(ind + margin, pred_values[::-1], width)
        axies[image_i][1].set_yticks(ind + margin)
        axies[image_i][1].set_yticklabels(pred_names[::-1])
        axies[image_i][1].set_xticks([0, 0.5, 1.0])
    plt.show()
import os
import numpy as np
import tensorflow as tf
import random
from unittest.mock import MagicMock
def _print_success_message():
print('Tests Passed')
检测文件路径
def test_folder_path(cifar10_dataset_folder_path):
assert cifar10_dataset_folder_path is not None,\
'Cifar-10 data folder not set.'
assert cifar10_dataset_folder_path[-1] != '/',\
'The "/" shouldn\'t be added to the end of the path.'
assert os.path.exists(cifar10_dataset_folder_path),\
'Path not found.'
assert os.path.isdir(cifar10_dataset_folder_path),\
'{} is not a folder.'.format(os.path.basename(cifar10_dataset_folder_path))
train_files = [cifar10_dataset_folder_path + '/data_batch_' + str(batch_id) for batch_id in range(1, 6)]
other_files = [cifar10_dataset_folder_path + '/batches.meta', cifar10_dataset_folder_path + '/test_batch']
missing_files = [path for path in train_files + other_files if not os.path.exists(path)]
assert not missing_files,\
'Missing files in directory: {}'.format(missing_files)
print('All files found!')
测试是否归一化
def test_normalize(normalize):
    """
    Check a normalize function on a random batch of images.

    :param normalize: callable taking an integer array of shape
        (n, 32, 32, 3) with values in 0..255
    :raises AssertionError: if the result is not a NumPy object, changes the
        shape, or has values outside [0, 1]
    """
    # Draw at least one image: an empty batch has no minimum or maximum and
    # would make a correct normalize function fail.
    test_shape = (np.random.choice(range(1, 1000)), 32, 32, 3)
    test_numbers = np.random.choice(range(256), test_shape)
    normalize_out = normalize(test_numbers)

    assert type(normalize_out).__module__ == np.__name__,\
        'Not Numpy Object'

    assert normalize_out.shape == test_shape,\
        'Incorrect Shape. {} shape found'.format(normalize_out.shape)

    assert normalize_out.max() <= 1 and normalize_out.min() >= 0,\
        'Incorect Range. {} to {} found'.format(normalize_out.min(), normalize_out.max())

    _print_success_message()
测试标签是否正确进行了one-hot编码
def test_one_hot_encode(one_hot_encode):
    """
    Check a one-hot encoder on random labels.

    :param one_hot_encode: callable taking a sequence of label ids in 0..9
    :raises AssertionError: if the result is not a NumPy object, does not
        have shape (n, 10), or differs between calls for the same labels
    """
    # Draw at least one label: with zero labels the index sampling below
    # raises before the encoder is actually checked.
    test_shape = np.random.choice(range(1, 1000))
    test_numbers = np.random.choice(range(10), test_shape)
    one_hot_out = one_hot_encode(test_numbers)

    assert type(one_hot_out).__module__ == np.__name__,\
        'Not Numpy Object'

    assert one_hot_out.shape == (test_shape, 10),\
        'Incorrect Shape. {} shape found'.format(one_hot_out.shape)

    # Re-encode a few of the labels and compare with the first encoding.
    n_encode_tests = 5
    test_pairs = list(zip(test_numbers, one_hot_out))
    test_indices = np.random.choice(len(test_numbers), n_encode_tests)
    labels = [test_pairs[test_i][0] for test_i in test_indices]
    enc_labels = np.array([test_pairs[test_i][1] for test_i in test_indices])
    new_enc_labels = one_hot_encode(labels)

    assert np.array_equal(enc_labels, new_enc_labels),\
        'Encodings returned different results for the same numbers.\n' \
        'For the first call it returned:\n' \
        '{}\n' \
        'For the second call it returned\n' \
        '{}\n' \
        'Make sure you save the map of labels to encodings outside of the function.'.format(enc_labels, new_enc_labels)

    _print_success_message()
检测图片样本输入
def test_nn_image_inputs(neural_net_image_input):
    """
    Check the image placeholder built by ``neural_net_image_input``.

    :param neural_net_image_input: callable taking an image shape and
        returning the placeholder
    :raises AssertionError: if the shape is not [None, 32, 32, 3], the op is
        not a Placeholder, or the tensor name is not 'x:0'
    """
    image_shape = (32, 32, 3)
    placeholder = neural_net_image_input(image_shape)

    expected_shape = [None, image_shape[0], image_shape[1], image_shape[2]]
    found_shape = placeholder.get_shape().as_list()
    assert found_shape == expected_shape, \
        'Incorrect Image Shape.  Found {} shape'.format(found_shape).replace('.  ', '. ')

    op_type = placeholder.op.type
    assert op_type == 'Placeholder', \
        'Incorrect Image Type.  Found {} type'.format(op_type).replace('.  ', '. ')

    tensor_name = placeholder.name
    assert tensor_name == 'x:0', \
        'Incorrect Name.  Found {}'.format(tensor_name).replace('.  ', '. ')

    print('Image Input Tests Passed.')
检测样本标签输入
def test_nn_label_inputs(neural_net_label_input):
    """
    Check the label placeholder built by ``neural_net_label_input``.

    :param neural_net_label_input: callable taking the number of classes and
        returning the placeholder
    :raises AssertionError: if the shape is not [None, 10], the op is not a
        Placeholder, or the tensor name is not 'y:0'
    """
    n_classes = 10
    placeholder = neural_net_label_input(n_classes)

    found_shape = placeholder.get_shape().as_list()
    assert found_shape == [None, n_classes], \
        'Incorrect Label Shape.  Found {} shape'.format(found_shape).replace('.  ', '. ')

    op_type = placeholder.op.type
    assert op_type == 'Placeholder', \
        'Incorrect Label Type.  Found {} type'.format(op_type).replace('.  ', '. ')

    tensor_name = placeholder.name
    assert tensor_name == 'y:0', \
        'Incorrect Name.  Found {}'.format(tensor_name).replace('.  ', '. ')

    print('Label Input Tests Passed.')
检测dropout留存率输入
def test_nn_keep_prob_inputs(neural_net_keep_prob_input):
    """
    Check the keep-probability placeholder.

    :param neural_net_keep_prob_input: zero-argument callable returning the
        placeholder
    :raises AssertionError: if the placeholder declares a rank, the op is not
        a Placeholder, or the tensor name is not 'keep_prob:0'
    """
    placeholder = neural_net_keep_prob_input()

    # The placeholder must be created without a shape, so its rank is unknown.
    found_rank = placeholder.get_shape().ndims
    assert found_rank is None, \
        'Too many dimensions found for keep prob.  Found {} dimensions.  It should be a scalar (0-Dimension Tensor).'.format(found_rank).replace('.  ', '. ')

    op_type = placeholder.op.type
    assert op_type == 'Placeholder', \
        'Incorrect keep prob Type.  Found {} type'.format(op_type).replace('.  ', '. ')

    tensor_name = placeholder.name
    assert tensor_name == 'keep_prob:0', \
        'Incorrect Name.  Found {}'.format(tensor_name).replace('.  ', '. ')

    print('Keep Prob Tests Passed.')
检测卷积最大池化层
def test_con_pool(conv2d_maxpool):
    """
    Check the output shape of a convolution + max-pool layer builder.

    :param conv2d_maxpool: callable taking (tensor, number of outputs,
        conv window, conv strides, pool window, pool strides)
    :raises AssertionError: if the output shape is not [None, 4, 4, 10]
    """
    inputs = tf.placeholder(tf.float32, [None, 32, 32, 5])
    n_outputs = 10
    conv_window, conv_step = (2, 2), (4, 4)
    pool_window, pool_step = (2, 2), (2, 2)

    layer = conv2d_maxpool(
        inputs, n_outputs, conv_window, conv_step, pool_window, pool_step)

    found_shape = layer.get_shape().as_list()
    assert found_shape == [None, 4, 4, 10], \
        'Incorrect Shape.  Found {} shape'.format(found_shape).replace('.  ', '. ')

    _print_success_message()
检测flatten层
def test_flatten(flatten):
    """
    Check that a flatten function collapses all non-batch axes.

    :param flatten: callable taking a 4-D tensor
    :raises AssertionError: if the output shape is not [None, 10*30*6]
    """
    inputs = tf.placeholder(tf.float32, [None, 10, 30, 6])
    found_shape = flatten(inputs).get_shape().as_list()

    assert found_shape == [None, 10*30*6], \
        'Incorrect Shape.  Found {} shape'.format(found_shape).replace('.  ', '. ')

    _print_success_message()
检测全连接层
def test_fully_conn(fully_conn):
    """
    Check the output shape of a fully connected layer builder.

    :param fully_conn: callable taking (tensor, number of outputs)
    :raises AssertionError: if the output shape is not [None, 40]
    """
    inputs = tf.placeholder(tf.float32, [None, 128])
    n_outputs = 40

    found_shape = fully_conn(inputs, n_outputs).get_shape().as_list()
    assert found_shape == [None, 40], \
        'Incorrect Shape.  Found {} shape'.format(found_shape).replace('.  ', '. ')

    _print_success_message()
检测输出层
def test_output(output):
    """
    Check the output shape of an output layer builder.

    :param output: callable taking (tensor, number of outputs)
    :raises AssertionError: if the output shape is not [None, 40]
    """
    inputs = tf.placeholder(tf.float32, [None, 128])
    n_outputs = 40

    found_shape = output(inputs, n_outputs).get_shape().as_list()
    assert found_shape == [None, 40], \
        'Incorrect Shape.  Found {} shape'.format(found_shape).replace('.  ', '. ')

    _print_success_message()
检测卷积模型
def test_conv_net(conv_net):
    """
    Check that a model builder produces logits for 10 classes.

    :param conv_net: callable taking (image tensor, keep-probability tensor)
    :raises AssertionError: if the logits shape is not [None, 10]
    """
    images = tf.placeholder(tf.float32, [None, 32, 32, 3])
    keep = tf.placeholder(tf.float32)

    found_shape = conv_net(images, keep).get_shape().as_list()
    assert found_shape == [None, 10], \
        'Incorrect Model Output.  Found {}'.format(found_shape).replace('.  ', '. ')

    print('Neural Network Built!')
检测训练神经网络
def test_train_nn(train_neural_network):
    """
    Check that a training function uses the session it is given.

    :param train_neural_network: callable taking (session, optimizer,
        keep probability, feature batch, label batch)
    :raises AssertionError: if the function never calls ``session.run``
    """
    mock_session = tf.Session()
    try:
        test_x = np.random.rand(128, 32, 32, 3)
        test_y = np.random.rand(128, 10)
        test_k = np.random.rand(1)
        test_optimizer = tf.train.AdamOptimizer()

        # Replace run() so that calls are recorded and nothing is executed.
        mock_session.run = MagicMock()
        train_neural_network(mock_session, test_optimizer, test_k, test_x, test_y)

        assert mock_session.run.called, 'Session not used'
    finally:
        # Release the resources held by the real session object.
        mock_session.close()

    _print_success_message()