【Deep Learning: Getting into TensorFlow 2.0】A quick demo of TensorFlow 2.0

Below, we use TensorFlow 2.0 to recognize handwritten digits (the MNIST dataset) and see what TensorFlow 2.0 code looks like.

Simple version: build the model with the Sequential API:

# -*- coding: utf-8 -*-

import tensorflow as tf


mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])


model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test,  y_test, verbose=2)
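The `sparse_categorical_crossentropy` loss used in `compile` expects integer class labels (like `y_train`, whose values are 0 to 9) rather than one-hot vectors, and scores them against the softmax probabilities from the last `Dense` layer. The pure-Python sketch below (no TensorFlow needed; the function names are illustrative, not Keras APIs) shows roughly what the loss and the `accuracy` metric compute for a tiny batch. It assumes the probabilities are already softmax-normalized and clips them with an epsilon of 1e-7, which is meant to mirror Keras's default `epsilon()`.

```python
import math


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean of -log(p[label]) over the batch; labels are integer class ids."""
    losses = []
    for label, row in zip(labels, probs):
        # Clip so that log(0) cannot occur (assumed to mirror Keras's clipping).
        p = min(max(row[label], eps), 1.0 - eps)
        losses.append(-math.log(p))
    return sum(losses) / len(losses)


def sparse_accuracy(labels, probs):
    """Fraction of samples whose highest-probability class equals the label."""
    correct = sum(1 for label, row in zip(labels, probs)
                  if max(range(len(row)), key=row.__getitem__) == label)
    return correct / len(labels)


# Two samples, three classes (the MNIST model has ten).
labels = [2, 0]
probs = [[0.1, 0.2, 0.7],
         [0.3, 0.6, 0.1]]
loss = sparse_categorical_crossentropy(labels, probs)
acc = sparse_accuracy(labels, probs)
print(loss)  # (-ln 0.7 - ln 0.3) / 2, about 0.7803
print(acc)   # 0.5: only the first sample is predicted correctly
```

Integer labels save memory compared with one-hot targets. If your labels were one-hot encoded instead, the matching Keras loss would be `categorical_crossentropy`.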

Advanced version: build a tf.keras model with the Keras model subclassing API and train it with a custom training loop based on tf.GradientTape:

# -*- coding: utf-8 -*-


from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model


mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)

model = MyModel()


loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

optimizer = tf.keras.optimizers.Adam()



train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)





@tf.function
def test_step(images, labels):
  predictions = model(images)
  t_loss = loss_object(labels, predictions)

  test_loss(t_loss)
  test_accuracy(labels, predictions)


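# Train for 100 epochs
EPOCHS = 100

for epoch in range(EPOCHS):
  # Metrics accumulate across calls, so reset them at the start of every epoch
  # to report per-epoch values (newer TF/Keras versions name this reset_state()).
  train_loss.reset_states()
  train_accuracy.reset_states()
  test_loss.reset_states()
  test_accuracy.reset_states()

  for images, labels in train_ds:
    train_step(images, labels)

  for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  print(template.format(epoch + 1,
                        train_loss.result(),
                        train_accuracy.result() * 100,
                        test_loss.result(),
                        test_accuracy.result() * 100))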
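Metric objects such as `tf.keras.metrics.Mean` and `SparseCategoricalAccuracy` keep accumulating every value passed to them until their state is explicitly reset. The sketch below imitates that accumulate/result/reset cycle in plain Python. `RunningMean` and `simulate` are hypothetical helpers, not TensorFlow APIs, and the batch losses are made-up numbers. It shows how the printed value differs depending on whether the metric is reset at the start of each epoch.

```python
class RunningMean:
    """Hypothetical stand-in for tf.keras.metrics.Mean: sum and count until reset."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, value):
        # Each call adds one observation, like train_loss(loss) in train_step.
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total = 0.0
        self.count = 0


def simulate(epoch_batches, reset_each_epoch):
    """Return the value that would be printed at the end of each epoch."""
    metric = RunningMean()
    printed = []
    for batches in epoch_batches:
        if reset_each_epoch:
            metric.reset()
        for batch_loss in batches:
            metric(batch_loss)
        printed.append(metric.result())
    return printed


# Made-up batch losses for three short epochs of two batches each.
epochs = [[1.0, 0.5], [0.5, 0.25], [0.25, 0.25]]
print(simulate(epochs, reset_each_epoch=False))  # running average over all epochs so far
print(simulate(epochs, reset_each_epoch=True))   # [0.75, 0.375, 0.25], one value per epoch
```

Without a reset, later epochs are diluted by earlier, worse ones. In this example the third printed value is about 0.458 instead of 0.25, which is why the printed curves look smoother and lag behind the real per-epoch numbers.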


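The advanced version was trained for 100 epochs on a GPU server and produced the log below. The metrics were not reset between epochs when this log was recorded, so every value is a running average over all epochs so far rather than that epoch's own result. The final cumulative test accuracy is about 98.55%.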

Epoch 1, Loss: 0.13505776226520538, Accuracy: 95.9316635131836, Test Loss: 0.06438148766756058, Test Accuracy: 97.79999542236328
Epoch 2, Loss: 0.08899391442537308, Accuracy: 97.30667114257812, Test Loss: 0.06306955963373184, Test Accuracy: 97.94499969482422
Epoch 3, Loss: 0.06712741404771805, Accuracy: 97.94611358642578, Test Loss: 0.06150639429688454, Test Accuracy: 98.08333587646484
Epoch 4, Loss: 0.05401107668876648, Accuracy: 98.33916473388672, Test Loss: 0.061927955597639084, Test Accuracy: 98.09500122070312
Epoch 5, Loss: 0.04508327320218086, Accuracy: 98.60566711425781, Test Loss: 0.06134289130568504, Test Accuracy: 98.1520004272461
Epoch 6, Loss: 0.03891943395137787, Accuracy: 98.79055786132812, Test Loss: 0.062376100569963455, Test Accuracy: 98.18500518798828
Epoch 7, Loss: 0.03424203023314476, Accuracy: 98.93238067626953, Test Loss: 0.06449465453624725, Test Accuracy: 98.2028579711914
Epoch 8, Loss: 0.030599264428019524, Accuracy: 99.04541015625, Test Loss: 0.0659203752875328, Test Accuracy: 98.23124694824219
Epoch 9, Loss: 0.027583660557866096, Accuracy: 99.1392593383789, Test Loss: 0.06723382323980331, Test Accuracy: 98.24333190917969
Epoch 10, Loss: 0.02518264576792717, Accuracy: 99.2135009765625, Test Loss: 0.06946582347154617, Test Accuracy: 98.24699401855469
Epoch 11, Loss: 0.023218179121613503, Accuracy: 99.27484893798828, Test Loss: 0.07024320960044861, Test Accuracy: 98.26818084716797
Epoch 12, Loss: 0.021557582542300224, Accuracy: 99.32666778564453, Test Loss: 0.0733482837677002, Test Accuracy: 98.25666809082031
Epoch 13, Loss: 0.020063508301973343, Accuracy: 99.37256622314453, Test Loss: 0.07517433911561966, Test Accuracy: 98.26461791992188
Epoch 14, Loss: 0.01887793280184269, Accuracy: 99.4088134765625, Test Loss: 0.07549845427274704, Test Accuracy: 98.29357147216797
Epoch 15, Loss: 0.017702585086226463, Accuracy: 99.44489288330078, Test Loss: 0.07679435610771179, Test Accuracy: 98.30133056640625
Epoch 16, Loss: 0.016694325953722, Accuracy: 99.47708129882812, Test Loss: 0.08155844360589981, Test Accuracy: 98.27874755859375
Epoch 17, Loss: 0.015871752053499222, Accuracy: 99.50254821777344, Test Loss: 0.08181449770927429, Test Accuracy: 98.29999542236328
Epoch 18, Loss: 0.015113672241568565, Accuracy: 99.52693939208984, Test Loss: 0.08238118141889572, Test Accuracy: 98.31222534179688
Epoch 19, Loss: 0.014416887424886227, Accuracy: 99.54833221435547, Test Loss: 0.08342304825782776, Test Accuracy: 98.31631469726562
Epoch 20, Loss: 0.013733655214309692, Accuracy: 99.569580078125, Test Loss: 0.08398226648569107, Test Accuracy: 98.32849884033203
Epoch 21, Loss: 0.013174858875572681, Accuracy: 99.5873794555664, Test Loss: 0.08482220023870468, Test Accuracy: 98.336669921875
Epoch 22, Loss: 0.012655692175030708, Accuracy: 99.60317993164062, Test Loss: 0.086238332092762, Test Accuracy: 98.34136199951172
Epoch 23, Loss: 0.012189731933176517, Accuracy: 99.6182632446289, Test Loss: 0.08778291195631027, Test Accuracy: 98.3365249633789
Epoch 24, Loss: 0.011765073984861374, Accuracy: 99.63096618652344, Test Loss: 0.08869967609643936, Test Accuracy: 98.34500122070312
Epoch 25, Loss: 0.011337540112435818, Accuracy: 99.64412689208984, Test Loss: 0.08982877433300018, Test Accuracy: 98.34880065917969
Epoch 26, Loss: 0.010948083363473415, Accuracy: 99.65692138671875, Test Loss: 0.09020914882421494, Test Accuracy: 98.35730743408203
Epoch 27, Loss: 0.01057758741080761, Accuracy: 99.6683349609375, Test Loss: 0.09300722181797028, Test Accuracy: 98.35296630859375
Epoch 28, Loss: 0.010247496888041496, Accuracy: 99.6786880493164, Test Loss: 0.09502799063920975, Test Accuracy: 98.35107421875
Epoch 29, Loss: 0.00994042307138443, Accuracy: 99.68868255615234, Test Loss: 0.09682592749595642, Test Accuracy: 98.34862518310547
Epoch 30, Loss: 0.009655582718551159, Accuracy: 99.69766998291016, Test Loss: 0.09846261888742447, Test Accuracy: 98.34766387939453
Epoch 31, Loss: 0.00935621839016676, Accuracy: 99.70709991455078, Test Loss: 0.09916603565216064, Test Accuracy: 98.35354614257812
Epoch 32, Loss: 0.009064075537025928, Accuracy: 99.71625518798828, Test Loss: 0.09976999461650848, Test Accuracy: 98.36062622070312
Epoch 33, Loss: 0.008789409883320332, Accuracy: 99.72484588623047, Test Loss: 0.1003497764468193, Test Accuracy: 98.36727142333984
Epoch 34, Loss: 0.00853089801967144, Accuracy: 99.73294067382812, Test Loss: 0.10092400014400482, Test Accuracy: 98.37382507324219
Epoch 35, Loss: 0.008287157863378525, Accuracy: 99.74057006835938, Test Loss: 0.10151258856058121, Test Accuracy: 98.38114166259766
Epoch 36, Loss: 0.008056959137320518, Accuracy: 99.7477798461914, Test Loss: 0.10213596373796463, Test Accuracy: 98.3883285522461
Epoch 37, Loss: 0.007839202880859375, Accuracy: 99.75459289550781, Test Loss: 0.10280793905258179, Test Accuracy: 98.39486694335938
Epoch 38, Loss: 0.007632908411324024, Accuracy: 99.76104736328125, Test Loss: 0.1035393550992012, Test Accuracy: 98.40157318115234
Epoch 39, Loss: 0.007437192834913731, Accuracy: 99.76718139648438, Test Loss: 0.10434520989656448, Test Accuracy: 98.40846252441406
Epoch 40, Loss: 0.007251263130456209, Accuracy: 99.77300262451172, Test Loss: 0.10522928833961487, Test Accuracy: 98.41474914550781
Epoch 41, Loss: 0.007074403110891581, Accuracy: 99.7785415649414, Test Loss: 0.1061784029006958, Test Accuracy: 98.42073059082031
Epoch 42, Loss: 0.0069059645757079124, Accuracy: 99.78380584716797, Test Loss: 0.10718268156051636, Test Accuracy: 98.42666625976562
Epoch 43, Loss: 0.006745360791683197, Accuracy: 99.7888412475586, Test Loss: 0.10827546566724777, Test Accuracy: 98.43162536621094
Epoch 44, Loss: 0.006592057179659605, Accuracy: 99.79364013671875, Test Loss: 0.10937657207250595, Test Accuracy: 98.4361343383789
Epoch 45, Loss: 0.0064455671235919, Accuracy: 99.7982177734375, Test Loss: 0.11045938730239868, Test Accuracy: 98.4408950805664
Epoch 46, Loss: 0.006305445916950703, Accuracy: 99.80260467529297, Test Loss: 0.11154481023550034, Test Accuracy: 98.44499969482422
Epoch 47, Loss: 0.006171287503093481, Accuracy: 99.80680847167969, Test Loss: 0.11257602274417877, Test Accuracy: 98.44893646240234
Epoch 48, Loss: 0.006042719352990389, Accuracy: 99.81082916259766, Test Loss: 0.11355891078710556, Test Accuracy: 98.45291137695312
Epoch 49, Loss: 0.00591939827427268, Accuracy: 99.814697265625, Test Loss: 0.11449698358774185, Test Accuracy: 98.45653533935547
Epoch 50, Loss: 0.00580101041123271, Accuracy: 99.81840515136719, Test Loss: 0.11539533734321594, Test Accuracy: 98.45999908447266
Epoch 51, Loss: 0.005687265191227198, Accuracy: 99.82196044921875, Test Loss: 0.11626259237527847, Test Accuracy: 98.46333312988281
Epoch 52, Loss: 0.005577894859015942, Accuracy: 99.82537841796875, Test Loss: 0.11711153388023376, Test Accuracy: 98.4669189453125
Epoch 53, Loss: 0.005472651217132807, Accuracy: 99.82868194580078, Test Loss: 0.11791962385177612, Test Accuracy: 98.47000122070312
Epoch 54, Loss: 0.005371306091547012, Accuracy: 99.83184814453125, Test Loss: 0.11869216710329056, Test Accuracy: 98.47333526611328
Epoch 55, Loss: 0.005273645743727684, Accuracy: 99.83490753173828, Test Loss: 0.11943681538105011, Test Accuracy: 98.47672271728516
Epoch 56, Loss: 0.005179473664611578, Accuracy: 99.83786010742188, Test Loss: 0.12015651166439056, Test Accuracy: 98.47999572753906
Epoch 57, Loss: 0.00508860545232892, Accuracy: 99.8406982421875, Test Loss: 0.12085559964179993, Test Accuracy: 98.48298645019531
Epoch 58, Loss: 0.005000871140509844, Accuracy: 99.84344482421875, Test Loss: 0.12151146680116653, Test Accuracy: 98.48603820800781
Epoch 59, Loss: 0.0049161105416715145, Accuracy: 99.84609985351562, Test Loss: 0.12214148789644241, Test Accuracy: 98.48898315429688
Epoch 60, Loss: 0.0048341755755245686, Accuracy: 99.84866333007812, Test Loss: 0.12275853753089905, Test Accuracy: 98.49183654785156
Epoch 61, Loss: 0.004754926543682814, Accuracy: 99.85115051269531, Test Loss: 0.12335605174303055, Test Accuracy: 98.49459075927734
Epoch 62, Loss: 0.004678233992308378, Accuracy: 99.85354614257812, Test Loss: 0.1239520013332367, Test Accuracy: 98.49694061279297
Epoch 63, Loss: 0.004603976383805275, Accuracy: 99.85587310791016, Test Loss: 0.1245238184928894, Test Accuracy: 98.4993667602539
Epoch 64, Loss: 0.004532039165496826, Accuracy: 99.85812377929688, Test Loss: 0.1250881403684616, Test Accuracy: 98.5015640258789
Epoch 65, Loss: 0.004462315700948238, Accuracy: 99.86031341552734, Test Loss: 0.1256319284439087, Test Accuracy: 98.50384521484375
Epoch 66, Loss: 0.004394704941660166, Accuracy: 99.8624267578125, Test Loss: 0.12615543603897095, Test Accuracy: 98.50606536865234
Epoch 67, Loss: 0.004329112358391285, Accuracy: 99.8644790649414, Test Loss: 0.1266685575246811, Test Accuracy: 98.50820922851562
Epoch 68, Loss: 0.00426544900983572, Accuracy: 99.86647033691406, Test Loss: 0.12717460095882416, Test Accuracy: 98.51029205322266
Epoch 69, Loss: 0.0042036306113004684, Accuracy: 99.868408203125, Test Loss: 0.12767845392227173, Test Accuracy: 98.51232147216797
Epoch 70, Loss: 0.004143578931689262, Accuracy: 99.87028503417969, Test Loss: 0.1281762272119522, Test Accuracy: 98.51414489746094
Epoch 71, Loss: 0.004085218533873558, Accuracy: 99.87211608886719, Test Loss: 0.12866199016571045, Test Accuracy: 98.51605987548828
Epoch 72, Loss: 0.004028479568660259, Accuracy: 99.87388610839844, Test Loss: 0.129144087433815, Test Accuracy: 98.51763916015625
Epoch 73, Loss: 0.003973294980823994, Accuracy: 99.87561798095703, Test Loss: 0.12962999939918518, Test Accuracy: 98.51917266845703
Epoch 74, Loss: 0.0039196014404296875, Accuracy: 99.8772964477539, Test Loss: 0.13010157644748688, Test Accuracy: 98.52067565917969
Epoch 75, Loss: 0.00386734027415514, Accuracy: 99.8789291381836, Test Loss: 0.13056932389736176, Test Accuracy: 98.52213287353516
Epoch 76, Loss: 0.003816454205662012, Accuracy: 99.88053131103516, Test Loss: 0.13102789223194122, Test Accuracy: 98.52355194091797
Epoch 77, Loss: 0.0037668899167329073, Accuracy: 99.882080078125, Test Loss: 0.13148051500320435, Test Accuracy: 98.5248031616211
Epoch 78, Loss: 0.0037185964174568653, Accuracy: 99.88359069824219, Test Loss: 0.13192768394947052, Test Accuracy: 98.52615356445312
Epoch 79, Loss: 0.00367152551189065, Accuracy: 99.88506317138672, Test Loss: 0.13237474858760834, Test Accuracy: 98.5274658203125
Epoch 80, Loss: 0.0036256315652281046, Accuracy: 99.8864974975586, Test Loss: 0.1328185796737671, Test Accuracy: 98.52862548828125
Epoch 81, Loss: 0.003580870572477579, Accuracy: 99.88790130615234, Test Loss: 0.13325545191764832, Test Accuracy: 98.52974700927734
Epoch 82, Loss: 0.0035372015554457903, Accuracy: 99.88926696777344, Test Loss: 0.13368935883045197, Test Accuracy: 98.53121948242188
Epoch 83, Loss: 0.0034945847000926733, Accuracy: 99.8906021118164, Test Loss: 0.1341179609298706, Test Accuracy: 98.53228759765625
Epoch 84, Loss: 0.0034529822878539562, Accuracy: 99.89190673828125, Test Loss: 0.1345466524362564, Test Accuracy: 98.53333282470703
Epoch 85, Loss: 0.0034123591613024473, Accuracy: 99.89317321777344, Test Loss: 0.1349654644727707, Test Accuracy: 98.53446960449219
Epoch 86, Loss: 0.0033726803958415985, Accuracy: 99.89441680908203, Test Loss: 0.13537940382957458, Test Accuracy: 98.53557586669922
Epoch 87, Loss: 0.0033339140936732292, Accuracy: 99.8956298828125, Test Loss: 0.13580235838890076, Test Accuracy: 98.53655242919922
Epoch 88, Loss: 0.0032960285898298025, Accuracy: 99.89682006835938, Test Loss: 0.1362161934375763, Test Accuracy: 98.53772735595703
Epoch 89, Loss: 0.003258994547650218, Accuracy: 99.8979721069336, Test Loss: 0.1366264969110489, Test Accuracy: 98.53887939453125
Epoch 90, Loss: 0.00322278356179595, Accuracy: 99.89911651611328, Test Loss: 0.1370314061641693, Test Accuracy: 98.54000091552734
Epoch 91, Loss: 0.003187368391081691, Accuracy: 99.90022277832031, Test Loss: 0.13743369281291962, Test Accuracy: 98.5409927368164
Epoch 92, Loss: 0.0031527229584753513, Accuracy: 99.90130615234375, Test Loss: 0.13783754408359528, Test Accuracy: 98.54206085205078
Epoch 93, Loss: 0.003118822816759348, Accuracy: 99.9023666381836, Test Loss: 0.13824236392974854, Test Accuracy: 98.54312133789062
Epoch 94, Loss: 0.0030856437515467405, Accuracy: 99.90340423583984, Test Loss: 0.1386358141899109, Test Accuracy: 98.54415130615234
Epoch 95, Loss: 0.0030531634110957384, Accuracy: 99.90442657470703, Test Loss: 0.139034241437912, Test Accuracy: 98.54495239257812
Epoch 96, Loss: 0.0030213596764951944, Accuracy: 99.9054183959961, Test Loss: 0.13942478597164154, Test Accuracy: 98.54582977294922
Epoch 97, Loss: 0.0029902115929871798, Accuracy: 99.90638732910156, Test Loss: 0.13981282711029053, Test Accuracy: 98.54669952392578
Epoch 98, Loss: 0.00295969913713634, Accuracy: 99.90734100341797, Test Loss: 0.14019285142421722, Test Accuracy: 98.5472412109375
Epoch 99, Loss: 0.002929803216829896, Accuracy: 99.90827941894531, Test Loss: 0.14057938754558563, Test Accuracy: 98.5478744506836
Epoch 100, Loss: 0.002900505205616355, Accuracy: 99.90919494628906, Test Loss: 0.14096887409687042, Test Accuracy: 98.54859924316406
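Because the metrics behind this log were never reset, the logged value A_k for epoch k is the average over epochs 1 to k. Every epoch evaluates the same 10,000 test images in the same number of batches, so the result of epoch k on its own can be recovered as `a_k = k * A_k - (k - 1) * A_{k-1}`. The sketch below applies that formula. `per_epoch_values` is a hypothetical helper, and the numbers are copied from the first three lines of the log.

```python
def per_epoch_values(cumulative):
    """Turn running averages A_1..A_n into per-epoch values a_1..a_n.

    Assumes every epoch contributes the same number of observations, so
    k * A_k equals the sum of the first k per-epoch values.
    """
    values = []
    prev_total = 0.0
    for k, avg in enumerate(cumulative, start=1):
        total = k * avg
        values.append(total - prev_total)
        prev_total = total
    return values


# Test accuracy (%) logged after epochs 1, 2 and 3.
logged = [97.79999542236328, 97.94499969482422, 98.08333587646484]
print([round(v, 2) for v in per_epoch_values(logged)])  # [97.8, 98.09, 98.36]
```

The same formula applied to the epoch 99 and epoch 100 lines gives a test accuracy of about 98.62% for epoch 100 alone. For the test loss, it gives about 0.18 for epoch 100, compared with 0.064 after epoch 1. Test loss rising while training loss keeps falling is a typical sign of overfitting.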


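The subclassing API provides a define-by-run interface for advanced research: you create a class for your model, then write the forward pass imperatively. This makes it easy to write custom layers, activation functions, and training loops.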
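To illustrate the custom-layer point, the sketch below re-implements a fully connected layer by subclassing `tf.keras.layers.Layer`. It assumes TensorFlow 2.x is installed, and the class name `MyDense` is hypothetical. Weights are created lazily in `build()`, once the input size is known, and the forward pass in `call()` is ordinary imperative TensorFlow code.

```python
import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    """A minimal fully connected layer: y = activation(x @ w + b)."""

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Accepts a name such as "relu" or "softmax"; None means linear.
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        # Called on first use, when the last input dimension is known.
        in_dim = int(input_shape[-1])
        self.w = self.add_weight(name="w", shape=(in_dim, self.units),
                                 initializer="glorot_uniform", trainable=True)
        self.b = self.add_weight(name="b", shape=(self.units,),
                                 initializer="zeros", trainable=True)

    def call(self, x):
        # Imperative forward pass, executed eagerly unless wrapped in tf.function.
        return self.activation(tf.matmul(x, self.w) + self.b)


layer = MyDense(10, activation="softmax")
out = layer(tf.random.uniform((2, 128)))
print(out.shape)  # (2, 10)
```

A layer like this can replace `Dense(10, activation='softmax')` inside `MyModel`. Its weights then show up in `model.trainable_variables`, so the `GradientTape` training step above updates them without any changes.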
