tf.gather()介绍

tf.gather(
    params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)

官网一共给了6个参数,只介绍其中重要的3个参数。

tf.gather(params,indices,axis=0)

给定一个4*3矩阵params=[[1,2,3],[4,5,6],[7,8,9],[10,11,12]] 

indices = [0,2]

运行tf.gather(params,indices,axis=0)

会得到[[1,2,3],[7,8,9]]

这是为何呢?先说axis,axis=0等于在params的行上面进行提取,注意params的shape,是(4,3),axis=0就是在(4,3)的第0个下标的数字进行提取。

indices=[0,2]:代表提取params的第0行和第2行。接下来上图。

tf.gather()介绍_第1张图片

 提出来的结果就是[[1,2,3],[7,8,9]]。

如果换成axis=1,自然就是在列上提取。第0列和第2列。tf.gather()介绍_第2张图片

结果就是:

[[ 1  3]
 [ 4  6]
 [ 7  9]
 [10 12]], shape=(4, 2), dtype=int32)。


接下来上难度,

params= tf.reshape(tf.range(210),[5,6,7])
indices = tf.reshape(tf.range(24),[2,3,4])
print(tf.gather(params, indices, axis=0))

params是这样:

tf.Tensor(
[[[  0   1   2   3   4   5   6]
  [  7   8   9  10  11  12  13]
  [ 14  15  16  17  18  19  20]
  [ 21  22  23  24  25  26  27]
  [ 28  29  30  31  32  33  34]
  [ 35  36  37  38  39  40  41]]

 [[ 42  43  44  45  46  47  48]
  [ 49  50  51  52  53  54  55]
  [ 56  57  58  59  60  61  62]
  [ 63  64  65  66  67  68  69]
  [ 70  71  72  73  74  75  76]
  [ 77  78  79  80  81  82  83]]

 [[ 84  85  86  87  88  89  90]
  [ 91  92  93  94  95  96  97]
  [ 98  99 100 101 102 103 104]
  [105 106 107 108 109 110 111]
  [112 113 114 115 116 117 118]
  [119 120 121 122 123 124 125]]

 [[126 127 128 129 130 131 132]
  [133 134 135 136 137 138 139]
  [140 141 142 143 144 145 146]
  [147 148 149 150 151 152 153]
  [154 155 156 157 158 159 160]
  [161 162 163 164 165 166 167]]

 [[168 169 170 171 172 173 174]
  [175 176 177 178 179 180 181]
  [182 183 184 185 186 187 188]
  [189 190 191 192 193 194 195]
  [196 197 198 199 200 201 202]
  [203 204 205 206 207 208 209]]], shape=(5, 6, 7), dtype=int32)

indices是这样:

tf.Tensor(
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]], shape=(2, 3, 4), dtype=int32)

那么再运行tf.gather(params,indices,axis=0)会是什么结果,会怎么切割,其实这已经不重要,重要的是会形成一个怎样的矩阵。

会形成一个shape为[2,3,4,6,7]的矩阵,其实就是拿indices的shape[2,3,4],去替换掉params的shape第0位,也就是替换掉5。为什么是第0位,因为axis=0。

如果axis=1,就是替换第1位,也就是6,会形成一个[5,2,3,4,7]的矩阵。

如果axis=2,就是替换第2位,也就是6,会形成一个[5,6,2,3,4]的矩阵。

你可能感兴趣的:(python,机器学习,深度学习)