The focus here is on further partitioning the dataset: the original training set is split into a training set and a validation set. This is done mainly with tf.split(), whose job is to split one tensor into several sub-tensors.
tf.split(
    value,
    num_or_size_splits,
    axis=0
)
value is the tensor to be split.
num_or_size_splits determines how the split is made.
axis is the dimension along which to split.
There are two splitting modes: passing an integer splits value into that many equal pieces along axis, while passing a list of sizes splits it into pieces of exactly those sizes, as in the sketch below.
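As a quick illustration (a minimal sketch, not part of the tutorial's own code; the tensor t and the variable names are just for demonstration), the two modes look like this:

import tensorflow as tf

t = tf.random.normal([6, 2])  # toy tensor of shape [6, 2]

# Mode 1: an integer -> split into that many equal pieces along axis 0
a, b, c = tf.split(t, num_or_size_splits=3, axis=0)   # three tensors of shape [2, 2]

# Mode 2: a list of sizes -> split into pieces of exactly those sizes
front, back = tf.split(t, num_or_size_splits=[4, 2], axis=0)  # shapes [4, 2] and [2, 2]

This is exactly how the 60000 MNIST training samples are cut into 50000 training and 10000 validation samples below.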
The complete code is as follows:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, Sequential, optimizers
# load data
(x, y), (x_test, y_test) = datasets.mnist.load_data()
# build datasets
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.   # scale pixels to [0, 1]
    x = tf.reshape(x, [-1, 28*28])            # flatten 28x28 images to 784-dim vectors
    y = tf.cast(y, dtype=tf.int64)
    y = tf.one_hot(y, depth=10)               # one-hot labels for 10 classes
    return x, y
batch_size = 128
x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000], axis=0)  # cut the data
y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(50000).batch(batch_size).map(preprocess)
db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.shuffle(10000).batch(batch_size).map(preprocess)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batch_size).map(preprocess)
#build network
network = Sequential([
    layers.Dense(256, activation=tf.nn.relu),  # [b, 784] to [b, 256]
    layers.Dense(128, activation=tf.nn.relu),  # [b, 256] to [b, 128]
    layers.Dense(64, activation=tf.nn.relu),   # [b, 128] to [b, 64]
    layers.Dense(32, activation=tf.nn.relu),   # [b, 64] to [b, 32]
    layers.Dense(10)                           # [b, 32] to [b, 10] logits
])
network.build(input_shape=[None,28*28])
network.summary()
# train and test
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['acc'])
network.fit(db_train, epochs=10, validation_data=db_val, validation_freq=2)  # run validation every 2 epochs
network.evaluate(db_test)
Finally, evaluating on the test set gives the test accuracy:
1/79 [..............................] - ETA: 0s - loss: 0.0861 - acc: 0.9766
23/79 [=======>......................] - ETA: 0s - loss: 0.2205 - acc: 0.9647
45/79 [================>.............] - ETA: 0s - loss: 0.2026 - acc: 0.9655
67/79 [========================>.....] - ETA: 0s - loss: 0.1686 - acc: 0.9703
79/79 [==============================] - 0s 2ms/step - loss: 0.1625 - acc: 0.9714
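If you want the numbers programmatically rather than reading them off the progress bar, evaluate() also returns the loss and metric values; a small sketch (the variable names here are just illustrative):

test_loss, test_acc = network.evaluate(db_test)
print('test loss:', test_loss, 'test acc:', test_acc)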