tf.keras.datasets
里面放了一些小的numpy数据集,主要用于做一些测试
放了哪些数据集呢?一共有7个,分别是:
http://lib.stat.cmu.edu/datasets/boston
每条数据包含房屋的13种属性,以及房子的均价(k$)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data() #可以加载数据集
返回的x_train和x_test都是(num_samples,13)的numpy矩阵
y_train, y_test都是(num_sample,)的numpy矩阵
tf.keras.datasets.boston_housing.load_data()有3个参数
分别是path, test_split, seed。默认值为
path=‘boston_housing.npz’
test_split=0.2
seed=113
path是boston_housing在本地的缓存地址,默认相对路径是~/.keras/datasets,如果这个地址没有缓存过boston_housing的数据,会自动下载
test_split用于指定用作测试数据的比例
seed是一个随机种子,可以在拆分训练和测试集之前,打乱数据的顺序
https://www.cs.toronto.edu/~kriz/cifar.html
cifar10和cifar100都是图片分类数据集,cifar10有10个类别,cifar100有100个类别
cifar10的每张图片是8bit(数值范围0~255) 32x32的彩图,训练数据有5w张,测试数据有1w张
(x_train,y_train), (x_test,y_test) = tf.keras.datasets.cifar10.load_data() # 加载数据
x_train的形状(50000,32,32,3)
y_train的形状(50000,1)
x_test的形状(10000,32,32,3)
y_test的形状(10000,1)
10个类别用0,1,2,3,4,5,6,7,8,9这10个数字表达
0:airplane, 1:automobile, 2:bird, 3:cat, 4:deer, 5:dog, 6:frog, 7:horse, 8:ship, 9:truck
cifar100的加载有2种可选,一种是加载细粒度的100个分类,一种是粗粒度的20个分类
默认的是100分类,可以使用label_mode设置。label_mode为‘fine’,加载100个细分类, label_mode为‘coarse’,加载20个粗分类。
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.cifar100.load_data(label_mode='fine')
fashion_mnnist里面是一些与服装有关的图片,每张图片都是8bit 28x28的灰图,分为10个类别。训练数据有6w张,测试数据有1w张。
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.fashion_mnist.load_data()
10个类别用0~9的10个数字表示
0:T-shirt, 1:Trouser, 2:Pullover, 3:Dress, 4:Coat, 5:Sandal, 6:shirt, 7:Sneaker, 8:Bag, 9:Ankle boot
mnist是个手写数字数据集
http://yann.lecun.com/exdb/mnist/
每张图片是8bit 28x28的灰图,训练数据有6w张,测试数据有1w张,10个数字(0~9)
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()
imdb里面是一些影评数据,用于做情感分类
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.imdb.load_data()
x_train, x_test都是列表类型,里面一行代表一句影评
影评的单词使用索引值表示
索引值是通过计算单词在这个数据集中的出现次数排名计算的。
x_test,y_test是列表类型,里面每个值为1或0,用于表示影评是正面的/负面的。
imdb的load_data函数有很多参数
path=‘imdb.npz’, # 表示缓存路径
num_words=None, # 为int时,表示仅使用最常出现的num_words个单词进行排序,为None表示使用全部单词排序
skip_top=0, # 去除前N个最常出现的
maxlen=None, # 每句话的最大长度,超过则截断。None表示无截断
seed=113, # 随机打乱顺序的种子
start_char=1, # 索引开始的数值
oov_char=2, # 句子中未出现的单词用oov_char表示其索引
index_from=3, # 单词的索引大于等于index_from,及从3开始排号
imdb还有一个函数,get_word_index,用于获取单词和索引的字典,key是单词,value是索引值
index_dict = tf.keras.datasets.imdb.get_word_index()
reuters是个主题分类数据集,里面是一些新闻,涵盖46个主题
每个新闻里面的单词都是索引表示
索引是用单词在数据集中出现的次数排名表示的
reuters的load_data函数和imdb的一样,有一堆参数
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.reuters.load_data()
tf.keras.datasets.reuters.get_word_index()可以获得单词和索引的字典。