Deep Learning in Practice (2): Model Reuse and Frozen Layers

Transfer learning: reuse the hidden layers of the model trained in the previous post to learn a related task from very little data.

Create a new DNN that reuses all the pretrained hidden layers of the previous model, freezes them, and replaces the softmax output layer with a fresh new one.

Train this new DNN on digits 5 to 9, using only 100 images per digit, and time how long it takes. Despite this small number of examples, can you achieve high precision?
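The script below loads the checkpoint saved by the previous post and fetches tensors by name ("X:0", "y:0", "loss/loss:0", "eval/accuracy:0"), plus the trainable variables under the "logits" scope. For orientation, the pretrained graph presumably looked roughly like the following sketch; this is a hedged reconstruction based only on those names (the layer count, widths, and ELU activation are assumptions), not the previous post's actual code:

import tensorflow as tf

X = tf.placeholder(tf.float32, shape=(None, 28 * 28), name="X")
y = tf.placeholder(tf.int32, shape=(None,), name="y")
with tf.name_scope("dnn"):
    hidden = X
    for i in range(1, 6):                                 # five hidden layers (assumed)
        hidden = tf.layers.dense(hidden, 100, activation=tf.nn.elu,
                                 name="hidden%d" % i)     # variables live under "hidden<i>/"
    logits = tf.layers.dense(hidden, 5, name="logits")    # variables live under "logits/"
    Y_proba = tf.nn.softmax(logits, name="Y_proba")       # op name "dnn/Y_proba"
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss = tf.reduce_mean(xentropy, name="loss")          # tensor "loss/loss:0"
with tf.name_scope("eval"):
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")  # "eval/accuracy:0"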

# TensorFlow 1.x graph-mode code throughout; on TensorFlow 2.x it would need
# tf.compat.v1 plus tf.compat.v1.disable_eager_execution().
import tensorflow as tf
import numpy as np
import time

def shuffle_batch(X, y, batch_size):
    # Yield mini-batches drawn in a fresh random order each epoch.
    rnd_idx = np.random.permutation(len(X))
    n_batches = len(X) // batch_size
    for batch_idx in np.array_split(rnd_idx, n_batches):
        X_batch, y_batch = X[batch_idx], y[batch_idx]
        yield X_batch, y_batch

# Load MNIST, flatten, scale pixels to [0, 1], and split off 5,000 validation images.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype(np.float32).reshape(-1, 28*28) / 255.0
X_test = X_test.astype(np.float32).reshape(-1, 28*28) / 255.0
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
X_valid, X_train = X_train[:5000], X_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]

# Keep only digits 5 to 9 and shift their labels to 0-4 so they match the
# model's 5 output units.
X_train = X_train[y_train > 4]
y_train = y_train[y_train > 4] - 5
X_valid = X_valid[y_valid > 4]
y_valid = y_valid[y_valid > 4] - 5
X_test = X_test[y_test > 4]
y_test = y_test[y_test > 4] - 5

# Keep a random sample of just 100 training images in total (about 20 per
# digit) -- stricter than the "100 images per digit" in the task statement.
rnd_idx = np.random.permutation(len(X_train))[:100]
X_train = X_train[rnd_idx]
y_train = y_train[rnd_idx]
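
# For the exercise as literally stated (100 images per digit), a per-class
# sampler like this hypothetical helper could replace the slice above. It is
# defined here for reference only and never called, so the logged results
# below still correspond to the 100-images-total run:
def sample_n_instances_per_class(X, y, n_instances=100):
    Xs, ys = [], []
    for label in np.unique(y):
        idx = np.random.permutation(np.where(y == label)[0])[:n_instances]
        Xs.append(X[idx])
        ys.append(y[idx])
    return np.concatenate(Xs), np.concatenate(ys)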

# Paths of the pretrained checkpoint and of the new fine-tuned model.
old_final_model_path = "./my_logreg_model"
old_final_model_meta = "./my_logreg_model.meta"
new_final_model_path = "./my_new_logreg_model"

learning_rate = 0.01

# Rebuild the pretrained graph from its .meta file, then look up the tensors
# we need by name.
saver = tf.train.import_meta_graph(old_final_model_meta)

X = tf.get_default_graph().get_tensor_by_name("X:0")
y = tf.get_default_graph().get_tensor_by_name("y:0")
loss = tf.get_default_graph().get_tensor_by_name("loss/loss:0")
accuracy = tf.get_default_graph().get_tensor_by_name("eval/accuracy:0")

# If the saved graph had no eval/accuracy op, it could be rebuilt from the
# softmax output instead:
#Y_proba = tf.get_default_graph().get_tensor_by_name("dnn/Y_proba:0")
#logits = Y_proba.op.inputs[0]
#correct = tf.nn.in_top_k(logits, y, 1)
#accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")

with tf.name_scope("new_train"):
    output_layer_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="logits")
    optimizer = tf.train.AdamOptimizer(learning_rate, name="Adam2")
    training_op = optimizer.minimize(loss, var_list=output_layer_vars)    
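
# Optional speed-up once the lower layers are frozen: their activations for a
# given image never change, so they could be computed once and cached instead
# of being recomputed every epoch. A hedged sketch -- the tensor name
# "dnn/hidden5/Elu:0" is a guess at the top frozen layer and would have to be
# adjusted to the actual checkpoint:
#   top_frozen = tf.get_default_graph().get_tensor_by_name("dnn/hidden5/Elu:0")
#   cached = sess.run(top_frozen, feed_dict={X: X_train})   # inside a session
#   ...then feed mini-batches of `cached` straight into the output layer.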
    
n_epochs = 1000
batch_size = 20

# Early stopping: stop after 20 epochs with no new best validation loss.
max_checks_without_progress = 20
checks_without_progress = 0
best_loss_val = np.infty

init = tf.global_variables_initializer()
new_saver = tf.train.Saver()

with tf.Session() as sess:
    init.run()                                  # initialize everything, including Adam2's slots
    saver.restore(sess, old_final_model_path)   # then overwrite with the pretrained weights
    for var in output_layer_vars:
        var.initializer.run()                   # re-initialize the output layer for a fresh start

    t0 = time.time()
    
    for epoch in range(n_epochs):
        for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})  # last batch; used only by the commented print below
        acc_val, loss_val = sess.run([accuracy, loss], feed_dict={X: X_valid, y: y_valid})
              
        if loss_val < best_loss_val:
            save_path = new_saver.save(sess, new_final_model_path)
            best_loss_val = loss_val
            checks_without_progress = 0
        else:
            checks_without_progress += 1
            if checks_without_progress > max_checks_without_progress:
                print("Early stopping!")
                break
                
        print("{}\tValidation loss: {:.6f}\tBest loss: {:.6f}\tAccuracy: {:.2f}%".format(
            epoch, loss_val, best_loss_val, acc_val * 100))
        
        #print(epoch, "Batch accuracy:", acc_batch, "Val accuracy:", acc_val)

    t1 = time.time()
    print("Total training time: {:.1f}s".format(t1 - t0))
    
with tf.Session() as sess:
    # Restore the best checkpoint (lowest validation loss) and evaluate on the test set.
    new_saver.restore(sess, new_final_model_path)
    acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
    print("Final test accuracy: {:.2f}%".format(acc_test * 100))

Log output:

INFO:tensorflow:Restoring parameters from ./my_logreg_model
0	Validation loss: 1.484102	Best loss: 1.484102	Accuracy: 45.33%
1	Validation loss: 1.316709	Best loss: 1.316709	Accuracy: 60.20%
2	Validation loss: 1.220808	Best loss: 1.220808	Accuracy: 58.72%
3	Validation loss: 1.159953	Best loss: 1.159953	Accuracy: 60.24%
4	Validation loss: 1.108175	Best loss: 1.108175	Accuracy: 62.16%
5	Validation loss: 1.072390	Best loss: 1.072390	Accuracy: 63.47%
6	Validation loss: 1.042931	Best loss: 1.042931	Accuracy: 65.23%
7	Validation loss: 1.026151	Best loss: 1.026151	Accuracy: 65.40%
8	Validation loss: 1.018076	Best loss: 1.018076	Accuracy: 65.56%
9	Validation loss: 1.007680	Best loss: 1.007680	Accuracy: 66.13%
10	Validation loss: 1.000437	Best loss: 1.000437	Accuracy: 66.58%
11	Validation loss: 0.992715	Best loss: 0.992715	Accuracy: 66.67%
12	Validation loss: 0.988555	Best loss: 0.988555	Accuracy: 66.91%
13	Validation loss: 0.985457	Best loss: 0.985457	Accuracy: 66.58%
14	Validation loss: 0.983999	Best loss: 0.983999	Accuracy: 66.46%
15	Validation loss: 0.985433	Best loss: 0.983999	Accuracy: 66.83%
16	Validation loss: 0.981804	Best loss: 0.981804	Accuracy: 66.87%
17	Validation loss: 0.978994	Best loss: 0.978994	Accuracy: 66.95%
18	Validation loss: 0.979878	Best loss: 0.978994	Accuracy: 67.32%
19	Validation loss: 0.975456	Best loss: 0.975456	Accuracy: 67.32%
20	Validation loss: 0.978571	Best loss: 0.975456	Accuracy: 67.16%
21	Validation loss: 0.980386	Best loss: 0.975456	Accuracy: 67.20%
22	Validation loss: 0.979259	Best loss: 0.975456	Accuracy: 66.95%
23	Validation loss: 0.975193	Best loss: 0.975193	Accuracy: 67.24%
24	Validation loss: 0.974504	Best loss: 0.974504	Accuracy: 67.28%
25	Validation loss: 0.975535	Best loss: 0.974504	Accuracy: 67.36%
26	Validation loss: 0.975207	Best loss: 0.974504	Accuracy: 67.32%
27	Validation loss: 0.977013	Best loss: 0.974504	Accuracy: 67.44%
28	Validation loss: 0.978441	Best loss: 0.974504	Accuracy: 67.16%
29	Validation loss: 0.978795	Best loss: 0.974504	Accuracy: 67.08%
30	Validation loss: 0.977175	Best loss: 0.974504	Accuracy: 67.28%
31	Validation loss: 0.975567	Best loss: 0.974504	Accuracy: 67.57%
32	Validation loss: 0.980125	Best loss: 0.974504	Accuracy: 67.12%
33	Validation loss: 0.979631	Best loss: 0.974504	Accuracy: 67.12%
34	Validation loss: 0.982773	Best loss: 0.974504	Accuracy: 67.32%
35	Validation loss: 0.984605	Best loss: 0.974504	Accuracy: 67.28%
36	Validation loss: 0.987710	Best loss: 0.974504	Accuracy: 67.16%
37	Validation loss: 0.984601	Best loss: 0.974504	Accuracy: 67.04%
38	Validation loss: 0.986800	Best loss: 0.974504	Accuracy: 67.08%
39	Validation loss: 0.985071	Best loss: 0.974504	Accuracy: 67.04%
40	Validation loss: 0.986665	Best loss: 0.974504	Accuracy: 66.83%
41	Validation loss: 0.988729	Best loss: 0.974504	Accuracy: 67.36%
42	Validation loss: 0.987109	Best loss: 0.974504	Accuracy: 67.16%
43	Validation loss: 0.987439	Best loss: 0.974504	Accuracy: 67.28%
44	Validation loss: 0.989915	Best loss: 0.974504	Accuracy: 67.32%
Early stopping!
Total training time: 5.8s
INFO:tensorflow:Restoring parameters from ./my_new_logreg_model
Final test accuracy: 68.65%
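
Training takes only a few seconds because a single layer is being updated, but 68.65% test accuracy is far from high precision: with only 100 images in total (rather than 100 per digit), the fresh output layer simply does not see enough examples. For readers on TensorFlow 2.x, the same recipe (reuse, freeze, replace the output layer) can be written with Keras. A minimal hedged sketch, assuming the previous model was saved as "my_model_0_to_4.h5" (a hypothetical file name):

import tensorflow as tf

pretrained = tf.keras.models.load_model("my_model_0_to_4.h5")  # hypothetical path

# Reuse every layer except the old output layer, and freeze them.
new_model = tf.keras.Sequential(pretrained.layers[:-1])
for layer in new_model.layers:
    layer.trainable = False
new_model.add(tf.keras.layers.Dense(5, activation="softmax"))  # fresh softmax output

new_model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
                  metrics=["accuracy"])
# Early stopping with patience 20 mirrors max_checks_without_progress above.
new_model.fit(X_train, y_train, epochs=1000,
              validation_data=(X_valid, y_valid),
              callbacks=[tf.keras.callbacks.EarlyStopping(
                  patience=20, restore_best_weights=True)])

Note that Sequential(pretrained.layers[:-1]) shares layer objects with the pretrained model, so training mutates it; clone the model (and copy its weights with set_weights) first if the original must stay untouched.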

 
