There are three main ways to train a model: the built-in fit method, the built-in train_on_batch method, and a custom training loop.
Note: the fit_generator method is not recommended in tf.keras; its functionality has been absorbed into fit.
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import *

# Print a timestamped separator line
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts % (24*60*60)

    hour = tf.cast(today_ts//3600 + 8, tf.int32) % tf.constant(24)  # UTC+8
    minute = tf.cast((today_ts % 3600)//60, tf.int32)
    second = tf.cast(tf.floor(today_ts % 60), tf.int32)

    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}", m)) == 1:
            return tf.strings.format("0{}", m)
        else:
            return tf.strings.format("{}", m)

    timestring = tf.strings.join([timeformat(hour), timeformat(minute),
                                  timeformat(second)], separator=":")
    tf.print("=========="*8, end="")
    tf.print(timestring)
MAX_LEN = 300
BATCH_SIZE = 32
(x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

MAX_WORDS = x_train.max() + 1
CAT_NUM = y_train.max() + 1

ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
           .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
           .prefetch(tf.data.experimental.AUTOTUNE).cache()

ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
          .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz
2113536/2110848 [==============================] - 4s 2us/step
This method is very powerful: it supports training on numpy arrays, tf.data.Dataset objects, and Python generators.
Complex control over the training process can be achieved by passing callback functions.
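As an illustration of callback-driven control, a minimal sketch using the standard tf.keras EarlyStopping callback (the patience value here is an arbitrary assumption, not taken from this example):

```python
import tensorflow as tf

# Stop training once validation loss has not improved for 3 consecutive
# epochs, and restore the best weights seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True)

# Passed to fit via the callbacks argument, e.g.:
# history = model.fit(ds_train, validation_data=ds_test,
#                     epochs=10, callbacks=[early_stopping])
```

Given the validation-loss curve in the log below, which bottoms out early and then rises, such a callback would end training well before epoch 10.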
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                  loss=losses.SparseCategoricalCrossentropy(),
                  metrics=[metrics.SparseCategoricalAccuracy(),
                           metrics.SparseTopKCategoricalAccuracy(5)])
    return model

model = create_model()
model.summary()
model = compile_model(model)
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, 300, 7) 216874
_________________________________________________________________
conv1d (Conv1D) (None, 296, 64) 2304
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64) 0
_________________________________________________________________
conv1d_1 (Conv1D) (None, 146, 32) 6176
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32) 0
_________________________________________________________________
flatten (Flatten) (None, 2336) 0
_________________________________________________________________
dense (Dense) (None, 46) 107502
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
history = model.fit(ds_train,validation_data = ds_test,epochs = 10)
Train for 281 steps, validate for 71 steps
Epoch 1/10
281/281 [==============================] - 13s 46ms/step - loss: 2.0064 - sparse_categorical_accuracy: 0.4593 - sparse_top_k_categorical_accuracy: 0.7438 - val_loss: 1.6618 - val_sparse_categorical_accuracy: 0.5654 - val_sparse_top_k_categorical_accuracy: 0.7627
Epoch 2/10
281/281 [==============================] - 13s 46ms/step - loss: 1.4743 - sparse_categorical_accuracy: 0.6166 - sparse_top_k_categorical_accuracy: 0.7966 - val_loss: 1.5357 - val_sparse_categorical_accuracy: 0.5975 - val_sparse_top_k_categorical_accuracy: 0.7934
Epoch 3/10
281/281 [==============================] - 12s 42ms/step - loss: 1.2023 - sparse_categorical_accuracy: 0.6879 - sparse_top_k_categorical_accuracy: 0.8485 - val_loss: 1.5544 - val_sparse_categorical_accuracy: 0.6233 - val_sparse_top_k_categorical_accuracy: 0.8037
Epoch 4/10
281/281 [==============================] - 12s 41ms/step - loss: 0.9329 - sparse_categorical_accuracy: 0.7573 - sparse_top_k_categorical_accuracy: 0.9061 - val_loss: 1.7027 - val_sparse_categorical_accuracy: 0.6149 - val_sparse_top_k_categorical_accuracy: 0.8041
Epoch 5/10
281/281 [==============================] - 12s 43ms/step - loss: 0.6946 - sparse_categorical_accuracy: 0.8221 - sparse_top_k_categorical_accuracy: 0.9442 - val_loss: 1.9096 - val_sparse_categorical_accuracy: 0.6064 - val_sparse_top_k_categorical_accuracy: 0.8019
Epoch 6/10
281/281 [==============================] - 12s 42ms/step - loss: 0.5219 - sparse_categorical_accuracy: 0.8690 - sparse_top_k_categorical_accuracy: 0.9686 - val_loss: 2.1816 - val_sparse_categorical_accuracy: 0.6006 - val_sparse_top_k_categorical_accuracy: 0.7956
Epoch 7/10
281/281 [==============================] - 12s 42ms/step - loss: 0.4114 - sparse_categorical_accuracy: 0.8999 - sparse_top_k_categorical_accuracy: 0.9810 - val_loss: 2.4422 - val_sparse_categorical_accuracy: 0.5988 - val_sparse_top_k_categorical_accuracy: 0.7956
Epoch 8/10
281/281 [==============================] - 11s 39ms/step - loss: 0.3419 - sparse_categorical_accuracy: 0.9197 - sparse_top_k_categorical_accuracy: 0.9863 - val_loss: 2.6622 - val_sparse_categorical_accuracy: 0.6037 - val_sparse_top_k_categorical_accuracy: 0.7970
Epoch 9/10
281/281 [==============================] - 11s 39ms/step - loss: 0.2969 - sparse_categorical_accuracy: 0.9293 - sparse_top_k_categorical_accuracy: 0.9900 - val_loss: 2.8685 - val_sparse_categorical_accuracy: 0.6051 - val_sparse_top_k_categorical_accuracy: 0.8014
Epoch 10/10
281/281 [==============================] - 12s 41ms/step - loss: 0.2654 - sparse_categorical_accuracy: 0.9354 - sparse_top_k_categorical_accuracy: 0.9919 - val_loss: 3.0531 - val_sparse_categorical_accuracy: 0.6109 - val_sparse_top_k_categorical_accuracy: 0.8023
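fit returns a History object whose history attribute maps each metric name to a per-epoch list. A minimal sketch of picking the best epoch by validation loss, using the val_loss values from the log above (with a real run, the list would be history.history["val_loss"]):

```python
# val_loss per epoch, copied from the training log above
val_loss = [1.6618, 1.5357, 1.5544, 1.7027, 1.9096,
            2.1816, 2.4422, 2.6622, 2.8685, 3.0531]

# 1-based epoch with the lowest validation loss
best_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i]) + 1
print("best epoch:", best_epoch)  # best epoch: 2
```

Here validation loss bottoms out at epoch 2 and climbs afterwards while training loss keeps falling, the classic overfitting signature.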
Compared with fit, this built-in method is more flexible: it allows finer-grained control of the training process at the batch level, without going through callback functions.
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                  loss=losses.SparseCategoricalCrossentropy(),
                  metrics=[metrics.SparseCategoricalAccuracy(),
                           metrics.SparseTopKCategoricalAccuracy(5)])
    return model

model = create_model()
model.summary()
model = compile_model(model)
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, 300, 7) 216874
_________________________________________________________________
conv1d (Conv1D) (None, 296, 64) 2304
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64) 0
_________________________________________________________________
conv1d_1 (Conv1D) (None, 146, 32) 6176
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32) 0
_________________________________________________________________
flatten (Flatten) (None, 2336) 0
_________________________________________________________________
dense (Dense) (None, 46) 107502
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs+1):
        model.reset_metrics()

        # Lower the learning rate in later epochs
        if epoch == 5:
            model.optimizer.lr.assign(model.optimizer.lr/2.0)
            tf.print("Lowering optimizer Learning Rate...\n\n")

        for x, y in ds_train:
            train_result = model.train_on_batch(x, y)

        for x, y in ds_valid:
            valid_result = model.test_on_batch(x, y, reset_metrics=False)

        if epoch % 1 == 0:
            printbar()
            tf.print("epoch = ", epoch)
            print("train:", dict(zip(model.metrics_names, train_result)))
            print("valid:", dict(zip(model.metrics_names, valid_result)))
            print("")

train_model(model, ds_train, ds_test, 10)
================================================================================09:18:28
epoch = 1
train: {'loss': 1.8016982, 'sparse_categorical_accuracy': 0.54545456, 'sparse_top_k_categorical_accuracy': 0.6363636}
valid: {'loss': 1.6246095, 'sparse_categorical_accuracy': 0.5765806, 'sparse_top_k_categorical_accuracy': 0.76268923}
================================================================================09:18:38
epoch = 2
train: {'loss': 1.7089436, 'sparse_categorical_accuracy': 0.59090906, 'sparse_top_k_categorical_accuracy': 0.6363636}
valid: {'loss': 1.610321, 'sparse_categorical_accuracy': 0.6046305, 'sparse_top_k_categorical_accuracy': 0.7934105}
================================================================================09:18:49
epoch = 3
train: {'loss': 1.5006503, 'sparse_categorical_accuracy': 0.54545456, 'sparse_top_k_categorical_accuracy': 0.8181818}
valid: {'loss': 1.9378943, 'sparse_categorical_accuracy': 0.6251113, 'sparse_top_k_categorical_accuracy': 0.80142474}
================================================================================09:18:59
epoch = 4
train: {'loss': 1.2124836, 'sparse_categorical_accuracy': 0.6363636, 'sparse_top_k_categorical_accuracy': 0.8181818}
valid: {'loss': 2.4395564, 'sparse_categorical_accuracy': 0.6246661, 'sparse_top_k_categorical_accuracy': 0.8009795}
Lowering optimizer Learning Rate...
================================================================================09:19:08
epoch = 5
train: {'loss': 0.8375286, 'sparse_categorical_accuracy': 0.6818182, 'sparse_top_k_categorical_accuracy': 0.95454544}
valid: {'loss': 2.9165814, 'sparse_categorical_accuracy': 0.63178986, 'sparse_top_k_categorical_accuracy': 0.7987533}
================================================================================09:19:17
epoch = 6
train: {'loss': 0.66361845, 'sparse_categorical_accuracy': 0.6818182, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.168786, 'sparse_categorical_accuracy': 0.6246661, 'sparse_top_k_categorical_accuracy': 0.7956367}
================================================================================09:19:26
epoch = 7
train: {'loss': 0.50838065, 'sparse_categorical_accuracy': 0.77272725, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.3748772, 'sparse_categorical_accuracy': 0.626447, 'sparse_top_k_categorical_accuracy': 0.7987533}
================================================================================09:19:35
epoch = 8
train: {'loss': 0.40919036, 'sparse_categorical_accuracy': 0.77272725, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.5492792, 'sparse_categorical_accuracy': 0.62422085, 'sparse_top_k_categorical_accuracy': 0.8009795}
================================================================================09:19:44
epoch = 9
train: {'loss': 0.35043362, 'sparse_categorical_accuracy': 0.8636364, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.69528, 'sparse_categorical_accuracy': 0.6202137, 'sparse_top_k_categorical_accuracy': 0.8032057}
================================================================================09:19:54
epoch = 10
train: {'loss': 0.30596277, 'sparse_categorical_accuracy': 0.8636364, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.799032, 'sparse_categorical_accuracy': 0.6121995, 'sparse_top_k_categorical_accuracy': 0.8036509}
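The learning-rate adjustment inside train_model above is a one-time step decay: the rate is halved once at epoch 5 and stays halved from then on. As a plain-Python sketch of the schedule (the helper name is hypothetical):

```python
def stepped_lr(initial_lr, epoch, drop_epoch=5, factor=0.5):
    """Learning rate in effect at a given epoch, mirroring the
    one-time halving at epoch 5 in train_model above."""
    return initial_lr * factor if epoch >= drop_epoch else initial_lr

print(stepped_lr(0.001, 4))  # 0.001  (before the drop)
print(stepped_lr(0.001, 7))  # 0.0005 (after the drop)
```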
A custom training loop does not require compiling the model at all: the optimizer updates the parameters directly by backpropagating the loss, which gives the highest degree of flexibility.
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

model = create_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, 300, 7) 216874
_________________________________________________________________
conv1d (Conv1D) (None, 296, 64) 2304
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64) 0
_________________________________________________________________
conv1d_1 (Conv1D) (None, 146, 32) 6176
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32) 0
_________________________________________________________________
flatten (Flatten) (None, 2336) 0
_________________________________________________________________
dense (Dense) (None, 46) 107502
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
optimizer = optimizers.Nadam()
loss_func = losses.SparseCategoricalCrossentropy()

train_loss = metrics.Mean(name='train_loss')
train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')
valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')

@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)

@tf.function
def valid_step(model, features, labels):
    predictions = model(features, training=False)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)

def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs+1):
        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'
        if epoch % 1 == 0:
            printbar()
            tf.print(tf.strings.format(logs,
                (epoch, train_loss.result(), train_metric.result(),
                 valid_loss.result(), valid_metric.result())))
            tf.print("")

        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

train_model(model, ds_train, ds_test, 10)
================================================================================09:20:04
Epoch=1,Loss:2.03290606,Accuracy:0.454353154,Valid Loss:1.69233954,Valid Accuracy:0.558325887
================================================================================09:20:14
Epoch=2,Loss:1.49616826,Accuracy:0.613226473,Valid Loss:1.52313662,Valid Accuracy:0.604185224
================================================================================09:20:22
Epoch=3,Loss:1.22066784,Accuracy:0.680806041,Valid Loss:1.51799047,Valid Accuracy:0.627337515
================================================================================09:20:33
Epoch=4,Loss:0.945678711,Accuracy:0.749944329,Valid Loss:1.65234017,Valid Accuracy:0.627337515
================================================================================09:20:44
Epoch=5,Loss:0.678333282,Accuracy:0.822533965,Valid Loss:1.8622793,Valid Accuracy:0.621549428
================================================================================09:20:55
Epoch=6,Loss:0.483631164,Accuracy:0.882208884,Valid Loss:2.06073833,Valid Accuracy:0.623775601
================================================================================09:21:04
Epoch=7,Loss:0.371374488,Accuracy:0.912714303,Valid Loss:2.21256471,Valid Accuracy:0.629118443
================================================================================09:21:13
Epoch=8,Loss:0.305030555,Accuracy:0.927410364,Valid Loss:2.36870408,Valid Accuracy:0.63223511
================================================================================09:21:22
Epoch=9,Loss:0.262721539,Accuracy:0.936317086,Valid Loss:2.50547385,Valid Accuracy:0.630454123
================================================================================09:21:32
Epoch=10,Loss:0.234934,Accuracy:0.941104412,Valid Loss:2.62294984,Valid Accuracy:0.626892269
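The SparseCategoricalAccuracy metric used throughout simply checks whether the argmax of the predicted probability distribution equals the integer class label. A minimal NumPy equivalent, for intuition only (the helper name is hypothetical):

```python
import numpy as np

def sparse_categorical_accuracy(y_true, y_pred_probs):
    # Fraction of samples whose most probable predicted class
    # matches the integer ground-truth label.
    return float(np.mean(np.argmax(y_pred_probs, axis=-1) == np.asarray(y_true)))

probs = np.array([[0.1, 0.9],
                  [0.8, 0.2],
                  [0.3, 0.7]])
print(sparse_categorical_accuracy([1, 0, 0], probs))  # 2 of 3 correct -> 0.666...
```

Unlike this one-shot function, the stateful Keras metric accumulates counts across batches via update_state and reports the running value via result, which is why the custom loop above calls reset_states at the end of each epoch.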