TF2: one way to define a custom multi-input network model with Keras

  My model takes two separate groups of input data, so I couldn't build it by simply stacking layers. I tried subclassing keras.Model, but since I'm really not familiar with TF and Keras I ran into all kinds of trouble, including but not limited to: how to register trainable parameters in the network, how to handle multiple inputs, Model.build failing, not knowing how to use add_weight, and summary() reporting no parameters. In the end I found a simple approach that really saved the day: build the graph with the functional API inside __init__ and hand the inputs and outputs to the parent constructor.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class MYM(keras.Model):
    def __init__(self):
        # Two named inputs: an image-like "view" and a flat "feature" vector
        ipt1 = keras.Input(shape=(13,13,7), name="view")
        ipt2 = keras.Input(shape=(34,), name="feature")
        # Convolutional branch on the view input
        x = layers.Conv2D(7, kernel_size=3)(ipt1)
        x = layers.Conv2D(1, kernel_size=3)(x)
        x = layers.Flatten()(x)
        # Merge the flattened view with the feature vector
        x = tf.concat([x, ipt2], axis=-1)
        x = layers.Dense(128)(x)
        x = layers.Dense(21)(x)
        out = layers.Softmax(axis=-1)(x)
        # Hand the finished functional graph to the parent constructor
        super(MYM, self).__init__(inputs=[ipt1, ipt2], outputs=out)
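
For comparison, the same graph can also be built with the plain functional API and no subclass at all; the class above essentially just wraps this construction inside __init__. A minimal sketch under that assumption (build_mym is a made-up helper name, layer sizes copied from above):

def build_mym():
    ipt1 = keras.Input(shape=(13,13,7), name="view")
    ipt2 = keras.Input(shape=(34,), name="feature")
    x = layers.Conv2D(7, kernel_size=3)(ipt1)
    x = layers.Conv2D(1, kernel_size=3)(x)
    x = layers.Flatten()(x)
    # layers.Concatenate instead of tf.concat; both wire up the graph the same way here
    x = layers.Concatenate(axis=-1)([x, ipt2])
    x = layers.Dense(128)(x)
    x = layers.Dense(21)(x)
    out = layers.Softmax(axis=-1)(x)
    return keras.Model(inputs=[ipt1, ipt2], outputs=out, name="mym_functional")

The subclass version is convenient when you want the model to be a named class you can hang extra methods on later, while still getting a fully traced graph, which is why summary() below reports every parameter.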

Check the summary

net = MYM() 
net.summary() 
Model: "mym_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
view (InputLayer)               [(None, 13, 13, 7)]  0                                            
__________________________________________________________________________________________________
conv2d_53 (Conv2D)              (None, 11, 11, 7)    448         view[0][0]                       
__________________________________________________________________________________________________
conv2d_54 (Conv2D)              (None, 9, 9, 1)      64          conv2d_53[0][0]                  
__________________________________________________________________________________________________
flatten_15 (Flatten)            (None, 81)           0           conv2d_54[0][0]                  
__________________________________________________________________________________________________
feature (InputLayer)            [(None, 34)]         0                                            
__________________________________________________________________________________________________
tf_op_layer_concat_8 (TensorFlo [(None, 115)]        0           flatten_15[0][0]                 
                                                                 feature[0][0]                    
__________________________________________________________________________________________________
dense_44 (Dense)                (None, 128)          14848       tf_op_layer_concat_8[0][0]       
__________________________________________________________________________________________________
dense_45 (Dense)                (None, 21)           2709        dense_44[0][0]                   
__________________________________________________________________________________________________
softmax_36 (Softmax)            (None, 21)           0           dense_45[0][0]                   
==================================================================================================
Total params: 18,069
Trainable params: 18,069
Non-trainable params: 0

Throw a batch at it and see

# Build a fake batch of two samples for each input
x = np.random.rand(13,13,7)
ix = np.array([x, x])
x2 = np.random.rand(34)
ix2 = np.array([x2, x2])
print(ix.shape, ix2.shape)

(2, 13, 13, 7) (2, 34)

# Forward pass with both inputs; output is one softmax distribution per sample
y = net(inputs=[ix, ix2])
print(y.shape)
tf.argmax(y, axis=-1)

(2, 21)
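
From here the natural next step is training. A minimal sketch, assuming made-up integer class labels (the labels array below is purely for illustration; compile and fit are the standard Keras calls):

# Hypothetical labels: one integer class id (0..20) per sample
labels = np.random.randint(0, 21, size=(2,))
net.compile(optimizer="adam",
            loss="sparse_categorical_crossentropy",  # matches the Softmax output
            metrics=["accuracy"])
# Inputs can go in as a list in Input order, or as a dict keyed by the
# Input names "view" and "feature" defined in __init__
net.fit({"view": ix, "feature": ix2}, labels, epochs=1, batch_size=2)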

 

OK, that one nearly killed me. Now back to building on top of it; going forward I should study properly and get hands-on with these popular tools earlier.
