pytorch用多层感知机实现鸢尾花3分类(亲测可用)

pytorch用多层感知机实现鸢尾花3分类(亲测可用)

泪目了,家人们
我终于能交出点东西了
这是上课的要求,不能直接用库,不能用sklearn函数,必须用多层感知机!而且要3分类,太难了。


鸢尾花分类是人工智能界的Hello World。各种人工智能的书籍,往往都会从鸢尾花的分类开始。下面我们将使用鸢尾花分类作为例子,来共同学习人工智能的若干基本概念。这里的人工智能,特指机器学习。

iris数据集的中文名是安德森鸢尾花卉数据集,含有5个key,分别是DESCT,target_name(分类名称,即四个特征值的名称),target(分类,有150个数值,有(0,1,2)三种取值,分别代表三个种类),feature_names(特征名称,三个种类的名称),data(四个特征值,花萼的长、宽,花瓣的长、宽)。 iris包含150个样本,对应数据集的每行数据。每行数据包含每个样本的四个特征和样本的类别信息,所以iris数据集是一个150行5列的二维表。通俗地说,iris数据集是用来给花做分类的数据集,每个样本包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征(前4列),我们需要建立一个分类器,分类器可以通过样本的四个特征来判断样本属于山鸢尾、变色鸢尾还是维吉尼亚鸢尾(这三个名词都是花的品种)。

鸢尾花主要有三个品种,setosa,versicolor,virginnica(山鸢尾、变色鸢尾和维吉尼亚鸢尾)。在进行分类时,主要依据是花瓣的长度(Petal Length)、宽度(Petal Width),花萼的长度(Sepal Length)和宽度(Sepal Width)(均以厘米做单位)。

首先,我们从鸢尾花数据集中,提取一部分,来作为训练数据(训练集),让机器学会如何辨识。然后,我们把剩下的一部分数据作为测试数据(测试集),让机器来识别,并判断识别的精准度。

pytorch用多层感知机实现鸢尾花3分类(亲测可用)_第1张图片

调用load_iris函数来加载数据:

from sklearn.datasets import load_iris
iris_dataset = load_iris()

下面是代码,而且是三分类,不是二分类
完整代码链接

完整代码亲测有效,没用你来打我!

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

import torch.nn as nn
import torch
import torch.utils.data as Data








你可能感兴趣的:(深度学习相关,pytorch,分类,深度学习)