Keras函数式API是定义复杂模型(如多输出模型、有向无环图,或具有共享层模型的方法)
from keras.layers import Input, Dense
from keras.models import Model
# 返回一个张量
inputs = Input(shape=(784,))
# 层的实例是可调用的,它以张量为参数,并且返回一个张量
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)
# 创建一个包含输入层和三个全连接层的模型
model = Model(inputs=inputs, outputs=predictions)
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, labels) # 开始训练
利用函数式 API ,可以重用训练好的模型:可将任何模型看作是一个层,然后通过传递一个张量来调用它。且在调用模型时,不仅重用了模型的结构,还重用了它的权重。
函数式 API 的另一个用途是使用共享网络层的模型。
若只有一个输入:
a = Input(shape=(280, 256))
lstm = LSTM(32)
encoded_a = lstm(a)
assert lstm.output == encoded_a
有多个输入时可通过如下方法解决:
assert lstm.get_output_at(0) == encoded_a
assert lstm.get_output_at(1) == encoded_b
input_shape 和 output_shape 也是如此:只要该层有一个节点,或者所有节点都具有相同的输入/输出尺寸,那么他们的输入/输出尺寸就能被很好定义,并且由函数唯一返回。但是如果一个层应用于两个尺寸的输入/输出,就要通过所属节点的索引来获取他们。
compile(optimizer, loss=None, metrics=None, loss_weights=None,
sample_weight_mode=None, weighted_metrics=None, target_tensors=None)
fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
validation_split=0.0, validation_data=None, shuffle=True,
class_weight=None, sample_weight=None, initial_epoch=0,
steps_per_epoch=None, validation_steps=None)
evaluate(x=None, y=None, batch_size=None, verbose=1, sample_weight=None,
steps=None)
predict(x, batch_size=None, verbose=0, steps=None)
train_on_batch(x, y, sample_weight=None, class_weight=None)
test_on_batch(x, y, sample_weight=None)
predict_on_batch(x)
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1,
callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1,
use_multiprocessing=False, shuffle=True, initial_epoch=0)
evaluate_generator(generator, steps=None, max_queue_size=10, workers=1,
use_multiprocessing=False, verbose=0)
predict_generator(generator, steps=None, max_queue_size=10, workers=1,
use_multiprocessing=False, verbose=0)
get_layer(self, name=None, index=None)