TensorFlow 2.0 教程持续更新：https://blog.csdn.net/qq_31456593/article/details/88606284
完整的 TensorFlow 2.0 教程代码请看：TensorFlow 2.0 中文教程 tensorflow2_tutorials_chinese（欢迎 star）
入门教程:
TensorFlow 2.0 教程- Keras 快速入门
TensorFlow 2.0 教程-keras 函数api
TensorFlow 2.0 教程-使用keras训练模型
TensorFlow 2.0 教程-用keras构建自己的网络层
TensorFlow 2.0 教程-keras模型保存和序列化
# Load the MNIST train/test splits. The labels are unpacked here but are only
# used for the shape printout below.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Cast the images to float32, append a trailing channel axis, and divide by
# 255.0 (presumably mapping 8-bit pixel values into [0, 1] -- confirm).
x_train = tf.expand_dims(x_train.astype('float32'), axis=-1) / 255.0
x_test = tf.expand_dims(x_test.astype('float32'), axis=-1) / 255.0

# Sanity-check the resulting image and label shapes.
print(x_train.shape, ' ', y_train.shape)
print(x_test.shape, ' ', y_test.shape)
(60000, 28, 28, 1) (60000,)
(10000, 28, 28, 1) (10000,)
# Convolutional autoencoder built with the Keras functional API.
# The input shape is taken from the training tensor (every axis except batch).
inputs = layers.Input(
    shape=tuple(x_train.shape[axis] for axis in (1, 2, 3)),
    name='inputs',
)
print(inputs.shape)

# Encoder: one 3x3 convolution followed by 2x2 max-pooling
# (28x28 -> 14x14 according to the shapes printed by this cell).
code = layers.Conv2D(
    filters=16, kernel_size=(3, 3), activation='relu', padding='same'
)(inputs)
code = layers.MaxPool2D(pool_size=(2, 2), padding='same')(code)
print(code.shape)

# Decoder: one 3x3 convolution followed by 2x2 up-sampling back to 28x28.
decoded = layers.Conv2D(
    filters=16, kernel_size=(3, 3), activation='relu', padding='same'
)(code)
decoded = layers.UpSampling2D(size=(2, 2))(decoded)
print(decoded.shape)

# Single-channel sigmoid output with the same spatial size as the input.
outputs = layers.Conv2D(
    filters=1, kernel_size=(3, 3), activation='sigmoid', padding='same'
)(decoded)
print(outputs.shape)

auto_encoder = keras.Model(inputs=inputs, outputs=outputs)
(None, 28, 28, 1)
(None, 14, 14, 16)
(None, 28, 28, 16)
(None, 28, 28, 1)
# Adam with its default arguments; binary cross-entropy is computed between
# the sigmoid outputs and the target images.
auto_encoder.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(),
)

# Render the model graph, including per-layer shapes.
keras.utils.plot_model(auto_encoder, show_shapes=True)

# Early stopping watches the *training* loss, not a validation metric.
early_stop = keras.callbacks.EarlyStopping(monitor='loss', patience=2)

# Autoencoder training: the images are both the input and the target.
# NOTE(review): with epochs=1 and validation_freq=10 the held-out 10% is
# presumably never evaluated in this call (the recorded log shows no val_loss),
# and a patience of 2 cannot trigger within one epoch -- confirm this is the
# intended quick-demo configuration.
auto_encoder.fit(
    x=x_train,
    y=x_train,
    batch_size=64,
    epochs=1,
    validation_split=0.1,
    validation_freq=10,
    callbacks=[early_stop],
)
Train on 54000 samples, validate on 6000 samples
54000/54000 [==============================] - 31s 572us/sample - loss: 0.1007
import matplotlib.pyplot as plt

# Run the trained autoencoder over the test images.
# Note: this rebinds `decoded`, which previously named a decoder layer output.
decoded = auto_encoder.predict(x_test)

# Draw a 2 x n grid: originals on the top row, reconstructions underneath.
n = 5
plt.figure(figsize=(10, 4))
for i in range(n):
    # row 0 -> original image, row 1 -> its reconstruction.
    for row, images in enumerate((x_test, decoded)):
        ax = plt.subplot(2, n, row * n + i + 1)
        # Samples are read at index i + 1, so the panels show test images 1..n.
        plt.imshow(tf.reshape(images[i + 1], (28, 28)))
        plt.gray()
        # Hide both axes so that only the image is visible.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()