tf.slice()回顾

之前用到了tf.slice,当时会用,现在又忘了,故重温一下,直接看示例

import tensorflow as tf
    
    t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
    sess = tf.Session()
    data = tf.slice(t, [1,0,0], [1,1,3])
    print(sess.run(data))
    data = tf.slice(t, [1, 0, 0], [2, 2, 3])
    print(sess.run(data))

tf.slice(inputs, begin, size, name)包含几个参数,一般只需要设置前三个即可。
inputs是你要切割的tensor,begin是切割的起始位置,size是要切出来的大小。
对应上面的示例,我的理解是对于三维矩阵,第一维表示batch,第二位表示行,第三维表示列(index都是从0开始)。
具体点t是一个三维矩阵,begin是[1,0,0],也就是在从第2个batch中第1行第1列的位置开始切割,切割的大小是多少呢?再看size是[1,1,3],三个数字分别表示一整个batch、一行、三列,由于起始位置已经固定,所以切割出来的应该是第2个batch中第一行中的三列,正是[3,3,3]。

你可能感兴趣的:(tf.slice()回顾)