很多同学(针对新手)在训练mnist数据的时候,根据书本上的内容都可以很好很快的编辑并跑出来,但是一旦换成自己的文件夹,就很头疼,毕竟mnist里面一个read_data解决你所有的输入问题,然而在现实中,该read_data是要自己编辑的,本文主要针对非ont_hot数据,如何利用tensorflow搭起网络并跑通自己的数据,话不多说,直接上代码。
python版本:2.7
tensorflow 版本:1.1.0
#!/usr/bin/env python2
“”"
Created on Thu Jan 25 11:28:55 2018
@author:huangxd
“”"
“”"
vision:python3
author:huangxd
“”"
import os
import math
import numpy as np
import tensorflow as tf
#生成图片路径和标签list
#train_dir=‘C:/Users/hxd/Desktop/tensorflow_study/Alexnet_dr’
zeroclass = []
label_zeroclass = []
oneclass = []
label_oneclass = []
twoclass = []
label_twoclass = []
threeclass = []
label_threeclass = []
fourclass = []
label_fourclass = []
fiveclass = []
label_fiveclass = []
#s1 获取路径下所有图片名和路径,存放到对应列表并贴标签
def get_files(file_dir,ratio):
for file in os.listdir(file_dir+’/0’):
zeroclass.append(file_dir +’/0’+’/’+ file)
label_zeroclass.append(0)
for file in os.listdir(file_dir+’/1’):
oneclass.append(file_dir +’/1’+’/’+file)
label_oneclass.append(1)
for file in os.listdir(file_dir+’/2’):
twoclass.append(file_dir +’/2’+’/’+ file)
label_twoclass.append(2)
for file in os.listdir(file_dir+’/3’):
threeclass.append(file_dir +’/3’+’/’+file)
label_threeclass.append(3)
for file in os.listdir(file_dir+’/4’):
fourclass.append(file_dir +’/4’+’/’+file)
label_fourclass.append(4)
for file in os.listdir(file_dir+’/5’):
fiveclass.append(file_dir +’/5’+’/’+file)
label_fiveclass.append(5)
#s2 对生成图片路径和标签list打乱处理(img和label)
image_list=np.hstack((zeroclass, oneclass, twoclass, threeclass, fourclass, fiveclass))
label_list=np.hstack((label_zeroclass, label_oneclass, label_twoclass, label_threeclass, label_fourclass, label_fiveclass))
#shuffle打乱
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp)
#将所有的img和lab转换成list
all_image_list=list(temp[:,0])
all_label_list=list(temp[:,1])
#将所得List分为2部分,一部分train,一部分val,ratio是验证集比例
n_sample = len(all_label_list)
n_val = int(math.ceil(n_sample*ratio)) #验证样本数
n_train = n_sample - n_val #训练样本数
tra_images = all_image_list[0:n_train]
tra_labels = all_label_list[0:n_train]
tra_labels = [int(float(i)) for i in tra_labels]
val_images = all_image_list[n_train:]
val_labels = all_label_list[n_train:]
val_labels = [int(float(i)) for i in val_labels]
return tra_images,tra_labels,val_images,val_labels
#生成batch
#s1:将上面的list传入get_batch(),转换类型,产生输入队列queue因为img和lab
#是分开的,所以使用tf.train.slice_input_producer(),然后用tf.read_file()从队列中读取图像
def get_batch(image,label,image_W,image_H,batch_size,capacity):
#转换类型
image=tf.cast(image,tf.string)
label=tf.cast(label,tf.int32)
#入队
input_queue=tf.train.slice_input_producer([image,label])
label=input_queue[1]
image_contents=tf.read_file(input_queue[0]) #读取图像
#s2图像解码,且必须是同一类型
image=tf.image.decode_jpeg(image_contents,channels=3)
#s3预处理,主要包括旋转,缩放,裁剪,归一化
image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
image = tf.image.per_image_standardization(image)
#s4生成batch
image_batch, label_batch = tf.train.batch([image, label],
batch_size= batch_size,
num_threads= 32,
capacity = capacity)
#重新排列label,行数为[batch_size]
label_batch = tf.reshape(label_batch, [batch_size])
#image_batch = tf.cast(image_batch, tf.float32)
return image_batch, label_batch
该数据生成的是bool型,非one_hot编码,系统自带的mnist编码是one_hot编码,大家可以先去了解下这块东西
转:https://blog.csdn.net/qq_36631272/article/details/79173035