在训练mnist数据集的过程中,我们会采用在线的学习方法,利用next_batch功能来不断地获取新的数据集进行训练。关于next_batch的功能以及其返回的数据格式学习一下
我们通过next_batch来获取下一个数据集来对我们的参数进行调整,用法为
#其中的n代表返回多少个训练数据集和对应的标签
batch_n=mnist_data.train.next_batch(n)
输入数据
import numpy as np
import tensorflow as tf
import input_data
#导入mnist数据(以one_hot的格式)
mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True)
mnist=mnist_data
#遍历两次
for i in range(2):
#每次返回两个训练集的数据
batch=mnist_data.train.next_batch(2)
#输出它的内容
print (batch)
#输出它的类型
print (type(batch))
返回结果
(array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32), array([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]]))
<type 'tuple'>
(array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32), array([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]]))
<type 'tuple'>
可以看到数据返回的是一个元组,元组的第一个元素为一个阵列,2行,784列,第二个元素为预测的标签,为两行十列,是one_hot的数据格式