tensorflow.squeeze用来删除张量中大小为1的维度;
关于张量的维度,只需要数中括号‘[’的个数,比如[[1 2 3]]就是二维张量,打印出张量,必然包括shape(1,3),更复杂的例子如下:
示例代码1
x = tf.constant([[[[1, 2, 3], [1, 2, 3], [1, 2, 3]]]])
print("x.shape:",x.shape)
输出结果如下:
维度具体判断方法就是:从左往右数中括号‘[’,数到第n个中括号,找出第n个“中括号对”,然后进入这个“中括号对”,假设这个“中括号对”中包含m个“中括号对”,或者这个“中括号对”中包含k个数字,那么最终打印出shape的时候,第n位的数值就是m或k。
理解了维度,再理解squeeze就好办了,比如下面这个例子:
x = tf.constant([[[[1, 2, 3], [1, 2, 3], [1, 2, 3]]]])
y = tf.squeeze(x);
print("x:", x);
print("x.shape:", x.shape)
print("y:", y)
print("y.shape", y.shape);
输出结果如下:
当我们去除所有大小为1的维度时,前2个中括号被去除了,原因很简单,因为前2个中括号对应的每个“中括号对”中,只含有1个“中括号对”。
squeeze也可以指定删除某个或某几个大小为1的维度:
x = tf.constant([[[[1, 2, 3], [1, 2, 3], [1, 2, 3]]]])
y = tf.squeeze(x,x[0]);
print("x:", x);
print("x.shape:", x.shape)
print("y:", y)
print("y.shape", y.shape);
可以看到,只有第一个中括号被删除了,因为在squeeze函数中增加了第二个参数[0],它表示删除第一个中括号。
squeeze还可以加第3个参数,表示操作的名称,不过现在暂时不知道这个参数有什么作用:
x = tf.constant([[[[1, 2, 3], [1, 2, 3], [1, 2, 3]]]])
y = tf.squeeze(x, [0], "haha");
print("x:", x);
print("x.shape:", x.shape)
print("y:", y)
print("y.shape", y.shape);