import tensorflow as tf
注:本文使用2.1.0
版本说明print(tf.__version__) # 2.1.0
记住两个重要函数dir(),help()
即可逐步向下找到使用方法
# 快速使用小窗口
# 导入数据集
import tensorflow as tf
tf.keras.datasets.数据集
'''
数据集 有:
'boston_housing', 'cifar10', 'cifar100', 'fashion_mnist', 'imdb', 'mnist', 'reuters']
'''
keras内置的数据集可直接导入,联网会自动下载,数据集位置在tf.keras.datasets
ds = tf.keras.datasets
print(dir(ds))
'''输出
['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_sys',
'boston_housing', 'cifar10', 'cifar100', 'fashion_mnist', 'imdb', 'mnist', 'reuters']
'''
可见数据集有以下几种:
boston_housing:Boston housing price regression dataset.
cifar10:CIFAR10 small images classification dataset.
cifar100:CIFAR100 small images classification dataset.
fashion_mnist:Fashion-MNIST dataset.
imdb:IMDB sentiment classification dataset
mnist:MNIST handwritten digits dataset.
reuters:Reuters topic classification dataset.
直接获取即可,如mnist
数据集
minst = tf.keras.datasets.mnist
其他都是这样获取,执行后会从https://storage.googleapis.com/tensorflow/tf-keras-datasets/
下载,可能会失败,重复几次一般可下载下来。
同样使用dir()
函数可以看到数据集自带的函数,如print(dir(mnist))
['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_sys', 'load_data']
该数据集中有一个函数可用,load_data
print(help(minst.load_data))
输出为:
Help on function load_data in module tensorflow.python.keras.datasets.mnist:
load_data(path='mnist.npz')
Loads the MNIST dataset.
Arguments:
path: path where to cache the dataset locally
(relative to ~/.keras/datasets).
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
License:
Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
which is a derivative work from original NIST datasets.
MNIST dataset is made available under the terms of the
[Creative Commons Attribution-Share Alike 3.0 license.](
https://creativecommons.org/licenses/by-sa/3.0/)
None
根据说明则数据可以这样导出:
(x_train, y_train), (x_test, y_test) = minst.load_data()
还可以进一步看数据形状,以便后续处理
print("x_train shape",x_train.shape)
print("y_train shape",y_train.shape)
print("x_test shape",x_test.shape)
print("y_test shape",y_test.shape)
'''输出
x_train shape (60000, 28, 28)
y_train shape (60000,)
x_test shape (10000, 28, 28)
y_test shape (10000,)
'''
导航:tensorflow——平地起高楼