一、环境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
二、官方说明
从张量的形状中移除所有尺寸为1的维数。(弃用参数)
https://tensorflow.google.cn/api_docs/python/tf/squeeze
tf.squeeze(
input,
axis=None,
name=None,
squeeze_dims=None
)
参数:
input:要缩减维度的张量
axis:可选整型列表,默认为 [ ],如果指定了给参数,值域列表中指定的维度会被移除。维度所以从 0 开始,范围是 [- rank(input), rank(input)]。不能移除尺度不为 1 的维度,否则会报错!
name:可选参数,设置操作的名称
squeeze_dims:被移除的关键字参数,通过 axis 替代
返回:
包含输入 input 中的数据,但移除了所有尺度为 1 的维度的张量,和输入 input 的数据类型相同
三、实例
(1)尺度缩减的错误方式
>>> raw_tensor = tf.constant(value=[[[1,2,3],[4,5,6]]])
>>> raw_tensor
>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[1])
Traceback (most recent call last):
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1628, in _create_c_op
c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "", line 1, in
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\ops\array_ops.py", line 2573, in squeeze
return gen_array_ops.squeeze(input, axis, name)
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\ops\gen_array_ops.py", line 10108, in squeeze
"Squeeze", input=input, squeeze_dims=axis, name=name)
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 3274, in create_op
op_def=op_def)
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1792, in __init__
control_input_ops)
File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1631, in _create_c_op
raise ValueError(str(e))
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].
即不能移除尺度不为 1 的维度
(2)尺度缩减的正确方式
>>> import tensorflow as tf
# 向量 (1, 3) --> 标量(3,)
# 移除尺度为 1 的第一个维度
>>>
>>> raw_tensor = tf.constant(value=[[1,2,3]])
>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor
# 矩阵 (1, 3, 3) --> 标量(3, 3)
# 移除尺度为 1 的第一个维度
>>> raw_tensor = tf.constant(value=[[[1,2,3],[4,5,6],[7,8,9]]])
>>> raw_tensor
>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor
# 矩阵 (1, 1, 3) --> 标量(3,)
# 移除尺度为 1 的前两个维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor
>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor
# 通过参数 axis 指定的一个要的尺度为 1 的维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor
>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[1])
>>> squeezed_tensor
# 通过参数 axis 指定的多个要的尺度为 1 的维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor
>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[0,1])
>>> squeezed_tensor