TensorFlow 2.0：分批读取数据集并训练模型

目录

头文件

一、处理数据集(dogs vs cats)

二、自定义构建模型

三、训练模型

实验结果


 

头文件

import tensorflow as tf
import os

一、处理数据集(dogs vs cats)

data_dir = "D:/dataset/cats_and_dogs_filtered"
train_cat_dir = data_dir + "/train/cats/"
train_dog_dir = data_dir + "/train/dogs/"

test_cat_dir = data_dir + "/validation/cats/"
test_dog_dir = data_dir + "/validation/dogs/"


def _list_files(directory):
    """Return a 1-D string tensor with the path of every regular file in *directory*.

    Paths are sorted so that file order (and therefore label order) is
    reproducible across runs and platforms; ``os.listdir`` itself returns
    entries in arbitrary order. Sub-directories are skipped because they
    cannot be read as images. ``dtype=tf.string`` keeps the result a string
    tensor even when the directory is empty (a bare ``tf.constant([])`` is
    float32 and cannot be concatenated with the other class's paths).
    """
    paths = sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )
    return tf.constant(paths, dtype=tf.string)


def _binary_labels(negative_paths, positive_paths):
    """Return float32 labels: 0.0 for each of *negative_paths*, 1.0 for each of *positive_paths*."""
    return tf.concat([
        tf.zeros(negative_paths.shape, dtype=tf.float32),
        tf.ones(positive_paths.shape, dtype=tf.float32),
    ], axis=-1)


# Cats are labelled 0, dogs 1; paths and labels are concatenated in the same order.
train_cat_filename = _list_files(train_cat_dir)
train_dog_filename = _list_files(train_dog_dir)
train_filename = tf.concat([train_cat_filename, train_dog_filename], axis=-1)
train_labels = _binary_labels(train_cat_filename, train_dog_filename)

test_cat_filename = _list_files(test_cat_dir)
test_dog_filename = _list_files(test_dog_dir)
test_filename = tf.concat([test_cat_filename, test_dog_filename], axis=-1)
test_labels = _binary_labels(test_cat_filename, test_dog_filename)

def _decode_and_resize(filename, label):
    """Load one JPEG from disk, resize it to 256x256 and scale pixels to [0, 1].

    Intended as a ``tf.data`` map function over ``(filename, label)`` pairs.

    Args:
        filename: scalar string tensor holding the path of a JPEG file.
        label: the label paired with ``filename``; returned unchanged.

    Returns:
        ``(image, label)`` where ``image`` is a float32 tensor of shape
        ``(256, 256, 3)`` with values in ``[0, 1]``.
    """
    image_string = tf.io.read_file(filename)
    # channels=3 forces RGB output: grayscale JPEGs are expanded to 3 channels,
    # so every element has the same shape and can be batched, and the channel
    # dimension is statically known (Conv2D needs it to build its kernel).
    image_decode = tf.image.decode_jpeg(image_string, channels=3)
    # tf.image.resize returns float32, so dividing by 255.0 maps [0, 255] to [0, 1].
    image_resized = tf.image.resize(image_decode, [256, 256]) / 255.0
    return image_resized, label

batch_size = 16

