tf.expand_dims详解

今天看见了一个代码里写得tf.expand_dims不是很理解,于是查看了一下官方文件,他是这样写的:

"""Returns a tensor with a length 1 axis inserted at index `axis`.

  Given a tensor `input`, this operation inserts a dimension of length 1 at the
  dimension index `axis` of `input`'s shape. The dimension index follows Python
  indexing rules: It's zero-based, a negative index it is counted backward
  from the end.

  This operation is useful to:

  * Add an outer "batch" dimension to a single element.
  * Align axes for broadcasting.
  * To add an inner vector length axis to a tensor of scalars.

  For example:

  If you have a single image of shape `[height, width, channels]`:

  >>> image = tf.zeros([10,10,3])

  You can add an outer `batch` axis by passing `axis=0`:

  >>> tf.expand_dims(image, axis=0).shape.as_list()
  [1, 10, 10, 3]

  The new axis location matches Python `list.insert(axis, 1)`:

  >>> tf.expand_dims(image, axis=1).shape.as_list()
  [10, 1, 10, 3]

  Following standard Python indexing rules, a negative `axis` counts from the
  end so `axis=-1` adds an inner most dimension:

  >>> tf.expand_dims(image, -1).shape.as_list()
  [10, 10, 3, 1]

  This operation requires that `axis` is a valid index for `input.shape`,
  following Python indexing rules:

 
  -1-tf.rank(input) <= axis <= tf.rank(input)
  
  """

其作用是在axis所规定的位置插入一个数字,用来增加Tensor的维度,axis遵循Python索引规则:它是从零开始的,而负索引,它是最后开始计数的。
举个例子:
比如给定一张图片,它的形状是【长,宽,通道数】,然后给定三种情况:

image = tf.zeros([10,10,3])

①:

tf.expand_dims(image, axis=0).shape

得到的结果是:
[1, 10, 10, 3]

②:

tf.expand_dims(image, axis=1).shape

得到的结果是:
[10, 1, 10, 3]
③:

tf.expand_dims(image, axis=-1).shape

得到的结果是:
[10, 10, 3,1]

axis的取值范围是:

-1-tf.rank(input) <= axis <= tf.rank(input)

你可能感兴趣的:(python中零星小知识,python,深度学习,tensorflow)