TensorFlow笔记(一)tensorflow加载数据的三种方式

最近在看TF2.0的内容,顺便把以前的内容也做下笔记,以便查阅。所有程序在不注明的情况下,默认使用tensorflow1.14版本。

数据加载是训练模型的第一步,合理的数据加载方式虽然不会对模型效果有促进作用,但是会大大加快训练过程。TensorFlow中常用的数据加载方式有四种:

  • 内存对象数据集,在学习阶段最常见的数据加载方式,在session中直接用字典变量feed_dict给变量喂数据,这种方式适用于数据量比较少的情况下。
  • TFRecord数据集,用tfRecord向模型喂数据,适用于大量数据集情况。
  • Dataset数据集,通过高级API(tf.data)给模型喂数据,也是TensorFlow高级版本比较推荐的方式,在开发中也比较建议使用这种方式。
  • tf.keras等高级接口,TensorFlow中的高级封装框架,如keras、slim等,会有自己的数据集接口,对于这些高级API比较熟悉的同学,可以使用这些API固有的接口,尤其是tf.kears,建议大家多多熟悉。

1、内存对象数据集

内存对象数据集比较常见,在学习的时候,常常会直接在内存中模拟数据,然后在sess.run中喂数据。这种方法简单直接高效,也比较容易理解,但是在大数据量情况下因为数据都在内存中则不适用。这种方式资料较多,就不做过多介绍了。这里介绍两种改进方法。

(1)第一种改进方法是结合python生成器(yield)和多线程(生产者消费者模型),或者只使用yield加载数据,可以实现每次加载一个batch的数据训练模型。

(2)利用TF中的队列API实现(1)中的功能。具体来说就是使用一个线程源源不断的将硬盘中的数据文件名读入到一个内存队列中,另一个线程负责计算任务,所需数据直接从内存队列中获取。先介绍几个需要用到的API:

tf.train.slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
                         capacity=32, shared_name=None, name=None)
##tf.train.slice_input_producer,从tensor_list中抽取tensor,准备放入队列中
  • tensor_list:包含一系列tensor的列表,表中tensor的第一维度的值必须相等,即个数必须相等,有多少个图像,就应该有多少个对应的标签。使用时常把训练数据和标签放在一个列表中。
  • num_epochs: 可选参数,是一个整数值,代表迭代的次数,如果设置 num_epochs=None,生成器可以无限次遍历tensor列表,如果设置为 num_epochs=N,生成器只能遍历tensor列表N次,遍历结束以后会抛出tf.errors.OutOfRangeError异常。
  • shuffle: bool类型,设置是否打乱样本的顺序。一般情况下,如果shuffle=True,生成的样本顺序就被打乱了,在批处理的时候不需要再次打乱样本,使用 tf.train.batch函数就可以了;如果shuffle=False,就需要在批处理时候使用 tf.train.shuffle_batch函数打乱样本。
  • seed: 可选的整数,是生成随机数的种子,在第三个参数设置为shuffle=True的情况下才有用。
  • capacity:设置tensor列表的容量。
  • shared_name:可选参数,如果设置一个‘shared_name’,则在不同的上下文环境(Session)中可以通过这个名字共享生成的tensor。
  • name:可选,设置操作的名称。
tf.train.batch(tensors, batch_size, num_threads=1, capacity=32,
          enqueue_many=False, shapes=None, dynamic_pad=False,
          allow_smaller_final_batch=False, shared_name=None, name=None)
  • tensors:tensor序列或tensor字典,可以是含有单个样本的序列,在使用的时候可以把tf.train.slice_input_producer的返回结果传赋给该值
  • batch_size: 生成的batch的大小;
  • num_threads:执行tensor入队操作的线程数量,可以设置使用多个线程同时并行执行,提高运行效率,但也不是数量越多越好;
  • capacity: 定义生成的tensor序列的最大容量;
  • enqueue_many: 定义第一个传入参数tensors是多个tensor组成的序列,还是单个tensor;
  • shapes: 可选参数,默认是推测出的传入的tensor的形状;
  • dynamic_pad: 定义是否允许输入的tensors具有不同的形状,设置为True,会把输入的具有不同形状的tensor归一化到相同的形状;
  • allow_smaller_final_batch: 设置为True,表示在tensor队列中剩下的tensor数量不够一个batch_size的情况下,允许最后一个batch的数量少于batch_size, 设置为False,则不管什么情况下,生成的batch都拥有batch_size个样本;
  • shared_name: 可选参数,设置生成的tensor序列在不同的Session中的共享名称;
  • name: 操作的名称;
