代码:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model,datasets,losses,optimizers
from tensorflow.keras.layers import Dense,Flatten,Conv2D,MaxPooling2D
# Load MNIST and scale pixel values from [0, 255] into [0, 1].
# Casting to float32 *before* dividing keeps the arrays at half the size of the
# float64 result that a plain `x / 255.0` on uint8 data would produce, and it
# matches the float32 dtype Keras layers compute in by default.
(x, y), (x_val, y_val) = datasets.mnist.load_data()
x = x.astype("float32") / 255.0          # train images: shape (60000, 28, 28)
x_val = x_val.astype("float32") / 255.0  # test images:  shape (10000, 28, 28)
# Add a trailing channel axis so Conv2D (default channels_last format) gets
# (batch, height, width, channels) inputs.
x = x[..., tf.newaxis]          # shape (60000, 28, 28, 1)
x_val = x_val[..., tf.newaxis]  # shape (10000, 28, 28, 1)
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(32)
class MyModel(Model):
    """LeNet-style CNN: two conv/max-pool stages followed by a 120-84-10 dense head.

    Expects batches of single-channel images (e.g. (batch, 28, 28, 1) for MNIST)
    and returns softmax probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3x3 ReLU convolutions, each followed by 2x2 max-pooling.
        self.conv1 = Conv2D(filters=6, kernel_size=3, activation='relu')
        self.maxp2 = MaxPooling2D(pool_size=2, strides=2)
        self.conv3 = Conv2D(filters=16, kernel_size=3, activation='relu')
        self.maxp4 = MaxPooling2D(pool_size=2, strides=2)
        # Classifier head.
        self.flatten = Flatten()
        self.d5 = Dense(units=120, activation='relu')
        self.d6 = Dense(units=84, activation='relu')
        self.d7 = Dense(units=10, activation='softmax')

    def call(self, x):
        """Run the forward pass and return the class probabilities."""
        pipeline = (
            self.conv1, self.maxp2,
            self.conv3, self.maxp4,
            self.flatten, self.d5, self.d6,
        )
        for stage in pipeline:
            x = stage(x)
        return self.d7(x)
# Build the network together with the loss, optimizer and running metrics that
# the train/test step functions update.
model = MyModel()

# The model ends in a softmax and the labels are integer class ids, so the
# sparse cross-entropy is used with its default from_logits=False.
loss_object = losses.SparseCategoricalCrossentropy()
optimizer = optimizers.Adam()

# Running averages; the epoch loop resets them before every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a batch and update the training metrics.

    Args:
        images: a batch of input images for `model`.
        labels: the matching integer class labels.
    """
    with tf.GradientTape() as tape:
        # training=True puts layers such as Dropout/BatchNormalization (none are
        # used yet) into training mode; for the current layers it is a no-op.
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)
@tf.function
def test_step(images, labels):
    """Evaluate one batch (no weight update) and update the test metrics.

    Args:
        images: a batch of input images for `model`.
        labels: the matching integer class labels.
    """
    # training=False keeps layers such as Dropout/BatchNormalization (none are
    # used yet) in inference mode; for the current layers it is a no-op.
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)
    test_loss(t_loss)
    test_accuracy(labels, predictions)
EPOCHS = 10
for epoch in range(EPOCHS):
    # Start every epoch with fresh running averages.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        # `reset_state()` is the current name (and the only one in Keras 3,
        # where `reset_states()` was removed); older TF 2.x releases only have
        # `reset_states()`, so fall back to it there.
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

    for imgs, labels in train_ds:
        train_step(imgs, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    template = 'Epoch:{}, Loss:{}, Accuracy:{}, Test Loss:{}, Test Accuracy:{}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))
MNIST 上的训练结果（仅列出前 5 个 epoch）:
Epoch:1, Loss:0.1898973435163498, Accuracy:94.21666717529297, Test Loss:0.06720136106014252, Test Accuracy:97.91999816894531
Epoch:2, Loss:0.06402933597564697, Accuracy:98.04666900634766, Test Loss:0.05430004000663757, Test Accuracy:98.22000122070312
Epoch:3, Loss:0.0453227162361145, Accuracy:98.58833312988281, Test Loss:0.038181234151124954, Test Accuracy:98.77999877929688
Epoch:4, Loss:0.03610095754265785, Accuracy:98.84500122070312, Test Loss:0.04592617228627205, Test Accuracy:98.48999786376953
Epoch:5, Loss:0.02789519727230072, Accuracy:99.12333679199219, Test Loss:0.0402412936091423, Test Accuracy:98.83999633789062
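结果（注：这组准确率明显低于上面的 MNIST 结果，且与 50000 个训练样本的规模一致，推测是把数据集换成 CIFAR-10 后运行得到的，而不是上面 MNIST 代码的输出；其中 Epoch:5 出现了两行，可能是两段日志拼接而成）: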
Epoch:1, Loss:1.5339884757995605, Accuracy:44.29399871826172, Test Loss:1.3234540224075317, Test Accuracy:52.96999740600586
Epoch:2, Loss:1.2463908195495605, Accuracy:55.544002532958984, Test Loss:1.239940881729126, Test Accuracy:55.48999786376953
Epoch:3, Loss:1.1332910060882568, Accuracy:59.91200256347656, Test Loss:1.1412276029586792, Test Accuracy:59.59000015258789
Epoch:4, Loss:1.0569039583206177, Accuracy:62.53799819946289, Test Loss:1.1044120788574219, Test Accuracy:61.38999938964844
Epoch:5, Loss:0.9972842931747437, Accuracy:64.98600006103516, Test Loss:1.0524744987487793, Test Accuracy:62.94000244140625
Epoch:5, Loss:0.9932153224945068, Accuracy:64.76399993896484, Test Loss:1.076656699180603, Test Accuracy:62.279998779296875
Epoch:6, Loss:0.9340294003486633, Accuracy:67.07599639892578, Test Loss:1.1172698736190796, Test Accuracy:61.290000915527344
Epoch:7, Loss:0.882939875125885, Accuracy:68.83599853515625, Test Loss:1.0803672075271606, Test Accuracy:63.27000427246094
Epoch:8, Loss:0.8316821455955505, Accuracy:70.61599731445312, Test Loss:1.0501688718795776, Test Accuracy:64.17000579833984
Epoch:9, Loss:0.7863742709159851, Accuracy:72.25599670410156, Test Loss:1.0738744735717773, Test Accuracy:64.06999969482422
Epoch:10, Loss:0.7494518756866455, Accuracy:73.4939956665039, Test Loss:1.0971062183380127, Test Accuracy:63.160003662109375
class MyModel(Model):
    """Deeper CNN variant with three 3x3 ReLU convolutions and a dense head.

    The convolutions use 32, 64 and 64 filters. The first two are each followed
    by 2x2 max-pooling. A 120-84-10 dense head produces softmax class
    probabilities.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = Conv2D(filters=32, kernel_size=3, activation='relu')
        self.maxp2 = MaxPooling2D(pool_size=2, strides=2)
        self.conv3 = Conv2D(filters=64, kernel_size=3, activation='relu')
        self.maxp4 = MaxPooling2D(pool_size=2, strides=2)
        self.conv5 = Conv2D(filters=64, kernel_size=3, activation='relu')
        # Classifier head.
        self.flatten = Flatten()
        self.d6 = Dense(units=120, activation='relu')
        self.d7 = Dense(units=84, activation='relu')
        self.d8 = Dense(units=10, activation='softmax')

    def call(self, x):
        """Apply the conv stack, then the dense head; returns class probabilities."""
        features = x
        for layer in (self.conv1, self.maxp2, self.conv3, self.maxp4, self.conv5):
            features = layer(features)
        hidden = self.d7(self.d6(self.flatten(features)))
        return self.d8(hidden)


model = MyModel()
改用上面三层卷积的模型后的结果（推测与上一组使用同一数据集）:
Epoch:1, Loss:1.5444934368133545, Accuracy:43.354000091552734, Test Loss:1.2567822933197021, Test Accuracy:54.43000030517578
Epoch:2, Loss:1.1718486547470093, Accuracy:58.37799835205078, Test Loss:1.071388602256775, Test Accuracy:62.30000305175781
Epoch:3, Loss:1.002740740776062, Accuracy:64.55199432373047, Test Loss:1.0042154788970947, Test Accuracy:64.68000030517578
Epoch:4, Loss:0.8955680131912231, Accuracy:68.29000091552734, Test Loss:0.9319255948066711, Test Accuracy:67.40999603271484
Epoch:5, Loss:0.8162031769752502, Accuracy:71.41000366210938, Test Loss:0.9057866930961609, Test Accuracy:68.9000015258789
Epoch:6, Loss:0.7521154880523682, Accuracy:73.58399963378906, Test Loss:0.9005085825920105, Test Accuracy:69.4000015258789
Epoch:7, Loss:0.6932022571563721, Accuracy:75.59000396728516, Test Loss:0.910836398601532, Test Accuracy:69.37000274658203
Epoch:8, Loss:0.632886528968811, Accuracy:77.69599914550781, Test Loss:0.9727368950843811, Test Accuracy:68.4800033569336
Epoch:9, Loss:0.5897338390350342, Accuracy:79.12999725341797, Test Loss:0.8926523923873901, Test Accuracy:71.17000579833984
Epoch:10, Loss:0.5396286249160767, Accuracy:80.79199981689453, Test Loss:0.9478877782821655, Test Accuracy:70.36000061035156