最近开始使用tensorflow自己搭建网络的框架,遇到的第一个问题就是怎么样批量的将文件中的训练图片读进来,并且可以每次抓取一部分图像数据来进行训练。经过了几天的学习,终于在最后完成了批量图片读入的实验。这次学习主要参考的博客文章是这一系列的。tensorflow入门之猫狗大战
先上整体的代码:
# -*- coding: utf-8 -*-
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
"""
Created on Fri Jul 13 20:03:46 2018
@author: Lenovo
"""
N_CLASSES = 2
IMG_W = 208 # resize图像,太大的话训练时间久
IMG_H = 208
BATCH_SIZE = 1
CAPACITY = 2000
MAX_STEP = 10000 # 一般大于10K
learning_rate = 0.0001 # 一般小于0.0001
left_dir = 'E:/Study/研究生文档/密集匹配程序/train/image_2'
right_dir = 'E:/Study/研究生文档/密集匹配程序/train/image_3'
groundtruth_dir = 'E:/Study/研究生文档/密集匹配程序/train/disp_noc_0'
tf.reset_default_graph()#这一句话非常重要,如果没有这句话,就会出现重复定义变量的错误
left=[]
right=[]
groundtruth=[]
left_label=[]
right_label=[]
groundtruth_label=[]
def get_imagelist(path):#将指定文件夹中的图片的地址保存为列表返回
image=[]#用于存储图片文件地址列表
label=[]#用于存储图片的标签
for file in os.listdir(left_dir):#os.listdir(file_dir+'/cat')#这个神器可以遍历一个文件夹中的文件
image.append(left_dir + '/'+file)
label.append(0)
image_list=np.hstack((image))
label_list=np.hstack((label))
#shuffle打乱
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp)
#将所有的img和lab转换成list
all_image_list=list(temp[:,0])
all_label_list=list(temp[:,1])
all_label_list = [int(i) for i in all_label_list]#将list里面的每一个值从字符串形式变成int类型
return all_image_list,all_label_list
def get_batch(image_list,label_list,IMG_W,IMG_H,BATCH_SIZE,CAPACITY):#从输入的图片、标签列表中随机抓取一个batch的数据,然后生成需要的图片格式返回
image=tf.cast(image_list,tf.string)#这个函数将一般的string转为tf.string,因为只有tf格式的数据才能参与tensorflow的运算,这也可以说是tensorflow的一个不好的地方
label=tf.cast(label_list,tf.int32)
#入队
input_queue=tf.train.slice_input_producer([image,label])#将label和image建立一个队列
label=input_queue[1]
image_contents=tf.read_file(input_queue[0]) #读取图像,tensorflow的这个read_file函数很强大,可以一次性读取很多的图片,而不是一张一张的读取图片。但是这时候这些数据需要进行解码,不是真正的图片数据
#s2图像解码,且必须是同一类型
image=tf.image.decode_png(image_contents,channels=3)#这可以说是tensorflow的一个bug吧,支持的图片格式太少了
#s3预处理,主要包括旋转,缩放,裁剪,归一化
image = tf.image.resize_image_with_crop_or_pad(image, IMG_W, IMG_H)
image = tf.image.per_image_standardization(image)
#s4生成batch
image_batch, label_batch = tf.train.batch([image, label], batch_size= BATCH_SIZE,num_threads= 32, capacity = CAPACITY)
#重新排列label,行数为[batch_size]
label_batch = tf.reshape(label_batch, [BATCH_SIZE])
image_batch = tf.cast(image_batch, tf.float32)
return image_batch,label_batch#用这种方法来获得批量的样本其实并不好,这种方法参与的数据类型只能是tensor,不能中间输出来检查检查。另外,这种方式支持的格式太少了,如pfm格式就无法支持。
def showimgs(image_batch,label_batch):#测试获取的batch中的图片能否正常显示
with tf.Session() as sess:
i = 0#如果不利用i这个条件来控制的话,会将所有的图片显示出来
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop() and i<1:
img, label = sess.run([image_batch, label_batch])
# just test one batch
for j in np.arange(BATCH_SIZE):
print('label: %d' %label[j])
plt.imshow(img[j,:,:,:])
plt.show()
i=i+1
except tf.errors.OutOfRangeError:
print('done!')
finally:
coord.request_stop()
coord.join(threads)
leftimglist,leftlabellist=get_imagelist(left_dir)
left_image_batch,left_label_batch=get_batch(leftimglist,leftlabellist,IMG_W,IMG_H,BATCH_SIZE,CAPACITY)
showimgs(left_image_batch,left_label_batch)
rightimglist,rightlabellist=get_imagelist(right_dir)
right_image_batch,right_label_batch=get_batch(rightimglist,rightlabellist,IMG_W,IMG_H,BATCH_SIZE,CAPACITY)
showimgs(right_image_batch,right_label_batch)
groundtruthimglist,groundtruthlabellist=get_imagelist(groundtruth_dir)
groundtruth_image_batch,groundtruth_label_batch=get_batch(groundtruthimglist,groundtruthlabellist,IMG_W,IMG_H,BATCH_SIZE,CAPACITY)
showimgs(groundtruth_image_batch,groundtruth_label_batch)
其中有几个重要的部分需要说明一下,这个小程序中包含三个主要的函数,一个是
def get_imagelist(path)
return all_image_list,all_label_list
这个函数主要是将指定路径下的图片的相关地址都记录在一个列表当中,然后返回所有带有训练图片地址的列表。
其中:
for file in os.listdir(left_dir):#os.listdir(file_dir+'/cat')#这个神器可以遍历一个文件夹中的文件
image.append(left_dir + '/'+file)
label.append(0)
os.listdir()这个函数可以遍历一个文件夹中所有文件的路径。它的参数是文件夹所在的路径。然后通过image.append()这个语句将文件的路径添加到列表当中。这是常见的一种获取样本列表的方法。
def get_batch(image_list,label_list,IMG_W,IMG_H,BATCH_SIZE,CAPACITY)
这个函数的主要作用就是将文件列表分批分批的读取图片数据。
def showimgs(image_batch,label_batch):#测试获取的batch中的图片能否正常显示
这个函数的作用是将批量读取的样本图片文件进行显示。