# Training pipeline: shuffle (filename, label) pairs *before* decoding, so the
# shuffle buffer holds short path strings instead of decoded 256x256x3 float32
# images (which would cost ~0.75 MB each and force every image to be decoded
# before the first batch). A buffer as large as the dataset gives a full
# uniform shuffle, redone every epoch (reshuffle_each_iteration defaults to True).
train_dataset = tf.data.Dataset.from_tensor_slices((train_filename, train_labels))
train_dataset = train_dataset.shuffle(buffer_size=max(int(train_filename.shape[0]), 1))
train_dataset = train_dataset.map(
    map_func=_decode_and_resize,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Evaluation pipeline: no shuffling, but decode in parallel and prefetch like
# the training pipeline so evaluation is not bottlenecked on file I/O.
test_dataset = tf.data.Dataset.from_tensor_slices((test_filename, test_labels))
test_dataset = test_dataset.map(
    map_func=_decode_and_resize,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
test_dataset = test_dataset.batch(batch_size)
test_dataset = test_dataset.prefetch(tf.data.experimental.AUTOTUNE)

二、自定义构建模型

class CNNModel(tf.keras.models.Model):
    """Small two-stage convolutional classifier with a two-class softmax head.

    Layer stack: Conv(12 filters, 3x3, ReLU) -> MaxPool -> Conv(12 filters,
    5x5, ReLU) -> MaxPool -> Flatten -> Dense(64, ReLU) -> Dense(2, softmax).
    The output is a per-class probability distribution, not logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: object-based
        # checkpoints key weights by attribute name, and auto-generated layer
        # names depend on creation order.
        self.conv1 = tf.keras.layers.Conv2D(filters=12, kernel_size=3, activation='relu')
        self.maxpool1 = tf.keras.layers.MaxPooling2D()
        self.conv2 = tf.keras.layers.Conv2D(filters=12, kernel_size=5, activation='relu')
        self.maxpool2 = tf.keras.layers.MaxPooling2D()
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(units=64, activation='relu')
        self.d2 = tf.keras.layers.Dense(units=2, activation='softmax')

    def call(self, inputs):
        """Apply every layer in order and return the class probabilities."""
        stages = (
            self.conv1, self.maxpool1,
            self.conv2, self.maxpool2,
            self.flatten, self.d1, self.d2,
        )
        out = inputs
        for stage in stages:
            out = stage(out)
        return out

三、训练模型

def train_CNNModel(epochs=5):
    """Train a fresh ``CNNModel`` on ``train_dataset`` and evaluate it on ``test_dataset``.

    Uses the module-level ``train_dataset`` / ``test_dataset`` pipelines. After
    each epoch it prints that epoch's training and test accuracy in percent.

    Args:
        epochs: number of full passes over ``train_dataset``; defaults to 5,
            the value that used to be hard-coded.

    Returns:
        The trained ``CNNModel``, so callers can evaluate or save it instead
        of the weights being discarded when the function exits.
    """
    model = CNNModel()
    # The model ends in a softmax, so the loss receives probabilities
    # (SparseCategoricalCrossentropy defaults to from_logits=False).
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam(0.001)

    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

    def _reset(metric):
        # Newer Keras only provides reset_state(); TF releases before 2.5 only
        # provide reset_states(). Use whichever this version has.
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            probs = model(images)
            loss = loss_obj(labels, probs)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        train_acc(labels, probs)

    @tf.function
    def test_step(images, labels):
        probs = model(images)
        test_acc(labels, probs)

    for epoch in range(epochs):
        # Reset so each printed accuracy covers only the current epoch.
        _reset(train_acc)
        _reset(test_acc)

        for images, labels in train_dataset:
            train_step(images, labels)

        for images, labels in test_dataset:
            test_step(images, labels)

        tmp = 'Epoch {}, Acc {}, Test Acc {}'
        print(tmp.format(epoch + 1,
                         train_acc.result() * 100,
                         test_acc.result() * 100))

    return model

实验结果

Epoch 1, Acc 51.0, Test Acc 54.20000076293945
Epoch 2, Acc 56.69999694824219, Test Acc 57.75
Epoch 3, Acc 67.8499984741211, Test Acc 59.766666412353516
Epoch 4, Acc 81.30000305175781, Test Acc 60.60000228881836
Epoch 5, Acc 92.5999984741211, Test Acc 61.18000030517578

（以上准确率均为百分比。训练准确率从 51.0% 持续升至约 92.6%，而测试准确率只从约 54.2% 升至约 61.2%，两者差距不断扩大，说明模型明显过拟合。）

 

你可能感兴趣的:(TensorFlow)