TensorFlow 移除所有尺度为1的维度 tf.squeeze 的基本用法及实例代码

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

从张量的形状中移除所有尺寸为1的维数。(弃用参数)

https://tensorflow.google.cn/api_docs/python/tf/squeeze

tf.squeeze(
    input,
    axis=None,
    name=None,
    squeeze_dims=None
)

参数:

input:要缩减维度的张量

axis:可选整型列表,默认为 [ ],如果指定了给参数,值域列表中指定的维度会被移除。维度所以从 0 开始,范围是 [- rank(input), rank(input)]。不能移除尺度不为 1 的维度,否则会报错!

name:可选参数,设置操作的名称

squeeze_dims:被移除的关键字参数,通过 axis 替代

 

返回:

包含输入 input 中的数据,但移除了所有尺度为 1 的维度的张量,和输入 input 的数据类型相同

 

三、实例

(1)尺度缩减的错误方式

>>> raw_tensor = tf.constant(value=[[[1,2,3],[4,5,6]]])
>>> raw_tensor

>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[1])
Traceback (most recent call last):
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1628, in _create_c_op
    c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "", line 1, in 
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\ops\array_ops.py", line 2573, in squeeze
    return gen_array_ops.squeeze(input, axis, name)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\ops\gen_array_ops.py", line 10108, in squeeze
    "Squeeze", input=input, squeeze_dims=axis, name=name)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\util\deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 3274, in create_op
    op_def=op_def)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1792, in __init__
    control_input_ops)
  File "C:\Users\WJW\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\framework\ops.py", line 1631, in _create_c_op
    raise ValueError(str(e))
ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].

ValueError: Can not squeeze dim[1], expected a dimension of 1, got 2 for 'Squeeze_4' (op: 'Squeeze') with input shapes: [1,2,3].

即不能移除尺度不为 1 的维度

 

(2)尺度缩减的正确方式

>>> import tensorflow as tf


# 向量 (1, 3) --> 标量(3,)
# 移除尺度为 1 的第一个维度
>>> 
>>> raw_tensor = tf.constant(value=[[1,2,3]])
>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor



# 矩阵 (1, 3, 3) --> 标量(3, 3)
# 移除尺度为 1 的第一个维度
>>> raw_tensor = tf.constant(value=[[[1,2,3],[4,5,6],[7,8,9]]])
>>> raw_tensor

>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor



# 矩阵 (1, 1, 3) --> 标量(3,)
# 移除尺度为 1 的前两个维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor

>>> squeezed_tensor = tf.squeeze(input=raw_tensor)
>>> squeezed_tensor



# 通过参数 axis 指定的一个要的尺度为 1 的维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor

>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[1])
>>> squeezed_tensor



# 通过参数 axis 指定的多个要的尺度为 1 的维度
>>> raw_tensor = tf.constant(value=[[[1,2,3]]])
>>> raw_tensor

>>> squeezed_tensor = tf.squeeze(input=raw_tensor, axis=[0,1])
>>> squeezed_tensor

你可能感兴趣的:(TensorFlow基础)