我的模型有两组数据输入,所以没法直接使用组装层来实现,使用keras.Model 由于实在是不熟悉tf 与keras,遇到了种种困难包括但不限于网络中的训练参数怎么添加、多输入的处理、Model.build的定义失败、add_weight不会用、网络summary没有参数,最后发现了一个简单的方法,真是结了燃眉之急
class MYM(keras.Model):
def __init__(self):
ipt1 = keras.Input(shape=(13,13,7),name="view")
ipt2 = keras.Input(shape=(34),name="feature")
x = layers.Conv2D(7,kernel_size=3)(ipt1)
x = layers.Conv2D(1,kernel_size=3)(x)
x = layers.Flatten()(x)
#print(x,ipt2)
x =tf.concat([x,ipt2],axis=-1)
x = layers.Dense(128)(x)
#print(x)
x = layers.Dense(21)(x)
out = layers.Softmax(axis=-1)(x)
super(MYM,self).__init__(inputs=[ipt1,ipt2],outputs=out)
summary
net = MYM()
net.summary()
Model: "mym_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
view (InputLayer) [(None, 13, 13, 7)] 0
__________________________________________________________________________________________________
conv2d_53 (Conv2D) (None, 11, 11, 7) 448 view[0][0]
__________________________________________________________________________________________________
conv2d_54 (Conv2D) (None, 9, 9, 1) 64 conv2d_53[0][0]
__________________________________________________________________________________________________
flatten_15 (Flatten) (None, 81) 0 conv2d_54[0][0]
__________________________________________________________________________________________________
feature (InputLayer) [(None, 34)] 0
__________________________________________________________________________________________________
tf_op_layer_concat_8 (TensorFlo [(None, 115)] 0 flatten_15[0][0]
feature[0][0]
__________________________________________________________________________________________________
dense_44 (Dense) (None, 128) 14848 tf_op_layer_concat_8[0][0]
__________________________________________________________________________________________________
dense_45 (Dense) (None, 21) 2709 dense_44[0][0]
__________________________________________________________________________________________________
softmax_36 (Softmax) (None, 21) 0 dense_45[0][0]
==================================================================================================
Total params: 18,069
Trainable params: 18,069
Non-trainable params: 0
扔一个batch试试
x = np.random.rand(13,13,7)
ix=np.array([x,x])
x2 = np.random.rand(34)
ix2=np.array([x2,x2])
print(ix.shape,ix2.shape)
(2, 13, 13, 7) (2, 34)
y=net(inputs=[ix,ix2])
print(y.shape)
tf.argmax(y,axis=-1)
(2, 21)
好了真是难死我了,继续去添砖加瓦,以后要好好学习,提前接触这些流行的东东