# Tensorflow MNIST 数据集使用
import tensorflow as tf
tf.__version__
'2.1.0'
import numpy as np
import matplotlib.pyplot as plt
TensorFlow 2.0 的数据集已集成到 Keras 高级接口之中，使用如下代码一般都能直接下载：
# Load MNIST through the Keras datasets API. load_data() is unpacked as
# two (x, y) pairs, kept in the order they are returned.
mint = tf.keras.datasets.mnist
first_pair, second_pair = mint.load_data()
x_, y_ = first_pair
x_1, y_1 = second_pair
# Legacy MNIST helper module; read_data_sets() from it is called further down.
from tensorflow.examples.tutorials.mnist import input_data

print("packs loaded")
packs loaded
from tensorflow.examples.tutorials.mnist import input_data

# Announce the step, then read MNIST from the local 'data/' directory,
# requesting one-hot labels.
print("Download and Extract MNIST dataset")
mnist = input_data.read_data_sets('data/', one_hot=True)
# NOTE (original author's remark): downloading the dataset can fail at
# this step.
print("Download and Extract MNIST dataset")
mnist = input_data.read_data_sets('data/', one_hot=True)
Download and Extract MNIST dataset
Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From d:\progra~2\python\virtua~1\py37_x64\lib\site-packages\tensorflow_core\examples\tutorials\mnist\input_data.py:328: _DataSet.__init__ (from tensorflow.examples.tutorials.mnist.input_data) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/_DataSet.py from tensorflow/models.
# Check how much data there is: report the container type, then the
# number of examples in the train and test splits.
print(" tpye of 'mnist' is %s" % type(mnist))
n_train = mnist.train.num_examples
print(" number of trian data is %d" % n_train)
n_test = mnist.test.num_examples
print(" number of test data is %d" % n_test)
tpye of 'mnist' is
number of trian data is 55000
number of test data is 10000
# What does the data of MNIST look like?
print("What does the data of MNIST look like?")

# Pull the image and label arrays out of each split.
trainimg, trainlabel = mnist.train.images, mnist.train.labels
testimg, testlabel = mnist.test.images, mnist.test.labels

# Inspect the data types
print("type of 'trainlabel' is %s" % (type(trainlabel),))
print("type of 'testimg' is %s" % (type(testimg),))
print("type of 'testlabel' is %s" % (type(testlabel),))
print()

# Inspect the data shapes (wrapped in a 1-tuple so that the shape tuple
# is formatted as a single %s argument)
print("shape of 'trainimg' is %s" % (trainimg.shape,))
print("shape of 'trainlabel' is %s" % (trainlabel.shape,))
print("shape of 'testimg' is %s" % (testimg.shape,))
print("shape of 'testlabel' is %s" % (testlabel.shape,))
What does the data of MNIST look like?
type of 'trainlabel' is <class 'numpy.ndarray'>
type of 'testimg' is <class 'numpy.ndarray'>
type of 'testlabel' is <class 'numpy.ndarray'>
shape of 'trainimg' is (55000, 784)
shape of 'trainlabel' is (55000, 10)
shape of 'testimg' is (10000, 784)
shape of 'testlabel' is (10000, 10)
# How does the training data look like?
# Display a few randomly chosen training images together with their labels.
print("How does the training data look like?")
nsample = 5
# Draw `nsample` random row indices into the training set; randint samples
# independently, so the same index may appear more than once.
randidx = np.random.randint(trainimg.shape[0], size=nsample)
for i in randidx:
    # Each row is a flattened image; reshape it to a 28 by 28 matrix.
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    # Labels were loaded with one_hot=True, so argmax recovers the class.
    curr_label = np.argmax(trainlabel[i, :])
    # Build the caption once and reuse it for the figure title and the log.
    caption = "%dth Training Data Label is %d" % (i, curr_label)
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title(caption)
    print(caption)
    # NOTE(review): shown once per sample; the flattened source does not
    # say whether this call sat inside or after the loop - confirm.
    plt.show()
How does the training data look like?
54259th Training Data Label is 4
33047th Training Data Label is 4
52715th Training Data Label is 1
15223th Training Data Label is 7
18188th Training Data Label is 4
# Batch Learning?
# Fetch one mini-batch from the training split and report what came back.
print("Batch Learning? ")
batch_size = 100
batch_xs, batch_ys = mnist.train.next_batch(batch_size)

# Types of the two returned objects
print("type of 'batch_xs' is %s" % type(batch_xs))
print("type of 'batch_ys' is %s" % type(batch_ys))

# Shapes (1-tuple wrapping keeps each shape a single %s argument)
print("shape of 'batch_xs' is %s" % (batch_xs.shape,))
print("shape of 'batch_ys' is %s" % (batch_ys.shape,))
Batch Learning?
type of 'batch_xs' is <class 'numpy.ndarray'>
type of 'batch_ys' is <class 'numpy.ndarray'>
shape of 'batch_xs' is (100, 784)
shape of 'batch_ys' is (100, 10)