# 尝试获取预训练模型的特定层,并基于这些层构建模型
# (Extract specific layers from a pre-trained model and build a new model on top of them.)
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, UpSampling2D, Input, BatchNormalization, concatenate, MaxPool2D
def load_data():
    """Load Fashion-MNIST as tensors ready for the model.

    Pixel values are rescaled with ``x / 255 * 2 - 1`` (i.e. to [-1, 1],
    assuming 8-bit pixels) and both images and labels gain a trailing axis
    of size 1.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    def _to_signed_unit_range(images):
        # [0, 255] -> [-1, 1]
        return images / 255. * 2 - 1

    splits = tf.keras.datasets.fashion_mnist.load_data()
    return tuple(
        (tf.expand_dims(_to_signed_unit_range(images), -1),
         tf.expand_dims(labels, -1))
        for images, labels in splits
    )
# Base ("pre-trained") network: two (28, 28, 1) float64 images in, two
# (28, 28, 1) maps out. Shape comments below exclude the batch axis.
input1 = Input((28, 28, 1), dtype=tf.float64)
input2 = Input((28, 28, 1), dtype=tf.float64)
input_ = concatenate([input1, input2])  # (28, 28, 2)

# Layers are created in exactly the order they are applied, matching the
# original layer-by-layer code.
net = input_
for stage in (
    # -- encoder --
    Conv2D(16, 3, padding='same', activation='relu'),    # (28, 28, 16)
    MaxPool2D(),                                         # (14, 14, 16)
    Conv2D(32, 3, padding='same', activation='relu'),    # (14, 14, 32)
    MaxPool2D(),                                         # (7, 7, 32)
    Conv2D(64, 3, activation='relu'),                    # (5, 5, 64)
    Conv2D(64, 3, activation='relu'),                    # (3, 3, 64)
    Conv2D(128, 3, activation='relu'),                   # (1, 1, 128)
    # -- decoder --
    UpSampling2D((4, 4)),                                # (4, 4, 128)
    Conv2D(64, 3, padding='same', activation='relu'),    # (4, 4, 64)
    UpSampling2D((4, 4)),                                # (16, 16, 64)
    Conv2D(32, 3, padding='same', activation='relu'),    # (16, 16, 32)
    UpSampling2D(),                                      # (32, 32, 32)
    Conv2D(16, 5, activation='relu', name='final_feature_conv7'),  # (28, 28, 16)
):
    net = stage(net)

# Two 1x1 heads on the shared features: a tanh map and a sigmoid map.
output1 = Conv2D(1, 1, activation=tf.keras.activations.tanh, name='output1')(net)
output2 = Conv2D(1, 1, activation='sigmoid', name='output2')(net)
model = tf.keras.Model(inputs=[input1, input2], outputs=[output1, output2])

# Freeze the base network; only the layers added below are trainable.
model.trainable = False

# New head built on the base network's 'final_feature_conv7' activations.
# NOTE(review): the new heads reuse the names 'output1'/'output2'; this only
# works because the old heads are not part of new_model's graph.
final_feature = Conv2D(16, 3, padding='same', activation='relu')(
    model.get_layer('final_feature_conv7').output)
new_output1 = Conv2D(1, 1, activation=tf.keras.activations.tanh, name='output1')(final_feature)
new_output2 = Conv2D(1, 1, activation='sigmoid', name='output2')(final_feature)
# Same inputs as the base model; the base layers are shared, not copied.
new_model = tf.keras.Model(inputs=model.input, outputs=[new_output1, new_output2])
# Training hyper-parameters and state shared by train_step() and train().
epochs = 10
batch_size = 32
optimizer = tf.keras.optimizers.Adam()
mean_loss = tf.metrics.Mean()  # running mean of the training loss
step = tf.Variable(1, name='global_step')

# Track new_model as well as model: model is frozen (trainable = False), so
# the layers actually being trained -- the head added on top of
# 'final_feature_conv7' -- exist only in new_model and were never saved.
# The 'model' key is kept so checkpoints written earlier still restore into
# the base network.
ckpt = tf.train.Checkpoint(step=step, optimizer=optimizer, model=model,
                           new_model=new_model)
manager = tf.train.CheckpointManager(ckpt, './checkpoints1', max_to_keep=3)
# Actually load the weights. restore(None) only initialises from scratch, so
# this is safe when no checkpoint exists yet.
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print(f'Restored from {manager.latest_checkpoint}')
else:
    print('Initializing from scratch')
def train_step(features: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
    """Run one optimisation step of new_model on a batch and return its loss.

    Args:
        features: batch of images; fed as new_model's first input and used as
            the reconstruction target for output1.
        labels: batch of class labels, compared element-wise against 6.

    Returns:
        Scalar float32 loss: mean MSE(features, output1) plus mean
        MSE(output2, input2).
    """
    with tf.GradientTape() as t:
        # True where the label equals 6.
        is_six = (labels == 6 * tf.ones_like(labels))
        # Second model input, built example by example from the mask above.
        input2 = []
        # NOTE(review): iterates batch_size items, so assumes every batch is
        # full -- TODO confirm (train() only produces full batches).
        for i in range(batch_size):
            if is_six[i]:
                # NOTE(review): the body of this `if` (and any `else`) is
                # missing from this file: nothing adds entries to input2 and
                # the block does not parse. Restore the original per-example
                # logic before running.
        input2 = tf.convert_to_tensor(input2)
        # print(features.shape, input2.shape)
        [output1, output2] = new_model([features, input2])
        # output1 is trained to reconstruct the input images.
        loss1 = tf.reduce_mean(tf.keras.losses.MSE(features, output1))
        # output2 is trained to reproduce input2.
        loss2 = tf.reduce_mean(tf.keras.losses.MSE(output2, input2))
        # Cast both terms to float32 before summing; presumably because the
        # float64 inputs can make a term float64 -- confirm.
        loss = tf.cast(loss1,dtype=tf.float32) + tf.cast(loss2, dtype=tf.float32)
    # Only new_model's trainable variables are updated (model itself was
    # frozen with model.trainable = False).
    grad = t.gradient(loss, new_model.trainable_variables)
    optimizer.apply_gradients(zip(grad, new_model.trainable_variables))
    return loss
def train():
    """Train new_model on the Fashion-MNIST training split.

    Iterates over the full batches of ``batch_size`` examples for ``epochs``
    epochs. Every 100 batches it prints the current loss, saves a checkpoint,
    and writes an image summary plus the mean loss since the previous log to
    './log/train'. Training stops early once that windowed mean loss falls
    below 0.5.
    """
    (x_train, y_train), _ = load_data()
    nr_batches_train = len(x_train) // batch_size
    train_summary_writer = tf.summary.create_file_writer('./log/train')
    with train_summary_writer.as_default():
        for epoch in range(epochs):
            for t in range(nr_batches_train):
                start_from = t * batch_size
                to = (t + 1) * batch_size
                features, labels = x_train[start_from:to], y_train[start_from:to]
                loss_value = train_step(features, labels)
                # Accumulate every step's loss. Without this, mean_loss.result()
                # is always 0.0, so the logged loss and the threshold check
                # below were meaningless.
                mean_loss(loss_value)
                if t % 100 == 0:
                    print('step:{}, loss:{}'.format(step.numpy(), loss_value))
                    saved_path = manager.save()
                    print(f'Checkpoint saved:{saved_path}')
                    # load_data() scales images to [-1, 1]; tf.summary.image
                    # expects floats in [0, 1], so shift them back.
                    tf.summary.image('train_set', (features + 1) / 2,
                                     max_outputs=3, step=step.numpy())
                    window_loss = mean_loss.result()
                    tf.summary.scalar('loss', window_loss, step=step.numpy())
                    # Start a fresh averaging window for the next log.
                    mean_loss.reset_states()
                    if window_loss < 0.5:
                        # NOTE(review): the original body of this check is
                        # missing from the source; stopping training here is
                        # an assumed early stop -- confirm the intended action.
                        return
                # Advance the global step: it is checkpointed and used as the
                # summary step, but was previously never incremented.
                step.assign_add(1)
if __name__ == '__main__':