【tensorflow基础】读取mnist数据

转载于:MNIST手写数字数据集读取方法

TensorFlow的封装让使用MNIST数据集变得更加方便。MNIST数据集是NIST数据集的一个子集,它包含了60000张图片作为训练数据,10000张图片作为测试数据。在MNIST数据集中的每一张图片都代表了0~9中的一个数字。图片的大小都为28*28,且数字都会出现在图片的正中间。 代码如下:

import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#经过上述代码,会自动下载mnist数据集

'''
(1)采用tensorflow自带的代码读取mnist数据集
'''
mnist = input_data.read_data_sets('./data/')  #读取
train_nums = mnist.train.num_examples             #训练数据的个数
validation_nums = mnist.validation.num_examples   #验证数据的个数
test_nums = mnist.test.num_examples               #测试数据的个数
print('MNIST数据集的个数')
print(' >>>train_nums=%d' % train_nums,'\n',
      '>>>validation_nums=%d'% validation_nums,'\n',
      '>>>test_nums=%d' % test_nums,'\n')

'''
(2)获得数据值
'''
train_data = mnist.train.images   #所有训练数据
val_data = mnist.validation.images  #验证数据,大小(5000,784)
test_data = mnist.test.images       #测试数据,大小(10000,784)
print('>>>训练集数据大小:',train_data.shape,'\n',
      '>>>一副图像的大小:',train_data[0].shape)

'''
(3)获取标签值label=[0,0,...,0,1],是一个1*10的向量
'''
train_labels = mnist.train.labels     #(55000,10)
val_labels = mnist.validation.labels  #(5000,10)
test_labels = mnist.test.labels       #(10000,10)
print('>>>训练集标签数组大小:',train_labels.shape,'\n',
      '>>>一副图像的标签大小:',train_labels[1].shape,'\n',
      '>>>一副图像的标签值:',train_labels[0])

'''
(4)批量获取数据和标签【使用next_batch(batch_size)】
'''
batch_size = 100    #每次批量训练100幅图像
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
print('使用mnist.train.next_batch(batch_size)批量读取样本\n')
print('>>>批量读取100个样本:数据集大小=',batch_xs.shape,'\n',
      '>>>批量读取100个样本:标签集大小=',batch_ys.shape)
#xs是图像数据(100,784);ys是标签(100,10)

'''
(5)显示图像
'''
plt.figure()
for i in range(10):
    im = train_data[i].reshape(28,28)
    im = batch_xs[i].reshape(28,28)
    plt.imshow(im,'gray')
    plt.pause(0.0000001)
plt.show()

 

你可能感兴趣的:(tensorflow,深度学习,tensorflow)