[tensorflow2笔记五] 使用八股搭建神经网络

用Tensorflow API:tf.keras搭建网络八股,全连接网络

文章目录

    • 1. 六步法
    • 2. 用法
    • 3. mnist数据集
    • 4. fashion数据集

1. 六步法

import									# 模块
train, test								# 训练集和测试集

# 二选一
model = tf.keras.models.Sequential		# 逐层搭建网络结构
class MyModel(Model)	model = Model()	# 类搭建网络

model.compile	    					# 配置训练方法(优化器、损失函数、评价指标)
model.fit								# 训练
model.summary							# 打印出网络结构和参数统计

2. 用法

(1)Sequential可以认为是一个容器,封装了一个神经网络结构。

拉直层:tf.keras.layers.Flatten()
# 该层不含计算,只是形状转换,把输入特征拉直变成一维数组。

全连接层:tf.keras.layers.Dense(神经元个数,
								activation=”激活函数”,
								kernel_regularizer=哪种正则化)
# 激活函数(以字符串形式给出):relu, softmax, sigmoid, tanh
# 正则化:tf.keras.regularizers.l1()	tf.keras.regularizers.l2()

卷积层:tf.keras.layers.Conv2D(filters=卷积核个数,kernel_size=卷积核尺寸,strides= 卷积步长,padding = “valid” or “same”)
     
循环神经网络层LSTM层:tf.keras.layers.LSTM()

(2)model.compile(optimizer=优化器,loss=损失函数,metrics=[“准确率”])

Optimizer可选:‘sgd’	‘adagrad’		‘adadelta’ 	‘adam’

loss可选:‘mse’或者tf.keras.losses.MeanSquareError()	
		   ‘sparse_categorical_crossentropy’或者
		     tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
			# 其中,False:是否原始输出(没有经过softmax概率分布的输出)

Metics可选:‘accuracy’:y和y_都是数值
		    ‘categorical_accuracy’:y_和y都是独热码
			‘sparse_categorical_accuracy’:y_是数值,y是独热码(概率分布)

(3)model.fit()

model.fit(
			训练集的输入特征,训练集的标签,
			batch_size=		, epochs=		,
			validation_data=(测试集的输入特征,测试集的标签),
			validation_split=从训练集划分多少比例给测试集,
			validation_freq=多少次epoch测试一次
			)

(4)model.summary():打印网络结构参数

Sequential方法搭建网络,很简单,但无法写出一些带有跳连的非顺序网络结构。此时可以使用类class封装一个搭建网络结构。

Class MyModel(Model):  #Model继承keras的Model类
	def __init__(self):
		super(MyModel, self).__init__()
		#定义网络结构块
		self.d1 = Dense(3)  # 3个神经元的全连接
	def call(self, x):
		# 调用网络结构块,实现前向传播
		y = self.d1(x)
		return y
model = MyModel()

3. mnist数据集

(1)MNIST数据集介绍

6万张 2828像素点的0~9手写数字图片和标签,用于训练
1万张28
28像素点的0~9手写数字图片和标签,用于测试

(2)导入数据集

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

(3)搭建网络

sequential方法和class方法(详细见tensorflow实战演练)

4. fashion数据集

(1)fashion数据集介绍

6万张 28x28 像素点的衣裤图片和标签,用于训练
1万张 28x28 像素点的衣裤图片和标签,用于测试

一共有10个分类:T恤、裤子…

(2)导入数据集

# 1.加载数据
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()

(3)搭建网络

sequential方法和class方法(详细见tensorflow实战演练)

你可能感兴趣的:(TensorFlow2,神经网络,卷积神经网络)