Reading data with the Dataset API involves three basic steps: construct the dataset from a data source (applying any transformations such as parsing), create an iterator over it, and fetch elements with get_next().
Example:
import tensorflow as tf

def parser(record):
    # Parse one serialized tf.train.Example into a dict of tensors.
    features = tf.parse_single_example(
        record,
        features={
            'feat1': tf.FixedLenFeature([], tf.int64),
            'feat2': tf.FixedLenFeature([], tf.int64)
        }
    )
    return features['feat1'], features['feat2']

# A dataset can also be built from an in-memory tensor or from text files:
#   - from a tensor: tf.data.Dataset.from_tensor_slices(input_data)
#   - from text files: tf.data.TextLineDataset(input_files)
# (a short sketch of both follows this example)
input_files = ['file1', 'file2']
dataset = tf.data.TFRecordDataset(input_files)
# TFRecord files store serialized binary records, so each record must be parsed
# into the desired format; map() applies the parser to every element.
dataset = dataset.map(parser)
# Fetch data through an iterator.
iterator = dataset.make_one_shot_iterator()
feat1, feat2 = iterator.get_next()
with tf.Session() as sess:
    for i in range(10):
        print(sess.run([feat1, feat2]))
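As noted in the comments above, a dataset can also be built from an in-memory tensor or from plain text files. A minimal sketch of both constructors (the values and file names here are hypothetical):

# Build a dataset from an in-memory tensor; each element is one slice of it.
input_data = [1, 2, 3, 5, 8]
num_dataset = tf.data.Dataset.from_tensor_slices(input_data)

# Build a dataset from text files; each element is one line of text.
text_files = ['data1.txt', 'data2.txt']   # hypothetical file names
text_dataset = tf.data.TextLineDataset(text_files)

it = num_dataset.make_one_shot_iterator()
x = it.get_next()
with tf.Session() as sess:
    for _ in range(len(input_data)):
        print(sess.run(x))   # prints 1, 2, 3, 5, 8 in order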
If the input data needs to be provided dynamically (e.g. the file names are only known at run time), use make_initializable_iterator() together with a placeholder; the same iterator can later be re-initialized with a different file list (see the sketch after this example).
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
# TFRecord files store serialized binary records, so each record must be parsed
# into the desired format; map() applies the parser to every element.
dataset = dataset.map(parser)
# Fetch data through an iterator.
iterator = dataset.make_initializable_iterator()
feat1, feat2 = iterator.get_next()
with tf.Session() as sess:
    # The iterator must be initialized first, feeding in the actual file names.
    sess.run(iterator.initializer, feed_dict={input_files: ['file1', 'file2']})
    # The dataset size is unknown, so loop until tf.errors.OutOfRangeError is
    # raised, which signals that all data has been consumed.
    while True:
        try:
            sess.run([feat1, feat2])
        except tf.errors.OutOfRangeError:
            break
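Because the file names are fed through a placeholder, the same pipeline can be re-initialized with a different list of files without rebuilding the graph, e.g. to switch from training to validation data. A minimal sketch (the file names are hypothetical):

with tf.Session() as sess:
    # First pass: read the training files.
    sess.run(iterator.initializer,
             feed_dict={input_files: ['train_file1', 'train_file2']})
    # ... consume the dataset as shown above ...

    # Re-initialize the same iterator to read a different set of files.
    sess.run(iterator.initializer,
             feed_dict={input_files: ['validation_file1']})
    # ... consume the dataset again ...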
Some higher-level operations on a dataset (a small sketch chaining them follows):
dataset.map(func)             # apply func to every element
dataset.shuffle(buffer_size)  # shuffle by sampling from a buffer of buffer_size elements
dataset.batch(batch_size)     # group consecutive elements into batches
dataset.repeat(N)             # repeat the dataset N times (N epochs)
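Before the full example, a minimal sketch chaining these transformations on a small in-memory dataset (the numbers are arbitrary):

toy = tf.data.Dataset.from_tensor_slices(list(range(10)))
toy = toy.map(lambda x: x * 2)    # 0, 2, 4, ..., 18
toy = toy.shuffle(buffer_size=5)  # shuffle with a 5-element buffer
toy = toy.batch(4)                # batches of 4 elements
toy = toy.repeat(2)               # 2 epochs

it = toy.make_one_shot_iterator()
batch = it.get_next()
with tf.Session() as sess:
    while True:
        try:
            print(sess.run(batch))   # e.g. [ 6  0  2 10]
        except tf.errors.OutOfRangeError:
            break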
For a complete end-to-end example:
import tensorflow as tf

train_files = tf.train.match_filenames_once('tfrecords/train_file-*')
test_file = tf.train.match_filenames_once('tfrecords/test_file-*')

def preprocess_for_train(image, height, width, bbox):
    pass

def build_net(input):
    pass

def calc_loss(logit, label):
    pass
def parse(record):
    # Parse one serialized tf.train.Example into an (image, label) pair.
    features = tf.parse_single_example(
        record,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64)
        }
    )
    # Decode the raw bytes and reshape to the stored dimensions (set_shape only
    # accepts static shapes, so reshape with the parsed tensors is used instead).
    image_data = tf.decode_raw(features['image'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    channels = tf.cast(features['channels'], tf.int32)
    image_data = tf.reshape(image_data, tf.stack([height, width, channels]))
    label = features['label']
    return image_data, label
if __name__ == '__main__':
    image_size = 299
    batch_size = 100
    shuffle_buffer = 10000

    # Read the training dataset.
    dataset = tf.data.TFRecordDataset(train_files)
    # Convert each TFRecord into an (image, label) pair.
    dataset = dataset.map(parse)
    # Preprocess every image.
    dataset = dataset.map(
        lambda image, label: (
            preprocess_for_train(image, image_size, image_size, None), label)
    )
    # Shuffle the dataset; shuffle_buffer is the size of the buffer that
    # elements are randomly drawn from (larger buffers shuffle more thoroughly).
    dataset = dataset.shuffle(shuffle_buffer)
    # Group elements into batches of batch_size; without batch() the iterator
    # returns one element at a time.
    dataset = dataset.batch(batch_size)
    num_epoches = 10
    # Repeat the dataset num_epoches times; because shuffle() comes first, the
    # order of each repetition is not necessarily the same.
    dataset = dataset.repeat(num_epoches)
    iterator = dataset.make_initializable_iterator()
    image_batch, label_batch = iterator.get_next()

    learning_rate = 0.01
    # Build the network and compute the logits.
    logit = build_net(image_batch)
    # Compute the loss.
    loss = calc_loss(logit, label_batch)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    # The test pipeline only resizes images; no random preprocessing is applied.
    test_dataset = tf.data.TFRecordDataset(test_file)
    test_dataset = test_dataset.map(parse).map(
        lambda image, label: (
            tf.image.resize_images(image, [image_size, image_size]), label)
    )
    test_dataset = test_dataset.batch(batch_size)
    test_iterator = test_dataset.make_initializable_iterator()
    test_image_batch, test_label_batch = test_iterator.get_next()
    test_logit = build_net(test_image_batch)
    prediction = tf.argmax(test_logit, 1)
    with tf.Session() as sess:
        # match_filenames_once() creates local variables, so local variables
        # must be initialized; the initializable iterator needs
        # iterator.initializer; global variables are initialized as usual.
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer(),
                  iterator.initializer])
        # Train until OutOfRangeError signals that all num_epoches repetitions
        # of the training data have been consumed.
        while True:
            try:
                sess.run(train_step)
            except tf.errors.OutOfRangeError:
                break

        # Evaluate on the test set.
        sess.run(test_iterator.initializer)
        test_results = []
        test_labels = []
        while True:
            try:
                pred, label = sess.run([prediction, test_label_batch])
                test_results.extend(pred)
                test_labels.extend(label)
            except tf.errors.OutOfRangeError:
                break
        correct = [float(y == y_) for (y, y_) in zip(test_results, test_labels)]
        acc = sum(correct) / len(correct)
        print(acc)
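The preprocess_for_train function is left as a stub above. Purely as an illustrative sketch (the resizing and random distortions here are assumptions, not part of the original; bbox is ignored), one plausible version using standard tf.image ops might look like this:

def preprocess_for_train(image, height, width, bbox):
    # Illustrative sketch only; the actual preprocessing is left open above.
    # bbox is ignored in this sketch.
    # Convert to float in [0, 1] so the tf.image adjustments behave sensibly.
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Resize to the input size expected by the network.
    image = tf.image.resize_images(image, [height, width])
    # Simple random distortions for data augmentation (assumed, not from the
    # original text): horizontal flip and brightness jitter.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    return image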