Lambda 层

Keras 中 Lambda 层的使用

'''
from keras.layers import Reshape, Permute, RepeatVector, Dense,Lambda,Embedding,Add
from keras.layers import Multiply
from keras.models import Sequential,Input,Model
import numpy as np
import keras.backend as K
# Two symbolic model inputs, each with per-sample shape (2,);
# Keras adds the batch dimension implicitly. Created left to right,
# so `a` is built before `b` exactly as in two separate statements.
a, b = Input(shape=(2,)), Input(shape=(2,))

def minus(inputs):
    """Return the mean of the element-wise difference ``x - y`` along axis 1.

    Intended as the function wrapped by a Keras ``Lambda`` layer that is
    called on a list of two tensors.

    Note: despite the name, this does not return the raw difference; it
    averages ``x - y`` over axis 1.

    Args:
        inputs: a pair ``(x, y)`` of tensors whose shapes are
            broadcast-compatible (in this script both are (batch, 2)).

    Returns:
        ``K.mean(x - y, axis=1)`` -- for (batch, 2) inputs, a tensor of
        shape (batch,).
    """
    x, y = inputs
    return K.mean(x - y, axis=1)

# Wrap `minus` in a Lambda layer and apply it to both inputs.
cha = Lambda(minus, name='minus')([a, b])

model = Model(inputs=[a, b], outputs=[cha])

# summary() prints the table itself and returns None; wrapping it in
# print() would emit an extra "None" line.
model.summary()

v0 = np.array([5, 2])
v1 = np.array([8, 4])
v2 = np.array([3, 2])
# Single-sample batches: each vector reshaped to (1, 2).
print(model.predict([v0.reshape(1, 2), v1.reshape(1, 2)]))  # mean([-3, -2]) = -2.5
print(model.predict([v0.reshape(1, 2), v2.reshape(1, 2)]))  # mean([2, 0]) = 1.0
# One batch of two samples: pairs (v0, v1) and (v0, v2).
print(model.predict([np.array([v0, v0]), np.array([v1, v2])]))  # [-2.5, 1.0]
'''

你可能感兴趣的:(Lambda 层)