TensorFlow 2: Extracting Specific Layers from a Pretrained Model and Rebuilding the Network on Top of Them

This section shows how, in TensorFlow 2, to build a new network on top of specific layers of an already-trained model. The parameters between the inputs and the output of the chosen layer are frozen; only the parameters between that layer's output and the new final outputs are trained.

input1 = Input((28,28,1), dtype=tf.float64)
input2 = Input((28,28,1), dtype=tf.float64)
input_ = concatenate([input1, input2])
# (28,28,2)
net = Conv2D(16, 3, padding='same', activation='relu')(input_)
# (28,28,16)
net = MaxPool2D()(net)
# (14,14,16)
net = Conv2D(32, 3, padding='same', activation='relu')(net)
# (14,14,32)
net = MaxPool2D()(net)
# (7,7,32)
net = Conv2D(64, 3, activation='relu')(net)
# (5, 5, 64)
net = Conv2D(64, 3, activation='relu')(net)
# (3, 3, 64)
net = Conv2D(128, 3, activation='relu')(net)
# (1, 1, 128)
# upsampling
net = UpSampling2D((4, 4))(net)
# (4, 4, 128)
net = Conv2D(64, 3, padding='same', activation='relu')(net)
# (4, 4, 64)
net = UpSampling2D((4, 4))(net)
# (16, 16, 64)
net = Conv2D(32, 3, padding='same', activation='relu')(net)
# (16, 16, 32)
net = UpSampling2D()(net)
# (32, 32, 32)
net = Conv2D(16, 5, activation='relu', name='final_feature_conv7')(net)
# (28, 28, 16)
output1 = Conv2D(1, 1, activation='tanh', name='output1')(net)
output2 = Conv2D(1, 1, activation='sigmoid', name='output2')(net)
# model is the trained network; we freeze its parameters
model = tf.keras.Model(inputs=[input1, input2], outputs=[output1, output2])
model.trainable = False
model.summary()
# We want the (28, 28, 16) feature map from model; we already named that layer 'final_feature_conv7',
# so the line below retrieves the layer's output by name
final_feature = model.get_layer('final_feature_conv7').output
final_feature = Conv2D(16, 3, padding='same', activation='relu')(final_feature)
# distinct names keep the new heads from shadowing the frozen model's output layers
new_output1 = Conv2D(1, 1, activation='tanh', name='new_output1')(final_feature)
new_output2 = Conv2D(1, 1, activation='sigmoid', name='new_output2')(final_feature)

new_model = tf.keras.Model(inputs=model.input, outputs=[new_output1, new_output2])
new_model.summary()
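
Because model.trainable = False freezes every layer between the inputs and final_feature_conv7, only the three newly added Conv2D layers contribute trainable variables. A quick sanity check (the expected count of 6 assumes the three new layers above, each with one kernel and one bias):

# everything inherited from the frozen model should be non-trainable
print(len(new_model.trainable_variables))   # expected: 6 (3 new layers x kernel + bias)
for v in new_model.trainable_variables:
    print(v.name, v.shape)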

The complete code for this example:

# Extract specific layers from a pretrained model and build a new model on top of them

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, UpSampling2D, Input, concatenate, MaxPool2D


def load_data():
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    x_train, x_test = x_train/255.*2-1, x_test/255.*2-1  # scale to [-1, 1] to match the tanh output head
    x_train = tf.expand_dims(x_train, -1)
    x_test = tf.expand_dims(x_test, -1)
    y_train = tf.expand_dims(y_train, -1)
    y_test = tf.expand_dims(y_test, -1)
    return (x_train, y_train), (x_test, y_test)

input1 = Input((28,28,1), dtype=tf.float64)
input2 = Input((28,28,1), dtype=tf.float64)
input_ = concatenate([input1, input2])
# (28,28,2)
net = Conv2D(16, 3, padding='same', activation='relu')(input_)
# (28,28,16)
net = MaxPool2D()(net)
# (14,14,16)
net = Conv2D(32, 3, padding='same', activation='relu')(net)
# (14,14,32)
net = MaxPool2D()(net)
# (7,7,32)
net = Conv2D(64, 3, activation='relu')(net)
# (5, 5, 64)
net = Conv2D(64, 3, activation='relu')(net)
# (3, 3, 64)
net = Conv2D(128, 3, activation='relu')(net)
# (1, 1, 128)
# upsampling
net = UpSampling2D((4, 4))(net)
# (4, 4, 128)
net = Conv2D(64, 3, padding='same', activation='relu')(net)
# (4, 4, 64)
net = UpSampling2D((4, 4))(net)
# (16, 16, 64)
net = Conv2D(32, 3, padding='same', activation='relu')(net)
# (16, 16, 32)
net = UpSampling2D()(net)
# (32, 32, 32)
net = Conv2D(16, 5, activation='relu', name='final_feature_conv7')(net)
# (28, 28, 16)
output1 = Conv2D(1, 1, activation='tanh', name='output1')(net)
output2 = Conv2D(1, 1, activation='sigmoid', name='output2')(net)

# model stands in for the pretrained network; freeze all of its parameters
model = tf.keras.Model(inputs=[input1, input2], outputs=[output1, output2])
model.trainable = False
model.summary()

# retrieve the (28, 28, 16) feature map by its layer name and build new heads on it
final_feature = model.get_layer('final_feature_conv7').output
final_feature = Conv2D(16, 3, padding='same', activation='relu')(final_feature)
# distinct names keep the new heads from shadowing the frozen model's output layers
new_output1 = Conv2D(1, 1, activation='tanh', name='new_output1')(final_feature)
new_output2 = Conv2D(1, 1, activation='sigmoid', name='new_output2')(final_feature)

new_model = tf.keras.Model(inputs=model.input, outputs=[new_output1, new_output2])
new_model.summary()

epochs = 10
batch_size = 32
optimizer = tf.keras.optimizers.Adam()
mean_loss = tf.metrics.Mean()
step = tf.Variable(1, name='global_step')
# track new_model (which also contains the frozen base layers) so the trained heads are saved
ckpt = tf.train.Checkpoint(step=step, optimizer=optimizer, model=new_model)
manager = tf.train.CheckpointManager(ckpt, './checkpoints1', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print(f'Restored from {manager.latest_checkpoint}')
else:
    print('Initializing from scratch')


def train_step(features, labels):
    with tf.GradientTape() as t:
        # input2 is an all-ones image when the label is class 6, otherwise all zeros
        is_six = tf.cast(labels == 6, features.dtype)            # (batch, 1)
        input2 = tf.reshape(is_six, (-1, 1, 1, 1)) * tf.ones_like(features)
        output1, output2 = new_model([features, input2])
        # loss1: reconstruct the input image; loss2: predict the class-6 indicator image
        loss1 = tf.reduce_mean(tf.keras.losses.MSE(features, output1))
        loss2 = tf.reduce_mean(tf.keras.losses.MSE(input2, output2))
        loss = tf.cast(loss1, dtype=tf.float32) + tf.cast(loss2, dtype=tf.float32)
    # only the new layers are trainable, so only their weights are updated
    grad = t.gradient(loss, new_model.trainable_variables)
    optimizer.apply_gradients(zip(grad, new_model.trainable_variables))
    step.assign_add(1)
    mean_loss.update_state(loss)
    return loss

def train():
    (x_train, y_train), (x_test, y_test) = load_data()

    nr_batches_train = len(x_train)//batch_size
    train_summary_writer = tf.summary.create_file_writer('./log/train')
    with train_summary_writer.as_default():
        for epoch in range(epochs):
            for t in range(nr_batches_train):
                start_from = t*batch_size
                to = (t+1)*batch_size
                features, labels = x_train[start_from:to], y_train[start_from:to]
                loss_value = train_step(features, labels)

                if t % 100 == 0:
                    print(f'step: {step.numpy()}, loss: {loss_value.numpy():.4f}')
                    saved_path = manager.save()
                    print(f'Checkpoint saved: {saved_path}')
                    # map the [-1, 1] images back to [0, 1] for display
                    tf.summary.image('train_set', (features + 1) / 2, max_outputs=3, step=step.numpy())
                    tf.summary.scalar('loss', mean_loss.result(), step=step.numpy())
                    mean_loss.reset_states()
            print(f'epoch {epoch} finished')
            # stop early once the running mean loss is low enough
            if mean_loss.result() < 0.5:
                break


if __name__ == '__main__':
    train()

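After training (or after restoring the latest checkpoint), the rebuilt model is used like any other Keras model. A minimal inference sketch, assuming the same input2 convention as in train_step (all-ones image for label 6, all-zeros otherwise); the batch of 8 test samples is just illustrative:

# feed a small test batch through new_model
(_, _), (x_test, y_test) = load_data()
features = x_test[:8]
is_six = tf.cast(y_test[:8] == 6, features.dtype)
input2 = tf.reshape(is_six, (-1, 1, 1, 1)) * tf.ones_like(features)
recon, mask = new_model([features, input2], training=False)
print(recon.shape, mask.shape)  # (8, 28, 28, 1) (8, 28, 28, 1)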