识花模型代码理解

import os
import numpy as np
import tensorflow as tf

from tensorflow_vgg import vgg16
from tensorflow_vgg import utils

# Root folder of the dataset; every sub-directory in it is one flower class.
data_dir = 'flower_photos/'
contents = os.listdir(data_dir)
# Keep only the directory entries; their names are used as the class names.
classes = [name for name in contents if os.path.isdir(data_dir + name)]

# Number of images sent through the network per run. Per the original note,
# a machine with more memory can use a larger value.
batch_size = 10

# codes_list: reserved for the extracted feature codes.
# labels:     the flower class of every processed image.
# batch:      temporary holder for the images of the batch being assembled.
codes_list, labels, batch = [], [], []

# Feature codes of all images; stays None until the first batch is computed.
codes = None

# Run every image of every class through VGG16 and collect the activations
# exposed as `vgg.relu6`, together with the class name of each image.
with tf.Session() as sess:
    # Build the VGG16 model object.
    vgg = vgg16.Vgg16()
    # None leaves the batch dimension unspecified.
    input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with tf.name_scope("content_vgg"):  # groups the ops under one name prefix
        # Load the VGG16 network on top of the placeholder.
        vgg.build(input_)

    # Compute the feature codes class by class.
    for each in classes:
        print("Starting {} images".format(each))
        class_path = data_dir + each
        files = os.listdir(class_path)
        for ii, file in enumerate(files, 1):
            # Load the image and queue it in the current batch.
            img = utils.load_image(os.path.join(class_path, file))
            batch.append(img.reshape((1, 224, 224, 3)))
            labels.append(each)

            # Run the network once the batch is full, or when the last image
            # of this class has been queued (flushes a partial batch).
            if ii % batch_size == 0 or ii == len(files):
                images = np.concatenate(batch)

                # feed_dict supplies the value of the placeholder tensor.
                feed_dict = {input_: images}
                codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)

                # Only remember the batch here; joining after every batch
                # would re-copy all previous codes each time.
                codes_list.append(codes_batch)

                # Start the next batch from scratch.
                batch = []
                print('{} images processed'.format(ii))

# Join all batches in a single pass. `codes` keeps its prior value (None)
# when no image was processed at all.
if codes_list:
    codes = np.concatenate(codes_list)

# Persist the feature codes. tofile() writes the array's raw bytes, so the
# file is opened in binary mode rather than text mode.
with open('codes', 'wb') as f:
    codes.tofile(f)

import csv
# Persist the labels, one per line: '\n' is used as the field delimiter and all
# labels are written as a single csv row. newline='' leaves line-ending
# handling to the csv module, as its documentation requires for file objects.
with open('labels', 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\n')
    writer.writerow(labels)
from sklearn.preprocessing import LabelBinarizer

# Learn the set of class names and encode every label as a binary indicator
# vector in one step (fit_transform is fit followed by transform on the same
# data).
lb = LabelBinarizer()
labels_vecs = lb.fit_transform(labels)
from sklearn.model_selection import StratifiedShuffleSplit

# A single stratified shuffle split that holds out 20% of the samples.
ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)

# split() yields (train indices, held-out indices); only one split is requested.
splitter = ss.split(codes, labels)
train_idx, val_idx = next(splitter)

# Cut the held-out 20% in two: the first half becomes the validation set and
# the second half the test set.
half_val_len = len(val_idx) // 2
val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]

# Materialise features (codes) and targets (binarised labels) for each subset.
train_x, train_y = codes[train_idx], labels_vecs[train_idx]
val_x, val_y = codes[val_idx], labels_vecs[val_idx]
test_x, test_y = codes[test_idx], labels_vecs[test_idx]

for title, feats, targets in (
        ("Train shapes (x, y):", train_x, train_y),
        ("Validation shapes (x, y):", val_x, val_y),
        ("Test shapes (x, y):", test_x, test_y)):
    print(title, feats.shape, targets.shape)

# Classifier trained on top of the pre-computed feature codes.
#
# The feature placeholder is named `inputs_`, the name the layers below and the
# training/evaluation feeds refer to. It was previously bound to `input_`,
# which left `inputs_` undefined.
inputs_ = tf.placeholder(tf.float32, shape=[None, codes.shape[1]])
# One column per column of the binarised label matrix.
labels_ = tf.placeholder(tf.int64, shape=[None, labels_vecs.shape[1]])

