mnist手写数字集制作tfrecords数据格式

一、tfrecords文件

tfrecords是一种二进制文件,可先将图片与标签制作成该格式的文件,使用tfrecords进行数据读取,会提高内存利用率,将不同输入文件统一起来。

二、mnist数据集

MNIST数据集是一个手写体数字集合,可到此处下载,数据包括四部分:训练图片集、训练标签集、测试图片集、测试标签集。该数据集的训练集中有55000张图片,验证集中5000有张图片,测试集中是10000张图片。
mnist手写数字集制作tfrecords数据格式_第1张图片

三、数据制作与读取

1、生成文件

文件生成的过程:

  • 新建一个writer
  • for循环遍历每张图和标签
  • 把每张图和标签封装到example中
  • 将example序列化

具体代码如下:

def generate_tfRecord():
    isExists=os.path.exists(data_path)   ##判断保存路径是否存在
    if not isExists:
        os.makedirs(data_path)
        print('路径创建成功')
    else:
        print('路径已存在')
    write_tfRecord(tfRecord_train, image_train_path, label_train_path)   ##使用自定义函数将训练集生成名叫tfRecord_train的tfrecords文件
    write_tfRecord(tfRecord_test, image_test_path, label_test_path)   ##同理训练集

def write_tfRecord(tfRecordName, image_path, label_path):
    writer=tf.python_io.TFRecordWriter(tfRecordName)   ##创建一个writer
    num_pic=0   ##计数器
    f=open(label_path,'r')  ##以读的形式打开标签文件
    contents = f.readlines()   ##读取整个文件内容
    f.close()
    for content in contents:
        value=content.split()   ##以空格分隔每行的内容,分割后组成列表value
        img_path=image_path+value[0]
        img=Image.open(img_path)  ##打开图片
        img_raw=img.tobytes()  ##将图片转换为二进制数据
        labels=[0]*10
        labels[int(value[1])]=1   ##将labels所对应的标签为赋值为1

        example=tf.train.Example(features=tf.train.Features(feature={                         ##创建一个example,用一个features进行封装
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),    ##在img_raw放入二进制图片
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))}))               ##在labels放入图片对应的标签
        writer.write(example.SerializeToString())    ##将example进行序列化
        num_pic+=1
        print('图片数为:',num_pic)
    writer.close()
    print('tfREcord文件写入成功')

其中标签文件的格式为:
mnist手写数字集制作tfrecords数据格式_第2张图片

2、读取文件

文件读取的过程:

  • 新建一个reader
  • 解序列化example读取图片和标签
  • 将图片和标签转化为网络需要的格式

具体代码如下:

###实现了批获取训练集或测试集的图片和标签
def get_tfRecord(num, isTrain=True):  ##参数num表示一次读取多少组
    if isTrain:
        tfRecord_path=tfRecord_train
    else:
        tfRecord_path=tfRecord_test
    img,label=read_tfRecord(tfRecord_path)
    img_batch, label_batch= tf.train.shuffle_batch([img, label],
                                                   batch_size=num,
                                                   capacity=1000,
                                                   min_after_dequeue=700,
                                                   num_threads=2)

def read_tfRecord(tfRecord_path):
    filmname_queue=tf.train.string_input_producer([tfRecord_path])
    reader=tf.TFRecordReader()   ##新建一个reader
    _,serialized_example = reader.read(filmname_queue)  ##将读出的每一个样本保存到serialized_example中进行解序列化
    features=tf.parse_single_example(serialized_example,features={     ##将图片和标签的键值要和制作数据集是的键值相同
        'img_raw':tf.FixedLenFeature([],tf.string),
        'label':tf.FixedLenFeature([10],tf.int64)})
    img=tf.decode_raw(features['img_raw'],tf.uint8)   ##将img_raw字符串转化为8位无符号整型
    img.set_shape([784])
    img=tf.cast(img,tf.float32)*(1./255)   ##转化为浮点数形式
    label=tf.cast(features['label'],tf.float32)
    return img, label

你可能感兴趣的:(数据)