Tensorflow2.0数据集下载

  • 通过tf.keras.datasets 下载数据集

import tensorflow as tf
fashion_mnist  = tf.keras.datasets.fashion_mnist
#返回四个numpy数组
(train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()
'''
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
'''
print(train_images.shape)
print(train_labels.shape)
print(test_images.shape)
print(test_labels.shape)

tf.keras.datasets下包含了以下数据集,可以直接下载使用

y“

boston_housing   
cifar10
cifar100
fashion_mnist
imdb
mnist
reuters
”
  • 下载数据之后再构建数据集tf.keras.utils.get_file

使用tf.keras.utils.get_file , 从网络中下载你想要的数据,然后创建一个文件夹用于存放数据。

import tensorflow as tf
import pathlib
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)

 然后通过预处理之后再通过tf.data.Dataset构建数据集。

ds = tf.data.Dataset.from_tensor_slices()  #对指定tensor进行切片,构成数据集

可以是features的数据集,也可以是(features,labels)数据集

输入:张量tensor   第一维须相同

输出:datasets

  • 要处理文件中的行,请使用tf.data.TextLineDataset
dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
  • 要处理以该TFRecord格式编写的记录,请使用TFRecordDataset
dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
  • 要创建与模式匹配的所有文件的数据集,请使用 tf.data.Dataset.list_files
dataset = tf.data.Dataset.list_files("/path/*.txt")  # doctest: +SKIP
  • 通过tfds下载数据集

 tf.keras.datasets只包含了几种常用的数据集,还有更多的数据集可以在tensorflow_datasets中找到

import tensorflow_datasets as tfds
#列出所有的数据集
tfds.list_builders()

#['abstract_reasoning', 'aeslc', 'aflw2k3d', 'ai2_arc', 'amazon_us_reviews', 'anli', 'arc', 'bair_robot_pushing_small', 'beans', 'big_patent', 'bigearthnet', 'billsum', 'binarized_mnist', 'binary_alpha_digits', 'blimp', 'c4', 'caltech101', 'caltech_birds2010', 'caltech_birds2011', 'cars196', 'cassava', 'cats_vs_dogs', 'celeb_a', 'celeb_a_hq', 'cfq', 'chexpert', 'cifar10', 'cifar100', 'cifar10_1', 'cifar10_corrupted', 'citrus_leaves', 'cityscapes', 'civil_comments', 'clevr', 'clinc_oos', 'cmaterdb', 'cnn_dailymail', 'coco', 'coil100', 'colorectal_histology', 'colorectal_histology_large', 'common_voice', 'cos_e', 'cosmos_qa', 'covid19sum', 'crema_d', 'curated_breast_imaging_ddsm', 'cycle_gan', 'deep_weeds', 'definite_pronoun_resolution', 'dementiabank', 'diabetic_retinopathy_detection', 'div2k', 'dmlab', 'downsampled_imagenet', 'dsprites', 'dtd', 'duke_ultrasound', 'emnist', 'eraser_multi_rc', 'esnli', 'eurosat', 'fashion_mnist', 'flic', 'flores', 'food101', 'forest_fires', 'fuss', 'gap', 'geirhos_conflict_stimuli', 'german_credit_numeric', 'gigaword', 'glue', 'groove', 'higgs', 'horses_or_humans', 'i_naturalist2017', 'imagenet2012', 'imagenet2012_corrupted', 'imagenet2012_real', 'imagenet2012_subset', 'imagenet_a', 'imagenet_resized', 'imagenet_v2', 'imagenette', 'imagewang', 'imdb_reviews', 'irc_disentanglement', 'iris', 'kitti', 'kmnist', 'lfw', 'librispeech', 'librispeech_lm', 'libritts', 'ljspeech', 'lm1b', 'lost_and_found', 'lsun', 'malaria', 'math_dataset', 'mctaco', 'mnist', 'mnist_corrupted', 'movie_lens', 'movie_rationales', 'moving_mnist', 'multi_news', 'multi_nli', 'multi_nli_mismatch', 'natural_questions', 'newsroom', 'nsynth', 'nyu_depth_v2', 'omniglot', 'open_images_challenge2019_detection', 'open_images_v4', 'openbookqa', 'opinion_abstracts', 'opinosis', 'opus', 'oxford_flowers102', 'oxford_iiit_pet', 'para_crawl', 'patch_camelyon', 'pet_finder', 'pg19', 'places365_small', 'plant_leaves', 'plant_village', 'plantae_k', 'qa4mre', 'quickdraw_bitmap', 'reddit', 'reddit_disentanglement', 'reddit_tifu', 'resisc45', 'robonet', 'rock_paper_scissors', 'rock_you', 'samsum', 'savee', 'scan', 'scene_parse150', 'scicite', 'scientific_papers', 'shapes3d', 'smallnorb', 'snli', 'so2sat', 'speech_commands', 'squad', 'stanford_dogs', 'stanford_online_products', 'starcraft_video', 'stl10', 'sun397', 'super_glue', 'svhn_cropped', 'ted_hrlr_translate', 'ted_multi_translate', 'tedlium', 'tf_flowers', 'the300w_lp', 'tiny_shakespeare', 'titanic', 'trivia_qa', 'uc_merced', 'ucf101', 'vctk', 'vgg_face2', 'visual_domain_decathlon', 'voc', 'voxceleb', 'voxforge', 'waymo_open_dataset', 'web_questions', 'wider_face', 'wiki40b', 'wikihow', 'wikipedia', 'wikipedia_toxicity_subtypes', 'winogrande', 'wmt14_translate', 'wmt15_translate', 'wmt16_translate', 'wmt17_translate', 'wmt18_translate', 'wmt19_translate', 'wmt_t2t_translate', 'wmt_translate', 'wordnet', 'xnli', 'xsum', 'yelp_polarity_reviews']

 tfds.load 方法,载入所需的数据集

tfds.load 方法返回一个 tf.data.Dataset 对象。部分重要的参数如下:

  • name:string类型,数据集的名称
  • as_supervised :若为 True,则根据数据集的特性返回为 (input, label) 格式,否则返回所有特征的字典。

  • split:指定返回数据集的特定部分,若无则返回整个数据集。一般有 tfds.Split.TRAIN (训练集)和 tfds.Split.TEST (测试集)选项。还可以对数据集切片后返回

  • download:布尔值,是否进行数据下载,如果数据准备好了,后续的 load 命令便不会重新下载,可以重复使用准备好的数据。你可以通过指定 data_dir= (默认是 ~/tensorflow_datasets/) 来自定义数据保存/加载的路径。

  • with_info:布尔值,是否返回数据集相关信息,包含数据集信息(版本,特征,拆分,num_examples等)的元组

  • try_gcs:布尔值,如果为True,则tfds.load将在本地构建数据集之前查看该数据集是否存在于公共GCS存储桶中。

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

 返回值: tf.data.Dataset 对象ds 

metadata:包含数据集信息(版本,特征,拆分,num_examples等)的元组

此文纯属搬砖

你可能感兴趣的:(python,深度学习,tensorflow)