# cifar10 + tfrecords: convert CIFAR-10 to TFRecords and read them back

import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
%matplotlib inline

# NOTE: download the CIFAR-10 dataset (python/pickle version) and extract it into ./cifar10/ first.

def load_data(path):
    """Load one pickled CIFAR-10 batch file.

    Requires ``numpy`` to be imported as ``np`` at module level.

    Args:
        path: Path to a pickled batch file containing a dict with the keys
            ``'data'`` and ``'labels'``.

    Returns:
        A tuple ``(x, y)`` where ``x`` is a float32 array of shape
        ``(N, 3072)`` with values divided by 255, and ``y`` is an int64
        array built from the batch labels.
    """
    # SECURITY: pickle can execute arbitrary code while loading; only use
    # this on batch files obtained from a trusted source.
    with open(path, 'rb') as f:
        # 'latin1' yields str keys; with encoding='bytes' the keys would be
        # b'data' / b'labels' instead.
        data = pickle.load(f, encoding='latin1')

    # Flatten every image to 3*32*32 values and scale them by 1/255.
    x = np.reshape(data['data'], [-1, 3 * 32 * 32]).astype('float32') / 255.
    y = np.array(data['labels']).astype('int64')
    return x, y
    
def load_train(root='cifar10/data_batch_'):
    """Load and concatenate the five CIFAR-10 training batches.

    Args:
        root: Path prefix of the batch files; the batch index (1..5) is
            appended to it.

    Returns:
        A tuple ``(train_x, train_y)`` with the images and labels of all
        five batches concatenated in batch order.
    """
    batches = [load_data(root + str(idx)) for idx in range(1, 6)]
    # Split the list of (x, y) pairs into one sequence of x's and one of y's.
    images, labels = zip(*batches)
    return np.concatenate(images), np.concatenate(labels)

def load_test():
    """Load the CIFAR-10 test batch from ``cifar10/test_batch``.

    Returns:
        The ``(x, y)`` pair produced by ``load_data`` for the test batch.
    """
    test_x, test_y = load_data('cifar10/test_batch')
    return test_x, test_y

def create_tfRecords(name='train'):
    """Serialize one CIFAR-10 split into a TFRecord file under ``tfRecords/``.

    Every record stores two bytes features: ``'img'`` (the raw bytes of one
    image row as returned by the loader) and ``'label'`` (the raw bytes of
    its label).

    Args:
        name: ``'train'`` writes the training split to
            ``tfRecords/train.tfrecords``; any other value writes the test
            split to ``tfRecords/test.tfrecords``.
    """
    if name == 'train':
        x, y = load_train()
        out_path = 'tfRecords/train.tfrecords'
    else:
        x, y = load_test()
        out_path = 'tfRecords/test.tfrecords'

    # Make sure the output directory exists before opening the writer.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # The context manager closes the writer, so buffered records are flushed
    # to disk even if serialization fails part-way through.
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for image, target in zip(x, y):
            img, label = image.tobytes(), target.tobytes()  # raw bytes
            example = tf.train.Example(features=tf.train.Features(feature={
                'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label]))
            }))
            writer.write(example.SerializeToString())
# Write both splits ('train' first, then 'test') to disk as TFRecords.
for split in ('train', 'test'):
    create_tfRecords(split)


#read tfrecords
def _parse(example):
    """Decode one serialized example into an ``(image, label)`` tensor pair.

    The ``'img'`` bytes are decoded as float32, reshaped to ``[3, 32, 32]``
    and transposed to ``[32, 32, 3]``. The ``'label'`` bytes are decoded as
    int64 and reshaped to a scalar.

    Args:
        example: A serialized ``tf.train.Example`` string tensor.

    Returns:
        A tuple ``(img, label)`` of tensors.
    """
    schema = {
        'img': tf.FixedLenFeature((), tf.string),
        'label': tf.FixedLenFeature((), tf.string)
    }
    parsed = tf.parse_single_example(example, schema)

    # Image: raw bytes -> float32 -> [3, 32, 32] -> [32, 32, 3].
    channels_first = tf.reshape(
        tf.decode_raw(parsed['img'], out_type=tf.float32), [3, 32, 32])
    image = tf.transpose(channels_first, [1, 2, 0])

    # Label: raw bytes -> int64 -> scalar.
    target = tf.reshape(tf.decode_raw(parsed['label'], tf.int64), [])
    return image, target
    
def load_tfRecords(name='train', batch_size=32, shuffle_buffer=1024):
    """Build an endlessly repeating, shuffled, batched input from a TFRecord file.

    Args:
        name: Split name; the file read is ``tfRecords/<name>.tfrecords``.
        batch_size: Number of examples per batch (default 32).
        shuffle_buffer: Size of the shuffle buffer (default 1024).

    Returns:
        The ``get_next()`` op of a one-shot iterator; evaluating it yields
        one ``(images, labels)`` batch as produced by ``_parse``.
    """
    path = 'tfRecords/' + name + '.tfrecords'
    ds = tf.data.TFRecordDataset(path)
    # Order matters: parse -> shuffle single examples -> batch -> repeat.
    ds = ds.map(_parse)
    ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size)
    ds = ds.repeat()
    return ds.make_one_shot_iterator().get_next()

# Build the batched input op for the default ('train') split.
train_next = load_tfRecords()
# Smoke test: fetch one batch and display its first image.
with tf.Session() as sess:
    # No variables are created in the code shown here, so this initializer
    # is presumably only a precaution.
    sess.run(tf.global_variables_initializer())
    x,y = sess.run(train_next)
    plt.imshow(x[0])
    plt.show()

# You may also be interested in: (tf)