Fixing an output shape of "multiple" in Keras model.summary()

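When a network is defined in TensorFlow by subclassing tf.keras.Model, model.summary() may report the Output Shape of its layers as "multiple". A subclassed model defines its forward pass imperatively in call(), so Keras has no static layer graph from which to read each layer's output shape. Three workarounds are given below.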

Original source: python - model.summary() can't print output shape while using subclass model - Stack Overflow
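To make the problem concrete, here is a minimal sketch that reproduces it. The class name SubclassedNet and the all-zeros input are placeholders chosen for illustration, not part of the original post. The exact text in the Output Shape column depends on the TensorFlow/Keras version ("multiple" in many TF 2.x releases), so no expected output is shown.

```python
import tensorflow as tf
from tensorflow.keras import Model, layers


class SubclassedNet(Model):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)


model = SubclassedNet()
# Calling the model once on concrete data creates its weights.
model(tf.zeros((1, 24, 24, 3)))

# Collect the summary text instead of printing it directly.
summary_lines = []
model.summary(print_fn=lambda line, *args, **kwargs: summary_lines.append(line))
summary_text = "\n".join(summary_lines)
print(summary_text)
```

The weights exist and the parameter count is correct; only the shape column is uninformative, which is what the workarounds below address.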

Solution 1:

Step 1: add an input_shape parameter to __init__.

Step 2: add an input layer:

self.input_layer = tf.keras.layers.Input(input_shape)

Step 3: run the call method on that input layer:

self.out = self.call(self.input_layer)

Example:

import tensorflow as tf

class MyModel(tf.keras.Model):

    def __init__(self, input_shape=(32, 32, 1), **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.input_layer = tf.keras.layers.Input(input_shape)
        self.dense10 = tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax)
        self.dense20 = tf.keras.layers.Dense(20, activation=tf.keras.activations.softmax)
        # Run call() once on the symbolic input so each layer records its output shape
        self.out = self.call(self.input_layer)

    def call(self, inputs):
        x = self.dense10(inputs)
        y_pred = self.dense20(x)
        return y_pred

model = MyModel()
# x_test is assumed to be an existing array of shape (N, 32, 32, 1)
model(x_test[:99])
print('x_test[:99].shape:', x_test[:99].shape)
model.summary()

Solution 2:

Call .build() and then .call() on the model before calling model.summary().

import tensorflow as tf
from tensorflow.keras import Input, layers, Model

class subclass(Model):
    def __init__(self):
        super(subclass, self).__init__()
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)

if __name__ == '__main__':
    model = subclass()
    model.build(input_shape=(None, 24, 24, 3))

    # Adding this call to the call() method solves it all
    model.call(Input(shape=(24, 24, 3)))

    # And the summary() outputs all the information
    model.summary()

Solution 3:

Add a custom method to the class that runs call() on a symbolic Input and wraps the result in a functional Model.

class subclass(Model):
    def __init__(self):
        ...
    def call(self, x):
        ...

    def model(self):
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))



if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()
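The snippet above elides __init__ and call. As a self-contained sketch, the version below fills them in with the Conv2D layer from Solution 2. That choice is an assumption made for illustration; any layers would work the same way. The class is renamed Subclass here only to follow Python naming conventions.

```python
import tensorflow as tf
from tensorflow.keras import Input, Model, layers


class Subclass(Model):
    def __init__(self):
        super().__init__()
        # Assumed layer, borrowed from the Solution 2 example
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)

    def model(self):
        # Wrap call() in a functional Model so Keras can trace shapes
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))


sub = Subclass()
functional = sub.model()
functional.summary()

# The functional wrapper exposes a concrete symbolic output shape.
out_shape = tuple(functional.outputs[0].shape)
print(out_shape)
```

The functional wrapper shares its layers, and therefore its weights, with the subclassed instance, so it can be used purely for inspection while training continues on sub.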
