tensorflow 读取自己的数据集用于训练

```

#!/usr/bin/python

# -*- coding: utf-8 -*-

from __future__ import division

import os

import numpy as np

import tensorflow as tf

from utils import *

image_path = "./ck/cohn-kanade-images/"

label_path = "./ck/Emotion_labels/"

Imglist, ImgLabellist, Labellist = GetImg_Label_list(image_path, label_path)  #得到图片和标签相应的列表


def parse_data(filename,label):

    '''

    导入数据,进行预处理,输出两张图像,

    分别是输入图像和标签

    Args:

        filaneme, 图片的路径

    Returns:

        处理后图像,标签

    '''

    # 读取图像

    image = tf.read_file(filename)

    # 解码图片

    image = tf.image.decode_image(image)

    # 数据预处理,或者数据增强,这一步根据需要自由发挥

    image = tf.image.crop_to_bounding_box(image, 0, 0, 64, 64)

    # 数据增强,随机水平翻转图像

    image = tf.image.random_flip_left_right(image)

    # 图像归一化

    image = tf.cast(image, tf.float32) / 255.0

    return image, label

def train_generator(batchsize, shuffle=True):

    with tf.Session() as sess:

        # 创建数据库

        train_dataset = tf.data.Dataset().from_tensor_slices((Imglist, Labellist))

        # 预处理数据

        train_dataset = train_dataset.map(parse_data)

        # 设置 batch size

        train_dataset = train_dataset.batch(batchsize)

        # 无限重复数据

        train_dataset = train_dataset.repeat()

        # 洗牌,打乱

        if shuffle:

            train_dataset = train_dataset.shuffle(buffer_size=4)

        # 创建迭代器

        train_iterator = train_dataset.make_initializable_iterator()

        sess.run(train_iterator.initializer)

        train_batch = train_iterator.get_next()

        # 开始生成数据

        while True:

            try:

                x_batch, y_batch = sess.run(train_batch)

                yield (x_batch, y_batch)

            except:

                # 如果没有  train_dataset = train_dataset.repeat()

                # 数据遍历完就到end了,就会抛出异常

                train_iterator = train_dataset.make_initializable_iterator()

                sess.run(train_iterator.initializer)

                train_batch = train_iterator.get_next()

                x_batch, y_batch = sess.run(train_batch)

                yield (x_batch, y_batch)

#检查结果是否正确

x_batch = train_generator(16)

for i in range(5):

    x,y = next(x_batch)

    print(x,y)    # 结果为一个batch 为16的数据包括图片和对应的标签

```

你可能感兴趣的:(tensorflow 读取自己的数据集用于训练)