目录
头文件
一、读取数据集(图片名)
二、将数据集图片、标签写入TFRecord
三、从TFRecord中读取数据集
四、构建模型
五、训练模型
实验结果
import tensorflow as tf
import os
# Dataset root.  The cats-vs-dogs data is expected in the standard Keras
# layout: train/{cats,dogs} and validation/{cats,dogs} under this folder.
data_dir = "D:/dataset/cats_and_dogs_filtered"

# Training split: per-class image folders plus the packed TFRecord file.
train_cat_dir = f"{data_dir}/train/cats/"
train_dog_dir = f"{data_dir}/train/dogs/"
train_tfrecord_file = f"{data_dir}/train/train.tfrecords"

# Validation split, used here as the test set.
test_cat_dir = f"{data_dir}/validation/cats/"
test_dog_dir = f"{data_dir}/validation/dogs/"
test_tfrecord_file = f"{data_dir}/validation/test.tfrecords"
def _full_paths(directory):
    """Return every entry of `directory` prefixed with the directory path."""
    return [directory + name for name in os.listdir(directory)]

# Per-class file listings; label convention: cat == 0, dog == 1.
train_cat_filenames = _full_paths(train_cat_dir)
train_dog_filenames = _full_paths(train_dog_dir)
train_filenames = train_cat_filenames + train_dog_filenames
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)

test_cat_filenames = _full_paths(test_cat_dir)
test_dog_filenames = _full_paths(test_dog_dir)
test_filenames = test_cat_filenames + test_dog_filenames
test_labels = [0] * len(test_cat_filenames) + [1] * len(test_dog_filenames)
def encoder(filenames, labels, tfrecord_file):
    """Serialize (image file, label) pairs into a TFRecord file.

    Args:
        filenames: paths of the image files to pack.
        labels: integer class labels aligned with `filenames`.
        tfrecord_file: destination path of the TFRecord file.
    """
    with tf.io.TFRecordWriter(tfrecord_file) as writer:
        for filename, label in zip(filenames, labels):
            # Read the raw, still-encoded image bytes; decoding happens
            # later in the input pipeline.  Use a context manager so the
            # file handle is closed promptly (the original leaked it).
            with open(filename, 'rb') as f:
                image = f.read()
            # Build the feature dict for this example.
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }
            # Wrap the features in an Example and append its serialized
            # form to the TFRecord file.
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())
# Materialize the TFRecord files for both splits.
for _split_files, _split_labels, _split_record in (
        (train_filenames, train_labels, train_tfrecord_file),
        (test_filenames, test_labels, test_tfrecord_file)):
    encoder(_split_files, _split_labels, _split_record)
def decoder(tfrecord_file, is_train_dataset=None):
    """Load a TFRecord file written by `encoder` into a batched Dataset.

    Args:
        tfrecord_file: path of the TFRecord file to read.
        is_train_dataset: any non-None value enables the training pipeline
            (shuffling); None builds a plain evaluation pipeline.

    Returns:
        A tf.data.Dataset of (image, label) batches; images are resized to
        256x256 and scaled to [0, 1].
    """
    dataset = tf.data.TFRecordDataset(tfrecord_file)
    # Schema must mirror exactly what `encoder` wrote.
    # (Fixed typo: was `feature_discription`.)
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64)
    }

    def _parse_example(example_string):
        # Decode one serialized Example back into an (image, label) pair.
        features = tf.io.parse_single_example(example_string, feature_description)
        image = tf.io.decode_jpeg(features['image'])
        # Fixed spatial size for the model; pixel values scaled to [0, 1].
        image = tf.image.resize(image, [256, 256]) / 255.0
        return image, features['label']

    batch_size = 32
    # Parallel decode; tf.data keeps output order deterministic by default.
    dataset = dataset.map(_parse_example,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if is_train_dataset is not None:
        dataset = dataset.shuffle(buffer_size=2000)
    dataset = dataset.batch(batch_size)
    # Prefetch on both paths (the original skipped it for evaluation).
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
# Input pipelines: shuffle the training split, keep validation in file order.
train_data = decoder(train_tfrecord_file, is_train_dataset=1)
test_data = decoder(test_tfrecord_file)
class CNNModel(tf.keras.models.Model):
    """Two conv/pool stages followed by a small dense softmax classifier."""

    def __init__(self):
        super(CNNModel, self).__init__()
        # Keras tracks layers held in a plain list attribute, so the
        # forward pass can simply chain them in order.
        self._pipeline = [
            tf.keras.layers.Conv2D(12, 3, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Conv2D(12, 5, activation='relu'),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            # Two output classes (cat / dog) as softmax probabilities.
            tf.keras.layers.Dense(2, activation='softmax'),
        ]

    def call(self, inputs):
        """Run a batch of images through every layer in order."""
        x = inputs
        for layer in self._pipeline:
            x = layer(x)
        return x
def train_CNNModel():
    """Train a fresh CNNModel on the global `train_data`/`test_data` pipelines.

    Runs a custom training loop for a fixed number of epochs, printing the
    mean training loss, training accuracy and test accuracy per epoch.

    Returns:
        The trained model.  (The original returned None; callers that
        ignore the return value are unaffected.)
    """
    model = CNNModel()
    # The model outputs softmax probabilities, so from_logits stays False.
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adam(0.001)

    # Track the loss too -- the original computed it but never reported it.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            # training=True so any mode-dependent layers behave correctly.
            probs = model(images, training=True)
            loss = loss_obj(labels, probs)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss(loss)
        train_acc(labels, probs)

    @tf.function
    def test_step(images, labels):
        probs = model(images, training=False)
        test_acc(labels, probs)

    Epochs = 5
    for epoch in range(Epochs):
        # Metrics accumulate across batches; reset at each epoch start.
        train_loss.reset_states()
        train_acc.reset_states()
        test_acc.reset_states()
        for images, labels in train_data:
            train_step(images, labels)
        for images, labels in test_data:
            test_step(images, labels)
        tmp = 'Epoch {}, Loss {}, Acc {}, Test Acc {}'
        print(tmp.format(epoch + 1,
                         train_loss.result(),
                         train_acc.result() * 100,
                         test_acc.result() * 100))
    return model
train_CNNModel()
Epoch 1, Acc 51.45000076293945, Test Acc 51.70000076293945
Epoch 2, Acc 60.650001525878906, Test Acc 58.099998474121094
Epoch 3, Acc 70.5, Test Acc 63.30000305175781
Epoch 4, Acc 78.05000305175781, Test Acc 69.30000305175781
Epoch 5, Acc 87.4000015258789, Test Acc 69.19999694824219