import os
import numpy as np
import tensorflow as tf
from tensorflow_vgg import vgg16
from tensorflow_vgg import utils
# Root folder holding one sub-directory per flower class.
data_dir = 'flower_photos/'
contents = os.listdir(data_dir)
# Keep only the entries that are directories (os.path.isdir); each directory
# name is used as a class label.
classes = [entry for entry in contents
           if os.path.isdir(os.path.join(data_dir, entry))]

# Number of images pushed through VGG16 per sess.run call; machines with more
# memory can use a larger value.
batch_size = 10
# Accumulators filled during feature extraction:
#   codes_list - storage for the extracted feature values
#   labels     - class name of every processed image
#   batch      - images waiting for the next forward pass
codes_list, labels, batch = [], [], []
# Feature array; None until features have been computed.
codes = None
# Run every image through VGG16 and collect the `vgg.relu6` activations as
# features (`codes`), recording each image's class name in `labels`.
with tf.Session() as sess:
    # Build the VGG16 model object.
    vgg = vgg16.Vgg16()
    # None leaves the batch dimension unspecified.
    input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
    # The name scope groups the VGG ops under one prefix, which keeps
    # parameter naming easier to manage.
    with tf.name_scope("content_vgg"):
        # Build the VGG16 graph on top of the placeholder.
        vgg.build(input_)

    # Compute features separately for each flower class.
    for each in classes:
        print("Starting {} images".format(each))
        class_path = os.path.join(data_dir, each)
        files = os.listdir(class_path)
        for ii, file in enumerate(files, 1):
            # Load the image and queue it in the current batch.
            img = utils.load_image(os.path.join(class_path, file))
            batch.append(img.reshape((1, 224, 224, 3)))
            labels.append(each)

            # Run the network once the batch is full or the class's last file
            # has been queued.
            if ii % batch_size == 0 or ii == len(files):
                images = np.concatenate(batch)
                # feed_dict supplies the value for the placeholder tensor.
                feed_dict = {input_: images}
                # Compute the features for this batch.
                codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)
                # Keep per-batch results in a list and concatenate once after
                # the loop; concatenating into `codes` on every batch re-copied
                # the whole growing array each time (quadratic in image count).
                codes_list.append(codes_batch)
                # Clear the batch for the next round.
                batch = []
                print('{} images processed'.format(ii))

# Stack all batches along the first axis. If no image was processed,
# `codes` keeps its previous value (None at initialisation).
if codes_list:
    codes = np.concatenate(codes_list)
# Persist the features. ndarray.tofile writes raw bytes (no shape or dtype
# header, so a reader must know both), which requires a binary-mode file;
# text mode would translate newline bytes on Windows and corrupt the data.
with open('codes', 'wb') as f:
    codes.tofile(f)

import csv
# csv output files should be opened with newline='' so the writer controls
# line endings itself.
with open('labels', 'w', newline='') as f:
    # csv defaults to ',' as delimiter and '"' as quotechar; using '\n' as the
    # delimiter puts every label of this single row on its own line.
    writer = csv.writer(f, delimiter='\n')
    writer.writerow(labels)
from sklearn.preprocessing import LabelBinarizer
# One-hot encode the string class names. LabelBinarizer.fit_transform is
# exactly fit(labels) followed by transform(labels).
lb = LabelBinarizer()
labels_vecs = lb.fit_transform(labels)
from sklearn.model_selection import StratifiedShuffleSplit
# A single stratified shuffle split: 80% of the samples for training and 20%
# held out, preserving the class proportions of `labels`.
ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
train_idx, val_idx = next(ss.split(codes, labels))

# Divide the held-out 20% evenly (1:1) into validation and test indices.
half_val_len = len(val_idx) // 2
val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]

# Gather features and one-hot labels for each split.
train_x, train_y = codes[train_idx], labels_vecs[train_idx]
val_x, val_y = codes[val_idx], labels_vecs[val_idx]
test_x, test_y = codes[test_idx], labels_vecs[test_idx]

