tensorflow 之Dataset数据集之批量数据

###生成批次数据
import tensorflow as tf

'''reapt()生成重复数据集  batch()将数据集按批次组合'''
file_name = ['img1','img2','img3','img4']
label = [1,2,3,4]
dataset =tf.data.Dataset.from_tensor_slices((file_name,label))
dataset1 = dataset.repeat().batch(3)
##定义一个迭代器迭代取批量数据
def getone(dataset):
    iterator = dataset.make_one_shot_iterator()  #生成一个迭代器
    one_element = iterator.get_next()            #迭代器取值
    return one_element
one_element1 = getone(dataset)
one_element2 = getone(dataset1)

##定义一个会话内调用的函数,用于显示批量数据
def showbatch(onebatch_element):
    for ii in range(3):
        datav = sess.run(onebatch_element)
        print('第%s批次'%ii,datav)
##开启会话,调用数据
with tf.Session() as sess:
    showbatch(one_element1)
    showbatch(one_element2)
    
'''
第0批次 (b'img1', 1)
第1批次 (b'img2', 2)
第2批次 (b'img3', 3)

第0批次 (array([b'img1', b'img2', b'img3'], dtype=object), array([1, 2, 3]))
第1批次 (array([b'img4', b'img1', b'img2'], dtype=object), array([4, 1, 2]))
第2批次 (array([b'img3', b'img4', b'img1'], dtype=object), array([3, 4, 1]))
'''

 

你可能感兴趣的:(tensorflow 之Dataset数据集之批量数据)