tf.train.start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS)
  • sess:使用的session,默认是默认session
  • coord:线程协调器
  • daemon:默认为True,表示是否把线程标记为守护,设为True,表示不会阻塞程序退出
  • start:默认为True,如果设为False,表示只创建线程,不启动线程
  • collection:指定获取的队列运行集合

启动队列之前,还需要通过 tf.train.Coordinator(clean_stop_exception_types=None) 类建立一个线程协调器,用来管理之后在Session中启动的所有线程,并将其传给tf.train.start_queue_runners的coord参数。操作示例如下:

##读取图片路径和标签,改为自己的数据即可
def load_sample(sample_dir):
    lfilenames = []
    labelsnames = []
    for (dirpath, dirnames, filenames) in os.walk(sample_dir):
        for filename in filenames:
            #print(dirnames)
            filename_path = os.sep.join([dirpath, filename])
            lfilenames.append(filename_path) 
            labelsnames.append( dirpath.split('\\')[-1] )

    lab= list(sorted(set(labelsnames)))
    labdict=dict( zip( lab  ,list(range(len(lab)))  ))
    labels = [labdict[i] for i in labelsnames]

    return shuffle(np.asarray( lfilenames),np.asarray( labels))

##返回batch
def get_batches(image,label,input_w,input_h,channels,batch_size):
    queue = tf.train.slice_input_producer([image,label])  #使用tf.train.slice_input_producer实现一个输入的队列
    label = queue[1]                                        #从输入队列里读取标签

    image_c = tf.read_file(queue[0])                        #从输入队列里读取image路径
    image = tf.image.decode_bmp(image_c,channels)           #按照路径读取图片

    image = tf.image.resize_image_with_crop_or_pad(image,input_w,input_h) #修改图片大小
    image = tf.image.per_image_standardization(image) #图像标准化处理,(x - mean) / adjusted_stddev
    image_batch,label_batch = tf.train.batch([image,label],#调用tf.train.batch函数生成批次数据
               batch_size = batch_size,
               num_threads = 64)

    images_batch = tf.cast(image_batch,tf.float32)   #将数据类型转换为float32
    labels_batch = tf.reshape(label_batch,[batch_size])#修改标签的形状shape

    return images_batch,labels_batch


batch_size = 16
image_batches,label_batches = get_batches(image,label,28,28,1,batch_size)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)  #初始化

    coord = tf.train.Coordinator()          #创建一个线程协调器,开启列队
    threads = tf.train.start_queue_runners(sess = sess,coord = coord)
    try:
        for step in np.arange(10):
            if coord.should_stop():
                break
            images,label = sess.run([image_batches,label_batches]) #注入数据
            print(label) 

    except tf.errors.OutOfRangeError:
        print("Done!!!")
    finally:
        coord.request_stop()

    coord.join(threads)                             #关闭列队

2、TFRecord数据集

TFRecord也是一种非常好用的读取数据的方法,并且是一种非常高效的数据持久化方法,尤其是对于需要预处理的数据。。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。一次处理,永久使用。另外一种类似的方法是使用 lmdb 库处理数据。

制作tfrecords文件:

def makeTFRec(filenames,labels): 
    #定义函数生成TFRecord,filenames是数据路径列表,labels是标签列表
    writer= tf.python_io.TFRecordWriter("mydata.tfrecords") #通过tf.python_io.TFRecordWriter 写入到TFRecords文件
    for i in tqdm( range(0,len(labels) ) ):
        img=Image.open(filenames[i])
        img = img.resize((256, 256))
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
                #存放图片的标签label
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]])),
                #存放具体的图片
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            })) #example对象对label和image数据进行封装

        writer.write(example.SerializeToString())  #序列化为字符串
    writer.close()  #数据集制作完成

