[TensorFlow 2.0] High-level API: the model class interfaces provided by tf.keras.models

The following example implements a linear regression model with TensorFlow's high-level API.

TensorFlow's high-level API consists mainly of the model class interfaces provided by tf.keras.models.

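The Keras interface offers three ways to build a model:
- stack layers in order with Sequential;
- build a model of arbitrary structure with the functional API;
- subclass the Model base class to build a custom model.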

Two of them are demonstrated here: building a model layer by layer with Sequential, and building a custom model by subclassing Model.
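
The third approach, the functional API, is not demonstrated in this article. Here is a minimal sketch of the same one-layer linear model written that way. It declares an input with `tf.keras.Input`, calls a `Dense` layer on it, and wraps the input and output tensors in `models.Model`. The variable name `linear_fn` is illustrative and does not come from the original.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Functional API: declare the input, chain layers on it, then wrap both ends in a Model.
inputs = tf.keras.Input(shape=(2,))      # each sample has two features
outputs = layers.Dense(1)(inputs)        # y = x @ w + b
linear_fn = models.Model(inputs=inputs, outputs=outputs)

# Same training configuration as the Sequential example
linear_fn.compile(optimizer="adam", loss="mse", metrics=["mae"])
print(linear_fn.count_params())  # 3 (2 weights + 1 bias)
```

Once compiled, this model can be trained with `fit` exactly like the Sequential version. The functional form becomes more useful when a model has several inputs, several outputs, or branches.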

1. Building a model layer by layer with Sequential (for beginners)

import tensorflow as tf
from tensorflow.keras import models,layers,optimizers
 
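# Number of samples
n = 800

# Generate a synthetic dataset: Y = X @ w0 + b0 + noise
X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-1.0]])
b0 = tf.constant(3.0)

Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0)  # @ is matrix multiplication; add Gaussian noise
tf.keras.backend.clear_session()  # reset Keras global state (e.g. layer name counters)

# A single Dense unit on 2 inputs is exactly the linear model y = x @ w + b
linear = models.Sequential()
linear.add(layers.Dense(1,input_shape =(2,)))
linear.summary()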
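
In the summary output, the single Dense layer has Param # 3. The layer holds a kernel of shape (2, 1), one weight per input feature, plus a bias of shape (1,), so 2 × 1 + 1 = 3. A quick check on a standalone Dense layer, built directly with `build` outside any model, might look like this:

```python
import tensorflow as tf
from tensorflow.keras import layers

# A Dense(1) layer on 2 input features holds a (2, 1) kernel and a (1,) bias.
dense = layers.Dense(1)
dense.build((None, 2))  # allocate weights for inputs of shape (batch, 2)

kernel_shape = tuple(dense.kernel.shape)  # (2, 1)
bias_shape = tuple(dense.bias.shape)      # (1,)
n_params = dense.count_params()           # 2 * 1 + 1 = 3
print(kernel_shape, bias_shape, n_params)
```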


### Training with the fit method
 
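# Adam optimizer, mean squared error loss; mean absolute error is reported as a metric
linear.compile(optimizer="adam",loss="mse",metrics=["mae"])
linear.fit(X,Y,batch_size = 20,epochs = 200)

# The learned parameters should end up close to w0 = [[2.0],[-1.0]] and b0 = 3.0
tf.print("w = ",linear.layers[0].kernel)
tf.print("b = ",linear.layers[0].bias)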

Output:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 1)                 3         
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
Epoch 1/200
40/40 [==============================] - 0s 908us/step - loss: 195.5055 - mae: 11.7040
Epoch 2/200
40/40 [==============================] - 0s 870us/step - loss: 188.2559 - mae: 11.4891
Epoch 3/200
40/40 [==============================] - 0s 820us/step - loss: 181.3084 - mae: 11.2766
Epoch 4/200
40/40 [==============================] - 0s 859us/step - loss: 174.4538 - mae: 11.0680
Epoch 5/200
40/40 [==============================] - 0s 886us/step - loss: 167.8749 - mae: 10.8582
Epoch 6/200
40/40 [==============================] - 0s 912us/step - loss: 161.5035 - mae: 10.6533
Epoch 7/200
40/40 [==============================] - 0s 916us/step - loss: 155.3012 - mae: 10.4504
Epoch 8/200
40/40 [==============================] - 0s 839us/step - loss: 149.3520 - mae: 10.2490
Epoch 9/200
40/40 [==============================] - 0s 977us/step - loss: 143.5773 - mae: 10.0487
Epoch 10/200
40/40 [==============================] - 0s 951us/step - loss: 137.9654 - mae: 9.8543
Epoch 11/200
40/40 [==============================] - 0s 964us/step - loss: 132.5708 - mae: 9.6616
Epoch 12/200
40/40 [==============================] - 0s 876us/step - loss: 127.3686 - mae: 9.4716
Epoch 13/200
40/40 [==============================] - 0s 885us/step - loss: 122.3309 - mae: 9.2796
Epoch 14/200
40/40 [==============================] - 0s 901us/step - loss: 117.4739 - mae: 9.0935
Epoch 15/200
40/40 [==============================] - 0s 919us/step - loss: 112.7674 - mae: 8.9095
Epoch 16/200
40/40 [==============================] - 0s 1ms/step - loss: 108.2400 - mae: 8.7304
Epoch 17/200
40/40 [==============================] - 0s 1ms/step - loss: 103.8868 - mae: 8.5522
Epoch 18/200
40/40 [==============================] - 0s 955us/step - loss: 99.6424 - mae: 8.3771
Epoch 19/200
40/40 [==============================] - 0s 951us/step - loss: 95.6005 - mae: 8.2044
Epoch 20/200
40/40 [==============================] - 0s 939us/step - loss: 91.7217 - mae: 8.0324
Epoch 21/200
40/40 [==============================] - 0s 1ms/step - loss: 87.9180 - mae: 7.8633
Epoch 22/200
40/40 [==============================] - 0s 1ms/step - loss: 84.2936 - mae: 7.6975
Epoch 23/200
40/40 [==============================] - 0s 1ms/step - loss: 80.7858 - mae: 7.5372
Epoch 24/200
40/40 [==============================] - 0s 891us/step - loss: 77.4177 - mae: 7.3785
Epoch 25/200
40/40 [==============================] - 0s 902us/step - loss: 74.1665 - mae: 7.2210
Epoch 26/200
40/40 [==============================] - 0s 876us/step - loss: 71.0455 - mae: 7.0657
Epoch 27/200
40/40 [==============================] - 0s 892us/step - loss: 68.0396 - mae: 6.9119
Epoch 28/200
40/40 [==============================] - 0s 898us/step - loss: 65.1385 - mae: 6.7610
Epoch 29/200
40/40 [==============================] - 0s 944us/step - loss: 62.3531 - mae: 6.6115
Epoch 30/200
40/40 [==============================] - 0s 1ms/step - loss: 59.6815 - mae: 6.4647
Epoch 31/200
40/40 [==============================] - 0s 1ms/step - loss: 57.0783 - mae: 6.3193
Epoch 32/200
40/40 [==============================] - 0s 978us/step - loss: 54.6050 - mae: 6.1775
Epoch 33/200
40/40 [==============================] - 0s 940us/step - loss: 52.2259 - mae: 6.0359
Epoch 34/200
40/40 [==============================] - 0s 966us/step - loss: 49.9196 - mae: 5.8980
Epoch 35/200
40/40 [==============================] - 0s 964us/step - loss: 47.7187 - mae: 5.7628
Epoch 36/200
40/40 [==============================] - 0s 1ms/step - loss: 45.6023 - mae: 5.6286
Epoch 37/200
40/40 [==============================] - 0s 953us/step - loss: 43.5680 - mae: 5.4965
Epoch 38/200
40/40 [==============================] - 0s 978us/step - loss: 41.6182 - mae: 5.3673
Epoch 39/200
40/40 [==============================] - 0s 1ms/step - loss: 39.7323 - mae: 5.2402
Epoch 40/200
40/40 [==============================] - 0s 976us/step - loss: 37.9372 - mae: 5.1159
Epoch 41/200
40/40 [==============================] - 0s 989us/step - loss: 36.2184 - mae: 4.9935
Epoch 42/200
40/40 [==============================] - 0s 964us/step - loss: 34.5556 - mae: 4.8724
Epoch 43/200
40/40 [==============================] - 0s 978us/step - loss: 32.9704 - mae: 4.7550
Epoch 44/200
40/40 [==============================] - 0s 954us/step - loss: 31.4466 - mae: 4.6392
Epoch 45/200
40/40 [==============================] - 0s 1ms/step - loss: 29.9887 - mae: 4.5273
Epoch 46/200
40/40 [==============================] - 0s 1ms/step - loss: 28.5938 - mae: 4.4169
Epoch 47/200
40/40 [==============================] - 0s 944us/step - loss: 27.2567 - mae: 4.3116
Epoch 48/200
40/40 [==============================] - 0s 874us/step - loss: 25.9801 - mae: 4.2037
Epoch 49/200
40/40 [==============================] - 0s 875us/step - loss: 24.7709 - mae: 4.1004
Epoch 50/200
40/40 [==============================] - 0s 843us/step - loss: 23.5911 - mae: 3.9987
Epoch 51/200
40/40 [==============================] - 0s 880us/step - loss: 22.4801 - mae: 3.8986
Epoch 52/200
40/40 [==============================] - 0s 862us/step - loss: 21.4129 - mae: 3.8020
Epoch 53/200
40/40 [==============================] - 0s 930us/step - loss: 20.4039 - mae: 3.7072
Epoch 54/200
40/40 [==============================] - 0s 921us/step - loss: 19.4387 - mae: 3.6129
Epoch 55/200
40/40 [==============================] - 0s 929us/step - loss: 18.5113 - mae: 3.5211
Epoch 56/200
40/40 [==============================] - 0s 958us/step - loss: 17.6301 - mae: 3.4325
Epoch 57/200
40/40 [==============================] - 0s 857us/step - loss: 16.7977 - mae: 3.3455
Epoch 58/200
40/40 [==============================] - 0s 924us/step - loss: 16.0002 - mae: 3.2620
Epoch 59/200
40/40 [==============================] - 0s 906us/step - loss: 15.2526 - mae: 3.1796
Epoch 60/200
40/40 [==============================] - 0s 989us/step - loss: 14.5282 - mae: 3.1000
Epoch 61/200
40/40 [==============================] - 0s 1ms/step - loss: 13.8489 - mae: 3.0228
Epoch 62/200
40/40 [==============================] - 0s 957us/step - loss: 13.2086 - mae: 2.9496
Epoch 63/200
40/40 [==============================] - 0s 1ms/step - loss: 12.5944 - mae: 2.8770
Epoch 64/200
40/40 [==============================] - 0s 1ms/step - loss: 12.0144 - mae: 2.8087
Epoch 65/200
40/40 [==============================] - 0s 939us/step - loss: 11.4699 - mae: 2.7409
Epoch 66/200
40/40 [==============================] - 0s 950us/step - loss: 10.9486 - mae: 2.6764
Epoch 67/200
40/40 [==============================] - 0s 922us/step - loss: 10.4627 - mae: 2.6140
Epoch 68/200
40/40 [==============================] - 0s 937us/step - loss: 10.0007 - mae: 2.5530
Epoch 69/200
40/40 [==============================] - 0s 1ms/step - loss: 9.5686 - mae: 2.4958
Epoch 70/200
40/40 [==============================] - 0s 926us/step - loss: 9.1566 - mae: 2.4412
Epoch 71/200
40/40 [==============================] - 0s 990us/step - loss: 8.7749 - mae: 2.3897
Epoch 72/200
40/40 [==============================] - 0s 1ms/step - loss: 8.4119 - mae: 2.3410
Epoch 73/200
40/40 [==============================] - 0s 1ms/step - loss: 8.0721 - mae: 2.2930
Epoch 74/200
40/40 [==============================] - 0s 996us/step - loss: 7.7548 - mae: 2.2490
Epoch 75/200
40/40 [==============================] - 0s 1ms/step - loss: 7.4565 - mae: 2.2054
Epoch 76/200
40/40 [==============================] - 0s 1ms/step - loss: 7.1764 - mae: 2.1642
Epoch 77/200
40/40 [==============================] - 0s 987us/step - loss: 6.9172 - mae: 2.1252
Epoch 78/200
40/40 [==============================] - 0s 1ms/step - loss: 6.6718 - mae: 2.0881
Epoch 79/200
40/40 [==============================] - 0s 1ms/step - loss: 6.4435 - mae: 2.0517
Epoch 80/200
40/40 [==============================] - 0s 1ms/step - loss: 6.2325 - mae: 2.0181
Epoch 81/200
40/40 [==============================] - 0s 946us/step - loss: 6.0333 - mae: 1.9845
Epoch 82/200
40/40 [==============================] - 0s 934us/step - loss: 5.8515 - mae: 1.9533
Epoch 83/200
40/40 [==============================] - 0s 922us/step - loss: 5.6774 - mae: 1.9230
Epoch 84/200
40/40 [==============================] - 0s 941us/step - loss: 5.5195 - mae: 1.8950
Epoch 85/200
40/40 [==============================] - 0s 1ms/step - loss: 5.3701 - mae: 1.8676
Epoch 86/200
40/40 [==============================] - 0s 1ms/step - loss: 5.2337 - mae: 1.8420
Epoch 87/200
40/40 [==============================] - 0s 1ms/step - loss: 5.1067 - mae: 1.8188
Epoch 88/200
40/40 [==============================] - 0s 894us/step - loss: 4.9888 - mae: 1.7968
Epoch 89/200
40/40 [==============================] - 0s 909us/step - loss: 4.8797 - mae: 1.7761
Epoch 90/200
40/40 [==============================] - 0s 876us/step - loss: 4.7784 - mae: 1.7572
Epoch 91/200
40/40 [==============================] - 0s 872us/step - loss: 4.6857 - mae: 1.7381
Epoch 92/200
40/40 [==============================] - 0s 866us/step - loss: 4.5981 - mae: 1.7221
Epoch 93/200
40/40 [==============================] - 0s 928us/step - loss: 4.5178 - mae: 1.7055
Epoch 94/200
40/40 [==============================] - 0s 868us/step - loss: 4.4441 - mae: 1.6920
Epoch 95/200
40/40 [==============================] - 0s 931us/step - loss: 4.3759 - mae: 1.6776
Epoch 96/200
40/40 [==============================] - 0s 963us/step - loss: 4.3143 - mae: 1.6650
Epoch 97/200
40/40 [==============================] - 0s 971us/step - loss: 4.2540 - mae: 1.6532
Epoch 98/200
40/40 [==============================] - 0s 914us/step - loss: 4.2015 - mae: 1.6427
Epoch 99/200
40/40 [==============================] - 0s 874us/step - loss: 4.1508 - mae: 1.6330
Epoch 100/200
40/40 [==============================] - 0s 897us/step - loss: 4.1059 - mae: 1.6243
Epoch 101/200
40/40 [==============================] - 0s 884us/step - loss: 4.0636 - mae: 1.6162
Epoch 102/200
40/40 [==============================] - 0s 971us/step - loss: 4.0239 - mae: 1.6081
Epoch 103/200
40/40 [==============================] - 0s 918us/step - loss: 3.9885 - mae: 1.6012
Epoch 104/200
40/40 [==============================] - 0s 990us/step - loss: 3.9542 - mae: 1.5946
Epoch 105/200
40/40 [==============================] - 0s 919us/step - loss: 3.9245 - mae: 1.5892
Epoch 106/200
40/40 [==============================] - 0s 872us/step - loss: 3.8949 - mae: 1.5834
Epoch 107/200
40/40 [==============================] - 0s 879us/step - loss: 3.8686 - mae: 1.5779
Epoch 108/200
40/40 [==============================] - 0s 872us/step - loss: 3.8441 - mae: 1.5735
Epoch 109/200
40/40 [==============================] - 0s 1ms/step - loss: 3.8221 - mae: 1.5693
Epoch 110/200
40/40 [==============================] - 0s 941us/step - loss: 3.7991 - mae: 1.5651
Epoch 111/200
40/40 [==============================] - 0s 958us/step - loss: 3.7793 - mae: 1.5617
Epoch 112/200
40/40 [==============================] - 0s 888us/step - loss: 3.7607 - mae: 1.5583
Epoch 113/200
40/40 [==============================] - 0s 834us/step - loss: 3.7446 - mae: 1.5555
Epoch 114/200
40/40 [==============================] - 0s 872us/step - loss: 3.7285 - mae: 1.5529
Epoch 115/200
40/40 [==============================] - 0s 878us/step - loss: 3.7146 - mae: 1.5499
Epoch 116/200
40/40 [==============================] - 0s 944us/step - loss: 3.7016 - mae: 1.5476
Epoch 117/200
40/40 [==============================] - 0s 949us/step - loss: 3.6883 - mae: 1.5449
Epoch 118/200
40/40 [==============================] - 0s 939us/step - loss: 3.6753 - mae: 1.5428
Epoch 119/200
40/40 [==============================] - 0s 859us/step - loss: 3.6651 - mae: 1.5408
Epoch 120/200
40/40 [==============================] - 0s 876us/step - loss: 3.6544 - mae: 1.5387
Epoch 121/200
40/40 [==============================] - 0s 860us/step - loss: 3.6459 - mae: 1.5371
Epoch 122/200
40/40 [==============================] - 0s 938us/step - loss: 3.6357 - mae: 1.5357
Epoch 123/200
40/40 [==============================] - 0s 918us/step - loss: 3.6284 - mae: 1.5345
Epoch 124/200
40/40 [==============================] - 0s 890us/step - loss: 3.6212 - mae: 1.5334
Epoch 125/200
40/40 [==============================] - 0s 853us/step - loss: 3.6131 - mae: 1.5318
Epoch 126/200
40/40 [==============================] - 0s 856us/step - loss: 3.6067 - mae: 1.5307
Epoch 127/200
40/40 [==============================] - 0s 1ms/step - loss: 3.6014 - mae: 1.5297
Epoch 128/200
40/40 [==============================] - 0s 990us/step - loss: 3.5953 - mae: 1.5289
Epoch 129/200
40/40 [==============================] - 0s 955us/step - loss: 3.5898 - mae: 1.5278
Epoch 130/200
40/40 [==============================] - 0s 929us/step - loss: 3.5857 - mae: 1.5270
Epoch 131/200
40/40 [==============================] - 0s 878us/step - loss: 3.5823 - mae: 1.5267
Epoch 132/200
40/40 [==============================] - 0s 925us/step - loss: 3.5767 - mae: 1.5255
Epoch 133/200
40/40 [==============================] - 0s 1ms/step - loss: 3.5735 - mae: 1.5246
Epoch 134/200
40/40 [==============================] - 0s 950us/step - loss: 3.5699 - mae: 1.5239
Epoch 135/200
40/40 [==============================] - 0s 855us/step - loss: 3.5664 - mae: 1.5233
Epoch 136/200
40/40 [==============================] - 0s 869us/step - loss: 3.5637 - mae: 1.5228
Epoch 137/200
40/40 [==============================] - 0s 920us/step - loss: 3.5611 - mae: 1.5224
Epoch 138/200
40/40 [==============================] - 0s 946us/step - loss: 3.5586 - mae: 1.5218
Epoch 139/200
40/40 [==============================] - 0s 864us/step - loss: 3.5570 - mae: 1.5216
Epoch 140/200
40/40 [==============================] - 0s 1ms/step - loss: 3.5544 - mae: 1.5208
Epoch 141/200
40/40 [==============================] - 0s 990us/step - loss: 3.5522 - mae: 1.5206
Epoch 142/200
40/40 [==============================] - 0s 914us/step - loss: 3.5508 - mae: 1.5200
Epoch 143/200
40/40 [==============================] - 0s 865us/step - loss: 3.5494 - mae: 1.5197
Epoch 144/200
40/40 [==============================] - 0s 867us/step - loss: 3.5487 - mae: 1.5194
Epoch 145/200
40/40 [==============================] - 0s 848us/step - loss: 3.5473 - mae: 1.5194
Epoch 146/200
40/40 [==============================] - 0s 920us/step - loss: 3.5453 - mae: 1.5188
Epoch 147/200
40/40 [==============================] - 0s 954us/step - loss: 3.5445 - mae: 1.5186
Epoch 148/200
40/40 [==============================] - 0s 958us/step - loss: 3.5443 - mae: 1.5188
Epoch 149/200
40/40 [==============================] - 0s 929us/step - loss: 3.5430 - mae: 1.5181
Epoch 150/200
40/40 [==============================] - 0s 919us/step - loss: 3.5430 - mae: 1.5186
Epoch 151/200
40/40 [==============================] - 0s 875us/step - loss: 3.5409 - mae: 1.5176
Epoch 152/200
40/40 [==============================] - 0s 931us/step - loss: 3.5425 - mae: 1.5177
Epoch 153/200
40/40 [==============================] - 0s 957us/step - loss: 3.5403 - mae: 1.5175
Epoch 154/200
40/40 [==============================] - 0s 967us/step - loss: 3.5403 - mae: 1.5172
Epoch 155/200
40/40 [==============================] - 0s 873us/step - loss: 3.5425 - mae: 1.5177
Epoch 156/200
40/40 [==============================] - 0s 905us/step - loss: 3.5402 - mae: 1.5173
Epoch 157/200
40/40 [==============================] - 0s 1ms/step - loss: 3.5395 - mae: 1.5172
Epoch 158/200
40/40 [==============================] - 0s 876us/step - loss: 3.5385 - mae: 1.5169
Epoch 159/200
40/40 [==============================] - 0s 877us/step - loss: 3.5383 - mae: 1.5167
Epoch 160/200
40/40 [==============================] - 0s 847us/step - loss: 3.5385 - mae: 1.5167
Epoch 161/200
40/40 [==============================] - 0s 846us/step - loss: 3.5375 - mae: 1.5165
Epoch 162/200
40/40 [==============================] - 0s 947us/step - loss: 3.5377 - mae: 1.5166
Epoch 163/200
40/40 [==============================] - 0s 986us/step - loss: 3.5371 - mae: 1.5165
Epoch 164/200
40/40 [==============================] - 0s 869us/step - loss: 3.5380 - mae: 1.5167
Epoch 165/200
40/40 [==============================] - 0s 875us/step - loss: 3.5402 - mae: 1.5169
Epoch 166/200
40/40 [==============================] - 0s 913us/step - loss: 3.5390 - mae: 1.5170
Epoch 167/200
40/40 [==============================] - 0s 926us/step - loss: 3.5389 - mae: 1.5163
Epoch 168/200
40/40 [==============================] - 0s 853us/step - loss: 3.5379 - mae: 1.5160
Epoch 169/200
40/40 [==============================] - 0s 925us/step - loss: 3.5380 - mae: 1.5159
Epoch 170/200
40/40 [==============================] - 0s 935us/step - loss: 3.5376 - mae: 1.5167
Epoch 171/200
40/40 [==============================] - 0s 873us/step - loss: 3.5371 - mae: 1.5164
Epoch 172/200
40/40 [==============================] - 0s 847us/step - loss: 3.5376 - mae: 1.5165
Epoch 173/200
40/40 [==============================] - 0s 874us/step - loss: 3.5383 - mae: 1.5167
Epoch 174/200
40/40 [==============================] - 0s 930us/step - loss: 3.5362 - mae: 1.5162
Epoch 175/200
40/40 [==============================] - 0s 960us/step - loss: 3.5386 - mae: 1.5165
Epoch 176/200
40/40 [==============================] - 0s 968us/step - loss: 3.5376 - mae: 1.5166
Epoch 177/200
40/40 [==============================] - 0s 986us/step - loss: 3.5373 - mae: 1.5164
Epoch 178/200
40/40 [==============================] - 0s 907us/step - loss: 3.5395 - mae: 1.5166
Epoch 179/200
40/40 [==============================] - 0s 911us/step - loss: 3.5375 - mae: 1.5161
Epoch 180/200
40/40 [==============================] - 0s 1ms/step - loss: 3.5377 - mae: 1.5165
Epoch 181/200
40/40 [==============================] - 0s 1ms/step - loss: 3.5367 - mae: 1.5164
Epoch 182/200
40/40 [==============================] - 0s 890us/step - loss: 3.5380 - mae: 1.5164
Epoch 183/200
40/40 [==============================] - 0s 926us/step - loss: 3.5373 - mae: 1.5167
Epoch 184/200
40/40 [==============================] - 0s 931us/step - loss: 3.5389 - mae: 1.5168
Epoch 185/200
40/40 [==============================] - 0s 839us/step - loss: 3.5371 - mae: 1.5158
Epoch 186/200
40/40 [==============================] - 0s 892us/step - loss: 3.5383 - mae: 1.5159
Epoch 187/200
40/40 [==============================] - 0s 915us/step - loss: 3.5371 - mae: 1.5163
Epoch 188/200
40/40 [==============================] - 0s 992us/step - loss: 3.5384 - mae: 1.5170
Epoch 189/200
40/40 [==============================] - 0s 913us/step - loss: 3.5376 - mae: 1.5160
Epoch 190/200
40/40 [==============================] - 0s 970us/step - loss: 3.5386 - mae: 1.5166
Epoch 191/200
40/40 [==============================] - 0s 954us/step - loss: 3.5398 - mae: 1.5163
Epoch 192/200
40/40 [==============================] - 0s 906us/step - loss: 3.5370 - mae: 1.5163
Epoch 193/200
40/40 [==============================] - 0s 892us/step - loss: 3.5371 - mae: 1.5166
Epoch 194/200
40/40 [==============================] - 0s 1ms/step - loss: 3.5389 - mae: 1.5167
Epoch 195/200
40/40 [==============================] - 0s 976us/step - loss: 3.5376 - mae: 1.5170
Epoch 196/200
40/40 [==============================] - 0s 925us/step - loss: 3.5371 - mae: 1.5164
Epoch 197/200
40/40 [==============================] - 0s 995us/step - loss: 3.5368 - mae: 1.5161
Epoch 198/200
40/40 [==============================] - 0s 957us/step - loss: 3.5380 - mae: 1.5161
Epoch 199/200
40/40 [==============================] - 0s 923us/step - loss: 3.5391 - mae: 1.5162
Epoch 200/200
40/40 [==============================] - 0s 899us/step - loss: 3.5368 - mae: 1.5160
w =  [[2.00381827]
 [-0.98936516]]
b =  [2.9572618]
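
The learned w ≈ [2.004, -0.989] and b ≈ 2.957 are close to the true w0 = [2, -1] and b0 = 3. The loss levels off around 3.54 instead of reaching zero. That floor comes from the Gaussian noise added to Y. With stddev 2, even the ideal linear model has a mean squared error of about 2² = 4 on average, and the exact value depends on the particular noise draw.

Because this is ordinary linear regression, the best achievable parameters also have a closed form. The sketch below regenerates data the same way and solves least squares with `tf.linalg.lstsq`. It fixes a random seed, which is an assumption of this sketch, so its numbers will not match the log above exactly.

```python
import tensorflow as tf

tf.random.set_seed(0)  # assumption: fixed seed so the run is repeatable
n = 800
X = tf.random.uniform([n, 2], minval=-10, maxval=10)
w0 = tf.constant([[2.0], [-1.0]])
b0 = tf.constant(3.0)
Y = X @ w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)

# Append a column of ones so the bias is solved for as a third coefficient.
X_aug = tf.concat([X, tf.ones([n, 1])], axis=1)  # shape (800, 3)
theta = tf.linalg.lstsq(X_aug, Y)                # shape (3, 1): [w1, w2, b]
w_hat, b_hat = theta[:2], theta[2, 0]

# Training MSE of the least-squares fit: the lowest loss a linear model can reach on this data.
mse_floor = tf.reduce_mean(tf.square(Y - X_aug @ theta))
print(w_hat.numpy().ravel(), float(b_hat), float(mse_floor))
```

Gradient-based training with Adam should approach these values. The remaining gap in b is typical because the bias is estimated less precisely than the weights here: X spans [-10, 10], while the bias column is constant.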

2. Building a custom model by subclassing Model (for experts)

import tensorflow as tf
from tensorflow.keras import models,layers,optimizers,losses,metrics
 
 
# Print a separator line followed by the current time
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts%(24*60*60)

    # +8 converts UTC to Beijing time (UTC+8)
    hour = tf.cast(today_ts//3600+8,tf.int32)%tf.constant(24)
    minute = tf.cast((today_ts%3600)//60,tf.int32)
    second = tf.cast(tf.floor(today_ts%60),tf.int32)

    # Zero-pad single-digit values, e.g. 5 -> "05"
    def timeformat(m):
        if tf.strings.length(tf.strings.format("{}",m))==1:
            return(tf.strings.format("0{}",m))
        else:
            return(tf.strings.format("{}",m))

    timestring = tf.strings.join([timeformat(hour),timeformat(minute),
                timeformat(second)],separator = ":")
    tf.print("=========="*8,end = "")
    tf.print(timestring)
 
# Number of samples
n = 800

# Generate a synthetic dataset: Y = X @ w0 + b0 + noise
X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-1.0]])
b0 = tf.constant(3.0)

Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0)  # @ is matrix multiplication; add Gaussian noise
 
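# First 75% of the samples for training, the remaining 25% for validation.
# cache() comes before shuffle() so the data is reshuffled every epoch;
# caching after shuffle() would freeze the order of the first epoch.
ds_train = tf.data.Dataset.from_tensor_slices((X[0:n*3//4,:],Y[0:n*3//4,:])) \
     .cache() \
     .shuffle(buffer_size = 1000).batch(20) \
     .prefetch(tf.data.experimental.AUTOTUNE)

ds_valid = tf.data.Dataset.from_tensor_slices((X[n*3//4:,:],Y[n*3//4:,:])) \
     .cache() \
     .shuffle(buffer_size = 1000).batch(20) \
     .prefetch(tf.data.experimental.AUTOTUNE)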
 
tf.keras.backend.clear_session()
 
class MyModel(models.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    # Create the layer in build(), which receives the input shape
    def build(self,input_shape):
        self.dense1 = layers.Dense(1)
        super(MyModel,self).build(input_shape)

    # Forward pass: a single linear layer
    def call(self, x):
        y = self.dense1(x)
        return(y)

model = MyModel()
model.build(input_shape =(None,2))  # None = any batch size, 2 features
model.summary()
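
MyModel creates its Dense layer inside `build()`, so `model.build(input_shape=(None, 2))` has to be called before `summary()`. For this subclassed model, the summary output shows multiple in the Output Shape column because Keras has no single recorded output shape for the layer.

A common alternative creates the layer in `__init__` and lets Keras allocate the weights on the first call. The sketch below shows that pattern. The class name `LazyLinear` is hypothetical.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

class LazyLinear(models.Model):  # hypothetical name for this sketch
    def __init__(self):
        super().__init__()
        # The layer object is created here; its weights are allocated on the first call.
        self.dense1 = layers.Dense(1)

    def call(self, x):
        return self.dense1(x)

model2 = LazyLinear()
y = model2(tf.zeros([4, 2]))  # first call builds dense1 for 2 input features
print(tuple(y.shape), model2.count_params())  # (4, 1) 3
```

With this pattern no explicit `build` call is needed. The trade-off is that `summary()` and `model2.layers[0].kernel` are only available after the model has seen data once.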
 


### Custom training loop (expert tutorial)
 
 
optimizer = optimizers.Adam()
loss_func = losses.MeanSquaredError()
 
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_metric = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
 
valid_loss = tf.keras.metrics.Mean(name='valid_loss')
valid_metric = tf.keras.metrics.MeanAbsoluteError(name='valid_mae')
 
 
@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
 
    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)
 
@tf.function
def valid_step(model, features, labels):
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
 
 
@tf.function
def train_model(model,ds_train,ds_valid,epochs):
    for epoch in tf.range(1,epochs+1):
        for features, labels in ds_train:
            train_step(model,features,labels)
 
        for features, labels in ds_valid:
            valid_step(model,features,labels)
 
        logs = 'Epoch={},Loss:{},MAE:{},Valid Loss:{},Valid MAE:{}'

        # Log progress every 100 epochs
        if epoch%100 == 0:
            printbar()
            tf.print(tf.strings.format(logs,
            (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())))
            tf.print("w=",model.layers[0].kernel)
            tf.print("b=",model.layers[0].bias)
            tf.print("")

        # Reset metric state at the end of every epoch
        # (newer TF releases name this method reset_state())
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()
 
train_model(model,ds_train,ds_valid,400)

Output:

Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  3         
=================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0
_________________________________________________________________
================================================================================15:40:27
Epoch=100,Loss:7.5666852,MAE:2.1710279,Valid Loss:6.50372219,Valid MAE:2.06310129
w= [[1.78483891]
 [-0.941808105]]
b= [1.89865637]

================================================================================15:40:34
Epoch=200,Loss:4.18288374,MAE:1.6310848,Valid Loss:3.79517508,Valid MAE:1.53697133
w= [[2.02300119]
 [-0.992656231]]
b= [2.88763976]

================================================================================15:40:42
Epoch=300,Loss:4.17580175,MAE:1.62464666,Valid Loss:3.80199885,Valid MAE:1.53819764
w= [[2.02173]
 [-0.992035568]]
b= [2.97494888]

================================================================================15:40:49
Epoch=400,Loss:4.17601919,MAE:1.6246767,Valid Loss:3.80182695,Valid MAE:1.53820801
w= [[2.02159858]
 [-0.992003262]]
b= [2.97537684]

 

References:

Open-source e-book: https://lyhue1991.github.io/eat_tensorflow2_in_30_days/

GitHub project: https://github.com/lyhue1991/eat_tensorflow2_in_30_days
