DL with python(4)——基于Keras的二层神经网络鸢尾花分类

本文涉及到的是中国大学慕课《人工智能实践:Tensorflow笔记》第三讲的内容,运用keras框架搭建神经网络模型,对一个传统数据集——鸢尾花数据集进行分类。

实现的功能和第一讲中的神经网络相同,详情见
DL with python(1)——基于TensorFlow的二层神经网络鸢尾花分类
具体函数的功能简介见DL with python(5)——Keras中的神经网络函数(简单)

tf.keras 搭建神经网络“六步法”

tf.keras 是 tensorflow2 引入的高封装度的框架,可以用于快速搭建神经网络模型,keras 为支持快速实验而生,能够把想法迅速转换为结果,是深度学习框架之中最终易上手的一个,它提供了一致而简洁的 API,能够极大地减少一般应用下的工作量,提高代码地封装程度和复用性。通过keras搭建神经网络只需要六个步骤,如下:

第一步:import 相关模块,如 import tensorflow as tf。
第二步:指定输入网络的训练集和测试集,如指定训练集的输入 x_train 和标签y_train,测试集的输入 x_test 和标签 y_test。
第三步:逐层搭建网络结构,model = tf.keras.models.Sequential()。
第四步:在 model.compile()中配置训练方法,选择训练时使用的优化器、损失函数和最终评价指标。
第五步:在 model.fit()中执行训练过程,告知训练集和测试集的输入值和标签、每个 batch 的大小(batchsize)和数据集的迭代次数(epoch)。
第六步 :使用 model.summary()打印网络结构,统计参数数目。

tf.keras六步代码实现

下面给出keras的代码,分为六个步骤

# 第一步,导入相关模块
import tensorflow as tf
from sklearn import datasets
import numpy as np
tf.compat.v1.enable_eager_execution()

# 第二步,导入数据集
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
# 打乱数据集
np.random.seed(116)         # 设置随机数种子
np.random.shuffle(x_train)  # 打乱x
np.random.seed(116)         # 设置与打乱x相同的随机数种子
np.random.shuffle(y_train)  # 打乱y
tf.compat.v2.random.set_seed(116)

# 第三步,搭建网络结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
]) # 依次为神经元个数,激活函数,正则化方法

# 第四步,配置训练方法
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),      # SGD优化器,学习率0.1
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), # 损失函数,输出概率分布->False
              metrics=['sparse_categorical_accuracy']) # 评价指标,根据实际标签和网络输出来选择

# 第五步,执行训练,依次为训练集样本,训练集标签,小批量大小,训练轮次,测试集占比20%,训练集循环20轮次进行一次测试
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

# 第六步,打印网络结构和参数统计
model.summary()

使用class搭建神经网络

使用 Sequential 可以快速搭建网络结构,但是如果网络包含跳连等其他复杂网络结构,Sequential 就无法表示了。这就需要使用 class 来声明网络结构。
实现代码如下,除了第三步,其他部分和前面相同。

# 第一步,导入相关模块
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np
tf.compat.v1.enable_eager_execution()

# 第二步,导入数据集
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
# 打乱数据集
np.random.seed(116)         # 设置随机数种子
np.random.shuffle(x_train)  # 打乱x
np.random.seed(116)         # 设置与打乱x相同的随机数种子
np.random.shuffle(y_train)  # 打乱y
tf.compat.v2.random.set_seed(116)

# 第三步,使用class定义网络模型
class IrisModel(Model):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2()) # 定义全连接层

    def call(self, x):
        y = self.d1(x) # 调用全连接层d1,实现x到y的传播
        return y

model = IrisModel() # 实例化model

# 第四步,配置训练方法
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),      # SGD优化器,学习率0.1
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), # 损失函数,输出概率分布->False
              metrics=['sparse_categorical_accuracy']) # 评价指标,根据实际标签和网络输出来选择

# 第五步,执行训练,依次为训练集样本,训练集标签,小批量大小,训练轮次,测试集占比20%,训练集循环20轮次进行一次测试
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

# 第六步,打印网络结构和参数统计
model.summary()

运行结果

两个代码的运行结果相同,首先是按照训练轮次输出参数

......

Epoch 499/500

 32/120 [=======>......................] - ETA: 0s - loss: 0.2851 - sparse_categorical_accuracy: 0.9688
120/120 [==============================] - 0s 58us/sample - loss: 0.3382 - sparse_categorical_accuracy: 0.9500
Epoch 500/500

 32/120 [=======>......................] - ETA: 0s - loss: 0.6113 - sparse_categorical_accuracy: 0.6875
120/120 [==============================] - 0s 66us/sample - loss: 0.5049 - sparse_categorical_accuracy: 0.7917 - val_loss: 0.3395 - val_sparse_categorical_accuracy: 1.0000

然后打印网络结构和参数统计(包括可训练参数和不可训练参数)

Model: "iris_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                multiple                  15        
=================================================================
Total params: 15
Trainable params: 15
Non-trainable params: 0

你可能感兴趣的:(python学习)