tf.slice() / tf.strided_slice() intro

Slicing in TensorFlow works much like slicing in NumPy. Both functions are described in the TensorFlow API documentation:

tf.slice

tf.slice(
    input_,     # input tensor
    begin,      # start index in each dimension
    size,       # number of elements to take in each dimension (-1 = all remaining)
    name=None   # name of the operation
)

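In the example below, a equals b. tf.slice always uses a stride of 1 in every dimension, so taking size [2,2,2] from begin [0,1,0] selects the same elements as tf.strided_slice from begin [0,1,0] to end [2,3,2] (begin + size, exclusive) with strides [1,1,1].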

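import tensorflow as tf  # TF 1.x API (tf.Session)

input_data = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
              [[4, 4, 4], [5, 5, 5], [6, 6, 6]],
              [[7, 7, 7], [8, 8, 8], [9, 9, 9]]]

sess = tf.Session()

# take 2 x 2 x 2 elements starting at index [0, 1, 0]
a = tf.slice(input_data, [0, 1, 0], [2, 2, 2])
print(sess.run(a))
print("\n")

# same region: begin [0, 1, 0], end [2, 3, 2] (exclusive), stride 1 everywhere
b = tf.strided_slice(input_data, [0, 1, 0], [2, 3, 2], [1, 1, 1])
print(sess.run(b))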
>>>
[[[2 2]
  [3 3]]

 [[5 5]
  [6 6]]]


[[[2 2]
  [3 3]]

 [[5 5]
  [6 6]]]
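Because tf.slice always steps by 1, its (begin, size) arguments can be rewritten as tf.strided_slice's (begin, end, strides) with end = begin + size. The sketch below is a pure-Python model of that conversion on nested lists, not TensorFlow's implementation. The helper names slice_args_to_strided and nested_slice are hypothetical, and the shape argument is only needed to expand a size of -1 ("all remaining elements").

```python
def slice_args_to_strided(begin, size, shape):
    """Rewrite tf.slice-style (begin, size) as strided-slice (begin, end, strides).

    size[i] == -1 means "all remaining elements of dimension i",
    so the dimension length from `shape` is used as the end.
    """
    end = [dim if s == -1 else b + s for b, s, dim in zip(begin, size, shape)]
    strides = [1] * len(begin)  # tf.slice never skips elements
    return list(begin), end, strides


def nested_slice(data, begin, end, strides):
    """Apply one Python slice per dimension to a nested list."""
    if not begin:  # all dimensions consumed: data is a single element
        return data
    rows = data[begin[0]:end[0]:strides[0]]
    return [nested_slice(row, begin[1:], end[1:], strides[1:]) for row in rows]


input_data = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
              [[4, 4, 4], [5, 5, 5], [6, 6, 6]],
              [[7, 7, 7], [8, 8, 8], [9, 9, 9]]]

begin, end, strides = slice_args_to_strided([0, 1, 0], [2, 2, 2], [3, 3, 3])
print(begin, end, strides)  # [0, 1, 0] [2, 3, 2] [1, 1, 1]
print(nested_slice(input_data, begin, end, strides))
# [[[2, 2], [3, 3]], [[5, 5], [6, 6]]]
```

The converted arguments are exactly the ones passed to tf.strided_slice for b, which is why a and b print the same values.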

tf.strided_slice

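tf.strided_slice(
    input_,              # input tensor
    begin,               # start index in each dimension
    end,                 # end index in each dimension, ATTENTION: not included!
    strides=None,        # step in each dimension (1 when None)
    begin_mask=0,        # bit i set: ignore begin[i], use the fullest range
    end_mask=0,          # bit i set: ignore end[i], use the fullest range
    ellipsis_mask=0,     # bit i set: position i acts like "..." in NumPy
    new_axis_mask=0,     # bit i set: insert a new size-1 dimension
    shrink_axis_mask=0,  # bit i set: take only begin[i] and drop dimension i
    var=None,
    name=None            # name of the operation
)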
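The mask arguments are bit fields: bit i refers to entry i of begin, end and strides. The hypothetical helpers below sketch how begin_mask, end_mask and shrink_axis_mask map onto ordinary Python indexing, using a nested list as a stand-in for a tensor. ellipsis_mask and new_axis_mask are left out, and this models the documented behaviour rather than TensorFlow's code.

```python
def strided_slice_to_python(begin, end, strides=None,
                            begin_mask=0, end_mask=0, shrink_axis_mask=0):
    """Turn strided-slice arguments into one Python index per dimension.

    bit i of begin_mask       -> begin[i] ignored (open start, like x[:3])
    bit i of end_mask         -> end[i] ignored (open end, like x[1:])
    bit i of shrink_axis_mask -> plain integer index begin[i]; dim i is dropped
    ellipsis_mask and new_axis_mask are not modelled.
    """
    if strides is None:
        strides = [1] * len(begin)
    indices = []
    for i, (b, e, s) in enumerate(zip(begin, end, strides)):
        bit = 1 << i
        if shrink_axis_mask & bit:
            indices.append(b)
        else:
            start = None if begin_mask & bit else b
            stop = None if end_mask & bit else e
            indices.append(slice(start, stop, s))
    return indices


def apply_indices(data, indices):
    """Index a nested list with the per-dimension indices, outermost first."""
    if not indices:
        return data
    first, rest = indices[0], indices[1:]
    if isinstance(first, int):  # shrunk dimension: pick one element, no list
        return apply_indices(data[first], rest)
    return [apply_indices(row, rest) for row in data[first]]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12]]

# like x[:, 1:3]: bit 0 of begin_mask/end_mask opens up dimension 0
cols = strided_slice_to_python([0, 1], [0, 3], [1, 1], begin_mask=0b01, end_mask=0b01)
print(apply_indices(x, cols))  # [[2, 3], [6, 7], [10, 11]]

# like x[1, ::2]: shrink dimension 0 to row 1, open end on dimension 1
row = strided_slice_to_python([1, 0], [2, 0], [1, 2], end_mask=0b10, shrink_axis_mask=0b01)
print(apply_indices(x, row))  # [5, 7]
```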
tf.strided_slice on 1-D data: tf.strided_slice(input_data, [begin], [end], [stride]), where end is not included and stride defaults to 1.
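import tensorflow as tf  # TF 1.x API (tf.Session)

input_data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
a = tf.strided_slice(input_data, [0], [4], [1])
b = tf.strided_slice(input_data, [0], [-1], [1])
c = tf.strided_slice(input_data, [0], [-1], [2])
d = tf.strided_slice(input_data, [0], [-1])          # strides default to 1
e = tf.strided_slice(input_data, [-1], [-2], [-1])   # negative stride walks backwards

with tf.Session() as sess:
    for name, t in zip("abcde", [a, b, c, d, e]):
        print(name, sess.run(t))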
a: begins at input_data[0], ends before input_data[4], stride 1
[1 2 3 4]

b: begins at input_data[0], ends before the last element (index -1), stride 1
[1 2 3 4 5 6 7 8]

c: begins at input_data[0], ends before the last element, stride 2
[1 3 5 7]

d: begins at input_data[0], ends before the last element, stride 1 by default
[1 2 3 4 5 6 7 8]

e: begins at the last element (index -1), ends before the second-to-last element (index -2), stride -1
[9]
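These five results follow Python/NumPy slicing rules: a negative begin or end counts from the back (index + length), end stays exclusive, and a negative stride walks backwards. A small pure-Python sketch, using a hypothetical helper strided_positions and assuming begin and end are already in range (no clamping), lists the positions each call visits:

```python
def strided_positions(length, begin, end, stride=1):
    """Positions visited by a 1-D strided slice (Python/NumPy rules).

    Assumes begin and end are already in range; out-of-range clamping
    is not modelled.
    """
    if begin < 0:
        begin += length  # -1 -> last element
    if end < 0:
        end += length    # end stays exclusive after conversion
    return list(range(begin, end, stride))


input_data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
n = len(input_data)

for name, args in [("a", (0, 4, 1)), ("b", (0, -1, 1)), ("c", (0, -1, 2)),
                   ("d", (0, -1)), ("e", (-1, -2, -1))]:
    positions = strided_positions(n, *args)
    print(name, positions, [input_data[p] for p in positions])
```

For e, begin -1 becomes position 8 and end -2 becomes position 7, so range(8, 7, -1) visits only position 8, which holds 9.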
tf.strided_slice on higher-dimensional data:
import tensorflow as tf
import numpy as np
input_data = np.arange(60).reshape(3, 4, 5)

slice_np_data = input_data[1:2, 0:2, 0:2]
print(slice_np_data)
>>>
[[[20 21]
  [25 26]]]

slice_tensor = tf.strided_slice(input_data,[1,0,0],[2,2,2])
with tf.Session() as sess:
    print(sess.run(slice_tensor))
>>>
[[[20 21]
  [25 26]]]
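Why 20, 21, 25 and 26? np.arange(60).reshape(3, 4, 5) stores each value at its own row-major flat offset, so element [i, j, k] equals 20*i + 5*j + k. The sketch below uses hypothetical helpers and handles only non-negative, in-range indices; it computes the flat offsets selected by begin [1,0,0], end [2,2,2] and strides [1,1,1]:

```python
import itertools


def row_major_strides(shape):
    """Element strides of a C-ordered (row-major) array, e.g. (3, 4, 5) -> (20, 5, 1)."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def sliced_flat_offsets(shape, begin, end, slice_strides):
    """Flat offsets picked by a strided slice with non-negative, in-range indices."""
    mem = row_major_strides(shape)
    axes = [range(b, e, s) for b, e, s in zip(begin, end, slice_strides)]
    return [sum(i * m for i, m in zip(idx, mem)) for idx in itertools.product(*axes)]


offsets = sliced_flat_offsets((3, 4, 5), [1, 0, 0], [2, 2, 2], [1, 1, 1])
print(offsets)  # [20, 21, 25, 26]
```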
tf.strided_slice in a seq2seq model:

In a seq2seq model, tf.strided_slice is used to build the decoder input from the target data:

ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
dec_input = tf.concat([tf.fill([batch_size, 1], target_letter_to_int['']), ending], 1)

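Suppose batch_size is n. tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1]) keeps every sequence in the batch and drops the last integer of each sequence, i.e. ending = targets[0:n, 0:-1]. tf.concat(..., 1) then prepends one column built by tf.fill([batch_size, 1], ...), so dec_input has the same shape as targets.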
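A pure-Python sketch of the same decoder-input construction on nested lists. The token ids and the start id below are made-up values; the original fills the new first column with a target_letter_to_int lookup whose key is blank in the snippet above, so `start_id` is a hypothetical stand-in.

```python
def make_decoder_input(targets, start_id):
    """Drop the last token of every sequence and prepend start_id.

    Mirrors tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
    followed by tf.concat([tf.fill([batch_size, 1], start_id), ending], 1).
    """
    ending = [row[:-1] for row in targets]    # targets[0:n, 0:-1]
    return [[start_id] + row for row in ending]  # prepend one column


targets = [[5, 6, 7, 3],  # hypothetical token ids, one row per sequence
           [8, 9, 3, 0]]
start_id = 1  # hypothetical id standing in for the target_letter_to_int lookup
print(make_decoder_input(targets, start_id))  # [[1, 5, 6, 7], [1, 8, 9, 3]]
```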
