```
#!/usr/bin/python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import numpy as np
import tensorflow as tf
from utils import *
image_path = "./ck/cohn-kanade-images/"
label_path = "./ck/Emotion_labels/"
Imglist, ImgLabellist, Labellist = GetImg_Label_list(image_path, label_path) #得到图片和标签相应的列表
def parse_data(filename,label):
'''
导入数据,进行预处理,输出两张图像,
分别是输入图像和标签
Args:
filaneme, 图片的路径
Returns:
处理后图像,标签
'''
# 读取图像
image = tf.read_file(filename)
# 解码图片
image = tf.image.decode_image(image)
# 数据预处理,或者数据增强,这一步根据需要自由发挥
image = tf.image.crop_to_bounding_box(image, 0, 0, 64, 64)
# 数据增强,随机水平翻转图像
image = tf.image.random_flip_left_right(image)
# 图像归一化
image = tf.cast(image, tf.float32) / 255.0
return image, label
def train_generator(batchsize, shuffle=True):
with tf.Session() as sess:
# 创建数据库
train_dataset = tf.data.Dataset().from_tensor_slices((Imglist, Labellist))
# 预处理数据
train_dataset = train_dataset.map(parse_data)
# 设置 batch size
train_dataset = train_dataset.batch(batchsize)
# 无限重复数据
train_dataset = train_dataset.repeat()
# 洗牌,打乱
if shuffle:
train_dataset = train_dataset.shuffle(buffer_size=4)
# 创建迭代器
train_iterator = train_dataset.make_initializable_iterator()
sess.run(train_iterator.initializer)
train_batch = train_iterator.get_next()
# 开始生成数据
while True:
try:
x_batch, y_batch = sess.run(train_batch)
yield (x_batch, y_batch)
except:
# 如果没有 train_dataset = train_dataset.repeat()
# 数据遍历完就到end了,就会抛出异常
train_iterator = train_dataset.make_initializable_iterator()
sess.run(train_iterator.initializer)
train_batch = train_iterator.get_next()
x_batch, y_batch = sess.run(train_batch)
yield (x_batch, y_batch)
#检查结果是否正确
x_batch = train_generator(16)
for i in range(5):
x,y = next(x_batch)
print(x,y) # 结果为一个batch 为16的数据包括图片和对应的标签
```