2021-04-18

用Tensorflow API:

tf.keras 搭建网络八股

六步法 提纲

  • import
  • train,test
  • model=tf.keras.models.Sequential # 前向传播,封装网络结构
  • model.compile # 配置训练方法,选择优化器,选择损失函数,选择评测指标
  • model.fit,#喂入标签,迭代数据集
  • model.summary #打印网络结构和参数统计
model.Sequential(神经元个数,activation="激活函数",kerenl_regularizer=哪种正则化)
model.compile(optimizer=优化器,loss=损失函数,metrics=["准确率"])
model.fit(训练时的输入特征,训练集的标签,
batch_size= ,eopchs= ,
validation_data=(测试集的输入特征,测试集的标签),
validation_split=从训练集划分多少比例给测试集,
validation_frep=多少epoc测试一次)

用class类封装一个神经网络结构:

class MyModel(Model):
	def_init_(self):
		super(MyModel,self)._init_() #  括号内与class 名字一致
		#定义网络结构块
	def call(self,x):
	 	# 调用网络结构块,实现前向传播
	 	return y
	 
class IrisModel(Model):
	def_init_(self):
		super(IrisModel,self)._init()
		self.dl=Dense(3,activation='softmax',kernel_reqularizer=tf.keras.reqularizers.12())
	def call(self,x):
		y=self.dl(x)
		return y
model=IrisModel() #  实例化出model


MNIST数据集:
 提供手写图片和标签。
 导入MINIST数据集:
```python
mnist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()

FASHION数据集:
提供衣裤图片和标签
导入FASHION数据集:

fashion=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()

作为输入特征,输入神经网络时,将数据拉伸为一维数组:

tf.keras.layers.Flatten()

用Sequential实现手写数字识别训练代码如下:

import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
x_train,x_test = x_train / 255.0,x_test / 255.0
model = tf.keras.models.Sequential([
	tf.keras.layers.Flatten().
	tf.keras.layers.Dense(128,actvation='relu'),
	tf.keras.layers.Dense(10,activation='softmax')
])
model.compile(optimizer='adam',loss=tf.keras.losses.SpareseCateggoricalCrossentropy(form_logits=False),metrics=['sparese_categorical_accuracy'])
model.fit(x_train,y_train,batch_size=32,epoch=5,validation_data=(x_test,y_test),validation_freq=1)
model.fit(x_train,y_train,
batch_size=32 ,eopchs=5,
validation_data=(x_test,y_test)#validation_split=从训练集划分多少比例给测试集,
validation_frep=1)
model.summary()

神经网络八股功能扩展:

  1. 自制数据集,解决本领域应用
  2. 数据增强,扩展数据集
  3. 断点续训,存取模型
  4. 参数提取,把参数存入文本
  5. acc/loss可视化,查看训练效果
  6. 应用程序,给图识物

你可能感兴趣的:(笔记,tensorflow)