tf.keras.layers.TimeDistributed

函数原型:

tf.keras.layers.TimeDistributed(
    layer, **kwargs
)

除了batch_size以外,第一个维度被认为是时间维度,在进行卷积或其他操作的时候,batch_size和时间维度保持不变,对后面的维度进行处理,所以至少应该为3维。
比如(32, 10, 128, 128, 3),batch_size = 32, 包含10个时间步长的128*128的RGB图片。

inputs = tf.keras.Input(shape=(10, 128, 128, 3))
conv_2d_layer = tf.keras.layers.Conv2D(64, (3, 3))
outputs = tf.keras.layers.TimeDistributed(conv_2d_layer)(inputs)
outputs.shape

输出:(None, 10, 126, 126, 64)

inputs = tf.keras.Input(shape=(10, 16, 16, 3))
x = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(inputs)
print(outputs.shape)

输出:(None, 10, 126, 126, 64)

你可能感兴趣的:(tensorflow)