Tenosorflow基础学习---------Tensorflow训练自己的数据集

一般我们获得的数据集并非是提前处理好的二进制的格式文件,所以我们需要将数据集进行处理,当然我们这里说的数据集类似于猫狗大战那样的,并不是MNIST和CIFAR-10那样拿来就可以直接用的,而且提前分类和标签的数据集,只不过给的是大量的图片,一般都是比赛提供方给的数据集,而对于这样的数据集当然不可能整张输入和读取,这样不仅的数据不仅数据量大,需要大量的内存消耗,而且时间也是相当的慢,于是在tensorflow中提供了一种专门用于tensorflow的数据集的格式转换。

目录

1.第一步:数据集的加工

2.第二步:图片数据集转化为Tensorflow专用格式

2.1附加读取数据


1.第一步:数据集的加工

数据集中的数据并不是按照规格大小处理,对于不同的的图片,其规格尺寸都不尽相同,因此在数据集提交之前需要对数据集进行处理。对于图像的处理我这里用的是opencv,opencv的功能真的方便,强力推荐大家!!

最简单的方式就是把数据裁剪成规定大小。比如输入模型中的图片大小为[227,227],因此我们这里将图片裁剪成[227,227]的尺寸,给出代码示例:

import cv2
import os
def rebulid(dir):
    #walk(top, topdown=True, onerror=None, followlinks=False)
    #top-是你所要遍历的目录的地址,返回的是一个三元组(root,dirs,files)
    #root所指的是当前正在遍历的这个文件夹的本身的地址
    #dirs是一个list,内容是该文件夹中所有的目录的名字(不包括子目录)
    #files同样是list,内容是该文件夹中所有的文件(不包含子目录)
    #topdown--可选,为True,则优先遍历top目录,否则优先遍历top的子目录(默认为开启)。
    # 如果topdown参数为True,walk会遍历top文件夹中每一个子目录
    #onerror--可选,需要一个callable对象,当walk需要异常时,会调用
    #followlinks--可选,如果为True,则会遍历目录下的快捷方式实际所指的目录
    for root,dirs,files in os.walk(dir):
        for file in files:
            filepath = os.path.join(root,file)
            try:
                image = cv2.imread(filepath)
                dim = (227,227)
                resized = cv2.resize(image,dim)
                path = "C:\\cat_and_dog\\dog-r\\"+file
                cv2.imwrite(path,resized)
            except:
                print(filepath)
                os.remove(filepath)
        cv2.waitKey(0)   #退出

在这里导入的是图片集的根目录,os对数据集所在的文件夹进行读取,之后的一个for循环重建了图片数据所在的路径(filepath),在图片被重构后重新写入给定的位置(path)。

2.第二步:图片数据集转化为Tensorflow专用格式

对于数据集来说,最好的办法就是将其转换成Tensorflow专用的数据格式,即TFRecord格式。

将裁剪后的图片的位置进行读取,之后根据文件名称的不同将处于不同文件夹中的图片标签设置为0或者1,如果有更多分类的话可以依据这个格式设置更多的标签类,之后使用创建的数组对所读取的文件位置和标签进行保存,而Numpy对数组的调整重构了存储有对应文件位置和文件标签的矩阵,并返回。

def get_file(file_dir):
    images=[]
    temp = []
    for root,sub_folders,files in os.walk(file_dir):
        #图片目录
        for name in files:
            images.append(os.path.join(root,name))
        #get 10 sub-folders:
        for name in sub_folders:
            temp.append(os.path.join(root,name))
        print(files)
    #根据文件夹名分配多个标签
    labels = []
    for one_folder in temp:
        n_img = len(os.listdir(one_folder))
        #split('\\')[-1]以\\分割字符串,保留最后一段
        letter = one_folder.split('\\')[-1]
        if letter == 'cat':
            #n_img*[0]和np.zeros[n_img]一样
            labels = np.append(labels,n_img*[0])
        else:
            #n_img*[1]和np.ones[n_img]一样
            labels = np.append(labels,n_img*[1])
    temp = np.array([images,labels])
    #转置
    temp = temp.transpose()
    #shuffle 随机排列
    np.random.shuffle(temp)

    image_list = list(temp[:,0])
    label_list = list(temp[:,1])
    label_list = [int(float(i) for i in label_list)]

    return image_list,label_list

在获取图片数据文件位置和图片标签之后,即可通过相应的程序对其进行读取,并生成专门用的TFRecord格式的数据集

首先是转换格式的定义,这里需要将数据转换为相应的格式。

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def convert_to_tfrecord(images_list,lables_list,save_dir,name):
    filename = os.path.join(save_dir,name+'.tfrecords')
    n_samples = len(lables_list)
    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start.......')
    for i in np.arange(0,n_samples):
        try:
            image = io.imread(images_list[i])
            image_raw = image.tostring()
            label = int(label[i])
            example = tf.train.Example(features=tf.train.Feature(feature={
                'label':int64_feature(label),
                'image_raw':bytes_feature(image_raw)}))
        except IOError as e:
            print('Could not read:',images[i])
        writer.close()
        print('Transform done!')

convert_to_tfrecord(images_list,labels_list,save_dir,name)函数中需要4个参数,其中images_list和labels_list是上一段代码段获取的图片位置和对应标签的列表。save_dir是存储路径,如果希望生成的TFRecord文件存储在当前目录下,直接使用空的双引号""即可。最后是生成的文件名,这里只需填写名称就会自动生成以".tfrecords"格式结尾的数据集。

2.1附加读取数据

当生成完数据集后,在神经网络使用数据集进行训练时,需要一个方法将数据从数据急中取出,下面代码段完成了数据读取功能。

def read_and_decode(tfrecords_file,batch_size):
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _,serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label':tf.FixedLenFeature([],tf.int64),
            'image_raw':tf.FixedLenFeature([],tf.string),
        }
    )
    image = tf.decode_raw(img_features['image_raw'],tf.uint8)
    image = tf.reshape(image,[227,227,3])
    lable = tf.cast(img_features['label'],tf.int32)
    image_batch,lable_batch = tf.train.shuffle_batch([image,label],
                                                     batch_size=batch_size,
                                                     min_after_dequeue=100,
                                                     num_threads=64,
                                                     capacity=200)
    return image_batch,tf.reshape(lable_batch,[batch_size])

 

你可能感兴趣的:(tensorflow,tensorflow,TFRecord)