一、代码中的数据集可以通过以下代码进行下载
(train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()
二、代码运行环境
tensorflow-gpu==2.4.0
Python==3.7
三、数据集的构建如下所示(保存为 data_loader.py,训练脚本通过 from data_loader import make_dataset 导入)
import tensorflow as tf
import os
# Configure TensorFlow through environment variables: pass the
# `--tf_xla_enable_xla_devices` flag via TF_XLA_FLAGS and request on-demand
# GPU memory growth.
# NOTE(review): these are assigned after `import tensorflow`; presumably
# TensorFlow reads them lazily at device initialisation -- confirm.
os.environ.update({
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})
def make_dataset(batch_size=32, shuffle_buffer=10000):
    """Build batched MNIST training and test datasets.

    The MNIST arrays are obtained from ``tf.keras.datasets.mnist.load_data``.
    Every image gets a trailing axis appended, is divided by 255 and cast to
    float32; labels are cast to int64.

    Args:
        batch_size: Number of examples per batch, used for both datasets.
        shuffle_buffer: Buffer size used to shuffle the training dataset.
            The test dataset is never shuffled.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of batched
        ``tf.data.Dataset`` objects yielding ``(image, label)`` pairs.
    """

    def _to_dataset(images, labels):
        # One shared preprocessing path so the train and test splits cannot
        # drift apart.
        images = tf.expand_dims(images, -1)
        images = tf.cast(images / 255, tf.float32)
        labels = tf.cast(labels, tf.int64)
        return tf.data.Dataset.from_tensor_slices((images, labels))

    (train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()
    train_dataset = _to_dataset(train_image, train_label).shuffle(shuffle_buffer).batch(batch_size)
    test_dataset = _to_dataset(test_image, test_label).batch(batch_size)
    return train_dataset, test_dataset
if __name__ == '__main__':
    # Smoke check: build both datasets and print them, training set first.
    for built_dataset in make_dataset():
        print(built_dataset)
四、模型的构建如下所示(保存为 model.py,训练脚本通过 from model import make_model 导入)
import tensorflow as tf
import os
# TensorFlow runtime configuration via environment variables: forward the
# `--tf_xla_enable_xla_devices` flag and ask for GPU memory to grow on demand.
# NOTE(review): set after `import tensorflow`; confirm TensorFlow still picks
# them up at this point.
os.environ.update(
    TF_XLA_FLAGS='--tf_xla_enable_xla_devices',
    TF_FORCE_GPU_ALLOW_GROWTH='true',
)
def make_model():
    """Create the (uncompiled) Keras ``Sequential`` classifier.

    The stack is: two 3x3 ReLU convolutions with 16 and 32 filters (the first
    one declares a ``(28, 28, 1)`` input shape), global average pooling, and
    a final 10-unit ``Dense`` layer without activation.

    Returns:
        The constructed ``tf.keras.Sequential`` model.
    """
    layers = tf.keras.layers
    layer_stack = [
        layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.GlobalAveragePooling2D(),
        # No activation here: the layer's raw outputs are returned as-is.
        layers.Dense(10),
    ]
    return tf.keras.Sequential(layer_stack)
if __name__ == '__main__':
    # Smoke check: build the network and print its layer summary.
    make_model().summary()
五、自定义的训练过程如下所示
import tensorflow as tf
import os
from data_loader import make_dataset
from model import make_model
import tqdm
# Environment-variable configuration for TensorFlow: hand over the
# `--tf_xla_enable_xla_devices` flag and enable on-demand GPU memory growth.
# NOTE(review): assigned after `import tensorflow` (and after the project
# imports); confirm TensorFlow still honours them at this point.
_TF_ENV_SETTINGS = (
    ('TF_XLA_FLAGS', '--tf_xla_enable_xla_devices'),
    ('TF_FORCE_GPU_ALLOW_GROWTH', 'true'),
)
os.environ.update(_TF_ENV_SETTINGS)
# Module-level state shared by train_step / test_step / train below.
train_dataset, test_dataset = make_dataset()
model = make_model()
optimizer = tf.keras.optimizers.Adam()
# from_logits=True: the loss is told that predictions are unnormalised logits
# rather than probabilities.
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Running means of the per-batch loss. They were previously labelled
# 'train_acc' / 'test_acc', which mislabelled loss values as accuracies.
train_loss_metric = tf.keras.metrics.Mean('train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
test_loss_metric = tf.keras.metrics.Mean('test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')
def loss(mol, x, y):
    """Return ``loss_func`` evaluated on labels ``y`` and the output of ``mol(x)``.

    Args:
        mol: Callable model producing predictions for ``x``.
        x: Model inputs.
        y: Ground-truth labels passed as the first argument of ``loss_func``.
    """
    return loss_func(y, mol(x))
def train_step(mol, images, labels):
    """Run one optimisation step on a single batch.

    Computes the loss under a gradient tape, applies the gradients to
    ``mol.trainable_variables`` with the module-level ``optimizer``, and
    updates the module-level training loss and accuracy metrics.

    Args:
        mol: Keras model being trained.
        images: Batch of model inputs.
        labels: Batch of ground-truth labels.
    """
    with tf.GradientTape() as tape:
        # training=True makes layers with train-only behaviour (e.g. Dropout,
        # BatchNormalization) run in training mode; without it they would
        # silently run in inference mode during optimisation.
        pred = mol(images, training=True)
        loss_step = loss_func(labels, pred)
    grads = tape.gradient(loss_step, mol.trainable_variables)
    optimizer.apply_gradients(zip(grads, mol.trainable_variables))
    train_loss_metric(loss_step)
    train_accuracy(labels, pred)
def test_step(mol, images, labels):
    """Evaluate one batch and fold the result into the test metrics.

    No gradients are computed and no variables are updated; only the
    module-level ``test_loss_metric`` and ``test_accuracy`` change.

    Args:
        mol: Model to evaluate.
        images: Batch of model inputs.
        labels: Batch of ground-truth labels.
    """
    outputs = mol(images)
    batch_loss = loss_func(labels, outputs)
    test_loss_metric.update_state(batch_loss)
    test_accuracy.update_state(labels, outputs)
def train(epochs=100):
    """Train ``model`` and evaluate it on the test set after every epoch.

    Each epoch runs ``train_step`` over every batch of ``train_dataset`` and
    then ``test_step`` over every batch of ``test_dataset``, showing running
    loss/accuracy on tqdm progress bars. All four metrics are reset at the
    end of the epoch so every epoch reports its own averages.

    Args:
        epochs: Number of passes over the training dataset.
    """
    for epoch in range(epochs):
        # The bars are used as context managers so they are closed even when
        # a step raises.
        with tqdm.tqdm(train_dataset, total=len(train_dataset)) as tqdm_train:
            # The label is constant for the whole epoch; set it once.
            tqdm_train.set_description_str('Epoch{}'.format(epoch))
            for images, labels in tqdm_train:
                train_step(model, images, labels)
                tqdm_train.set_postfix_str(
                    'train_loss is {:.14f} train_accuracy is {:.14f}'.format(
                        train_loss_metric.result(), train_accuracy.result()))
        with tqdm.tqdm(test_dataset, total=len(test_dataset)) as tqdm_test:
            tqdm_test.set_description_str('Epoch{}'.format(epoch))
            for images, labels in tqdm_test:
                test_step(model, images, labels)
                tqdm_test.set_postfix_str(
                    'test_loss is {:.14f} test_accuracy is {:.14f}'.format(
                        test_loss_metric.result(), test_accuracy.result()))
        train_loss_metric.reset_states()
        train_accuracy.reset_states()
        test_loss_metric.reset_states()
        test_accuracy.reset_states()
        print('\n')
if __name__ == '__main__':
    # Run the full training loop, then save the trained model to an .h5 file.
    saved_model_path = r'model_data/my_train.h5'
    train()
    model.save(saved_model_path)
六、训练的过程输出展示