Transfer learning.
Create a new DNN that reuses all the pretrained hidden layers of the previous model, freezes them, and replaces the softmax output layer with a fresh new one.
Train this new DNN on digits 5 to 9, using only 100 images per digit, and time how long it takes. Despite this small number of examples, can you achieve high precision?
import tensorflow as tf
import numpy as np
from datetime import datetime
import os
import time
def shuffle_batch(X, y, batch_size):
    # Yield shuffled mini-batches of (X, y)
    rnd_idx = np.random.permutation(len(X))
    n_batches = len(X) // batch_size
    for batch_idx in np.array_split(rnd_idx, n_batches):
        X_batch, y_batch = X[batch_idx], y[batch_idx]
        yield X_batch, y_batch
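# Load MNIST, scale pixels to [0, 1], flatten the images, and split off a 5,000-image validation set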
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype(np.float32).reshape(-1, 28*28) / 255.0
X_test = X_test.astype(np.float32).reshape(-1, 28*28) / 255.0
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
X_valid, X_train = X_train[:5000], X_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]
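# Keep only digits 5 to 9 and relabel them 0 to 4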
X_train = X_train[y_train > 4]
y_train = y_train[y_train > 4]-5
X_valid = X_valid[y_valid > 4]
y_valid = y_valid[y_valid > 4]-5
X_test = X_test[y_test > 4]
y_test = y_test[y_test > 4]-5
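# Keep just 100 training instances in total (the exercise statement asks for 100 per digit)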
rnd_idx = np.random.permutation(len(X_train))
rnd_idx = rnd_idx[0:100]
X_train = X_train[rnd_idx,:]
y_train = y_train[rnd_idx]
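# Checkpoint of the previously trained model (digits 0 to 4) and path for the new one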
old_final_model_path = "./my_logreg_model"
old_final_model_meta = "./my_logreg_model.meta"
new_final_model_path = "./my_new_logreg_model"
learning_rate = 0.01
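# Restore the pretrained graph and get handles on the tensors and ops we reuse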
saver = tf.train.import_meta_graph(old_final_model_meta)
X = tf.get_default_graph().get_tensor_by_name("X:0")
y = tf.get_default_graph().get_tensor_by_name("y:0")
loss = tf.get_default_graph().get_tensor_by_name("loss/loss:0")
accuracy = tf.get_default_graph().get_tensor_by_name("eval/accuracy:0")
#Y_proba = tf.get_default_graph().get_tensor_by_name("dnn/Y_proba:0")
#logits = Y_proba.op.inputs[0]
#correct = tf.nn.in_top_k(logits, y, 1)
#accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
with tf.name_scope("new_train"):
    # Only the output layer's variables are handed to the optimizer,
    # so all the pretrained hidden layers stay frozen
    output_layer_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="logits")
    optimizer = tf.train.AdamOptimizer(learning_rate, name="Adam2")
    training_op = optimizer.minimize(loss, var_list=output_layer_vars)
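# Training-loop settings and early-stopping bookkeeping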
n_epochs = 1000
batch_size = 20
max_checks_without_progress = 20
checks_without_progress = 0
best_loss_val = np.infty
init = tf.global_variables_initializer()
new_saver = tf.train.Saver()
with tf.Session() as sess:
    init.run()
    saver.restore(sess, old_final_model_path)
    # Re-initialize the output layer so it starts fresh for the new task
    for var in output_layer_vars:
        var.initializer.run()

    t0 = time.time()

    for epoch in range(n_epochs):
        for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
        acc_val, loss_val = sess.run([accuracy, loss], feed_dict={X: X_valid, y: y_valid})
        # Early stopping on the validation loss
        if loss_val < best_loss_val:
            save_path = new_saver.save(sess, new_final_model_path)
            best_loss_val = loss_val
            checks_without_progress = 0
        else:
            checks_without_progress += 1
            if checks_without_progress > max_checks_without_progress:
                print("Early stopping!")
                break
        print("{}\tValidation loss: {:.6f}\tBest loss: {:.6f}\tAccuracy: {:.2f}%".format(
            epoch, loss_val, best_loss_val, acc_val * 100))
        #print(epoch, "Batch accuracy:", acc_batch, "Val accuracy:", acc_val)

    t1 = time.time()
    print("Total training time: {:.1f}s".format(t1 - t0))
with tf.Session() as sess:
    new_saver.restore(sess, new_final_model_path)
    # Evaluate the best saved model on the test set (digits 5 to 9)
    acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
    print("Final test accuracy: {:.2f}%".format(acc_test * 100))
Log output:
INFO:tensorflow:Restoring parameters from ./my_logreg_model
0 Validation loss: 1.484102 Best loss: 1.484102 Accuracy: 45.33%
1 Validation loss: 1.316709 Best loss: 1.316709 Accuracy: 60.20%
2 Validation loss: 1.220808 Best loss: 1.220808 Accuracy: 58.72%
3 Validation loss: 1.159953 Best loss: 1.159953 Accuracy: 60.24%
4 Validation loss: 1.108175 Best loss: 1.108175 Accuracy: 62.16%
5 Validation loss: 1.072390 Best loss: 1.072390 Accuracy: 63.47%
6 Validation loss: 1.042931 Best loss: 1.042931 Accuracy: 65.23%
7 Validation loss: 1.026151 Best loss: 1.026151 Accuracy: 65.40%
8 Validation loss: 1.018076 Best loss: 1.018076 Accuracy: 65.56%
9 Validation loss: 1.007680 Best loss: 1.007680 Accuracy: 66.13%
10 Validation loss: 1.000437 Best loss: 1.000437 Accuracy: 66.58%
11 Validation loss: 0.992715 Best loss: 0.992715 Accuracy: 66.67%
12 Validation loss: 0.988555 Best loss: 0.988555 Accuracy: 66.91%
13 Validation loss: 0.985457 Best loss: 0.985457 Accuracy: 66.58%
14 Validation loss: 0.983999 Best loss: 0.983999 Accuracy: 66.46%
15 Validation loss: 0.985433 Best loss: 0.983999 Accuracy: 66.83%
16 Validation loss: 0.981804 Best loss: 0.981804 Accuracy: 66.87%
17 Validation loss: 0.978994 Best loss: 0.978994 Accuracy: 66.95%
18 Validation loss: 0.979878 Best loss: 0.978994 Accuracy: 67.32%
19 Validation loss: 0.975456 Best loss: 0.975456 Accuracy: 67.32%
20 Validation loss: 0.978571 Best loss: 0.975456 Accuracy: 67.16%
21 Validation loss: 0.980386 Best loss: 0.975456 Accuracy: 67.20%
22 Validation loss: 0.979259 Best loss: 0.975456 Accuracy: 66.95%
23 Validation loss: 0.975193 Best loss: 0.975193 Accuracy: 67.24%
24 Validation loss: 0.974504 Best loss: 0.974504 Accuracy: 67.28%
25 Validation loss: 0.975535 Best loss: 0.974504 Accuracy: 67.36%
26 Validation loss: 0.975207 Best loss: 0.974504 Accuracy: 67.32%
27 Validation loss: 0.977013 Best loss: 0.974504 Accuracy: 67.44%
28 Validation loss: 0.978441 Best loss: 0.974504 Accuracy: 67.16%
29 Validation loss: 0.978795 Best loss: 0.974504 Accuracy: 67.08%
30 Validation loss: 0.977175 Best loss: 0.974504 Accuracy: 67.28%
31 Validation loss: 0.975567 Best loss: 0.974504 Accuracy: 67.57%
32 Validation loss: 0.980125 Best loss: 0.974504 Accuracy: 67.12%
33 Validation loss: 0.979631 Best loss: 0.974504 Accuracy: 67.12%
34 Validation loss: 0.982773 Best loss: 0.974504 Accuracy: 67.32%
35 Validation loss: 0.984605 Best loss: 0.974504 Accuracy: 67.28%
36 Validation loss: 0.987710 Best loss: 0.974504 Accuracy: 67.16%
37 Validation loss: 0.984601 Best loss: 0.974504 Accuracy: 67.04%
38 Validation loss: 0.986800 Best loss: 0.974504 Accuracy: 67.08%
39 Validation loss: 0.985071 Best loss: 0.974504 Accuracy: 67.04%
40 Validation loss: 0.986665 Best loss: 0.974504 Accuracy: 66.83%
41 Validation loss: 0.988729 Best loss: 0.974504 Accuracy: 67.36%
42 Validation loss: 0.987109 Best loss: 0.974504 Accuracy: 67.16%
43 Validation loss: 0.987439 Best loss: 0.974504 Accuracy: 67.28%
44 Validation loss: 0.989915 Best loss: 0.974504 Accuracy: 67.32%
Early stopping!
Total training time: 5.8s
INFO:tensorflow:Restoring parameters from ./my_new_logreg_model
Final test accuracy: 68.65%
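With only 100 training images in total and only the output layer being trained, the transferred model tops out at roughly 69% test accuracy, but training takes just a few seconds. For reference, here is a minimal sketch of the same freeze-and-replace idea written against the tf.keras API; the saved-model file name, the architecture it contains, and the training hyperparameters are assumptions for illustration, not taken from the graph restored above.

import tensorflow as tf

# Hypothetical Keras-format save of the model pretrained on digits 0 to 4
# (the graph above was saved with tf.train.Saver, so this file is an assumption).
pretrained = tf.keras.models.load_model("my_model_0_to_4.h5")

# Reuse every layer except the old softmax output, and freeze them all.
new_model = tf.keras.Sequential(pretrained.layers[:-1])
for layer in new_model.layers:
    layer.trainable = False

# Fresh softmax output layer for the five new classes (digits 5 to 9, relabeled 0 to 4).
new_model.add(tf.keras.layers.Dense(5, activation="softmax"))

new_model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
                  metrics=["accuracy"])

# X_train, y_train, X_valid, y_valid are the small digits-5-to-9 sets prepared above.
new_model.fit(X_train, y_train, epochs=1000, batch_size=20,
              validation_data=(X_valid, y_valid),
              callbacks=[tf.keras.callbacks.EarlyStopping(patience=20,
                                                          restore_best_weights=True)])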