tf.slice()

import tensorflow as tf

input = [[[1,1,1],[2,2,2]],
         [[3,3,3],[4,4,4]],
         [[5,5,5],[6,6,6]]]

x = tf.slice(input,[0,0,0],[1,2,3])

sess = tf.InteractiveSession()
print(sess.run(x))

>>>   [[[1 1 1]
       [[2 2 2]]]

首先来看tf.slice里的几个参数,

  • input代表输入的tensor,
  • [0,0,0]代表begin,起始值
  • [1,2,3]代表切的大小size。

要明白tf.slice是一个切片函数,那应该怎么切呢?

注意到tf.slice从begin开始切,
例如上面就是从[0,0,0],也就是第0行第0列第0维开始切,
然后size[1,2,3]表示切出1行2列3维的大小。
所以切出来了:
       [[[1 1 1]
       [[2 2 2]]]


倘若是下面的代码:

import tensorflow as tf

input = [[[1,1,1],[2,2,2]],
         [[3,3,3],[4,4,4]],
         [[5,5,5],[6,6,6]]]

x = tf.slice(input,[1,0,0],[2,1,3])

sess = tf.InteractiveSession()
print(sess.run(x))

>>>   [[[3 3 3]]
       [[5 5 5]]]

tf.slice会从第一行第0列第0维开始切,
并切出2行1列3维的大小。

你可能感兴趣的:(tf.slice())