makeTFRec(filenames,labels)

makeTFRec的参数获取可以参考1中的load_sample函数。这里主要有3个API:

  1. tf.python_io.TFRecordWriter(path, options=None),根据path创建一个tfrecords文件,并返回一个TFRecordWriter实例去写入数据
  2. tf.train.Example。下面是Example协议块,我们可以看出tf_example可以写入的数据形式有三种,分别BytesListFloatList以及Int64List的类型(注意没有string)。
    message Example {
      Features features = 1;
    };
    
    message Features {
      map feature = 1;
    };
    
    message Feature {
      oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
      }
    };

     

  3. example.SerializeToString()序列化为字符串

读取tfrecords文件:

def read_and_decode(filenames,batch_size = 3):
    #根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example, #取出包含image和label的feature对象
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })

    #tf.decode_raw可以将字符串解析成图像对应的像素数组
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [256,256,3])
    label = tf.cast(features['label'], tf.int32)

    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5     #归一化
    img_batch, label_batch = tf.train.batch([image, label],batch_size=batch_size, capacity=20) ##注意设置capacity大小            

    return img_batch, label_batch

TFRecordfilenames = ["mydata.tfrecords"]
image, label = read_and_decode(TFRecordfilenames)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)

    try:
        for i in range(5):
            example, examplelab = sess.run([image,label])#在会话中取出image和label
            ##这里的example, examplelab大小是batchsize

    except tf.errors.OutOfRangeError:
        print('Done Test -- epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)
        print("stop()")

3、Dataset数据集

tf.data.Dataset是TF比较推荐的数据处理接口。Dataset可以看作是相同类型“元素”的有序列表。在实际使用时,单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。tf.data.Dataset接口是通过创建Dataset对象来生成数据集的,并且Dataset对象可以做shuffle、map、iterate、zip、repeat、batch、flat_map、apply、filter等操作。使用demo如下:

def make_dataset(directory,batchsize):
    filenames,labels =load_sample(directory,shuffleflag=False) #载入文件名称与标签
    def _parseone(filename, label):                         #解析一个图片文件
        """ Reading and handle  image"""
        image_string = tf.read_file(filename)         #读取整个文件
        image_decoded = tf.image.decode_image(image_string)
        image_decoded = tf.cast(image_decoded,dtype=tf.float32)
        label = tf.cast(tf.reshape(label, []) ,dtype=tf.int32)#将label 转为张量
        return image_decoded, label

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))#生成Dataset对象
    dataset = dataset.map(_parseone).shuffle(buffersize=1000).repeat().batch(batchsize) #批次划分数据集

    return dataset

path = "data"
dataset = make_dataset(path,32)
iterator = dataset.make_one_shot_iterator()	 #生成一个迭代器
one_element = iterator.get_next()			#从iterator里取出一个元素,实际大小是batchsize

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    try:
        for step in np.arange(1):
            value = sess.run(one_element)
            ##value就是图片和标签值

    except tf.errors.OutOfRangeError:           #捕获异常
        print("Done!!!")

load_sample函数可以参考1中的程序。

  • dataset.map(map_func,num_parallel_calls=None):通过map_func函数转换数据集元素,返回新数据集;num_parallel_calls表示并行线程数。
  • dataset.shuffle(buffer_size,seed=None,reshuffle_each_iteration=None):随机打乱顺序,buffer_size越大越混乱;seed随机种子;reshuffle_each_iteration是否每次迭代都随机乱序。
  • dataset.repeat(count=None):生成重复的数据集,count代表重复次数,默认无限次重复。
  • dataset.batch(batch_size,drop_remainder=False):批次取数据,batch_size批次大小;drop_remainder是否忽略批次组合后剩余的数据,默认为False,会把最后剩余的数据
  • dataset.filter(predicate):对整个数据集过滤,留下使函数predicate为True的数据。

 

参考资料

https://blog.csdn.net/dcrmg/article/details/79780331

https://blog.csdn.net/lyb3b3b/article/details/82910863

你可能感兴趣的:(深度学习)