import tensorflow as tf

# tf.expand_dims demo. Uses the TF 1.x graph-mode API (tf.placeholder,
# tf.Session); under TF 2.x these exist only as tf.compat.v1.* and require
# eager execution to be disabled.
l = tf.placeholder(tf.int32, [2, 4])
# Insert a new size-1 axis at position 0: shape (2, 4) -> (1, 2, 4).
l_expand_0 = tf.expand_dims(l, axis=0)
# Insert a new size-1 axis at position 1: shape (2, 4) -> (2, 1, 4).
l_expand_1 = tf.expand_dims(l, axis=1)
# Build the shape op as part of the graph so it can be evaluated in the
# session. Calling tf.shape() on the already-evaluated NumPy result would
# only create a new op and print its symbolic Tensor, not the shape values.
l_expand_1_shape = tf.shape(l_expand_1)

# The same input is fed to every fetch; define it once.
feed = {l: [[1, 2, 3, 4], [5, 6, 7, 8]]}

with tf.Session() as sess:
    # One graph execution fetches all three results.
    le_0, le_1, le_1_shape = sess.run(
        [l_expand_0, l_expand_1, l_expand_1_shape], feed_dict=feed)
    print('le_0 shape:', le_0.shape)
    print(le_0)
    print('-------------------')
    print('le_1 shape:', le_1.shape)
    print(le_1_shape)  # evaluated shape values, e.g. [2 1 4]
    print(le_1)
把打印结果写成嵌套列表来看：shape 为 (1, 2, 4)（axis=0）时，两行之间是 `],[`，逗号前只有一个中括号；shape 为 (2, 1, 4)（axis=1）时，两行之间是 `]],[[`，逗号前有两个中括号。
tf.expand_dims 经常与 tf.tile 联合使用：想沿哪个维度复制，就先用 expand_dims 在该位置插入一个长度为 1 的维度，再用 tile 在这个维度上复制。
import tensorflow as tf

# tf.tile demo: expand a (2, 4) tensor to (2, 1, 4) with tf.expand_dims, then
# replicate it along a chosen axis. tf.tile's `multiples` list gives the repeat
# count per axis (1 = leave that axis unchanged). TF 1.x graph-mode API.
l = tf.placeholder(tf.int32, [2, 4])
# (2, 4) -> (2, 1, 4)
l_expand_1 = tf.expand_dims(l, axis=1)
# Repeat twice along axis 0: (2, 1, 4) -> (4, 1, 4).
l_tile_0 = tf.tile(l_expand_1, [2, 1, 1])
# Repeat twice along the newly inserted axis 1: (2, 1, 4) -> (2, 2, 4).
l_tile_1 = tf.tile(l_expand_1, [1, 2, 1])

# The same input is fed to every fetch; define it once.
feed = {l: [[1, 2, 3, 4], [5, 6, 7, 8]]}

with tf.Session() as sess:
    # One graph execution fetches all three results. The name le_1 matches
    # the tensor it comes from (l_expand_1).
    le_1, lt_0, lt_1 = sess.run(
        [l_expand_1, l_tile_0, l_tile_1], feed_dict=feed)
    print(le_1)
    print('-------------------')
    print(lt_0)
    print('-------------------')
    print(lt_1)