print("Train shapes (x, y):", train_x.shape, train_y.shape)
print("Validation shapes (x, y):", val_x.shape, val_y.shape)
print("Test shapes (x, y):", test_x.shape, test_y.shape)
# Small classifier trained on top of the extracted VGG16 features.
# `inputs_` is the name used by the layer below and by the training/test feeds
# (it was previously undefined, raising NameError). `input_` stays bound to the
# same placeholder, matching the original rebinding of that name here.
inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])
input_ = inputs_
# NOTE(review): labels_ is int64 while logits are float32; this relies on
# softmax_cross_entropy_with_logits casting labels to the logits dtype —
# confirm for the TF version in use.
labels_ = tf.placeholder(tf.int64, shape=[None, labels_vecs.shape[1]])

# One hidden fully connected layer of 256 units (library-default activation),
# followed by a linear layer producing one logit per label column.
# NOTE(review): with exactly two classes LabelBinarizer yields a single
# column, which would make the softmax below degenerate — confirm if the
# dataset changes.
fc = tf.contrib.layers.fully_connected(inputs_, 256)
logits = tf.contrib.layers.fully_connected(fc, labels_vecs.shape[1],
                                           activation_fn=None)

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels_,
                                                        logits=logits)
cost = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer().minimize(cost)

# Accuracy: fraction of rows whose most probable class matches the argmax of
# the one-hot label.
predicted = tf.nn.softmax(logits)
correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
def get_batches(x, y, n_batches=10):
    """Yield ``n_batches`` aligned (x, y) slices covering all of x and y.

    Every batch has ``len(x) // n_batches`` items except the last one, which
    also takes the remainder. If there are fewer samples than requested
    batches, everything is yielded as a single batch (nothing for empty input)
    instead of failing on a zero-length step.

    Args:
        x: sliceable sequence of samples (e.g. a NumPy array or list).
        y: sliceable sequence of targets aligned with ``x``.
        n_batches: number of batches to split into.

    Yields:
        ``(x_batch, y_batch)`` tuples sliced with identical bounds.
    """
    batch_size = len(x) // n_batches
    if batch_size == 0:
        if len(x):
            yield x, y
        return

    last_start = (n_batches - 1) * batch_size
    for start in range(0, n_batches * batch_size, batch_size):
        # Use fresh names for the slices: rebinding x/y here would make later
        # batches slice from the previous batch instead of the full data.
        if start != last_start:
            stop = start + batch_size
            yield x[start:stop], y[start:stop]
        else:
            yield x[start:], y[start:]
# Train the classifier on the VGG16 features, reporting validation accuracy
# every 5 iterations, then checkpoint the trained variables.
epochs = 20
iteration = 0
saver = tf.train.Saver()
# Saver.save refuses to write when the checkpoint's parent directory is
# missing, so make sure it exists.
os.makedirs('checkpoints', exist_ok=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for x, y in get_batches(train_x, train_y):
            feed = {inputs_: x,
                    labels_: y}
            loss, _ = sess.run([cost, optimizer], feed_dict=feed)
            print('Epoch:{}/{}'.format(e + 1, epochs),
                  'Iteration:{}'.format(iteration),
                  'Training loss: {:.5f}'.format(loss))
            iteration += 1

            if iteration % 5 == 0:
                # Feed the classifier placeholder `inputs_`; `input_` could
                # refer to the 224x224x3 VGG placeholder instead.
                feed = {inputs_: val_x,
                        labels_: val_y}
                val_acc = sess.run(accuracy, feed_dict=feed)
                # Report the epoch 1-based, matching the training-loss line.
                print('Epoch:{}/{}'.format(e + 1, epochs),
                      'Iteration:{}'.format(iteration),
                      'Validation Acc: {:.4f}'.format(val_acc))

    saver.save(sess, 'checkpoints/flowers.ckpt')
# Restore the most recently saved classifier variables and score the held-out
# test split.
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    feed = {inputs_: test_x, labels_: test_y}
    test_acc = sess.run(fetches=accuracy, feed_dict=feed)
    print('Test accuracy : {:.4f}'.format(test_acc))