TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集

参考:《深度学习图像识别技术--基于TensorFlow Object Detection API 和 OpenVINO》

问题:假设你的生物学家,要对鸢尾(Iris)花分类。Iris有300多类,这里仅仅对Iris setosa,Iris virginica,Iris versicolor 这三类进行识别,如下图所示

本文范例程序下载地址:IrisClassifier.py

TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集_第1张图片

方法有很多种,比如,基于CNN的深度学习,直接学习图像。这里采用更加简单的方法,通过  sepals(花萼)和 petals(花瓣)的长度和宽度数据,进行模型训练和分类,这样更加适合初学者。

收集和构架数据集要花很多时间,幸运的是,已经有现成的Iris flower data set,which contains a set of 150 records under 5 attributes - Petal Length , Petal Width , Sepal Length , Sepal width and Class 如下图所示


TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集_第2张图片

基于这样的数据集(DataSet),可以让我们更加专注于学习机器学习的算法,而不需要花大量时间准备数据

第一步:下载训练数据集

我们需要把dataset文件下载到本地,然后把它转化为Python可以使用的数据结构。范例代码如下:


TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集_第3张图片

打开文件:C:\Users\tf\.keras\datasets\iris_training.csv


TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集_第4张图片

可以看到有120行数据,跟Iris data set wiki里面说的不大一样,不过没有关系,不影响训练。

前四列是Features,分别是:Petal Length , Petal Width , Sepal Length , Sepal width

第五列是label,分别用整型数来代表花的种类,对机器来说,用整型数比用字符串更加方便,但我们要知道整型数和花种类之间的映射:

0: Iris setosa

1: Iris versicolor

2: Iris virginica

第二步:解析(Parse)数据集

下载到本地的数据集iris_training.csv 是一个 CSV格式的文本文件, TensorFlow模型还不能直接使用。我们需要把feature和label的值按照TensorFlow模型的数据输入要求,重新格式化。

创建一个函数 parse_csv

输入参数是:iris_training.csv文件的一行(line),

功能是:把 前四个 feature 值合并成为一个List,并reshape成为一个 single tensor;把最后一个 label 变量reshape成为一个single tensor.

返回值: features 和 label tensors

如下所示:


TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集_第5张图片

tf.decode_csv函数功能是:Convert CSV records to tensors. Each column maps to one tensor.

tf.reshape(tensor, shape,name=None)函数的功能是:Given tensor, this operation returns a tensor that has the same values as tensor with shape shape

第三步:创建训练 tf.data.Dataset

TensorFlow's Dataset API 用于feeding data into a model,它负责读取data,并将data转换为适合模型训练的格式

代码如下所示:

TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集_第6张图片

执行结果如下所示:

TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集_第7张图片

你可能感兴趣的:(TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集)