TensorFlow的 各模块关系keras、nn、metrics、model、Sequential、data.Dataset、keras.datasets

TensorFlow下的API结构

    • 前言:
      • 一、tf 下面有三部分内容:模块、类、常用的函数
      • 二、其中像比较常用的`tf.keras`中
          • 1. `Model`母类中有针对训练的函数
          • 2. `tf.metrics`中测量三步走
      • 三、数据处理的`tf.data.Dataset`下的
      • 四、Tensorflow构建神经网络和全连接层常用的函数
        • 1.数据集操作
        • 2.搭建网络层
        • 3.计算误差
    • 推荐:Tensorflow的[龙良曲老师GitHub](https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book)
      • 本文参考:


前言:


在使用tensorflow的函数时,对它整体API的结构比较模糊,搜索了一遍之后官方文档解答了我的疑惑,以下为小总结,如有错误欢迎指正。

 

一、tf 下面有三部分内容:模块、类、常用的函数

|--- tf
     |---- 大模块
            tf.nn,神经网络模块
            tf.keras,高阶API
            tf.math,数学工具模块
            tf.losses,计算误差
            tf.data,数据模块
            tf.random
            tf.summary,展示模块信息
            tf.train,训练函数
            tf.contrib,实验性质常变动的函数模块
            ....

     |--- 类
        Dtype
        Variable
        ...

     |--- 常用的函数
       tf.argmax()
       tf.add()
       tf.constant()
	   tf.one_hot()
       tf.cast()
       tf.reduce_mean()
       tf.square()
       ...

二、其中像比较常用的tf.keras

tf,keras.datasets(下载数据集)
tf.keras.metrics(计算精度,评估性能)
tf.keras.layers


自定义层
tf.keras.layers.Layer
tf.keras.Squentital
tf.keras.Model

1. Model母类中有针对训练的函数

compile(),训练
network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)
	
fit(),训练期间测试
network.fit(db, epochs=5, validation_data=ds_val, validation_freq=2)

evaluate(),训练结束最终评估
network.evaluate(ds_val)

predict(),预测值
2. tf.metrics中测量三步走
1、生成测量器
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()

2、喂数据
loss_meter.update_state(loss)
acc_meter.update_state(y, pred)

3、取结果
loss_meter.result().numpy()
acc_meter.result().numpy()

4、一个迭代之后重置
loss_meter.reset_states()
acc_meter.reset_states()

三、数据处理的tf.data.Dataset下的

tf.data.Dataset.from_tensor_slices(),切分数据
tf.data.Dataset.shuffle(),打散数据
tf.data.Dataset.map(),预处理
tf.data.Dataset.batch(),分批
tf.data.Dataset.repeat(),重复迭代

四、Tensorflow构建神经网络和全连接层常用的函数


1.数据集操作

  • 数据加载,返回numpy类型的数据
(x_data,y_data),(x_test,y_test) = tf.keras.datasets.mnist.load_data()
  • 数据预处理(类型转换。。)
  • 打散、分批

2.搭建网络层

  • 搭建线性神经网络层,获取输出
tf.keras.Squentital()
  • 对输出数据进行tf.nn.softmax(),将数值映射到0-1,并且和为1(这是和tf.softmax的区别)

3.计算误差

  • MSE:tf.losses.MSE()有时用MSE会出现梯度消失的情况,所以交叉熵也很好
  • 交叉熵:tf.losses.categorical_cossentropy(),交叉熵越小,说明信息量越大,不可知的东西多,既误差很大

推荐:Tensorflow的龙良曲老师GitHub


老师讲的很好,资料很全,能让自己学习更清晰

TensorFlow的 各模块关系keras、nn、metrics、model、Sequential、data.Dataset、keras.datasets_第1张图片
 

本文参考:

tensorflow的官方文档

一文读懂TensorFlow 2.0高阶API

TensorFlow 的常用模块介绍

你可能感兴趣的:(早期编程语言基础)