一 实例描述
通过import cifar10_input来导入CIFAR数据集。
二 代码
import cifar10_input
import tensorflow as tf
import pylab
#取数据
batch_size = 12
data_dir = '/tmp/cifar10_data/cifar-10-batches-bin'
'''
cifar10_input.inputs是专门获取数据的函数,返回数据集和对应的标签,但是cifar10_input.inputs会将图片裁剪好,由原来的32*32*3变成了24*24*3。
该函数默认是使用测试数据集,如果使用训练数据集,可以将第一个参数传入eval_data=False
另外再将batch_size和dir传入,就可以得到dir下面的batch_size个数据。
注意:这里获得的图片并不是原始图片,是经过了两次变换,首先将32*32尺寸裁剪成24尺寸,然后又进行了一次图片标准化(减去均值像素,并除以像素方差)。这样做得好处是,使所有的输入都在一个有效的数据分布之内,便于特征的分类处理,会使梯度下降算法的收敛更快。
'''
images_test, labels_test = cifar10_input.inputs(eval_data = True, data_dir = data_dir, batch_size = batch_size)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.train.start_queue_runners()
image_batch, label_batch = sess.run([images_test, labels_test])
print("__\n",image_batch[0])
print("__\n",label_batch[0])
pylab.imshow(image_batch[0])
pylab.show()
三 运行结果
[[[ 1.01835191 1.2548697 1.43225801]
[ 1.31399918 1.51109731 1.64906597]
[ 1.57022679 1.72790527 1.84616411]
...,
[ 1.84616411 1.84616411 1.84616411]
[ 1.92500341 1.92500341 1.90529358]
[ 2.14181137 2.14181137 2.12210155]]
[[ 1.05777156 1.2548697 1.35341883]
[ 1.29428935 1.47167766 1.49138749]
[ 1.60964632 1.7476151 1.72790527]
...,
[ 1.80674446 1.80674446 1.72790527]
[ 1.7476151 1.72790527 1.60964632]
[ 2.04326224 2.00384259 1.86587393]]
[[ 1.05777156 1.11690104 1.15632069]
[ 1.27457952 1.29428935 1.2548697 ]
[ 1.6687758 1.64906597 1.57022679]
...,
[ 1.68848562 1.70819545 1.62935615]
[ 1.55051696 1.51109731 1.35341883]
[ 1.92500341 1.80674446 1.58993649]]
...,
[[-0.69640189 -0.262786 0.21024954]
[-0.75553137 -0.34162524 0.11170048]
[-0.97233933 -0.53872341 -0.10510748]
...,
[-1.80015147 -1.5636338 -1.36653554]
[-2.11550856 -1.87899077 -1.68189263]
[-2.15492821 -1.91841042 -1.72131228]]
[[-1.62276316 -1.24827671 -0.89350003]
[-1.72131228 -1.34682584 -1.01175892]
[-1.81986129 -1.44537485 -1.09059823]
...,
[-2.07608891 -1.81986129 -1.64247298]
[-2.17463803 -1.91841042 -1.74102211]
[-2.15492821 -1.91841042 -1.72131228]]
[[-2.09579873 -1.74102211 -1.52421415]
[-2.07608891 -1.68189263 -1.50450432]
[-2.05637908 -1.66218281 -1.4847945 ]
...,
[-2.17463803 -1.91841042 -1.74102211]
[-2.19434786 -1.93812025 -1.76073194]
[-2.17463803 -1.91841042 -1.74102211]]]
__
8
四 运行说明
上面结果输出的是图片像素数据和标签数据。可以看到,读取的数据都是经过标准化处理的(变成了均值为0,方差为1的数据分布),所以输出的图片就是乱的。