# Hidden fully connected layer with 256 units.
fc = tf.contrib.layers.fully_connected(inputs_, 256)
# Output layer: activation_fn=None keeps raw logits, because the loss below
# applies the softmax itself.
logits = tf.contrib.layers.fully_connected(fc, labels_vecs.shape[1],
                                           activation_fn=None)

# Mean softmax cross-entropy over the batch, minimised with Adam.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels_,
                                                        logits=logits)
cost = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer().minimize(cost)

# Class probabilities, used for the accuracy metric.
predicted = tf.nn.softmax(logits)

# A prediction is correct when its arg-max matches the arg-max of the label row.
correct_pred = tf.equal(tf.argmax(predicted, 1), tf.argmax(labels_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

def get_batches(x, y, n_batches=10):
    """Yield aligned ``(x_batch, y_batch)`` slices that together cover all data.

    The data is cut into ``n_batches`` consecutive slices of
    ``len(x) // n_batches`` samples each; the final slice additionally takes
    whatever remainder is left, so no sample is dropped.

    Args:
        x: Sliceable sequence of samples (list, NumPy array, ...).
        y: Sliceable sequence of targets, aligned with ``x``.
        n_batches: Number of batches to produce. When there are fewer samples
            than ``n_batches``, one single-sample batch per sample is produced
            instead.

    Yields:
        Tuples ``(x_batch, y_batch)`` covering the same index range of ``x``
        and ``y``. Nothing is yielded for empty input.
    """
    n_samples = len(x)
    if n_samples == 0:
        return

    # More batches than samples would give batch_size == 0, which is an
    # invalid step for range().
    n_batches = min(n_batches, n_samples)
    batch_size = n_samples // n_batches
    last_start = (n_batches - 1) * batch_size

    for start in range(0, n_batches * batch_size, batch_size):
        if start != last_start:
            # Slice the untouched inputs (never rebind x / y) and use the same
            # bounds for both so samples and targets stay aligned.
            stop = start + batch_size
            yield x[start:stop], y[start:stop]
        else:
            # The last batch absorbs the remainder.
            yield x[start:], y[start:]

# Train the classifier on the pre-computed feature codes, report validation
# accuracy every 5 iterations, and save a checkpoint when training is done.
epochs = 20
iteration = 0
saver = tf.train.Saver()

# Make sure the checkpoint directory exists before saving into it.
os.makedirs('checkpoints', exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for x, y in get_batches(train_x, train_y):
            feed = {inputs_: x,
                    labels_: y}
            # One optimisation step; the optimizer op's result is not needed.
            loss, _ = sess.run([cost, optimizer], feed_dict=feed)
            print('Epoch:{}/{}'.format(e + 1, epochs),
                  'Iteration:{}'.format(iteration),
                  'Training loss: {:.5f}'.format(loss))
            iteration += 1

            if iteration % 5 == 0:
                # Feed the classifier's own placeholder (inputs_), not the VGG
                # image placeholder.
                feed = {inputs_: val_x,
                        labels_: val_y}
                val_acc = sess.run(accuracy, feed_dict=feed)
                # Epochs are shown 1-based, matching the training-loss line.
                print('Epoch:{}/{}'.format(e + 1, epochs),
                      'Iteration:{}'.format(iteration),
                      'Validation Acc: {:.4f}'.format(val_acc))

    # Save while the session is still open: the variables live in the session,
    # so saving after the `with` block would use a closed session.
    saver.save(sess, 'checkpoints/flowers.ckpt')


# Evaluate the trained classifier on the held-out test set.
with tf.Session() as sess:
    # Reload the most recent checkpoint found under checkpoints/.
    newest_ckpt = tf.train.latest_checkpoint('checkpoints')
    saver.restore(sess, newest_ckpt)

    test_feed = {inputs_: test_x, labels_: test_y}
    test_acc = sess.run(accuracy, feed_dict=test_feed)
    print('Test accuracy : {:.4f}'.format(test_acc))




你可能感兴趣的:(迁移学习)