Reference: the Chinese edition of the official TensorFlow documentation, hosted by Jikexueyuan (极客学院):
http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
import tensorflow.examples.tutorials.mnist.input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
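Running the code above does not fetch the dataset. It fails with an error like the one in the screenshot below (the exact error may differ):
Below is the corrected code. It probably contains quite a few redundant imports, but I did not go through them closely, and I may not understand all of them anyway: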
# __future__ imports must come before every other import, otherwise Python raises a SyntaxError
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import gzip
import os
import tempfile
import numpy
from IPython.display import Image
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves.urllib.request import urlretrieve
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
WORK_DIRECTORY = "/tmp/mnist"  # directory where the dataset files are stored
def maybe_download(filename):
    """A helper to download the data files if not present."""
    if not os.path.exists(WORK_DIRECTORY):
        os.mkdir(WORK_DIRECTORY)
    filepath = os.path.join(WORK_DIRECTORY, filename)
    if not os.path.exists(filepath):
        filepath, _ = urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    else:
        print('Already downloaded', filename)
    # print(filepath)  # uncomment to see where the file was saved if the path is unclear
    return filepath
train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
mnist = input_data.read_data_sets("/tmp/mnist/", one_hot=True)
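# /tmp/mnist/ must match the WORK_DIRECTORY path used above.
# Sometimes a subfolder has to be created. Its name can usually be anything,
# but sometimes it only works when the folder has the same name as one of the
# four .gz files above, e.g. train-images-idx3-ubyte.gz,
# i.e. /tmp/mnist/train-images-idx3-ubyte.gz
# print(mnist)  # uncomment to confirm that the dataset was loaded successfully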
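Pay special attention to the two paths (WORK_DIRECTORY and the argument to read_data_sets). A path that worked before may stop working on the next run. If that happens, change the path and make sure both places use the same one. The code should then run and return a Datasets object.
I am taking notes as I learn, so feel free to get in touch if you have questions.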
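When a path that worked before suddenly fails, it can help to look at what is actually in the directory before calling read_data_sets. The following is only a rough sketch, assuming Python 3 and the standard library. The names check_mnist_dir and MNIST_FILES are made up for this note and are not part of TensorFlow. It reports each of the four archives as missing, corrupt (for example an interrupted download), or ok, by reading the 4-byte IDX magic number at the start of each gzip file (2051 for image files, 2049 for label files).

```python
import gzip
import os
import struct

# The four archive names passed to maybe_download() in this post, mapped to
# the IDX magic number expected in the first four bytes of each file.
MNIST_FILES = {
    "train-images-idx3-ubyte.gz": 2051,  # image file
    "train-labels-idx1-ubyte.gz": 2049,  # label file
    "t10k-images-idx3-ubyte.gz": 2051,
    "t10k-labels-idx1-ubyte.gz": 2049,
}


def check_mnist_dir(directory):
    """Return {filename: 'ok' | 'missing' | 'corrupt'} for the four archives.

    Hypothetical helper for illustration only.
    """
    status = {}
    for name, expected_magic in MNIST_FILES.items():
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            # Also reported as missing when `path` is a folder that merely
            # carries the archive's name.
            status[name] = "missing"
            continue
        try:
            with gzip.open(path, "rb") as f:
                header = f.read(4)
            # The IDX header starts with a big-endian unsigned 32-bit integer.
            magic = struct.unpack(">I", header)[0]
        except (OSError, EOFError, struct.error):
            # Not a gzip file, a truncated download, or an empty file.
            status[name] = "corrupt"
            continue
        status[name] = "ok" if magic == expected_magic else "corrupt"
    return status
```

With the setup in this post, one could call check_mnist_dir(WORK_DIRECTORY) right before read_data_sets and delete any file reported as corrupt, so that maybe_download fetches it again on the next run.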