TensorFlow —— 使用 TensorFlow 进行自定义训练（custom training loop）

一、代码中使用的 MNIST 数据集可以通过以下代码下载（首次调用时自动下载并缓存到本地）

# Download MNIST through Keras (cached locally after the first call) and unpack
# the (train, test) splits. Requires `import tensorflow as tf` beforehand.
(train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()

二、代码运行环境

Tensorflow-gpu==2.4.0

Python==3.7

三、数据集的构建如下所示（请将该脚本保存为 data_loader.py，第五部分的训练脚本通过 `from data_loader import make_dataset` 导入它）

import tensorflow as tf
import os

# Environment configuration. TF_FORCE_GPU_ALLOW_GROWTH lets TensorFlow allocate
# GPU memory on demand; TF_XLA_FLAGS enables XLA devices.
# NOTE(review): both are presumably read when TensorFlow initializes its
# devices, so they must be set before the first TensorFlow op runs — confirm.
os.environ.update({
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})


# Data loading
def _preprocess(images, labels):
    """Turn raw image/label arrays into model-ready tensors.

    Adds a trailing channel axis to the images, divides pixel values by 255
    and casts them to float32, and casts the labels to int64.

    Returns:
        An ``(images, labels)`` tuple of tensors.
    """
    images = tf.expand_dims(images, -1)
    images = tf.cast(images / 255, tf.float32)
    labels = tf.cast(labels, tf.int64)
    return images, labels


def make_dataset(batch_size=32, shuffle_buffer=10000):
    """Build batched MNIST training and test pipelines.

    Args:
        batch_size: Number of examples per batch, used for both datasets
            (default 32).
        shuffle_buffer: Buffer size used to shuffle the training set
            (default 10000). The test set is not shuffled.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``tf.data.Dataset``
        objects that yield ``(images, labels)`` batches.
    """
    (train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()

    train_dataset = tf.data.Dataset.from_tensor_slices(_preprocess(train_image, train_label))
    train_dataset = train_dataset.shuffle(shuffle_buffer).batch(batch_size)

    test_dataset = tf.data.Dataset.from_tensor_slices(_preprocess(test_image, test_label))
    test_dataset = test_dataset.batch(batch_size)

    return train_dataset, test_dataset


if __name__ == '__main__':
    # Smoke test: build both pipelines and print their reprs.
    datasets = make_dataset()
    for ds in datasets:
        print(ds)

四、模型的构建如下所示（请将该脚本保存为 model.py，第五部分的训练脚本通过 `from model import make_model` 导入它）

import tensorflow as tf
import os

# Environment configuration. TF_FORCE_GPU_ALLOW_GROWTH lets TensorFlow allocate
# GPU memory on demand; TF_XLA_FLAGS enables XLA devices.
# NOTE(review): both are presumably read when TensorFlow initializes its
# devices, so they must be set before the first TensorFlow op runs — confirm.
os.environ.update({
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})


# Model construction
def make_model(input_shape=(28, 28, 1), num_classes=10):
    """Build a small convolutional classifier.

    The network is two 3x3 ReLU convolutions (16 and 32 filters), then
    global average pooling, then a dense output layer.

    Args:
        input_shape: Shape of a single input example, excluding the batch
            axis (default ``(28, 28, 1)``).
        num_classes: Number of output units (default 10).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model. The final ``Dense``
        layer has no activation, so it outputs unnormalized logits. Pair it
        with a loss constructed with ``from_logits=True``.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=input_shape),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes)
    ])
    return model


if __name__ == '__main__':
    # Smoke test: build the network and print its layer summary.
    make_model().summary()

五、自定义的训练过程如下所示（需要额外安装 tqdm 用于显示进度条；训练结束后模型保存为 model_data/my_train.h5）

import tensorflow as tf
import os
from data_loader import make_dataset
from model import make_model
import tqdm

# Environment configuration. TF_FORCE_GPU_ALLOW_GROWTH lets TensorFlow allocate
# GPU memory on demand; TF_XLA_FLAGS enables XLA devices.
# NOTE(review): both are presumably read when TensorFlow initializes its
# devices, so they must be set before the first TensorFlow op runs — confirm.
os.environ.update({
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})

# Data loading
train_dataset, test_dataset = make_dataset()

# Model construction
model = make_model()

# Training configuration. from_logits=True: the loss expects unnormalized
# logits from the model rather than probabilities.
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Running per-epoch metrics. The Mean metrics accumulate the loss, so they
# are named after the loss rather than the accuracy.
train_loss_metric = tf.keras.metrics.Mean('train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')

test_loss_metric = tf.keras.metrics.Mean('test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')


# Loss helper (the training loop below computes the loss inline instead).
def loss(mol, x, y):
    """Return ``loss_func`` evaluated on the model's predictions for ``x``.

    Args:
        mol: A callable model.
        x: Input batch.
        y: Ground-truth labels for ``x``.
    """
    return loss_func(y, mol(x))


# One optimization step on a single batch.
def train_step(mol, images, labels):
    """Run a forward/backward pass on one batch and apply the gradients.

    The batch loss and accuracy are also accumulated into the module-level
    ``train_loss_metric`` and ``train_accuracy`` metrics.

    Args:
        mol: The Keras model being trained.
        images: Input batch.
        labels: Class labels for ``images``.
    """
    with tf.GradientTape() as tape:
        # training=True so layers that behave differently during training
        # (e.g. Dropout, BatchNormalization) run in training mode.
        pred = mol(images, training=True)
        loss_step = loss_func(labels, pred)
    grads = tape.gradient(loss_step, mol.trainable_variables)
    optimizer.apply_gradients(zip(grads, mol.trainable_variables))
    train_loss_metric(loss_step)
    train_accuracy(labels, pred)


# One evaluation step on a single batch.
def test_step(mol, images, labels):
    """Evaluate the model on one batch without updating its weights.

    The batch loss and accuracy are accumulated into the module-level
    ``test_loss_metric`` and ``test_accuracy`` metrics.

    Args:
        mol: The Keras model being evaluated.
        images: Input batch.
        labels: Class labels for ``images``.
    """
    # training=False so layers such as Dropout/BatchNormalization run in
    # inference mode.
    pred = mol(images, training=False)
    loss_step = loss_func(labels, pred)
    test_loss_metric(loss_step)
    test_accuracy(labels, pred)


# Training / evaluation loop
def train(epochs=100):
    """Train ``model`` for ``epochs`` epochs, evaluating after each epoch.

    Each epoch runs ``train_step`` over every batch of ``train_dataset`` and
    then ``test_step`` over every batch of ``test_dataset``. The running
    loss/accuracy is shown in tqdm progress bars. All four metrics are reset
    at the end of the epoch, so the reported values are per-epoch.

    Args:
        epochs: Number of passes over the training set (default 100).
    """
    for epoch in range(epochs):
        desc = 'Epoch{}'.format(epoch)

        # len() on a tf.data.Dataset raises TypeError if its cardinality is
        # unknown or infinite.
        with tqdm.tqdm(train_dataset, total=len(train_dataset), desc=desc) as tqdm_train:
            for images, labels in tqdm_train:
                train_step(model, images, labels)
                tqdm_train.set_postfix_str(
                    'train_loss is {:.14f} train_accuracy is {:.14f}'.format(
                        train_loss_metric.result(), train_accuracy.result()))

        with tqdm.tqdm(test_dataset, total=len(test_dataset), desc=desc) as tqdm_test:
            for images, labels in tqdm_test:
                test_step(model, images, labels)
                tqdm_test.set_postfix_str(
                    'test_loss is {:.14f} test_accuracy is {:.14f}'.format(
                        test_loss_metric.result(), test_accuracy.result()))

        train_loss_metric.reset_states()
        train_accuracy.reset_states()
        test_loss_metric.reset_states()
        test_accuracy.reset_states()
        print('\n')


if __name__ == '__main__':
    # Run the full training schedule, then persist the trained model; the
    # .h5 suffix makes Keras write it in HDF5 format.
    train()
    save_path = r'model_data/my_train.h5'
    model.save(save_path)

六、训练过程的输出展示

（图：训练过程的控制台输出截图）

你可能感兴趣的:(Tensorflow,tensorflow,深度学习,机器学习)