TensorFlow 中生成Batch数据

TensorFlow 中生成Batch数据

Reference: https://www.tensorflow.org/programmers_guide/datasets#simple_batching

1. 利用sklearn

def batch(self, STEPS=10000, BATCH=100):
    """Yield mini-batches using sklearn's ShuffleSplit.

    Draws STEPS independent random index samples of size BATCH from
    self.features / self.labels (each step is an independent shuffle,
    so examples may repeat across steps).

    Args:
        STEPS: number of batches to generate (ShuffleSplit n_splits).
        BATCH: number of examples per batch (ShuffleSplit train_size).

    Yields:
        (features[idx], labels[idx]): index-selected arrays of length BATCH.
    """
    ss = ShuffleSplit(n_splits=STEPS, train_size=BATCH)
    # Only the sampled ("train") indices are needed; the complement split
    # is ignored. The original wrapped this in enumerate() but never used
    # the step counter, and called get_n_splits() without using its result.
    for idx, _ in ss.split(self.features, self.labels):
        yield self.features[idx], self.labels[idx]

2. 利用TensorFlow里面的tf.data Dataset

def generate_batch(self):
    """Build a tf.data pipeline yielding batches of 100 for 100 epochs.

    The dataset is constructed directly from the in-memory arrays
    self.features / self.labels.

    NOTE(review): embedding numpy arrays directly via from_tensor_slices
    bakes them into the graph as constants; for large data prefer the
    placeholder-feeding variant (section 3).

    Returns:
        (initializer, batch_xs, batch_ys): the iterator's initializer op
        and the next-batch tensors. Run `initializer` in a session once
        before fetching `batch_xs` / `batch_ys`.
    """
    # Dead code removed: the original created features/labels placeholders
    # here but never used them (the dataset was built from the arrays).
    dataset = tf.data.Dataset.from_tensor_slices((self.features, self.labels))
    dataset = dataset.repeat(100)          # iterate the data 100 times (epochs)
    batched_dataset = dataset.batch(100)   # 100 examples per fetched batch
    iterator = batched_dataset.make_initializable_iterator()
    batch_xs, batch_ys = iterator.get_next()
    return iterator.initializer, batch_xs, batch_ys

调用的时候

# generate_batch() returns THREE values: the iterator initializer op plus
# the two next-batch tensors. The original unpacked sess.run(batch) into
# two names (ValueError) and fused the initializer into the same run call.
init_op, batch_xs_t, batch_ys_t = generate_batch()
with tf.Session() as sess:
    # The iterator must be initialized before get_next() can be fetched.
    sess.run(init_op)
    batch_xs, batch_ys = sess.run((batch_xs_t, batch_ys_t))

3. Feedable的Dataset接口

    def generate_batch(self):
        """Feedable tf.data pipeline: data enters through placeholders.

        Unlike the from_tensor_slices-on-arrays variant, the arrays are fed
        at initialization time via placeholders, so they are not embedded
        in the graph as constants.

        Side effects:
            Creates and stores a session on self.sess and runs the
            iterator initializer with the data fed in.

        Returns:
            (initializer, batch_xs, batch_ys): the initializer op (already
            run once here) and the next-batch tensors.
        """
        features_placeholder = tf.placeholder(self.features.dtype, self.features.shape)
        labels_placeholder = tf.placeholder(self.labels.dtype, self.labels.shape)
        # Build the dataset from the placeholders, not the raw arrays.
        dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
        dataset = dataset.repeat(100)          # 100 epochs
        batched_dataset = dataset.batch(100)   # 100 examples per batch

        iterator = batched_dataset.make_initializable_iterator()
        self.sess = tf.Session()
        # BUG fix: the original called bare `sess.run(...)` (NameError) —
        # the session was bound to self.sess on the previous line.
        self.sess.run(iterator.initializer,
                      feed_dict={features_placeholder: self.features,
                                 labels_placeholder: self.labels})

        batch_xs, batch_ys = iterator.get_next()
        return iterator.initializer, batch_xs, batch_ys

你可能感兴趣的:(Tensorflow)