Tool series: TensorFlow Decision Forests (1): Building, training, and evaluating models

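Table of contents

    • 1. Introduction
    • 2. Installing TensorFlow Decision Forests
    • 3. Importing libraries
    • 4. Training a Random Forest model
      • 4.1 Load the dataset and convert it to a tf.Dataset
      • 4.2 Train the model
      • 4.3 Remarks
    • 5. Evaluating the model
    • 6. Preparing the model for TensorFlow Serving
    • 7. Plotting the model
    • 8. Model structure and feature importance
    • 9. Model self-evaluation
    • 10. Plotting the training logs
    • 11. Retraining the model with a different learning algorithm
    • 12. Using a subset of features
    • 13. Hyperparameters
    • 14. Feature preprocessing
    • 15. Training a regression model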

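1. Introduction

Decision Forests (DF) are a family of machine learning algorithms for supervised classification, regression, and ranking. As the name suggests, DFs use decision trees as building blocks. Today, the two most popular DF training algorithms are Random Forests and Gradient Boosted Decision Trees.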
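To make the building-block idea concrete, here is a minimal pure-Python sketch. It is not TF-DF code: each tree is a hypothetical nested dictionary, an example is routed from the root to a leaf, and the forest returns the class that most trees vote for. The feature names are borrowed from the penguin dataset used later in this tutorial; the trees and thresholds are invented for illustration.

```python
from collections import Counter

# A tree is either a leaf {"value": label} or an internal node holding a
# numerical condition "feature >= threshold" and two children.
# All trees and thresholds below are made up for illustration.
TREE_A = {
    "feature": "flipper_length_mm", "threshold": 206.0,
    "neg": {"value": "Adelie"},
    "pos": {"value": "Gentoo"},
}
TREE_B = {
    "feature": "bill_length_mm", "threshold": 43.0,
    "neg": {"value": "Adelie"},
    "pos": {
        "feature": "body_mass_g", "threshold": 4500.0,
        "neg": {"value": "Chinstrap"},
        "pos": {"value": "Gentoo"},
    },
}
TREE_C = {
    "feature": "bill_depth_mm", "threshold": 16.0,
    "neg": {"value": "Gentoo"},
    "pos": {"value": "Adelie"},
}


def predict_tree(tree, example):
    """Route one example from the root to a leaf and return the leaf label."""
    node = tree
    while "value" not in node:
        branch = "pos" if example[node["feature"]] >= node["threshold"] else "neg"
        node = node[branch]
    return node["value"]


def predict_forest(trees, example):
    """Return the label predicted by the largest number of trees."""
    votes = Counter(predict_tree(tree, example) for tree in trees)
    return votes.most_common(1)[0][0]


example = {
    "flipper_length_mm": 181.0,
    "bill_length_mm": 39.1,
    "bill_depth_mm": 18.7,
    "body_mass_g": 3750.0,
}
print(predict_forest([TREE_A, TREE_B, TREE_C], example))
```

The vote only mirrors how a Random Forest can combine its trees; Gradient Boosted Decision Trees train and combine their trees differently.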

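TensorFlow Decision Forests (TF-DF) is a library for training, evaluating, interpreting, and running inference with Decision Forest models.

In this tutorial, you will learn how to:

  1. Train a multi-class classification Random Forest model on a dataset containing numerical, categorical, and missing features.
  2. Evaluate the model on a test dataset.
  3. Prepare the model for TensorFlow Serving.
  4. Examine the overall structure of the model and the importance of each feature.
  5. Retrain the model with a different learning algorithm (Gradient Boosted Decision Trees).
  6. Use a different set of input features.
  7. Change the hyperparameters of the model.
  8. Preprocess the features.
  9. Train a regression model.

Detailed documentation is available in the user manual. The examples directory contains other end-to-end examples.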

2. Installing TensorFlow Decision Forests

Install TF-DF by running the following cell.

# Install the tensorflow_decision_forests library
!pip install tensorflow_decision_forests
Successfully installed tensorflow_decision_forests-1.5.0 wurlitzer-3.0.3

Wurlitzer is needed to display the detailed training logs in Colabs (when using verbose=2 in the model constructor).

!pip install wurlitzer
Requirement already satisfied: wurlitzer in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (3.0.3)
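Outside of a notebook, it can be handy to check that both packages are importable before going further. The following is a small sketch using only the standard library; the helper name is_installed is made up for this example.

```python
import importlib.util


def is_installed(module_name):
    """Return True if `module_name` can be found by the import system."""
    return importlib.util.find_spec(module_name) is not None


for name in ("tensorflow_decision_forests", "wurlitzer"):
    status = "available" if is_installed(name) else "missing (install it with pip)"
    print(f"{name}: {status}")
```

The check only locates the module; it does not import it, so it says nothing about version compatibility with the installed TensorFlow.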

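3. Importing libraries

# Import the required libraries

import tensorflow_decision_forests as tfdf  # decision forests library
import os  # operating system utilities
import numpy as np  # numerical computing
import pandas as pd  # data handling
import tensorflow as tf  # deep learning framework
import math  # math utilities

The hidden code cell below limits the height of the output in Colab.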

#@title

# Import the required modules
from IPython.core.magic import register_line_magic
from IPython.display import Javascript
from IPython.display import display as ipy_display

# Define a line magic that sets the maximum height of a cell's output
@register_line_magic
def set_cell_height(size):
  # Run JavaScript that caps the height of the output iframe
  ipy_display(
      Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: " +
                 str(size) + "})"))

# Check the version of TensorFlow Decision Forests
print("Found TensorFlow Decision Forests v" + tfdf.__version__)
Found TensorFlow Decision Forests v1.5.0

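4. Training a Random Forest model

In this section, we train, evaluate, analyze, and export a multi-class classification Random Forest trained on the Palmer's Penguins dataset.

Note: The dataset was exported to a csv file without preprocessing: library(palmerpenguins); write.csv(penguins, file="penguins.csv", quote=F, row.names=F)

4.1 Load the dataset and convert it to a tf.Dataset

This dataset is very small (344 examples) and is stored as a .csv-like file. Therefore, use Pandas to load it.

Note: Pandas is practical because you don't have to type in the names of the input features to load them. For larger datasets (>1M examples), using the TensorFlow Dataset to read the files may be better suited.

Let's assemble the dataset into a csv file (i.e. add the header), and load it: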

# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load the dataset into a Pandas DataFrame
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Display the first 3 examples
dataset_df.head(3)
   species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g     sex  year
0   Adelie  Torgersen            39.1           18.7              181.0       3750.0    male  2007
1   Adelie  Torgersen            39.5           17.4              186.0       3800.0  female  2007
2   Adelie  Torgersen            40.3           18.0              195.0       3250.0  female  2007

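The dataset contains a mix of numerical (e.g. bill_depth_mm), categorical (e.g. island), and missing features. TF-DF supports all these feature types natively (unlike NN-based models), so there is no need for preprocessing such as one-hot encoding, normalization, or an extra is_present feature.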
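To see where the missing values sit before training, a per-column count is usually enough. The sketch below uses only the standard library on a tiny inline sample: the first three rows follow the preview above, and the fourth row is a made-up example with gaps. It assumes missing cells are either empty or the literal string "NA", which is how R's write.csv writes them.

```python
import csv
import io

# Three rows modeled on the preview above plus one made-up row with gaps.
SAMPLE_CSV = """species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex,year
Adelie,Torgersen,39.1,18.7,181,3750,male,2007
Adelie,Torgersen,39.5,17.4,186,3800,female,2007
Adelie,Torgersen,40.3,18,195,3250,female,2007
Adelie,Torgersen,NA,NA,NA,NA,NA,2007
"""


def count_missing(csv_text, missing_tokens=("", "NA")):
    """Count, per column, the cells whose text marks a missing value."""
    reader = csv.DictReader(io.StringIO(csv_text))
    counts = {name: 0 for name in reader.fieldnames}
    for row in reader:
        for name, cell in row.items():
            if cell in missing_tokens:
                counts[name] += 1
    return counts


for column, n_missing in count_missing(SAMPLE_CSV).items():
    print(f"{column}: {n_missing} missing")
```

On the DataFrame loaded above, dataset_df.isna().sum() gives the same kind of per-column count.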

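Labels are a bit different: Keras metrics expect integers. The label (species) is stored as a string, so let's convert it into an integer.

# Encode the categorical label as integers.
#
# Details:
# This step is necessary if your classification label is represented as a
# string, since Keras expects integer classification labels.
# When using `pd_dataframe_to_tf_dataset` (see below), this step can be skipped.

# Name of the label column.
label = "species"

# Unique values of the label, as a list.
classes = dataset_df[label].unique().tolist()

# Print the label classes.
print(f"Label classes: {classes}")

# Replace each label value by its index in `classes`.
dataset_df[label] = dataset_df[label].map(classes.index)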
Label classes: ['Adelie', 'Gentoo', 'Chinstrap']
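Since the model will now predict integers, it helps to keep the mapping in both directions. The following sketch is plain Python; the encode and decode helpers are hypothetical names, and the class order is copied from the output above.

```python
# Class order taken from the "Label classes" output above.
classes = ["Adelie", "Gentoo", "Chinstrap"]


def encode(names, classes):
    """Map class names to their integer index in `classes`."""
    index_of = {name: i for i, name in enumerate(classes)}
    return [index_of[name] for name in names]


def decode(indices, classes):
    """Map integer labels back to class names."""
    return [classes[i] for i in indices]


encoded = encode(["Gentoo", "Adelie", "Chinstrap", "Adelie"], classes)
print(encoded)                   # [1, 0, 2, 0]
print(decode(encoded, classes))  # the original names
```

The order returned by unique() is the order of first appearance in the file, so it is worth storing the classes list next to the model rather than recomputing it from a different file.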

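Next, split the dataset into a training and a testing dataset:

# Split the dataset into a training and a testing dataset.

def split_dataset(dataset, test_ratio=0.30):
  """Splits a pandas dataframe in two."""
  # Draw one uniform random number per row; rows whose number falls below
  # test_ratio go to the test set.
  test_indices = np.random.rand(len(dataset)) < test_ratio
  # Return the training split first, then the testing split.
  return dataset[~test_indices], dataset[test_indices]


# Split the dataset and keep the two parts as train_ds_pd and test_ds_pd.
train_ds_pd, test_ds_pd = split_dataset(dataset_df)

# Print the number of examples in each split.
print("{} examples in training, {} examples for testing.".format(
    len(train_ds_pd), len(test_ds_pd)))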
239 examples in training, 105 examples for testing.
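The split above draws fresh random numbers on every run, so the exact counts change from run to run. If a repeatable split is preferred, one option is to seed the generator. The sketch below shows the idea with the standard library only; split_indices and the seed value are illustrative choices, and 344 is the total of the two counts printed above.

```python
import random


def split_indices(num_examples, test_ratio=0.30, seed=1234):
    """Assign each row index to train or test using a seeded generator."""
    rng = random.Random(seed)
    train, test = [], []
    for i in range(num_examples):
        (test if rng.random() < test_ratio else train).append(i)
    return train, test


train_idx, test_idx = split_indices(344)
print(f"{len(train_idx)} examples in training, {len(test_idx)} examples for testing.")
```

With the DataFrame from this tutorial, dataset_df.iloc[train_idx] and dataset_df.iloc[test_idx] would then select the two parts.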

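Finally, convert the pandas dataframe (pd.Dataframe) into a tensorflow dataset (tf.data.Dataset):

# Convert the Pandas DataFrames into TensorFlow datasets, specifying the label column.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label)  # training set
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label)  # testing set

Note: Recall that pd_dataframe_to_tf_dataset converts string labels to integers if necessary.

If you want to create the tf.data.Dataset yourself, there are a couple of things to remember:

  • The learning algorithms work with a one-epoch dataset and without shuffling.
  • The batch size does not impact the training algorithm, but a small value might slow down reading the dataset.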
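The two points above can be illustrated without TensorFlow. The sketch below is an analogy, not TF-DF's actual reader: a single unshuffled pass over the batches is concatenated back into one list, and the result is identical for any batch size; only the number of read steps changes. The 239 stand-in rows match the training count printed earlier.

```python
def batches(examples, batch_size):
    """Yield consecutive, unshuffled batches covering exactly one epoch."""
    for start in range(0, len(examples), batch_size):
        yield examples[start:start + batch_size]


def collect_one_epoch(examples, batch_size):
    """Read every batch once and return (all examples, number of batches)."""
    collected, num_batches = [], 0
    for batch in batches(examples, batch_size):
        collected.extend(batch)
        num_batches += 1
    return collected, num_batches


examples = list(range(239))  # stand-in for the 239 training rows
for batch_size in (1, 64, 1000):
    collected, num_batches = collect_one_epoch(examples, batch_size)
    print(batch_size, num_batches, collected == examples)
```

A batch size of 1 needs 239 read steps for the same data that a batch size of 1000 delivers in one, which is the kind of slowdown the second bullet warns about.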

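4.2 Train the model

# Limit the height of the cell output to 300 (uses the line magic defined above).
%set_cell_height 300

# Specify the model: a Random Forest with verbosity level 2.
model_1 = tfdf.keras.RandomForestModel(verbose=2)

# Train the model.
model_1.fit(train_ds)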



Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmpblfnf8hv as temporary training directory
Reading training dataset...
Training tensor examples:
Features: {'island': , 'bill_length_mm': , 'bill_depth_mm': , 'flipper_length_mm': , 'body_mass_g': , 'sex': , 'year': }
Label: Tensor("data_7:0", shape=(None,), dtype=int64)
Weights: None
Normalized tensor features:
 {'island': SemanticTensor(semantic=, tensor=), 'bill_length_mm': SemanticTensor(semantic=, tensor=), 'bill_depth_mm': SemanticTensor(semantic=, tensor=), 'flipper_length_mm': SemanticTensor(semantic=, tensor=), 'body_mass_g': SemanticTensor(semantic=, tensor=), 'sex': SemanticTensor(semantic=, tensor=), 'year': SemanticTensor(semantic=, tensor=)}
Training dataset read in 0:00:03.556705. Found 239 examples.
Training model...
Standard output detected as not visible to the user e.g. running in a notebook. Creating a training log redirection. If training gets stuck, try calling tfdf.keras.set_training_logs_redirection(False).


[INFO 23-08-16 11:05:20.8059 UTC kernel.cc:773] Start Yggdrasil model training
[INFO 23-08-16 11:05:20.8059 UTC kernel.cc:774] Collect training examples
[INFO 23-08-16 11:05:20.8060 UTC kernel.cc:787] Dataspec guide:
column_guides {
  column_name_pattern: "^__LABEL$"
  type: CATEGORICAL
  categorial {
    min_vocab_frequency: 0
    max_vocab_count: -1
  }
}
default_column_guide {
  categorial {
    max_vocab_count: 2000
  }
  discretized_numerical {
    maximum_num_bins: 255
  }
}
ignore_columns_without_guides: false
detect_numerical_as_discretized_numerical: false

[INFO 23-08-16 11:05:20.8063 UTC kernel.cc:393] Number of batches: 1
[INFO 23-08-16 11:05:20.8064 UTC kernel.cc:394] Number of examples: 239
[INFO 23-08-16 11:05:20.8064 UTC kernel.cc:794] Training dataset:
Number of records: 239
Number of columns: 8

Number of columns by type:
	NUMERICAL: 5 (62.5%)
	CATEGORICAL: 3 (37.5%)

Columns:

NUMERICAL: 5 (62.5%)
	1: "bill_depth_mm" NUMERICAL num-nas:1 (0.41841%) mean:17.0387 min:13.2 max:21.5 sd:1.97169
	2: "bill_length_mm" NUMERICAL num-nas:1 (0.41841%) mean:44.0025 min:32.1 max:55.9 sd:5.27172
	3: "body_mass_g" NUMERICAL num-nas:1 (0.41841%) mean:4230.57 min:2700 max:6300 sd:821.055
	4: "flipper_length_mm" NUMERICAL num-nas:1 (0.41841%) mean:201.176 min:172 max:231 sd:14.2924
	7: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.807521

CATEGORICAL: 3 (37.5%)
	0: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item
	5: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 121 (50.6276%)
	6: "sex" CATEGORICAL num-nas:7 (2.92887%) has-dict vocab-size:3 zero-ood-items most-frequent:"female" 120 (51.7241%)

Terminology:
	nas: Number of non-available (i.e. missing) values.
	ood: Out of dictionary.
	manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
	tokenized: The attribute value is obtained through tokenization.
	has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
	vocab-size: Number of unique values.

[INFO 23-08-16 11:05:20.8065 UTC kernel.cc:810] Configure learner
[INFO 23-08-16 11:05:20.8067 UTC kernel.cc:824] Training config:
learner: "RANDOM_FOREST"
features: "^bill_depth_mm$"
features: "^bill_length_mm$"
features: "^body_mass_g$"
features: "^flipper_length_mm$"
features: "^island$"
features: "^sex$"
features: "^year$"
label: "^__LABEL$"
task: CLASSIFICATION
random_seed: 123456
metadata {
  framework: "TF Keras"
}
pure_serving_model: false
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
  num_trees: 300
  decision_tree {
    max_depth: 16
    min_examples: 5
    in_split_min_examples_check: true
    keep_non_leaf_label_distribution: true
    num_candidate_attributes: 0
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
    uplift {
      min_examples_in_treatment: 5
      split_score: KULLBACK_LEIBLER
    }
  }
  winner_take_all_inference: true
  compute_oob_performances: true
  compute_oob_variable_importances: false
  num_oob_variable_importances_permutations: 1
  bootstrap_training_dataset: true
  bootstrap_size_ratio: 1
  adapt_bootstrap_size_ratio_for_maximum_training_duration: false
  sampling_with_replacement: true
}

[INFO 23-08-16 11:05:20.8070 UTC kernel.cc:827] Deployment config:
cache_path: "/tmpfs/tmp/tmpblfnf8hv/working_cache"
num_threads: 32
try_resume_training: true

[INFO 23-08-16 11:05:20.8072 UTC kernel.cc:889] Train model
[INFO 23-08-16 11:05:20.8073 UTC random_forest.cc:416] Training random forest on 239 example(s) and 7 feature(s).
[INFO 23-08-16 11:05:20.8130 UTC random_forest.cc:802] Training of tree  1/300 (tree index:0) done accuracy:0.943182 logloss:2.04793
[INFO 23-08-16 11:05:20.8139 UTC random_forest.cc:802] Training of tree  11/300 (tree index:8) done accuracy:0.949367 logloss:0.383614
[INFO 23-08-16 11:05:20.8144 UTC random_forest.cc:802] Training of tree  21/300 (tree index:4) done accuracy:0.953975 logloss:0.386135
[INFO 23-08-16 11:05:20.8146 UTC random_forest.cc:802] Training of tree  35/300 (tree index:20) done accuracy:0.953975 logloss:0.249595
[INFO 23-08-16 11:05:20.8147 UTC random_forest.cc:802] Training of tree  50/300 (tree index:30) done accuracy:0.949791 logloss:0.249004
[INFO 23-08-16 11:05:20.8149 UTC random_forest.cc:802] Training of tree  62/300 (tree index:61) done accuracy:0.949791 logloss:0.247371
[INFO 23-08-16 11:05:20.8155 UTC random_forest.cc:802] Training of tree  73/300 (tree index:73) done accuracy:0.962343 logloss:0.246108
[INFO 23-08-16 11:05:20.8158 UTC random_forest.cc:802] Training of tree  83/300 (tree index:82) done accuracy:0.958159 logloss:0.240771
[INFO 23-08-16 11:05:20.8163 UTC random_forest.cc:802] Training of tree  96/300 (tree index:98) done accuracy:0.962343 logloss:0.0994905
[INFO 23-08-16 11:05:20.8166 UTC random_forest.cc:802] Training of tree  106/300 (tree index:105) done accuracy:0.966527 logloss:0.100095
[INFO 23-08-16 11:05:20.8170 UTC random_forest.cc:802] Training of tree  117/300 (tree index:117) done accuracy:0.962343 logloss:0.0959006
[INFO 23-08-16 11:05:20.8173 UTC random_forest.cc:802] Training of tree  127/300 (tree index:125) done accuracy:0.958159 logloss:0.0962165
[INFO 23-08-16 11:05:20.8177 UTC random_forest.cc:802] Training of tree  138/300 (tree index:137) done accuracy:0.958159 logloss:0.0927663
[INFO 23-08-16 11:05:20.8182 UTC random_forest.cc:802] Training of tree  148/300 (tree index:147) done accuracy:0.966527 logloss:0.0931921
[INFO 23-08-16 11:05:20.8187 UTC random_forest.cc:802] Training of tree  158/300 (tree index:157) done accuracy:0.966527 logloss:0.092117
[INFO 23-08-16 11:05:20.8190 UTC random_forest.cc:802] Training of tree  170/300 (tree index:170) done accuracy:0.966527 logloss:0.0926436
[INFO 23-08-16 11:05:20.8196 UTC random_forest.cc:802] Training of tree  180/300 (tree index:181) done accuracy:0.966527 logloss:0.0927239
[INFO 23-08-16 11:05:20.8200 UTC random_forest.cc:802] Training of tree  190/300 (tree index:187) done accuracy:0.966527 logloss:0.0942833
[INFO 23-08-16 11:05:20.8203 UTC random_forest.cc:802] Training of tree  200/300 (tree index:198) done accuracy:0.966527 logloss:0.0941766
[INFO 23-08-16 11:05:20.8208 UTC random_forest.cc:802] Training of tree  210/300 (tree index:208) done accuracy:0.962343 logloss:0.0938748
[INFO 23-08-16 11:05:20.8211 UTC random_forest.cc:802] Training of tree  220/300 (tree index:219) done accuracy:0.958159 logloss:0.0950461
[INFO 23-08-16 11:05:20.8214 UTC random_forest.cc:802] Training of tree  231/300 (tree index:231) done accuracy:0.953975 logloss:0.0951599
[INFO 23-08-16 11:05:20.8218 UTC random_forest.cc:802] Training of tree  241/300 (tree index:241) done accuracy:0.962343 logloss:0.0948531
[INFO 23-08-16 11:05:20.8221 UTC random_forest.cc:802] Training of tree  251/300 (tree index:250) done accuracy:0.962343 logloss:0.0942377
[INFO 23-08-16 11:05:20.8224 UTC random_forest.cc:802] Training of tree  262/300 (tree index:261) done accuracy:0.962343 logloss:0.0940229
[INFO 23-08-16 11:05:20.8228 UTC random_forest.cc:802] Training of tree  272/300 (tree index:276) done accuracy:0.958159 logloss:0.0934476
[INFO 23-08-16 11:05:20.8231 UTC random_forest.cc:802] Training of tree  282/300 (tree index:281) done accuracy:0.958159 logloss:0.0934649
[INFO 23-08-16 11:05:20.8234 UTC random_forest.cc:802] Training of tree  292/300 (tree index:292) done accuracy:0.958159 logloss:0.0943068
[INFO 23-08-16 11:05:20.8236 UTC random_forest.cc:802] Training of tree  300/300 (tree index:299) done accuracy:0.958159 logloss:0.0945677
[INFO 23-08-16 11:05:20.8250 UTC random_forest.cc:882] Final OOB metrics: accuracy:0.958159 logloss:0.0945677
[INFO 23-08-16 11:05:20.8261 UTC kernel.cc:926] Export model in log directory: /tmpfs/tmp/tmpblfnf8hv with prefix 3b862cbea45f4b2a
[INFO 23-08-16 11:05:20.8303 UTC kernel.cc:944] Save model in resources
[INFO 23-08-16 11:05:20.8335 UTC abstract_model.cc:849] Model self evaluation:
Number of predictions (without weights): 239
Number of predictions (with weights): 239
Task: CLASSIFICATION
Label: __LABEL

Accuracy: 0.958159  CI95[W][0.930062 0.977127]
LogLoss: : 0.0945677
ErrorRate: : 0.041841

Default Accuracy: : 0.422594
Default LogLoss: : 1.04864
Default ErrorRate: : 0.577406

Confusion Table:
truth\prediction
   0   1   2   3
0  0   0   0   0
1  0  98   0   3
2  0   1  91   0
3  0   4   2  40
Total: 239

One vs other classes:

[INFO 23-08-16 11:05:20.8441 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpblfnf8hv/model/ with prefix 3b862cbea45f4b2a
[INFO 23-08-16 11:05:20.8582 UTC decision_forest.cc:660] Model loaded with 300 root(s), 4336 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:20.8582 UTC abstract_model.cc:1311] Engine "RandomForestGeneric" built
[INFO 23-08-16 11:05:20.8582 UTC kernel.cc:1075] Use fast generic engine


Model trained in 0:00:00.059670
Compiling model...
WARNING:tensorflow:AutoGraph could not transform  and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert


Model compiled.






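4.3 Remarks

  • No input features are specified. Therefore, all the columns are used as input features except for the label. The features used by the model are shown in the training logs and in model.summary().
  • DFs natively consume numerical, categorical, and categorical-set features, as well as missing values. Numerical features do not need to be normalized. Categorical string values do not need to be encoded in a dictionary.
  • No training hyperparameters are specified. Therefore, the default hyperparameters are used. Default hyperparameters provide reasonable results in most situations.
  • Calling compile on the model before fit is optional. Compile can be used to provide extra evaluation metrics.
  • Training algorithms do not need validation datasets. If a validation dataset is provided, it is only used to show metrics.
  • Tweak the verbose argument of RandomForestModel to control the amount of displayed training logs. Set verbose=0 to hide most of the logs. Set verbose=2 to show all the logs.

Note: A categorical-set feature is composed of a set of categorical values (while a categorical feature is a single value). More details and examples are given later.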

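5. Evaluating the model

Let's evaluate our model on the test dataset.

# Compile the model with the accuracy metric.
model_1.compile(metrics=["accuracy"])

# Evaluate the model on the test dataset.
evaluation = model_1.evaluate(test_ds, return_dict=True)

print()

# Print the evaluation results.
for name, value in evaluation.items():
  print(f"{name}: {value:.4f}")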
1/1 [==============================] - ETA: 0s - loss: 0.0000e+00 - accuracy: 0.9619
1/1 [==============================] - 0s 295ms/step - loss: 0.0000e+00 - accuracy: 0.9619

loss: 0.0000
accuracy: 0.9619

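Remark: The test accuracy is close to the out-of-bag accuracy shown in the training logs.

See the Model self-evaluation section below for more evaluation methods.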
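The out-of-bag figures in the training log of section 4.2 can be recomputed from the confusion table printed there. The sketch below copies that table by hand and derives the accuracy, the error rate, and the default accuracy. Reading the default accuracy as "always predict the most frequent class" is an assumption that happens to reproduce the logged value.

```python
# Confusion table copied from the training log (rows: truth, columns: prediction).
# Row and column 0 contain no examples in this log.
CONFUSION = [
    [0, 0, 0, 0],
    [0, 98, 0, 3],
    [0, 1, 91, 0],
    [0, 4, 2, 40],
]


def accuracy(table):
    """Share of examples on the diagonal, i.e. correct predictions."""
    correct = sum(table[i][i] for i in range(len(table)))
    total = sum(sum(row) for row in table)
    return correct / total


def default_accuracy(table):
    """Accuracy of always predicting the most frequent true class."""
    total = sum(sum(row) for row in table)
    return max(sum(row) for row in table) / total


print(f"accuracy: {accuracy(CONFUSION):.6f}")
print(f"error rate: {1 - accuracy(CONFUSION):.6f}")
print(f"default accuracy: {default_accuracy(CONFUSION):.6f}")
```

These come out as 0.958159, 0.041841, and 0.422594, matching the Accuracy, ErrorRate, and Default Accuracy lines of the log. For comparison, the test accuracy of 0.9619 over 105 test examples corresponds to 101 correct predictions (101/105 ≈ 0.9619).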

6. Preparing the model for TensorFlow Serving

Export the model to the SavedModel format for later reuse, e.g. with TensorFlow Serving.

# Save the model to the given path.
model_1.save("/tmp/my_saved_model")
INFO:tensorflow:Assets written to: /tmp/my_saved_model/assets
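Before pointing a server at the export, a quick sanity check of the directory layout can catch a wrong path early. The sketch below uses only the standard library; describe_saved_model is a hypothetical helper, and it only checks for the usual SavedModel entries (saved_model.pb, variables, assets) without loading anything. It is demonstrated on a throwaway directory that mimics the layout.

```python
import os
import tempfile


def describe_saved_model(path):
    """Report which standard SavedModel entries exist under `path`."""
    expected = ("saved_model.pb", "variables", "assets")
    return {name: os.path.exists(os.path.join(path, name)) for name in expected}


# Demonstration on a temporary directory that mimics the layout; on a real
# export, pass "/tmp/my_saved_model" instead.
with tempfile.TemporaryDirectory() as fake_export:
    os.makedirs(os.path.join(fake_export, "assets"))
    os.makedirs(os.path.join(fake_export, "variables"))
    open(os.path.join(fake_export, "saved_model.pb"), "wb").close()
    report = describe_saved_model(fake_export)
    print(report)
```

A False entry in the report for the real export path would suggest that the save step did not complete or that the path is wrong.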

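7. Plotting the model

Plotting a decision tree and following the first branches helps in learning about decision forests. In some cases, plotting a model can even be used for debugging.

Because of the difference in the way they are trained, some models are more interesting to plot than others. Because of the noise injected during training and the depth of the trees, plotting a Random Forest is less informative than plotting a CART or the first tree of a Gradient Boosted Tree.

Nevertheless, let's plot the first tree of our Random Forest model:

# Plot the model with plot_model_in_colab from the model_plotter module.
# model_1 is the model to plot.
# tree_idx is the index of the tree to plot; 0 selects the first tree.
# max_depth is the maximum depth of the plotted tree, set to 3 here.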
tfdf.model_plotter.plot_model_in_colab(model_1, tree_idx=0, max_depth=3)

[Plot output: interactive rendering of the first tree of model_1 (tree_idx=0), truncated at depth 3. In this view, each non-leaf node shows its condition; the negative child is drawn in red and the positive child in green.]

/**

  • Adds a condition inside of a node.
  • @param {!options} options Dictionary of configurations.
  • @param {!condition} condition Condition to display.
  • @param {!output} output Output display accumulator.
    */
    function display_condition(options, condition, output) {
    threshold_format = d3.format(‘r’);

if (condition.type === ‘IS_MISSING’) {
display_node_text(options, ${condition.attribute} is missing, output);
return;
}

if (condition.type === ‘IS_TRUE’) {
display_node_text(options, ${condition.attribute} is true, output);
return;
}

if (condition.type === ‘NUMERICAL_IS_HIGHER_THAN’) {
format = d3.format(‘r’);
display_node_text(
options,
${condition.attribute} >= ${threshold_format(condition.threshold)},
output);
return;
}

if (condition.type === ‘CATEGORICAL_IS_IN’) {
display_node_text_with_tooltip(
options, ${condition.attribute} in [...],
${condition.attribute} in [${condition.mask}], output);
return;
}

if (condition.type === ‘CATEGORICAL_SET_CONTAINS’) {
display_node_text_with_tooltip(
options, ${condition.attribute} intersect [...],
${condition.attribute} intersect [${condition.mask}], output);
return;
}

if (condition.type === ‘NUMERICAL_SPARSE_OBLIQUE’) {
display_node_text_with_tooltip(
options, Sparse oblique split...,
[${condition.attributes}]*[${condition.weights}]>=${ threshold_format(condition.threshold)},
output);
return;
}

display_node_text(
options, Non supported condition ${condition.type}, output);
}

/**

  • Adds a value inside of a node.

  • @param {!options} options Dictionary of configurations.

  • @param {!value} value Value to display.

  • @param {!output} output Output display accumulator.
    */
    function display_value(options, value, output) {
    if (value.type === ‘PROBABILITY’) {
    const left_margin = 0;
    const right_margin = 50;
    const plot_width = options.node_x_size - options.node_padding * 2 -
    left_margin - right_margin;

    let cusum = Array.from(d3.cumsum(value.distribution));
    cusum.unshift(0);
    const distribution_plot = output.content.append(‘g’).attr(
    ‘transform’, translate(0,${output.vertical_offset + 0.5}));

    distribution_plot.selectAll(‘rect’)
    .data(value.distribution)
    .join(‘rect’)
    .attr(‘height’, 10)
    .attr(
    ‘x’,
    (d, i) =>
    (cusum[i] * plot_width + left_margin + options.node_padding))
    .attr(‘width’, (d, i) => d * plot_width)
    .style(‘fill’, (d, i) => d3.schemeSet1[i]);

    const num_examples =
    output.content.append(‘g’)
    .attr(‘transform’, translate(0,${output.vertical_offset}))
    .append(‘text’)
    .attr(‘x’, options.node_x_size - options.node_padding)
    .attr(‘alignment-baseline’, ‘hanging’)
    .attr(‘text-anchor’, ‘end’)
    .text((${value.num_examples}));

    const distribution_details = d3.create(‘ul’);
    distribution_details.selectAll(‘li’)
    .data(value.distribution)
    .join(‘li’)
    .append(‘span’)
    .text(
    (d, i) =>
    ‘class ’ + i + ‘: ’ + d3.format(’.3%’)(value.distribution[i]));

    add_tooltip(options, distribution_plot, () => distribution_details.html());
    add_tooltip(options, num_examples, () => ‘Number of examples’);

    output.vertical_offset += 10;
    return;
    }

if (value.type === ‘REGRESSION’) {
display_node_text(
options,
‘value: ’ + d3.format(‘r’)(value.value) + ( +
d3.format(’.6’)(value.num_examples) + ),
output);
return;
}

if (value.type === ‘UPLIFT’) {
display_node_text(
options,
‘effect: ’ + d3.format(‘r’)(value.treatment_effect) + ( +
d3.format(’.6’)(value.num_examples) + ),
output);
return;
}

display_node_text(options, Non supported value ${value.type}, output);
}

/**

  • Adds an explanation inside of a node.
  • @param {!options} options Dictionary of configurations.
  • @param {!explanation} explanation Explanation to display.
  • @param {!output} output Output display accumulator.
    */
    function display_explanation(options, explanation, output) {
    // Margin before the explanation.
    output.vertical_offset += 10;

display_node_text(
options, Non supported explanation ${explanation.type}, output);
}

/**

  • Draw the edges of the tree.
  • @param {!options} options Dictionary of configurations.
  • @param {!graph} graph D3 search handle containing the graph.
  • @param {!tree_struct} tree_struct Structure of the tree (node placement,
  • data, etc.).
    

*/
function display_edges(options, graph, tree_struct) {
// Draw an edge between a parent and a child node with a bezier.
function draw_single_edge(d) {
return ‘M’ + (d.source.y + options.node_x_size) + ‘,’ + d.source.x + ’ C’ +
(d.source.y + options.node_x_size + options.edge_rounding) + ‘,’ +
d.source.x + ’ ’ + (d.target.y - options.edge_rounding) + ‘,’ +
d.target.x + ’ ’ + d.target.y + ‘,’ + d.target.x;
}

graph.append(‘g’)
.attr(‘fill’, ‘none’)
.attr(‘stroke-width’, 1.2)
.selectAll(‘path’)
.data(tree_struct.links())
.join(‘path’)
.attr(‘d’, draw_single_edge)
.attr(
‘stroke’, d => (d.target === d.source.children[0]) ? ‘#0F0’ : ‘#F00’);
}

display_tree({“margin”: 10, “node_x_size”: 160, “node_y_size”: 28, “node_x_offset”: 180, “node_y_offset”: 33, “font_size”: 10, “edge_rounding”: 20, “node_padding”: 2, “show_plot_bounding_box”: false}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.4435146443514644, 0.34309623430962344, 0.21338912133891214], “num_examples”: 239.0}, “condition”: {“type”: “NUMERICAL_IS_HIGHER_THAN”, “attribute”: “flipper_length_mm”, “threshold”: 206.5}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 0.9534883720930233, 0.046511627906976744], “num_examples”: 86.0}, “condition”: {“type”: “NUMERICAL_IS_HIGHER_THAN”, “attribute”: “bill_depth_mm”, “threshold”: 17.200000762939453}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 0.2, 0.8], “num_examples”: 5.0}}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 1.0, 0.0], “num_examples”: 81.0}}]}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.6928104575163399, 0.0, 0.30718954248366015], “num_examples”: 153.0}, “condition”: {“type”: “CATEGORICAL_IS_IN”, “attribute”: “island”, “mask”: [“Biscoe”, “Torgersen”]}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [1.0, 0.0, 0.0], “num_examples”: 81.0}}, {“value”: {“type”: “PROBABILITY”, “distribution”: [0.3472222222222222, 0.0, 0.6527777777777778], “num_examples”: 72.0}, “condition”: {“type”: “NUMERICAL_IS_HIGHER_THAN”, “attribute”: “bill_length_mm”, “threshold”: 42.30000305175781}, “children”: [{“value”: {“type”: “PROBABILITY”, “distribution”: [0.0, 0.0, 1.0], “num_examples”: 47.0}}, {“value”: {“type”: “PROBABILITY”, “distribution”: [1.0, 0.0, 0.0], “num_examples”: 25.0}}]}]}]}, “#tree_plot_28efd97aa2df4edca61ad38bf7763da0”)

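The root node on the left contains the first condition (flipper_length_mm >= 206.5), the number of examples (239) and the label distribution (the red-blue-green bar).

Examples that satisfy the condition flipper_length_mm >= 206.5 branch to the green path. The other examples branch to the red path.

The deeper a node is, the more "pure" it becomes, i.e. its label distribution is concentrated on a subset of the classes.

**Note:** Hover the mouse over the plot for details.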
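The notion of purity can be made concrete with a small sketch. It computes the Shannon entropy of three label distributions read along one root-to-leaf path of the plotted tree (values rounded to four decimals). Entropy is only one possible purity measure and is used here as an illustration, not as a statement about how TF-DF scores splits.

```python
import math


def entropy(distribution):
    """Shannon entropy in bits; zero-probability classes contribute nothing."""
    return sum(-p * math.log2(p) for p in distribution if p > 0)


# Label distributions along one root-to-leaf path (rounded values).
path = [
    ("root", [0.4435, 0.3431, 0.2134]),
    ("depth 1", [0.0, 0.9535, 0.0465]),
    ("leaf", [0.0, 1.0, 0.0]),
]

entropies = [(name, entropy(dist)) for name, dist in path]
for name, value in entropies:
    print(f"{name}: {value:.3f} bits")
```

The entropy decreases at each step along the path and reaches zero at the leaf, where a single class remains.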

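8. Model structure and feature importance

The overall structure of the model can be displayed with .summary(). You will see:

  • Type: the learning algorithm used to train the model (Random Forest in our case).
  • Task: the problem solved by the model (classification in our case).
  • Input Features: the input features of the model.
  • Variable Importance: different measures of how important each feature is to the model.
  • Out-of-bag evaluation: the out-of-bag evaluation of the model. This is a cheap and efficient alternative to cross-validation.
  • Number of {trees, nodes} and other metrics: statistics about the structure of the decision forest.

Remark: the content of the summary depends on the learning algorithm (e.g. out-of-bag evaluation is only available for Random Forest) and on the hyperparameters (e.g. the mean-decrease-in-accuracy variable importance can be disabled in the hyperparameters).

# Set the cell height to 300
%set_cell_height 300

# Print the summary of model_1
model_1.summary()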



Model: "random_forest_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1 (1.00 Byte)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 1 (1.00 Byte)
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (7):
	bill_depth_mm
	bill_length_mm
	body_mass_g
	flipper_length_mm
	island
	sex
	year

No weights

Variable Importance: INV_MEAN_MIN_DEPTH:
    1. "flipper_length_mm"  0.459730 ################
    2.    "bill_length_mm"  0.428357 #############
    3.     "bill_depth_mm"  0.318034 #####
    4.            "island"  0.302253 ####
    5.       "body_mass_g"  0.270350 ##
    6.               "sex"  0.239199 
    7.              "year"  0.238541 

Variable Importance: NUM_AS_ROOT:
    1. "flipper_length_mm" 161.000000 ################
    2.    "bill_length_mm" 66.000000 ######
    3.     "bill_depth_mm" 55.000000 ####
    4.       "body_mass_g" 10.000000 
    5.            "island"  8.000000 

Variable Importance: NUM_NODES:
    1.    "bill_length_mm" 686.000000 ################
    2.     "bill_depth_mm" 411.000000 #########
    3. "flipper_length_mm" 357.000000 ########
    4.       "body_mass_g" 291.000000 ######
    5.            "island" 238.000000 #####
    6.               "sex" 21.000000 
    7.              "year" 14.000000 

Variable Importance: SUM_SCORE:
    1. "flipper_length_mm" 26375.887035 ################
    2.    "bill_length_mm" 23387.499002 ##############
    3.     "bill_depth_mm" 9981.270101 ######
    4.            "island" 8813.632840 #####
    5.       "body_mass_g" 3264.050597 #
    6.               "sex" 101.269852 
    7.              "year" 29.719130 



Winner takes all: true
Out-of-bag evaluation: accuracy:0.958159 logloss:0.0945677
Number of trees: 300
Total number of nodes: 4336

Number of nodes by tree:
Count: 300 Average: 14.4533 StdDev: 2.95654
Min: 7 Max: 25 Ignored: 0
----------------------------------------------
[  7,  8)  6   2.00%   2.00% #
[  8,  9)  0   0.00%   2.00%
[  9, 10)  9   3.00%   5.00% #
[ 10, 11)  0   0.00%   5.00%
[ 11, 12) 36  12.00%  17.00% ####
[ 12, 13)  0   0.00%  17.00%
[ 13, 14) 92  30.67%  47.67% ##########
[ 14, 15)  0   0.00%  47.67%
[ 15, 16) 73  24.33%  72.00% ########
[ 16, 17)  0   0.00%  72.00%
[ 17, 18) 48  16.00%  88.00% #####
[ 18, 19)  0   0.00%  88.00%
[ 19, 20) 26   8.67%  96.67% ###
[ 20, 21)  0   0.00%  96.67%
[ 21, 22)  8   2.67%  99.33% #
[ 22, 23)  0   0.00%  99.33%
[ 23, 24)  1   0.33%  99.67%
[ 24, 25)  0   0.00%  99.67%
[ 25, 25]  1   0.33% 100.00%

Depth by leafs:
Count: 2318 Average: 3.27653 StdDev: 1.02213
Min: 1 Max: 7 Ignored: 0
----------------------------------------------
[ 1, 2)  20   0.86%   0.86%
[ 2, 3) 563  24.29%  25.15% #######
[ 3, 4) 787  33.95%  59.10% ##########
[ 4, 5) 704  30.37%  89.47% #########
[ 5, 6) 200   8.63%  98.10% ###
[ 6, 7)  36   1.55%  99.65%
[ 7, 7]   8   0.35% 100.00%

Number of training obs by leaf:
Count: 2318 Average: 30.9318 StdDev: 32.1481
Min: 5 Max: 110 Ignored: 0
----------------------------------------------
[   5,  10) 1143  49.31%  49.31% ##########
[  10,  15)   88   3.80%  53.11% #
[  15,  20)   81   3.49%  56.60% #
[  20,  26)   78   3.36%  59.97% #
[  26,  31)   74   3.19%  63.16% #
[  31,  36)   81   3.49%  66.65% #
[  36,  42)  103   4.44%  71.10% #
[  42,  47)   46   1.98%  73.08%
[  47,  52)   34   1.47%  74.55%
[  52,  58)   20   0.86%  75.41%
[  58,  63)   30   1.29%  76.70%
[  63,  68)   39   1.68%  78.39%
[  68,  73)   58   2.50%  80.89% #
[  73,  79)   65   2.80%  83.69% #
[  79,  84)   98   4.23%  87.92% #
[  84,  89)   93   4.01%  91.93% #
[  89,  95)   98   4.23%  96.16% #
[  95, 100)   57   2.46%  98.62%
[ 100, 105)   25   1.08%  99.70%
[ 105, 110]    7   0.30% 100.00%

Attribute in nodes:
	686 : bill_length_mm [NUMERICAL]
	411 : bill_depth_mm [NUMERICAL]
	357 : flipper_length_mm [NUMERICAL]
	291 : body_mass_g [NUMERICAL]
	238 : island [CATEGORICAL]
	21 : sex [CATEGORICAL]
	14 : year [NUMERICAL]

Attribute in nodes with depth <= 0:
	161 : flipper_length_mm [NUMERICAL]
	66 : bill_length_mm [NUMERICAL]
	55 : bill_depth_mm [NUMERICAL]
	10 : body_mass_g [NUMERICAL]
	8 : island [CATEGORICAL]

Attribute in nodes with depth <= 1:
	258 : flipper_length_mm [NUMERICAL]
	252 : bill_length_mm [NUMERICAL]
	181 : bill_depth_mm [NUMERICAL]
	132 : island [CATEGORICAL]
	57 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 2:
	460 : bill_length_mm [NUMERICAL]
	318 : bill_depth_mm [NUMERICAL]
	317 : flipper_length_mm [NUMERICAL]
	207 : island [CATEGORICAL]
	172 : body_mass_g [NUMERICAL]
	3 : sex [CATEGORICAL]

Attribute in nodes with depth <= 3:
	631 : bill_length_mm [NUMERICAL]
	390 : bill_depth_mm [NUMERICAL]
	341 : flipper_length_mm [NUMERICAL]
	265 : body_mass_g [NUMERICAL]
	234 : island [CATEGORICAL]
	14 : sex [CATEGORICAL]
	9 : year [NUMERICAL]

Attribute in nodes with depth <= 5:
	683 : bill_length_mm [NUMERICAL]
	411 : bill_depth_mm [NUMERICAL]
	357 : flipper_length_mm [NUMERICAL]
	290 : body_mass_g [NUMERICAL]
	238 : island [CATEGORICAL]
	21 : sex [CATEGORICAL]
	14 : year [NUMERICAL]

Condition type in nodes:
	1759 : HigherCondition
	259 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
	292 : HigherCondition
	8 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
	748 : HigherCondition
	132 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
	1267 : HigherCondition
	210 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
	1636 : HigherCondition
	248 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
	1755 : HigherCondition
	259 : ContainsBitmapCondition
Node format: NOT_SET

Training OOB:
	trees: 1, Out-of-bag evaluation: accuracy:0.943182 logloss:2.04793
	trees: 11, Out-of-bag evaluation: accuracy:0.949367 logloss:0.383614
	trees: 21, Out-of-bag evaluation: accuracy:0.953975 logloss:0.386135
	trees: 35, Out-of-bag evaluation: accuracy:0.953975 logloss:0.249595
	trees: 50, Out-of-bag evaluation: accuracy:0.949791 logloss:0.249004
	trees: 62, Out-of-bag evaluation: accuracy:0.949791 logloss:0.247371
	trees: 73, Out-of-bag evaluation: accuracy:0.962343 logloss:0.246108
	trees: 83, Out-of-bag evaluation: accuracy:0.958159 logloss:0.240771
	trees: 96, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0994905
	trees: 106, Out-of-bag evaluation: accuracy:0.966527 logloss:0.100095
	trees: 117, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0959006
	trees: 127, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0962165
	trees: 138, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0927663
	trees: 148, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0931921
	trees: 158, Out-of-bag evaluation: accuracy:0.966527 logloss:0.092117
	trees: 170, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0926436
	trees: 180, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0927239
	trees: 190, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0942833
	trees: 200, Out-of-bag evaluation: accuracy:0.966527 logloss:0.0941766
	trees: 210, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0938748
	trees: 220, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0950461
	trees: 231, Out-of-bag evaluation: accuracy:0.953975 logloss:0.0951599
	trees: 241, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0948531
	trees: 251, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0942377
	trees: 262, Out-of-bag evaluation: accuracy:0.962343 logloss:0.0940229
	trees: 272, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0934476
	trees: 282, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0934649
	trees: 292, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0943068
	trees: 300, Out-of-bag evaluation: accuracy:0.958159 logloss:0.0945677
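As a quick sanity check on how the structure statistics fit together, the sketch below recomputes a few of them from numbers copied from the summary above. It assumes that every tree is binary and that every non-leaf node carries exactly one condition, which is consistent with these figures but is stated here as an assumption.

```python
# Figures copied from the summary above.
num_trees = 300
total_nodes = 4336
num_leaves = 2318              # "Depth by leafs" count
condition_nodes = 1759 + 259   # HigherCondition + ContainsBitmapCondition

# Every node is either a condition node or a leaf.
assert condition_nodes + num_leaves == total_nodes

# A binary tree has one more leaf than it has internal nodes.
assert num_leaves == condition_nodes + num_trees

average_nodes_per_tree = total_nodes / num_trees
print(f"average nodes per tree: {average_nodes_per_tree:.4f}")
```

The computed average matches the "Average" reported under "Number of nodes by tree".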

All the information in the summary is also available programmatically through the model inspector:

# Get the list of input features of the model
model_1.make_inspector().features()
["bill_depth_mm" (1; #1),
 "bill_length_mm" (1; #2),
 "body_mass_g" (1; #3),
 "flipper_length_mm" (1; #4),
 "island" (4; #5),
 "sex" (4; #6),
 "year" (1; #7)]
# Print the variable importances
model_1.make_inspector().variable_importances()
{'NUM_AS_ROOT': [("flipper_length_mm" (1; #4), 161.0),
  ("bill_length_mm" (1; #2), 66.0),
  ("bill_depth_mm" (1; #1), 55.0),
  ("body_mass_g" (1; #3), 10.0),
  ("island" (4; #5), 8.0)],
 'SUM_SCORE': [("flipper_length_mm" (1; #4), 26375.887034731917),
  ("bill_length_mm" (1; #2), 23387.499002089724),
  ("bill_depth_mm" (1; #1), 9981.270100556314),
  ("island" (4; #5), 8813.63283989951),
  ("body_mass_g" (1; #3), 3264.0505972094834),
  ("sex" (4; #6), 101.26985213905573),
  ("year" (1; #7), 29.719129994511604)],
 'NUM_NODES': [("bill_length_mm" (1; #2), 686.0),
  ("bill_depth_mm" (1; #1), 411.0),
  ("flipper_length_mm" (1; #4), 357.0),
  ("body_mass_g" (1; #3), 291.0),
  ("island" (4; #5), 238.0),
  ("sex" (4; #6), 21.0),
  ("year" (1; #7), 14.0)],
 'INV_MEAN_MIN_DEPTH': [("flipper_length_mm" (1; #4), 0.4597295587756743),
  ("bill_length_mm" (1; #2), 0.42835670851367663),
  ("bill_depth_mm" (1; #1), 0.31803398397339727),
  ("island" (4; #5), 0.30225257091871593),
  ("body_mass_g" (1; #3), 0.27035044480247944),
  ("sex" (4; #6), 0.23919881592559233),
  ("year" (1; #7), 0.23854067913913543)]}

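The content of the summary and of the inspector depends on the learning algorithm (tfdf.keras.RandomForestModel in this case) and its hyperparameters (e.g. compute_oob_variable_importances=True triggers the computation of out-of-bag variable importances for the Random Forest learner).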
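Raw importance values are easier to compare once normalized. The sketch below turns the NUM_AS_ROOT counts from the output above into the share of trees that use each feature at the root. For simplicity it uses a plain dictionary of feature names, not the feature objects returned by the inspector, so the data structure is an assumption made for illustration.

```python
# NUM_AS_ROOT counts copied from the output above (feature name -> trees).
num_as_root = {
    "flipper_length_mm": 161,
    "bill_length_mm": 66,
    "bill_depth_mm": 55,
    "body_mass_g": 10,
    "island": 8,
}

# Each tree has exactly one root, so the counts add up to the number of trees.
num_trees = sum(num_as_root.values())

root_share = {name: count / num_trees for name, count in num_as_root.items()}
for name, share in sorted(root_share.items(), key=lambda item: -item[1]):
    print(f"{name:>18}: {share:.1%}")
```

The counts sum to 300, the number of trees in the model. Features that never appear at the root (sex and year here) are simply absent from this importance.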

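9. Model self-evaluation

During training, TF-DF models can evaluate themselves even if no validation dataset is passed to the fit() method. The exact logic depends on the model. For example, Random Forest uses out-of-bag evaluation, while Gradient Boosted Trees use an internal train-validation split.

**Note:** Although this evaluation is computed during training, it is not computed on the training dataset, so it can serve as a rough, lower-quality evaluation.

The model self-evaluation is available through the inspector's evaluation() method.

# Build an inspector for model_1 and retrieve its self-evaluation
model_1.make_inspector().evaluation()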
Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09456771872859744, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)
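To illustrate why out-of-bag evaluation needs no separate validation set, here is a minimal sketch of bootstrap sampling in plain Python. It assumes the textbook setup where each tree draws as many examples as the dataset contains, with replacement. It does not reproduce TF-DF's internal sampling, and the seed and helper name are arbitrary choices.

```python
import random


def out_of_bag_indices(num_examples, rng):
    """Draw one bootstrap sample and return the indices that were never drawn."""
    drawn = {rng.randrange(num_examples) for _ in range(num_examples)}
    return [i for i in range(num_examples) if i not in drawn]


rng = random.Random(0)
num_examples = 239   # size of the training set reported above
num_trials = 300     # one bootstrap sample per hypothetical tree

oob_counts = [len(out_of_bag_indices(num_examples, rng)) for _ in range(num_trials)]
average_oob_fraction = sum(oob_counts) / (num_trials * num_examples)

# Probability that a given example is never drawn in one bootstrap sample.
expected_fraction = (1 - 1 / num_examples) ** num_examples

print(f"average OOB fraction: {average_oob_fraction:.3f}")
print(f"theoretical fraction: {expected_fraction:.3f}")
```

Roughly a third of the examples are left out of each bootstrap sample, so every tree has unseen examples on which it can be evaluated.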

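10. Plotting the training logs

The training logs show the quality of the model (e.g. the accuracy evaluated on the out-of-bag or validation dataset) as a function of the number of trees in the model. These logs help study the trade-off between model size and model quality.

The logs are available in several ways:

  1. Displayed during training if fit() is wrapped in with sys_pipes(): (see the example above).
  2. At the end of the model summary, i.e. model.summary() (see the example above).
  3. Programmatically, using the model inspector, i.e. model.make_inspector().training_logs().
  4. Using TensorBoard.

Let's try options 2 and 3 (option 2 is the "Training OOB" block at the end of the summary above):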

# Set the cell height to 150
%set_cell_height 150

# Build an inspector for model_1 and fetch the training logs
model_1.make_inspector().training_logs()

[TrainLog(num_trees=1, evaluation=Evaluation(num_examples=88, accuracy=0.9431818181818182, loss=2.04793474890969, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=11, evaluation=Evaluation(num_examples=237, accuracy=0.9493670886075949, loss=0.3836141189693902, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=21, evaluation=Evaluation(num_examples=239, accuracy=0.9539748953974896, loss=0.38613533478027606, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=35, evaluation=Evaluation(num_examples=239, accuracy=0.9539748953974896, loss=0.24959545451602178, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=50, evaluation=Evaluation(num_examples=239, accuracy=0.9497907949790795, loss=0.2490036289936329, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=62, evaluation=Evaluation(num_examples=239, accuracy=0.9497907949790795, loss=0.24737085921058594, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=73, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.24610795769606675, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=83, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.24077113418524235, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=96, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.0994904973703947, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=106, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.1000949550326524, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=117, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09590058801033258, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=127, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09621651767593298, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=138, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09276632447123029, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=148, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09319210400859432, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=158, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09211699942041142, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=170, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09264358151002658, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=180, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09272387361925516, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=190, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.0942832787314656, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=200, evaluation=Evaluation(num_examples=239, accuracy=0.9665271966527197, loss=0.09417655552390604, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=210, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09387483396353083, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=220, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.0950461220674248, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=231, evaluation=Evaluation(num_examples=239, accuracy=0.9539748953974896, loss=0.09515991921548314, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=241, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09485313651701396, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=251, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09423767419134473, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=262, evaluation=Evaluation(num_examples=239, accuracy=0.9623430962343096, loss=0.09402294695439697, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=272, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09344756691307453, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=282, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.0934649518804009, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=292, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09430678192307884, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)),
 TrainLog(num_trees=300, evaluation=Evaluation(num_examples=239, accuracy=0.9581589958158996, loss=0.09456771872859744, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None))]

Let's plot it:

# Import matplotlib.pyplot for plotting
import matplotlib.pyplot as plt

# Get the training logs of the model
logs = model_1.make_inspector().training_logs()

# Create a 12x4 figure
plt.figure(figsize=(12, 4))

# Left subplot (1 row, 2 columns, position 1): accuracy versus number of trees
plt.subplot(1, 2, 1)
plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Accuracy (out-of-bag)")

# Right subplot (1 row, 2 columns, position 2): logloss versus number of trees
plt.subplot(1, 2, 2)
plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])
plt.xlabel("Number of trees")
plt.ylabel("Logloss (out-of-bag)")

# Display the figure
plt.show()

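[Figure: out-of-bag accuracy (left) and out-of-bag logloss (right) as a function of the number of trees]

This dataset is small. You can see the model converging almost immediately.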
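One way to use these logs when studying the size/quality trade-off is to pick the smallest forest whose logloss is close to the best one observed. The sketch below does this on a handful of (number of trees, out-of-bag logloss) pairs copied from the training logs above. The 5% tolerance and the helper name are hypothetical choices for illustration, not part of TF-DF.

```python
# (num_trees, out-of-bag logloss) pairs copied from the training logs above.
logs = [
    (1, 2.04793),
    (11, 0.383614),
    (50, 0.249004),
    (83, 0.240771),
    (96, 0.0994905),
    (117, 0.0959006),
    (158, 0.092117),
    (200, 0.0941766),
    (300, 0.0945677),
]


def smallest_model_within(logs, tolerance):
    """Return the smallest num_trees whose loss is within `tolerance` of the best loss."""
    best_loss = min(loss for _, loss in logs)
    for num_trees, loss in sorted(logs):
        if loss <= best_loss * (1 + tolerance):
            return num_trees


chosen = smallest_model_within(logs, tolerance=0.05)
print(f"smallest model within 5% of the best logloss: {chosen} trees")
```

On these values the best logloss is reached at 158 trees, and 117 trees already fall within 5% of it, which suggests that well under half of the 300 trees would give nearly the same out-of-bag logloss.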

Let's use TensorBoard:

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Google-internal version of the TensorBoard notebook extension
# %load_ext google3.learning.brain.tensorboard.notebook.extension
# Clear existing results (if any)
!rm -fr "/tmp/tensorboard_logs"
# Export the meta-data to tensorboard.
model_1.make_inspector().export_to_tensorboard("/tmp/tensorboard_logs")

# %tensorboard is a magic command that starts a TensorBoard instance;
# --logdir points it at the log directory
%tensorboard --logdir "/tmp/tensorboard_logs"

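11. Re-train the model with a different learning algorithm

The learning algorithm is defined by the model class. For example, tfdf.keras.RandomForestModel() trains a Random Forest, while tfdf.keras.GradientBoostedTreesModel() trains Gradient Boosted Decision Trees.

The available learning algorithms can be listed by calling tfdf.keras.get_all_models(), or found in the learner list.

# Get all the available models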
tfdf.keras.get_all_models()
[tensorflow_decision_forests.keras.RandomForestModel,
 tensorflow_decision_forests.keras.GradientBoostedTreesModel,
 tensorflow_decision_forests.keras.CartModel,
 tensorflow_decision_forests.keras.DistributedGradientBoostedTreesModel]

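Descriptions of the learning algorithms and their hyperparameters are also available in the API reference and in the built-in help:

# help() works anywhere.
help(tfdf.keras.RandomForestModel)

# In IPython or a notebook, ? also shows the help, usually in a separate panel.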
tfdf.keras.RandomForestModel?
Help on class RandomForestModel in module tensorflow_decision_forests.keras:

class RandomForestModel(tensorflow_decision_forests.keras.wrappers.RandomForestModel)
 |  RandomForestModel(*args, **kwargs)
 |  
 |  Method resolution order:
 |      RandomForestModel
 |      tensorflow_decision_forests.keras.wrappers.RandomForestModel
 |      tensorflow_decision_forests.keras.core.CoreModel
 |      tensorflow_decision_forests.keras.core_inference.InferenceCoreModel
 |      keras.src.engine.training.Model
 |      keras.src.engine.base_layer.Layer
 |      tensorflow.python.module.module.Module
 |      tensorflow.python.trackable.autotrackable.AutoTrackable
 |      tensorflow.python.trackable.base.Trackable
 |      keras.src.utils.version_utils.LayerVersionSelector
 |      keras.src.utils.version_utils.ModelVersionSelector
 |      builtins.object
 |  
 |  Methods inherited from tensorflow_decision_forests.keras.wrappers.RandomForestModel:
 |  
 |  __init__(self, task: Optional[ForwardRef('abstract_model_pb2.Task')] = 1, features: Optional[List[tensorflow_decision_forests.keras.core.FeatureUsage]] = None, exclude_non_specified_features: Optional[bool] = False, preprocessing: Optional[ForwardRef('tf.keras.models.Functional')] = None, postprocessing: Optional[ForwardRef('tf.keras.models.Functional')] = None, ranking_group: Optional[str] = None, uplift_treatment: Optional[str] = None, temp_directory: Optional[str] = None, verbose: int = 1, hyperparameter_template: Optional[str] = None, advanced_arguments: Optional[tensorflow_decision_forests.keras.core_inference.AdvancedArguments] = None, num_threads: Optional[int] = None, name: Optional[str] = None, max_vocab_count: Optional[int] = 2000, try_resume_training: Optional[bool] = True, check_dataset: Optional[bool] = True, tuner: Optional[tensorflow_decision_forests.component.tuner.tuner.Tuner] = None, discretize_numerical_features: bool = False, num_discretized_numerical_bins: int = 255, multitask: Optional[List[tensorflow_decision_forests.keras.core_inference.MultiTaskItem]] = None, adapt_bootstrap_size_ratio_for_maximum_training_duration: Optional[bool] = False, allow_na_conditions: Optional[bool] = False, bootstrap_size_ratio: Optional[float] = 1.0, bootstrap_training_dataset: Optional[bool] = True, categorical_algorithm: Optional[str] = 'CART', categorical_set_split_greedy_sampling: Optional[float] = 0.1, categorical_set_split_max_num_items: Optional[int] = -1, categorical_set_split_min_item_frequency: Optional[int] = 1, compute_oob_performances: Optional[bool] = True, compute_oob_variable_importances: Optional[bool] = False, growing_strategy: Optional[str] = 'LOCAL', honest: Optional[bool] = False, honest_fixed_separation: Optional[bool] = False, honest_ratio_leaf_examples: Optional[float] = 0.5, in_split_min_examples_check: Optional[bool] = True, keep_non_leaf_label_distribution: Optional[bool] = True, max_depth: Optional[int] = 16, max_num_nodes: 
Optional[int] = None, maximum_model_size_in_memory_in_bytes: Optional[float] = -1.0, maximum_training_duration_seconds: Optional[float] = -1.0, min_examples: Optional[int] = 5, missing_value_policy: Optional[str] = 'GLOBAL_IMPUTATION', num_candidate_attributes: Optional[int] = 0, num_candidate_attributes_ratio: Optional[float] = -1.0, num_oob_variable_importances_permutations: Optional[int] = 1, num_trees: Optional[int] = 300, pure_serving_model: Optional[bool] = False, random_seed: Optional[int] = 123456, sampling_with_replacement: Optional[bool] = True, sorting_strategy: Optional[str] = 'PRESORT', sparse_oblique_normalization: Optional[str] = None, sparse_oblique_num_projections_exponent: Optional[float] = None, sparse_oblique_projection_density_factor: Optional[float] = None, sparse_oblique_weights: Optional[str] = None, split_axis: Optional[str] = 'AXIS_ALIGNED', uplift_min_examples_in_treatment: Optional[int] = 5, uplift_split_score: Optional[str] = 'KULLBACK_LEIBLER', winner_take_all: Optional[bool] = True, explicit_args: Optional[Set[str]] = None)
 |  
 |  ----------------------------------------------------------------------
 |  Static methods inherited from tensorflow_decision_forests.keras.wrappers.RandomForestModel:
 |  
 |  capabilities() -> yggdrasil_decision_forests.learner.abstract_learner_pb2.LearnerCapabilities
 |      Lists the capabilities of the learning algorithm.
 |  
 |  predefined_hyperparameters() -> List[tensorflow_decision_forests.keras.core.HyperParameterTemplate]
 |      Returns a better than default set of hyper-parameters.
 |      
 |      They can be used directly with the `hyperparameter_template` argument of the
 |      model constructor.
 |      
 |      These hyper-parameters outperform the default hyper-parameters (either
 |      generally or in specific scenarios). Like default hyper-parameters, existing
 |      pre-defined hyper-parameters cannot change.
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from tensorflow_decision_forests.keras.core.CoreModel:
 |  
 |  collect_data_step(self, data, is_training_example)
 |      Collect examples e.g. training or validation.
 |  
 |  fit(self, x=None, y=None, callbacks=None, verbose: Optional[Any] = None, validation_steps: Optional[int] = None, validation_data: Optional[Any] = None, sample_weight: Optional[Any] = None, steps_per_epoch: Optional[Any] = None, class_weight: Optional[Any] = None, **kwargs) -> keras.src.callbacks.History
 |      Trains the model.
 |      
 |      Local training
 |      ==============
 |      
 |      It is recommended to use a Pandas Dataframe dataset and to convert it to
 |      a TensorFlow dataset with `pd_dataframe_to_tf_dataset()`:
 |        ```python
 |        pd_dataset = pandas.DataFrame(...)
 |        tf_dataset = pd_dataframe_to_tf_dataset(pd_dataset, label="my_label")
 |        model.fit(tf_dataset)
 |        ```
 |      
 |      The following dataset formats are supported:
 |      
 |        1. "x" is a `tf.data.Dataset` containing a tuple "(features, labels)".
 |           "features" can be a tensor, a list of tensors or a dictionary of
 |           tensors (recommended). "labels" is a tensor.
 |      
 |        2. "x" is a tensor, list of tensors or dictionary of tensors containing
 |           the input features. "y" is a tensor.
 |      
 |        3. "x" is a numpy-array, list of numpy-arrays or dictionary of
 |           numpy-arrays containing the input features. "y" is a numpy-array.
 |      
 |      IMPORTANT: This model trains on the entire dataset at once. This has the
 |      following consequences:
 |      
 |        1. The dataset needs to be read exactly once. If you use a TensorFlow
 |           dataset, make sure NOT to add a "repeat" operation.
 |        2. The algorithm does not benefit from shuffling the dataset. If you use a
 |           TensorFlow dataset, make sure NOT to add a "shuffle" operation.
 |        3. The dataset needs to be batched (i.e. with a "batch" operation).
 |           However, the number of elements per batch has no impact on the model.
 |           Generally, it is recommended to use batches as large as possible, as
 |           this speeds up reading the dataset in TensorFlow.
 |      
 |      Input features do not need to be normalized (e.g. dividing numerical values
 |      by the variance) or indexed (e.g. replacing categorical string values by
 |      an integer). Additionally, missing values can be consumed natively.
 |      
 |      Distributed training
 |      ====================
 |      
 |      Some of the learning algorithms will support distributed training with the
 |      ParameterServerStrategy.
 |      
 |      In this case, the dataset is read asynchronously by the workers. The
 |      distribution of the training depends on the learning algorithm.
 |      
 |      As with non-distributed training, the dataset should be read exactly once.
 |      The simplest solution is to divide the dataset into different files (i.e.
 |      shards) and have each worker read a non-overlapping subset of shards.
 |      
 |      IMPORTANT: The training dataset should not be infinite i.e. the training
 |      dataset should not contain any repeat operation.
 |      
 |      Currently (to be changed), the validation dataset (if provided) is simply
 |      fed to the `model.evaluate()` method. Therefore, it should satisfy Keras'
 |      evaluate API. Notably, for distributed training, the validation dataset
 |      should be infinite (i.e. have a repeat operation).
 |      
 |      See https://www.tensorflow.org/decision_forests/distributed_training for
 |      more details and examples.
 |      
 |      Here is a simple example of distributed training using PSS for both dataset
 |      reading and training distribution.
 |      
 |        ```python
 |        def dataset_fn(context, paths, training=True):
 |          ds_path = tf.data.Dataset.from_tensor_slices(paths)
 |      
 |      
 |          if context is not None:
 |            # Train on at least 2 workers.
 |            current_worker = tfdf.keras.get_worker_idx_and_num_workers(context)
 |            assert current_worker.num_workers >= 2
 |      
 |            # Split the dataset's examples among the workers.
 |            ds_path = ds_path.shard(
 |                num_shards=current_worker.num_workers,
 |                index=current_worker.worker_idx)
 |      
 |          # Defined outside read_csv_file so that map_features can use them.
 |          column_names = [
 |            "age", "workclass", "fnlwgt", ...
 |          ]
 |          label_name = "label"
 |      
 |          def read_csv_file(path):
 |            numerical = tf.constant([math.nan], dtype=tf.float32)
 |            categorical_string = tf.constant([""], dtype=tf.string)
 |            csv_columns = [
 |                numerical,  # age
 |                categorical_string,  # workclass
 |                numerical,  # fnlwgt
 |                ...
 |            ]
 |            return tf.data.experimental.CsvDataset(path, csv_columns, header=True)
 |      
 |          ds_columns = ds_path.interleave(read_csv_file)
 |      
 |          def map_features(*columns):
 |            assert len(column_names) == len(columns)
 |            features = {column_names[i]: col for i, col in enumerate(columns)}
 |            label = label_table.lookup(features.pop(label_name))
 |            return features, label
 |      
 |          ds_dataset = ds_columns.map(map_features)
 |          if not training:
 |            ds_dataset = ds_dataset.repeat(None)
 |          ds_dataset = ds_dataset.batch(batch_size)
 |          return ds_dataset
 |      
 |        strategy = tf.distribute.experimental.ParameterServerStrategy(...)
 |        sharded_train_paths = [list of dataset files]
 |        with strategy.scope():
 |          model = DistributedGradientBoostedTreesModel()
 |          train_dataset = strategy.distribute_datasets_from_function(
 |            lambda context: dataset_fn(context, sharded_train_paths))
 |      
 |          test_dataset = strategy.distribute_datasets_from_function(
 |            lambda context: dataset_fn(context, sharded_test_paths, training=False))
 |      
 |        model.fit(train_dataset)
 |        evaluation = model.evaluate(test_dataset, steps=num_test_examples //
 |          batch_size)
 |        ```
 |      
 |      Args:
 |        x: Training dataset (See details above for the supported formats).
 |        y: Label of the training dataset. Only used if "x" does not contain the
 |          labels.
 |        callbacks: Callbacks triggered during the training. The training runs in a
 |          single epoch, itself run in a single step. Therefore, callback logic can
 |          be called equivalently before/after the fit function.
 |        verbose: Verbosity mode. 0 = silent, 1 = small details, 2 = full details.
 |        validation_steps: Number of steps in the evaluation dataset when
 |          evaluating the trained model with `model.evaluate()`. If not specified,
 |          evaluates the model on the entire dataset (generally recommended; not
 |          yet supported for distributed datasets).
 |        validation_data: Validation dataset. If specified, the learner might use
 |          this dataset to help training e.g. early stopping.
 |        sample_weight: Training weights. Note: training weights can also be
 |          provided as the third output in a `tf.data.Dataset` e.g. (features,
 |          label, weights).
 |        steps_per_epoch: [Parameter will be removed] Number of training batches to
 |          load before training the model. Currently, only supported for
 |          distributed training.
 |        class_weight: For binary classification only. Mapping class indices
 |          (integers) to a weight (float) value. Only available for non-Distributed
 |          training. For maximum compatibility, feed example weights through the
 |          tf.data.Dataset or using the `weight` argument of
 |          `pd_dataframe_to_tf_dataset`.
 |        **kwargs: Extra arguments passed to the core keras model's fit. Note that
 |          not all keras' model fit arguments are supported.
 |      
 |      Returns:
 |        A `History` object. Its `History.history` attribute is not yet
 |        implemented for decision forests algorithms, and will return empty.
 |        All other fields are filled as usual for `keras.Model.fit()`.
 |  
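As a rough illustration of the "non-overlapping subset of shards" advice in the distributed-training notes above, the sketch below mimics in plain Python how `tf.data.Dataset.shard(num_shards, index)` selects elements: worker `index` keeps the elements whose position modulo `num_shards` equals `index`. The file names and the helper `shards_for_worker` are hypothetical and exist only for this example; no TensorFlow call is made.

```python
from typing import List


def shards_for_worker(paths: List[str], num_workers: int, worker_idx: int) -> List[str]:
    """Returns the files a given worker should read (hypothetical helper).

    Mirrors the element selection of `Dataset.shard`: keep every element whose
    position modulo `num_workers` equals `worker_idx`.
    """
    if not 0 <= worker_idx < num_workers:
        raise ValueError("worker_idx must be in [0, num_workers)")
    return [p for i, p in enumerate(paths) if i % num_workers == worker_idx]


# Hypothetical shard files of one training dataset.
all_paths = [f"train-{i:05d}.csv" for i in range(7)]

# One list of files per worker, for 3 workers.
assignment = [shards_for_worker(all_paths, 3, w) for w in range(3)]

for worker_idx, files in enumerate(assignment):
    print(worker_idx, files)
```

Every file is assigned to exactly one worker, so the dataset as a whole is still read exactly once.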
 |  fit_on_dataset_path(self, train_path: str, label_key: Optional[str] = None, weight_key: Optional[str] = None, valid_path: Optional[str] = None, dataset_format: Optional[str] = 'csv', max_num_scanned_rows_to_accumulate_statistics: Optional[int] = 100000, try_resume_training: Optional[bool] = True, input_model_signature_fn: Optional[Callable[[tensorflow_decision_forests.component.inspector.inspector.AbstractInspector], Any]] = <function build_default_input_model_signature>, num_io_threads: int = 10)
 |      Trains the model on a dataset stored on disk.
 |      
 |      This solution is generally more efficient and easier than loading the
 |      dataset with a `tf.data.Dataset` both for local and distributed training.
 |      
 |      Usage example:
 |      
 |        # Local training
 |        ```python
 |        model = keras.GradientBoostedTreesModel()
 |        model.fit_on_dataset_path(
 |          train_path="/path/to/dataset.csv",
 |          label_key="label",
 |          dataset_format="csv")
 |        model.save("/model/path")
 |        ```
 |      
 |        # Distributed training
 |        ```python
 |        with tf.distribute.experimental.ParameterServerStrategy(...).scope():
 |          model = keras.DistributedGradientBoostedTreesModel()
 |        model.fit_on_dataset_path(
 |          train_path="/path/to/dataset@10",
 |          label_key="label",
 |          dataset_format="tfrecord+tfe")
 |        model.save("/model/path")
 |        ```
 |      
 |      Args:
 |        train_path: Path to the training dataset. Supports comma separated files,
 |          shard and glob notation.
 |        label_key: Name of the label column.
 |        weight_key: Name of the weighting column.
 |        valid_path: Path to the validation dataset. If not provided, or if the
 |          learning algorithm does not support/need a validation dataset,
 |          `valid_path` is ignored.
 |        dataset_format: Format of the dataset. Should be one of the registered
 |          dataset formats (see [User
 |          Manual](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format)
 |          for more details). The format "csv" is always available but it is
 |          generally only suited for small datasets.
 |        max_num_scanned_rows_to_accumulate_statistics: Maximum number of examples
 |          to scan to determine the statistics of the features (i.e. the dataspec,
 |          e.g. mean value, dictionaries). (Currently) the "first" examples of the
 |          dataset are scanned (e.g. the first examples of the file if the dataset
 |          is a single file). Therefore, it is important that the scanned examples
 |          are a relatively uniform sample of the dataset; notably, they should
 |          contain all the possible categorical values (otherwise the unseen
 |          values will be treated as out-of-vocabulary). If set to None, the entire
 |          dataset is scanned. This parameter has no effect if the dataset is
 |          stored in a format that already contains those values.
 |        try_resume_training: If true, tries to resume training from the model
 |          checkpoint stored in the `temp_directory` directory. If `temp_directory`
 |          does not contain any model checkpoint, starts the training from the
 |          beginning. Works in the following three situations: (1) the training
 |          was interrupted by the user (e.g. ctrl+c), (2) the training job was
 |          interrupted (e.g. rescheduling), and (3) the hyper-parameters of the
 |          model were changed such that an initially completed training is now
 |          incomplete (e.g. increasing the number of trees).
 |        input_model_signature_fn: A lambda that returns the
 |          (Dense,Sparse,Ragged)TensorSpec (or structure of TensorSpec e.g.
 |          dictionary, list) corresponding to input signature of the model. If not
 |          specified, the input model signature is created by
 |          `build_default_input_model_signature`. For example, specify
 |          `input_model_signature_fn` if a numerical input feature (which is
 |          consumed as DenseTensorSpec(float32) by default) will be fed
 |          differently (e.g. RaggedTensor(int64)).
 |        num_io_threads: Number of threads to use for IO operations e.g. reading a
 |          dataset from disk. Increasing this value can speed up IO operations when
 |          IO operations are either latency-bound or CPU-bound.
 |      
 |      Returns:
 |        A `History` object. Its `History.history` attribute is not yet
 |        implemented for decision forests algorithms, and will return empty.
 |        All other fields are filled as usual for `keras.Model.fit()`.
 |  
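The distributed example above passes `train_path="/path/to/dataset@10"`, which uses shard notation. The sketch below shows one way such a path could be expanded into individual file names. It assumes the widespread `name@N` to `name-00000-of-0000N` naming convention; that convention and the helper `expand_sharded_path` are assumptions made for illustration, so check the linked User Manual for the exact rules the library applies.

```python
from typing import List


def expand_sharded_path(path: str) -> List[str]:
    """Expands "base@N" into N shard file names (assumed naming convention)."""
    base, sep, count = path.rpartition("@")
    if not sep or not count.isdigit():
        return [path]  # Not a sharded path: return it unchanged.
    n = int(count)
    return [f"{base}-{i:05d}-of-{n:05d}" for i in range(n)]


shards = expand_sharded_path("/path/to/dataset@10")
print(len(shards), shards[0], shards[-1])
```

Splitting a dataset into several shards of this kind is what lets each worker read its own subset of files during distributed training.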
 |  load_weights(self, *args, **kwargs)
 |      No-op for TensorFlow Decision Forests models.
 |      
 |      `load_weights` is not supported by TensorFlow Decision Forests models.
 |      To save and restore a model, use the SavedModel API i.e.
 |      `model.save(...)` and `tf.keras.models.load_model(...)`. To resume the
 |      training of an existing model, create the model with
 |      `try_resume_training=True` (default value) and with a similar
 |      `temp_directory` argument. See documentation of `try_resume_training`
 |      for more details.
 |      
 |      Args:
 |        *args: Passed through to base `keras.Model` implementation.
 |        **kwargs: Passed through to base `keras.Model` implementation.
 |  
 |  save(self, filepath: str, overwrite: Optional[bool] = True, **kwargs)
 |      Saves the model as a TensorFlow SavedModel.
 |      
 |      The exported SavedModel contains a standalone Yggdrasil Decision Forests
 |      model in the "assets" sub-directory. The Yggdrasil model can be used
 |      directly using the Yggdrasil API. However, this model does not contain the
 |      "preprocessing" layer (if any).
 |      
 |      Args:
 |        filepath: Path to the output model.
 |        overwrite: If true, overwrite an already existing model. If false, raise an
 |          error if a model already exists.
 |        **kwargs: Arguments passed to the core keras model's save.
 |  
 |  support_distributed_training(self)
 |  
 |  train_on_batch(self, *args, **kwargs)
 |      Not supported for TensorFlow Decision Forests models.
 |      
 |      Decision forests are not trained in batches the same way neural networks
 |      are. To avoid confusion, train_on_batch is disabled.
 |      
 |      Args:
 |        *args: Ignored
 |        **kwargs: Ignored.
 |  
 |  train_step(self, data)
 |      Collects training examples.
 |  
 |  valid_step(self, data)
 |      Collects validation examples.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties inherited from tensorflow_decision_forests.keras.core.CoreModel:
 |  
 |  exclude_non_specified_features
 |      If true, only use the features specified in "features".
 |  
 |  learner
 |      Name of the learning algorithm used to train the model.
 |  
 |  learner_params
 |      Gets the dictionary of hyper-parameters passed in the model constructor.
 |      
 |      Changing this dictionary will impact the training.
 |  
 |  num_threads
 |      Number of threads used to train the model.
 |  
 |  num_training_examples
 |      Number of training examples.
 |  
 |  num_validation_examples
 |      Number of validation examples.
 |  
 |  training_model_id
 |      Identifier of the model.
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from tensorflow_decision_forests.keras.core_inference.InferenceCoreModel:
 |  
 |  call(self, inputs, training=False)
 |      Inference of the model.
 |      
 |      This method is used for prediction and evaluation of a trained model.
 |      
 |      Args:
 |        inputs: Input tensors.
 |        training: Is the model being trained. Always False.
 |      
 |      Returns:
 |        Model predictions.
 |  
 |  call_get_leaves(self, inputs)
 |      Computes the index of the active leaf in each tree.
 |      
 |      The active leaf is the leaf that receives the example during inference.
 |      
 |      The returned value "leaves[i,j]" is the index of the active leaf for the
 |      i-th example and the j-th tree. Leaves are indexed by depth first
 |      exploration with the negative child visited before the positive one
 |      (similarly as "iterate_on_nodes()" iteration). Leaf indices are also
 |      available with LeafNode.leaf_idx.
 |      
 |      Args:
 |        inputs: Input tensors. Same signature as the model's "call(inputs)".
 |      
 |      Returns:
 |        Index of the active leaf for each tree in the model.
 |  
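To make the leaf-numbering rule described above concrete, here is a minimal pure-Python sketch of a toy tree whose leaves are indexed by depth-first exploration with the negative child visited before the positive one. The `Node` class, the `x >= threshold` condition and both helpers are hypothetical stand-ins, not the library's data structures; the real indices come from `call_get_leaves` / `predict_get_leaves`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Toy binary tree node; a node without children is a leaf."""
    threshold: float = 0.0
    neg: Optional["Node"] = None  # Child taken when the condition is false.
    pos: Optional["Node"] = None  # Child taken when the condition is true.
    leaf_idx: int = -1


def index_leaves(root: Node) -> int:
    """Numbers the leaves depth-first, negative child before positive child."""
    counter = 0

    def visit(node: Node) -> None:
        nonlocal counter
        if node.neg is None and node.pos is None:
            node.leaf_idx = counter
            counter += 1
            return
        visit(node.neg)
        visit(node.pos)

    visit(root)
    return counter


def active_leaf(root: Node, x: float) -> int:
    """Routes one numerical value to its leaf using `x >= threshold`."""
    node = root
    while node.neg is not None and node.pos is not None:
        node = node.pos if x >= node.threshold else node.neg
    return node.leaf_idx


# One toy tree: split at 5.0, then the positive branch splits again at 8.0.
tree = Node(threshold=5.0, neg=Node(), pos=Node(threshold=8.0, neg=Node(), pos=Node()))
num_leaves = index_leaves(tree)

# Shape [num_examples, num_trees], here with a single tree.
leaves = [[active_leaf(tree, x)] for x in (1.0, 6.0, 9.0)]
print(num_leaves, leaves)
```

With a real forest, each row would hold one index per tree, matching the "leaves[i,j]" layout described in the docstring.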
 |  compile(self, metrics=None, weighted_metrics=None, **kwargs)
 |      Configure the model for training.
 |      
 |      Unlike for most Keras models, calling "compile" is optional before calling
 |      "fit".
 |      
 |      Args:
 |        metrics: List of metrics to be evaluated by the model during training and
 |          testing.
 |        weighted_metrics: List of metrics to be evaluated and weighted by
 |          `sample_weight` or `class_weight` during training and testing.
 |        **kwargs: Other arguments passed to compile.
 |      
 |      Raises:
 |        ValueError: Invalid arguments.
 |  
 |  make_inspector(self, index: int = 0) -> tensorflow_decision_forests.component.inspector.inspector.AbstractInspector
 |      Creates an inspector to access the internal model structure.
 |      
 |      Usage example:
 |      
 |      ```python
 |      inspector = model.make_inspector()
 |      print(inspector.num_trees())
 |      print(inspector.variable_importances())
 |      ```
 |      
 |      Args:
 |        index: Index of the sub-model. Only used for multitask models.
 |      
 |      Returns:
 |        A model inspector.
 |  
 |  make_predict_function(self)
 |      Prediction of the model (!= evaluation).
 |  
 |  make_test_function(self)
 |      Predictions for evaluation.
 |  
 |  predict_get_leaves(self, x)
 |      Gets the index of the active leaf of each tree.
 |      
 |      The active leaf is the leaf that receives the example during inference.
 |      
 |      The returned value "leaves[i,j]" is the index of the active leaf for the
 |      i-th example and the j-th tree. Leaves are indexed by depth first
 |      exploration with the negative child visited before the positive one
 |      (similarly as "iterate_on_nodes()" iteration). Leaf indices are also
 |      available with LeafNode.leaf_idx.
 |      
 |      Args:
 |        x: Input samples as a tf.data.Dataset.
 |      
 |      Returns:
 |        Index of the active leaf for each tree in the model.
 |  
 |  ranking_group(self) -> Optional[str]
 |  
 |  summary(self, line_length=None, positions=None, print_fn=None)
 |      Shows information about the model.
 |  
 |  uplift_treatment(self) -> Optional[str]
 |  
 |  yggdrasil_model_path_tensor(self, multitask_model_index: int = 0) -> Optional[tensorflow.python.framework.ops.Tensor]
 |      Gets the path to yggdrasil model, if available.
 |      
 |      The effective path can be obtained with:
 |      
 |      ```python
 |      yggdrasil_model_path_tensor().numpy().decode("utf-8")
 |      ```
 |      
 |      Args:
 |        multitask_model_index: Index of the sub-model. Only used for multitask
 |          models.
 |      
 |      Returns:
 |        Path to the Yggdrasil model.
 |  
 |  yggdrasil_model_prefix(self, index: int = 0) -> str
 |      Gets the prefix of the internal yggdrasil model.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties inherited from tensorflow_decision_forests.keras.core_inference.InferenceCoreModel:
 |  
 |  multitask
 |      Tasks to solve.
 |  
 |  task
 |      Task to solve (e.g. CLASSIFICATION, REGRESSION, RANKING).
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from keras.src.engine.training.Model:
 |  
 |  __call__(self, *args, **kwargs)
 |  
 |  __copy__(self)
 |  
 |  __deepcopy__(self, memo)
 |  
 |  __reduce__(self)
 |      Helper for pickle.
 |  
 |  __setattr__(self, name, value)
 |      Support self.foo = trackable syntax.
 |  
 |  build(self, input_shape)
 |      Builds the model based on input shapes received.
 |      
 |      This is to be used for subclassed models, which do not know at
 |      instantiation time what their inputs look like.
 |      
 |      This method only exists for users who want to call `model.build()` in a
 |      standalone way (as a substitute for calling the model on real data to
 |      build it). It will never be called by the framework (and thus it will
 |      never throw unexpected errors in an unrelated workflow).
 |      
 |      Args:
 |       input_shape: Single tuple, `TensorShape` instance, or list/dict of
 |         shapes, where shapes are tuples, integers, or `TensorShape`
 |         instances.
 |      
 |      Raises:
 |        ValueError:
 |          1. In case of invalid user-provided data (not of type tuple,
 |             list, `TensorShape`, or dict).
 |          2. If the model requires call arguments that are agnostic
 |             to the input shapes (positional or keyword arg in call
 |             signature).
 |          3. If not all layers were properly built.
 |          4. If float type inputs are not supported within the layers.
 |      
 |        In each of these cases, the user should build their model by calling
 |        it on real tensor data.
 |  
 |  compile_from_config(self, config)
 |      Compiles the model with the information given in config.
 |      
 |      This method uses the information in the config (optimizer, loss,
 |      metrics, etc.) to compile the model.
 |      
 |      Args:
 |          config: Dict containing information for compiling the model.
 |  
 |  compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None)
 |      Compute the total loss, validate it, and return it.
 |      
 |      Subclasses can optionally override this method to provide custom loss
 |      computation logic.
 |      
 |      Example:
 |      ```python
 |      class MyModel(tf.keras.Model):
 |      
 |        def __init__(self, *args, **kwargs):
 |          super(MyModel, self).__init__(*args, **kwargs)
 |          self.loss_tracker = tf.keras.metrics.Mean(name='loss')
 |      
 |        def compute_loss(self, x, y, y_pred, sample_weight):
 |          loss = tf.reduce_mean(tf.math.squared_difference(y_pred, y))
 |          loss += tf.add_n(self.losses)
 |          self.loss_tracker.update_state(loss)
 |          return loss
 |      
 |        def reset_metrics(self):
 |          self.loss_tracker.reset_states()
 |      
 |        @property
 |        def metrics(self):
 |          return [self.loss_tracker]
 |      
 |      tensors = tf.random.uniform((10, 10)), tf.random.uniform((10,))
 |      dataset = tf.data.Dataset.from_tensor_slices(tensors).repeat().batch(1)
 |      
 |      inputs = tf.keras.layers.Input(shape=(10,), name='my_input')
 |      outputs = tf.keras.layers.Dense(10)(inputs)
 |      model = MyModel(inputs, outputs)
 |      model.add_loss(tf.reduce_sum(outputs))
 |      
 |      optimizer = tf.keras.optimizers.SGD()
 |      model.compile(optimizer, loss='mse', steps_per_execution=10)
 |      model.fit(dataset, epochs=2, steps_per_epoch=10)
 |      print('My custom loss: ', model.loss_tracker.result().numpy())
 |      ```
 |      
 |      Args:
 |        x: Input data.
 |        y: Target data.
 |        y_pred: Predictions returned by the model (output of `model(x)`)
 |        sample_weight: Sample weights for weighting the loss function.
 |      
 |      Returns:
 |        The total loss as a `tf.Tensor`, or `None` if no loss results (which
 |        is the case when called by `Model.test_step`).
 |  
 |  compute_metrics(self, x, y, y_pred, sample_weight)
 |      Update metric states and collect all metrics to be returned.
 |      
 |      Subclasses can optionally override this method to provide custom metric
 |      updating and collection logic.
 |      
 |      Example:
 |      ```python
 |      class MyModel(tf.keras.Sequential):
 |      
 |        def compute_metrics(self, x, y, y_pred, sample_weight):
 |      
 |          # This super call updates `self.compiled_metrics` and returns
 |          # results for all metrics listed in `self.metrics`.
 |          metric_results = super(MyModel, self).compute_metrics(
 |              x, y, y_pred, sample_weight)
 |      
 |          # Note that `self.custom_metric` is not listed in `self.metrics`.
 |          self.custom_metric.update_state(x, y, y_pred, sample_weight)
 |          metric_results['custom_metric_name'] = self.custom_metric.result()
 |          return metric_results
 |      ```
 |      
 |      Args:
 |        x: Input data.
 |        y: Target data.
 |        y_pred: Predictions returned by the model (output of `model.call(x)`)
 |        sample_weight: Sample weights for weighting the loss function.
 |      
 |      Returns:
 |        A `dict` containing values that will be passed to
 |        `tf.keras.callbacks.CallbackList.on_train_batch_end()`. Typically, the
 |        values of the metrics listed in `self.metrics` are returned. Example:
 |        `{'loss': 0.2, 'accuracy': 0.7}`.
 |  
 |  evaluate(self, x=None, y=None, batch_size=None, verbose='auto', sample_weight=None, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, return_dict=False, **kwargs)
 |      Returns the loss value & metrics values for the model in test mode.
 |      
 |      Computation is done in batches (see the `batch_size` arg.)
 |      
 |      Args:
 |          x: Input data. It could be:
 |            - A Numpy array (or array-like), or a list of arrays
 |              (in case the model has multiple inputs).
 |            - A TensorFlow tensor, or a list of tensors
 |              (in case the model has multiple inputs).
 |            - A dict mapping input names to the corresponding array/tensors,
 |              if the model has named inputs.
 |            - A `tf.data` dataset. Should return a tuple
 |              of either `(inputs, targets)` or
 |              `(inputs, targets, sample_weights)`.
 |            - A generator or `keras.utils.Sequence` returning `(inputs,
 |              targets)` or `(inputs, targets, sample_weights)`.
 |            A more detailed description of unpacking behavior for iterator
 |            types (Dataset, generator, Sequence) is given in the `Unpacking
 |            behavior for iterator-like inputs` section of `Model.fit`.
 |          y: Target data. Like the input data `x`, it could be either Numpy
 |            array(s) or TensorFlow tensor(s). It should be consistent with `x`
 |            (you cannot have Numpy inputs and tensor targets, or inversely).
 |            If `x` is a dataset, generator or `keras.utils.Sequence` instance,
 |            `y` should not be specified (since targets will be obtained from
 |            the iterator/dataset).
 |          batch_size: Integer or `None`. Number of samples per batch of
 |            computation. If unspecified, `batch_size` will default to 32. Do
 |            not specify the `batch_size` if your data is in the form of a
 |            dataset, generators, or `keras.utils.Sequence` instances (since
 |            they generate batches).
 |          verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
 |              0 = silent, 1 = progress bar, 2 = single line.
 |              `"auto"` becomes 1 for most cases, and to 2 when used with
 |              `ParameterServerStrategy`. Note that the progress bar is not
 |              particularly useful when logged to a file, so `verbose=2` is
 |              recommended when not running interactively (e.g. in a production
 |              environment). Defaults to 'auto'.
 |          sample_weight: Optional Numpy array of weights for the test samples,
 |            used for weighting the loss function. You can either pass a flat
 |            (1D) Numpy array with the same length as the input samples
 |            (1:1 mapping between weights and samples), or in the case of
 |            temporal data, you can pass a 2D array with shape `(samples,
 |            sequence_length)`, to apply a different weight to every
 |            timestep of every sample. This argument is not supported when
 |            `x` is a dataset, instead pass sample weights as the third
 |            element of `x`.
 |          steps: Integer or `None`. Total number of steps (batches of samples)
 |            before declaring the evaluation round finished. Ignored with the
 |            default value of `None`. If x is a `tf.data` dataset and `steps`
 |            is None, 'evaluate' will run until the dataset is exhausted. This
 |            argument is not supported with array inputs.
 |          callbacks: List of `keras.callbacks.Callback` instances. List of
 |            callbacks to apply during evaluation. See
 |            [callbacks](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks).
 |          max_queue_size: Integer. Used for generator or
 |            `keras.utils.Sequence` input only. Maximum size for the generator
 |            queue. If unspecified, `max_queue_size` will default to 10.
 |          workers: Integer. Used for generator or `keras.utils.Sequence` input
 |            only. Maximum number of processes to spin up when using
 |            process-based threading. If unspecified, `workers` will default to
 |            1.
 |          use_multiprocessing: Boolean. Used for generator or
 |            `keras.utils.Sequence` input only. If `True`, use process-based
 |            threading. If unspecified, `use_multiprocessing` will default to
 |            `False`. Note that because this implementation relies on
 |            multiprocessing, you should not pass non-picklable arguments to
 |            the generator as they can't be passed easily to children
 |            processes.
 |          return_dict: If `True`, loss and metric results are returned as a
 |            dict, with each key being the name of the metric. If `False`, they
 |            are returned as a list.
 |          **kwargs: Unused at this time.
 |      
 |      See the discussion of `Unpacking behavior for iterator-like inputs` for
 |      `Model.fit`.
 |      
 |      Returns:
 |          Scalar test loss (if the model has a single output and no metrics)
 |          or list of scalars (if the model has multiple outputs
 |          and/or metrics). The attribute `model.metrics_names` will give you
 |          the display labels for the scalar outputs.
 |      
 |      Raises:
 |          RuntimeError: If `model.evaluate` is wrapped in a `tf.function`.
 |  
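The `return_dict` argument and the `Returns` note above describe two shapes for the same evaluation result: a list of scalars labelled by `model.metrics_names`, or a dict keyed by metric name. The sketch below shows how the two relate, using plain Python. The helper `as_metric_dict` is hypothetical, and the metric names and numbers are made-up placeholders, not results from any model.

```python
from typing import Dict, List


def as_metric_dict(names: List[str], values: List[float]) -> Dict[str, float]:
    """Pairs each metric name with its value, like `return_dict=True` would."""
    if len(names) != len(values):
        raise ValueError("names and values must have the same length")
    return dict(zip(names, values))


# Made-up placeholders: stand-ins for `model.metrics_names` and for the list
# returned by `model.evaluate(...)` with the default `return_dict=False`.
metrics_names = ["loss", "accuracy"]
values = [0.0, 0.9]

report = as_metric_dict(metrics_names, values)
print(report)
```

Passing `return_dict=True` to `evaluate` avoids the manual pairing altogether.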
 |  evaluate_generator(self, generator, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)
 |      Evaluates the model on a data generator.
 |      
 |      DEPRECATED:
 |        `Model.evaluate` now supports generators, so there is no longer any
 |        need to use this endpoint.
 |  
 |  export(self, filepath)
 |      Create a SavedModel artifact for inference (e.g. via TF-Serving).
 |      
 |      This method lets you export a model to a lightweight SavedModel artifact
 |      that contains the model's forward pass only (its `call()` method)
 |      and can be served via e.g. TF-Serving. The forward pass is registered
 |      under the name `serve()` (see example below).
 |      
 |      The original code of the model (including any custom layers you may
 |      have used) is *no longer* necessary to reload the artifact -- it is
 |      entirely standalone.
 |      
 |      Args:
 |          filepath: `str` or `pathlib.Path` object. Path where to save
 |              the artifact.
 |      
 |      Example:
 |      
 |      ```python
 |      # Create the artifact
 |      model.export("path/to/location")
 |      
 |      # Later, in a different process / environment...
 |      reloaded_artifact = tf.saved_model.load("path/to/location")
 |      predictions = reloaded_artifact.serve(input_data)
 |      ```
 |      
 |      If you would like to customize your serving endpoints, you can
 |      use the lower-level `keras.export.ExportArchive` class. The `export()`
 |      method relies on `ExportArchive` internally.
 |  
 |  fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, validation_freq=1, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
 |      Fits the model on data yielded batch-by-batch by a Python generator.
 |      
 |      DEPRECATED:
 |        `Model.fit` now supports generators, so there is no longer any need to
 |        use this endpoint.
 |  
 |  get_compile_config(self)
 |      Returns a serialized config with information for compiling the model.
 |      
 |      This method returns a config dictionary containing all the information
 |      (optimizer, loss, metrics, etc.) with which the model was compiled.
 |      
 |      Returns:
 |          A dict containing information for compiling the model.
 |  
 |  get_config(self)
 |      Returns the config of the `Model`.
 |      
 |      Config is a Python dictionary (serializable) containing the
 |      configuration of an object, which in this case is a `Model`. This allows
 |      the `Model` to be reinstantiated later (without its trained weights)
 |      from this configuration.
 |      
 |      Note that `get_config()` is not guaranteed to return a fresh copy of
 |      the dict every time it is called. Callers should make a copy of the
 |      returned dict if they want to modify it.
 |      
 |      Developers of subclassed `Model` are advised to override this method,
 |      and continue to update the dict from `super(MyModel, self).get_config()`
 |      to provide the proper configuration of this `Model`. The default config
 |      will return config dict for init parameters if they are basic types.
 |      Raises `NotImplementedError` in cases where a custom
 |      `get_config()` implementation is required for the subclassed model.
 |      
 |      Returns:
 |          Python dictionary containing the configuration of this `Model`.
 |  
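The copy-before-modify advice in the `get_config` entry above can be illustrated without TensorFlow. The sketch below uses a plain-Python stand-in: `TinyModel` and its `units` and `activation` parameters are hypothetical and are not part of Keras. It assumes only the behavior the docstring describes, namely that `get_config()` may return the same dict object on every call.

```python
import copy


class TinyModel:
    """Plain-Python stand-in for a configurable model; not the Keras implementation."""

    def __init__(self, units=8, activation="relu"):
        self.units = units
        self.activation = activation
        # Cached config: get_config() hands back this same object each time.
        self._config = {"units": units, "activation": activation}

    def get_config(self):
        return self._config  # deliberately not a fresh copy

    @classmethod
    def from_config(cls, config):
        # Rebuild an instance from its config (no trained state is restored).
        return cls(**config)


original = TinyModel(units=16)

# Copy first, then edit: the original's config stays untouched.
config = copy.deepcopy(original.get_config())
config["units"] = 32
clone = TinyModel.from_config(config)
```

Editing the dict returned by `get_config()` directly would have changed `original`'s cached config as well, which is the situation the docstring warns about.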
 |  get_layer(self, name=None, index=None)
 |      Retrieves a layer based on either its name (unique) or index.
 |      
 |      If `name` and `index` are both provided, `index` will take precedence.
 |      Indices are based on order of horizontal graph traversal (bottom-up).
 |      
 |      Args:
 |          name: String, name of layer.
 |          index: Integer, index of layer.
 |      
 |      Returns:
 |          A layer instance.
 |  
 |  get_metrics_result(self)
 |      Returns the model's metrics values as a dict.
 |      
 |      If any metric result is a dict (containing multiple metrics),
 |      each of them gets added to the top level returned dict of this method.
 |      
 |      Returns:
 |        A `dict` containing values of the metrics listed in `self.metrics`.
 |        Example:
 |        `{'loss': 0.2, 'accuracy': 0.7}`.
 |  
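The flattening rule in the `get_metrics_result` entry above (a metric whose result is itself a dict contributes its entries to the top level) can be sketched in plain Python. `flatten_metric_results` is a hypothetical helper and the metric names and values are invented for illustration; this is a reading of the docstring, not the Keras source.

```python
def flatten_metric_results(results):
    """Hypothetical helper: lift dict-valued metric results to the top level."""
    flat = {}
    for name, value in results.items():
        if isinstance(value, dict):
            # A metric that reports several values: merge them directly.
            flat.update(value)
        else:
            flat[name] = value
    return flat


raw = {
    "loss": 0.2,
    "accuracy": 0.7,
    "per_class": {"precision": 0.6, "recall": 0.8},  # made-up compound metric
}
flat = flatten_metric_results(raw)
```

Note that the name of the compound metric (`per_class` here) does not appear in the flattened result; only its inner keys do.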
 |  get_weight_paths(self)
 |      Retrieve all the variables and their paths for the model.
 |      
 |      The variable path (string) is a stable key to identify a `tf.Variable`
 |      instance owned by the model. It can be used to specify variable-specific
 |      configurations (e.g. DTensor, quantization) from a global view.
 |      
 |      This method returns a dict with weight object paths as keys
 |      and the corresponding `tf.Variable` instances as values.
 |      
 |      Note that if the model is a subclassed model and the weights haven't
 |      been initialized, an empty dict will be returned.
 |      
 |      Returns:
 |          A dict where keys are variable paths and values are `tf.Variable`
 |           instances.
 |      
 |      Example:
 |      
 |      ```python
 |      class SubclassModel(tf.keras.Model):
 |      
 |        def __init__(self, name=None):
 |          super().__init__(name=name)
 |          self.d1 = tf.keras.layers.Dense(10)
 |          self.d2 = tf.keras.layers.Dense(20)
 |      
 |        def call(self, inputs):
 |          x = self.d1(inputs)
 |          return self.d2(x)
 |      
 |      model = SubclassModel()
 |      model(tf.zeros((10, 10)))
 |      weight_paths = model.get_weight_paths()
 |      # weight_paths:
 |      # {
 |      #    'd1.kernel': model.d1.kernel,
 |      #    'd1.bias': model.d1.bias,
 |      #    'd2.kernel': model.d2.kernel,
 |      #    'd2.bias': model.d2.bias,
 |      # }
 |      
 |      # Functional model
 |      inputs = tf.keras.Input((10,), batch_size=10)
 |      x = tf.keras.layers.Dense(20, name='d1')(inputs)
 |      output = tf.keras.layers.Dense(30, name='d2')(x)
 |      model = tf.keras.Model(inputs, output)
 |      d1 = model.layers[1]
 |      d2 = model.layers[2]
 |      weight_paths = model.get_weight_paths()
 |      # weight_paths:
 |      # {
 |      #    'd1.kernel': d1.kernel,
 |      #    'd1.bias': d1.bias,
 |      #    'd2.kernel': d2.kernel,
 |      #    'd2.bias': d2.bias,
 |      # }
 |      ```
 |  
 |  get_weights(self)
 |      Retrieves the weights of the model.
 |      
 |      Returns:
 |          A flat list of Numpy arrays.
 |  
 |  make_train_function(self, force=False)
 |      Creates a function that executes one step of training.
 |      
 |      This method can be overridden to support custom training logic.
 |      This method is called by `Model.fit` and `Model.train_on_batch`.
 |      
 |      Typically, this method directly controls `tf.function` and
 |      `tf.distribute.Strategy` settings, and delegates the actual training
 |      logic to `Model.train_step`.
 |      
 |      This function is cached the first time `Model.fit` or
 |      `Model.train_on_batch` is called. The cache is cleared whenever
 |      `Model.compile` is called. You can skip the cache and regenerate the
 |      function with `force=True`.
 |      
 |      Args:
 |        force: Whether to regenerate the train function and skip the cached
 |          function if available.
 |      
 |      Returns:
 |        Function. The function created by this method should accept a
 |        `tf.data.Iterator`, and return a `dict` containing values that will
 |        be passed to `tf.keras.callbacks.Callback.on_train_batch_end`, such as
 |        `{'loss': 0.2, 'accuracy': 0.7}`.
 |  
 |  predict(self, x, batch_size=None, verbose='auto', steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False)
 |      Generates output predictions for the input samples.
 |      
 |      Computation is done in batches. This method is designed for batch
 |      processing of large numbers of inputs. It is not intended for use inside
 |      of loops that iterate over your data and process small numbers of inputs
 |      at a time.
 |      
 |      For small numbers of inputs that fit in one batch,
 |      directly use `__call__()` for faster execution, e.g.,
 |      `model(x)`, or `model(x, training=False)` if you have layers such as
 |      `tf.keras.layers.BatchNormalization` that behave differently during
 |      inference. You may pair the individual model call with a `tf.function`
 |      for additional performance inside your inner loop.
 |      If you need access to numpy array values instead of tensors after your
 |      model call, you can use `tensor.numpy()` to get the numpy array value of
 |      an eager tensor.
 |      
 |      Also, note the fact that test loss is not affected by
 |      regularization layers like noise and dropout.
 |      
 |      Note: See [this FAQ entry](
 |      https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call)
 |      for more details about the difference between `Model` methods
 |      `predict()` and `__call__()`.
 |      
 |      Args:
 |          x: Input samples. It could be:
 |            - A Numpy array (or array-like), or a list of arrays
 |              (in case the model has multiple inputs).
 |            - A TensorFlow tensor, or a list of tensors
 |              (in case the model has multiple inputs).
 |            - A `tf.data` dataset.
 |            - A generator or `keras.utils.Sequence` instance.
 |            A more detailed description of unpacking behavior for iterator
 |            types (Dataset, generator, Sequence) is given in the `Unpacking
 |            behavior for iterator-like inputs` section of `Model.fit`.
 |          batch_size: Integer or `None`.
 |              Number of samples per batch.
 |              If unspecified, `batch_size` will default to 32.
 |              Do not specify the `batch_size` if your data is in the
 |              form of dataset, generators, or `keras.utils.Sequence` instances
 |              (since they generate batches).
 |          verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
 |              0 = silent, 1 = progress bar, 2 = single line.
 |              `"auto"` becomes 1 for most cases, and 2 when used with
 |              `ParameterServerStrategy`. Note that the progress bar is not
 |              particularly useful when logged to a file, so `verbose=2` is
 |              recommended when not running interactively (e.g. in a production
 |              environment). Defaults to 'auto'.
 |          steps: Total number of steps (batches of samples)
 |              before declaring the prediction round finished.
 |              Ignored with the default value of `None`. If x is a `tf.data`
 |              dataset and `steps` is None, `predict()` will
 |              run until the input dataset is exhausted.
 |          callbacks: List of `keras.callbacks.Callback` instances.
 |              List of callbacks to apply during prediction.
 |              See [callbacks](
 |              https://www.tensorflow.org/api_docs/python/tf/keras/callbacks).
 |          max_queue_size: Integer. Used for generator or
 |              `keras.utils.Sequence` input only. Maximum size for the
 |              generator queue. If unspecified, `max_queue_size` will default
 |              to 10.
 |          workers: Integer. Used for generator or `keras.utils.Sequence` input
 |              only. Maximum number of processes to spin up when using
 |              process-based threading. If unspecified, `workers` will default
 |              to 1.
 |          use_multiprocessing: Boolean. Used for generator or
 |              `keras.utils.Sequence` input only. If `True`, use process-based
 |              threading. If unspecified, `use_multiprocessing` will default to
 |              `False`. Note that because this implementation relies on
 |              multiprocessing, you should not pass non-picklable arguments to
 |              the generator as they can't be passed easily to child
 |              processes.
 |      
 |      See the discussion of `Unpacking behavior for iterator-like inputs` for
 |      `Model.fit`. Note that Model.predict uses the same interpretation rules
 |      as `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for
 |      all three methods.
 |      
 |      Returns:
 |          Numpy array(s) of predictions.
 |      
 |      Raises:
 |          RuntimeError: If `model.predict` is wrapped in a `tf.function`.
 |          ValueError: In case of mismatch between the provided
 |              input data and the model's expectations,
 |              or in case a stateful model receives a number of samples
 |              that is not a multiple of the batch size.
 |  
 |  predict_generator(self, generator, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)
 |      Generates predictions for the input samples from a data generator.
 |      
 |      DEPRECATED:
 |        `Model.predict` now supports generators, so there is no longer any
 |        need to use this endpoint.
 |  
 |  predict_on_batch(self, x)
 |      Returns predictions for a single batch of samples.
 |      
 |      Args:
 |          x: Input data. It could be:
 |            - A Numpy array (or array-like), or a list of arrays (in case the
 |                model has multiple inputs).
 |            - A TensorFlow tensor, or a list of tensors (in case the model has
 |                multiple inputs).
 |      
 |      Returns:
 |          Numpy array(s) of predictions.
 |      
 |      Raises:
 |          RuntimeError: If `model.predict_on_batch` is wrapped in a
 |            `tf.function`.
 |  
 |  predict_step(self, data)
 |      The logic for one inference step.
 |      
 |      This method can be overridden to support custom inference logic.
 |      This method is called by `Model.make_predict_function`.
 |      
 |      This method should contain the mathematical logic for one step of
 |      inference.  This typically includes the forward pass.
 |      
 |      Configuration details for *how* this logic is run (e.g. `tf.function`
 |      and `tf.distribute.Strategy` settings), should be left to
 |      `Model.make_predict_function`, which can also be overridden.
 |      
 |      Args:
 |        data: A nested structure of `Tensor`s.
 |      
 |      Returns:
 |        The result of one inference step, typically the output of calling the
 |        `Model` on data.
 |  
 |  reset_metrics(self)
 |      Resets the state of all the metrics in the model.
 |      
 |      Examples:
 |      
 |      >>> inputs = tf.keras.layers.Input(shape=(3,))
 |      >>> outputs = tf.keras.layers.Dense(2)(inputs)
 |      >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
 |      >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
 |      
 |      >>> x = np.random.random((2, 3))
 |      >>> y = np.random.randint(0, 2, (2, 2))
 |      >>> _ = model.fit(x, y, verbose=0)
 |      >>> assert all(float(m.result()) for m in model.metrics)
 |      
 |      >>> model.reset_metrics()
 |      >>> assert all(float(m.result()) == 0 for m in model.metrics)
 |  
 |  reset_states(self)
 |  
 |  save_spec(self, dynamic_batch=True)
 |      Returns the `tf.TensorSpec` of call args as a tuple `(args, kwargs)`.
 |      
 |      This value is automatically defined after calling the model for the
 |      first time. Afterwards, you can use it when exporting the model for
 |      serving:
 |      
 |      ```python
 |      model = tf.keras.Model(...)
 |      
 |      @tf.function
 |      def serve(*args, **kwargs):
 |        outputs = model(*args, **kwargs)
 |        # Apply postprocessing steps, or add additional outputs.
 |        ...
 |        return outputs
 |      
 |      # arg_specs is `[tf.TensorSpec(...), ...]`. kwarg_specs, in this
 |      # example, is an empty dict since functional models do not use keyword
 |      # arguments.
 |      arg_specs, kwarg_specs = model.save_spec()
 |      
 |      model.save(path, signatures={
 |        'serving_default': serve.get_concrete_function(*arg_specs,
 |                                                       **kwarg_specs)
 |      })
 |      ```
 |      
 |      Args:
 |        dynamic_batch: Whether to set the batch sizes of all the returned
 |          `tf.TensorSpec` to `None`. (Note that when defining functional or
 |          Sequential models with `tf.keras.Input([...], batch_size=X)`, the
 |          batch size will always be preserved). Defaults to `True`.
 |      Returns:
 |        If the model inputs are defined, returns a tuple `(args, kwargs)`. All
 |        elements in `args` and `kwargs` are `tf.TensorSpec`.
 |        If the model inputs are not defined, returns `None`.
 |        The model inputs are automatically set when calling the model,
 |        `model.fit`, `model.evaluate` or `model.predict`.
 |  
 |  save_weights(self, filepath, overwrite=True, save_format=None, options=None)
 |      Saves all layer weights.
 |      
 |      Either saves in HDF5 or in TensorFlow format based on the `save_format`
 |      argument.
 |      
 |      When saving in HDF5 format, the weight file has:
 |        - `layer_names` (attribute), a list of strings
 |            (ordered names of model layers).
 |        - For every layer, a `group` named `layer.name`
 |            - For every such layer group, a group attribute `weight_names`,
 |                a list of strings
 |                (ordered names of weights tensor of the layer).
 |            - For every weight in the layer, a dataset
 |                storing the weight value, named after the weight tensor.
 |      
 |      When saving in TensorFlow format, all objects referenced by the network
 |      are saved in the same format as `tf.train.Checkpoint`, including any
 |      `Layer` instances or `Optimizer` instances assigned to object
 |      attributes. For networks constructed from inputs and outputs using
 |      `tf.keras.Model(inputs, outputs)`, `Layer` instances used by the network
 |      are tracked/saved automatically. For user-defined classes which inherit
 |      from `tf.keras.Model`, `Layer` instances must be assigned to object
 |      attributes, typically in the constructor. See the documentation of
 |      `tf.train.Checkpoint` and `tf.keras.Model` for details.
 |      
 |      While the formats are the same, do not mix `save_weights` and
 |      `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should
 |      be loaded using `Model.load_weights`. Checkpoints saved using
 |      `tf.train.Checkpoint.save` should be restored using the corresponding
 |      `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
 |      `save_weights` for training checkpoints.
 |      
 |      The TensorFlow format matches objects and variables by starting at a
 |      root object, `self` for `save_weights`, and greedily matching attribute
 |      names. For `Model.save` this is the `Model`, and for `Checkpoint.save`
 |      this is the `Checkpoint` even if the `Checkpoint` has a model attached.
 |      This means saving a `tf.keras.Model` using `save_weights` and loading
 |      into a `tf.train.Checkpoint` with a `Model` attached (or vice versa)
 |      will not match the `Model`'s variables. See the
 |      [guide to training checkpoints](
 |      https://www.tensorflow.org/guide/checkpoint) for details on
 |      the TensorFlow format.
 |      
 |      Args:
 |          filepath: String or PathLike, path to the file to save the weights
 |              to. When saving in TensorFlow format, this is the prefix used
 |              for checkpoint files (multiple files are generated). Note that
 |              the '.h5' suffix causes weights to be saved in HDF5 format.
 |          overwrite: Whether to silently overwrite any existing file at the
 |              target location, or provide the user with a manual prompt.
 |          save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
 |              '.keras' will default to HDF5 if `save_format` is `None`.
 |              Otherwise, `None` becomes 'tf'. Defaults to `None`.
 |          options: Optional `tf.train.CheckpointOptions` object that specifies
 |              options for saving weights.
 |      
 |      Raises:
 |          ImportError: If `h5py` is not available when attempting to save in
 |              HDF5 format.
 |  
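The `save_format` default described in the `save_weights` entry above depends on the file suffix. The sketch below restates only that documented rule for the `save_format=None` case. `default_save_format` is a hypothetical helper, not a Keras API, and it does not model what Keras does when an explicit `save_format` conflicts with the suffix.

```python
import os


def default_save_format(filepath):
    """Hypothetical helper: the format chosen when `save_format` is None."""
    path = os.fspath(filepath)  # accepts str or PathLike, as the docstring allows
    # '.h5' and '.keras' suffixes select HDF5; anything else becomes 'tf'.
    return "h5" if path.endswith((".h5", ".keras")) else "tf"


checkpoint_format = default_save_format("checkpoints/my_weights")
hdf5_format = default_save_format("my_weights.h5")
```

In TensorFlow format the path acts as a prefix for several checkpoint files, so it is normally given without a suffix.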
 |  test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True, return_dict=False)
 |      Test the model on a single batch of samples.
 |      
 |      Args:
 |          x: Input data. It could be:
 |            - A Numpy array (or array-like), or a list of arrays (in case the
 |                model has multiple inputs).
 |            - A TensorFlow tensor, or a list of tensors (in case the model has
 |                multiple inputs).
 |            - A dict mapping input names to the corresponding array/tensors,
 |                if the model has named inputs.
 |          y: Target data. Like the input data `x`, it could be either Numpy
 |            array(s) or TensorFlow tensor(s). It should be consistent with `x`
 |            (you cannot have Numpy inputs and tensor targets, or inversely).
 |          sample_weight: Optional array of the same length as x, containing
 |            weights to apply to the model's loss for each sample. In the case
 |            of temporal data, you can pass a 2D array with shape (samples,
 |            sequence_length), to apply a different weight to every timestep of
 |            every sample.
 |          reset_metrics: If `True`, the metrics returned will be only for this
 |            batch. If `False`, the metrics will be statefully accumulated
 |            across batches.
 |          return_dict: If `True`, loss and metric results are returned as a
 |            dict, with each key being the name of the metric. If `False`, they
 |            are returned as a list.
 |      
 |      Returns:
 |          Scalar test loss (if the model has a single output and no metrics)
 |          or list of scalars (if the model has multiple outputs
 |          and/or metrics). The attribute `model.metrics_names` will give you
 |          the display labels for the scalar outputs.
 |      
 |      Raises:
 |          RuntimeError: If `model.test_on_batch` is wrapped in a
 |            `tf.function`.
 |  
 |  test_step(self, data)
 |      The logic for one evaluation step.
 |      
 |      This method can be overridden to support custom evaluation logic.
 |      This method is called by `Model.make_test_function`.
 |      
 |      This function should contain the mathematical logic for one step of
 |      evaluation.
 |      This typically includes the forward pass, loss calculation, and metrics
 |      updates.
 |      
 |      Configuration details for *how* this logic is run (e.g. `tf.function`
 |      and `tf.distribute.Strategy` settings), should be left to
 |      `Model.make_test_function`, which can also be overridden.
 |      
 |      Args:
 |        data: A nested structure of `Tensor`s.
 |      
 |      Returns:
 |        A `dict` containing values that will be passed to
 |        `tf.keras.callbacks.CallbackList.on_test_batch_end`. Typically, the
 |        values of the `Model`'s metrics are returned.
 |  
 |  to_json(self, **kwargs)
 |      Returns a JSON string containing the network configuration.
 |      
 |      To load a network from a JSON save file, use
 |      `keras.models.model_from_json(json_string, custom_objects={})`.
 |      
 |      Args:
 |          **kwargs: Additional keyword arguments to be passed to
 |              `json.dumps()`.
 |      
 |      Returns:
 |          A JSON string.
 |  
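The `**kwargs` forwarding in the `to_json` entry above can be shown with the standard library alone. In the sketch below, `config_to_json` is a hypothetical stand-in that simply passes its keyword arguments to `json.dumps()`, and the config dict is a made-up placeholder, not the structure Keras actually emits.

```python
import json


def config_to_json(config, **kwargs):
    """Hypothetical stand-in: forward keyword arguments to json.dumps()."""
    return json.dumps(config, **kwargs)


config = {"name": "tiny_model", "layers": [{"units": 16}]}  # placeholder config

compact = config_to_json(config)
pretty = config_to_json(config, indent=2, sort_keys=True)

# Both strings decode back to the same configuration.
restored = json.loads(pretty)
```

Options such as `indent` and `sort_keys` change only the formatting of the string, which makes the output easier to read or to diff.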
 |  to_yaml(self, **kwargs)
 |      Returns a yaml string containing the network configuration.
 |      
 |      Note: Since TF 2.6, this method is no longer supported and will raise a
 |      RuntimeError.
 |      
 |      To load a network from a yaml save file, use
 |      `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
 |      
 |      `custom_objects` should be a dictionary mapping
 |      the names of custom losses / layers / etc to the corresponding
 |      functions / classes.
 |      
 |      Args:
 |          **kwargs: Additional keyword arguments
 |              to be passed to `yaml.dump()`.
 |      
 |      Returns:
 |          A YAML string.
 |      
 |      Raises:
 |          RuntimeError: announces that the method poses a security risk
 |  
 |  ----------------------------------------------------------------------
 |  Class methods inherited from keras.src.engine.training.Model:
 |  
 |  from_config(config, custom_objects=None) from builtins.type
 |      Creates a layer from its config.
 |      
 |      This method is the reverse of `get_config`,
 |      capable of instantiating the same layer from the config
 |      dictionary. It does not handle layer connectivity
 |      (handled by Network), nor weights (handled by `set_weights`).
 |      
 |      Args:
 |          config: A Python dictionary, typically the
 |              output of get_config.
 |      
 |      Returns:
 |          A layer instance.
 |  
 |  ----------------------------------------------------------------------
 |  Static methods inherited from keras.src.engine.training.Model:
 |  
 |  __new__(cls, *args, **kwargs)
 |      Create and return a new object.  See help(type) for accurate signature.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties inherited from keras.src.engine.training.Model:
 |  
 |  distribute_strategy
 |      The `tf.distribute.Strategy` this model was created under.
 |  
 |  metrics
 |      Return metrics added using `compile()` or `add_metric()`.
 |      
 |      Note: Metrics passed to `compile()` are available only after a
 |      `keras.Model` has been trained/evaluated on actual data.
 |      
 |      Examples:
 |      
 |      >>> inputs = tf.keras.layers.Input(shape=(3,))
 |      >>> outputs = tf.keras.layers.Dense(2)(inputs)
 |      >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
 |      >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
 |      >>> [m.name for m in model.metrics]
 |      []
 |      
 |      >>> x = np.random.random((2, 3))
 |      >>> y = np.random.randint(0, 2, (2, 2))
 |      >>> model.fit(x, y)
 |      >>> [m.name for m in model.metrics]
 |      ['loss', 'mae']
 |      
 |      >>> inputs = tf.keras.layers.Input(shape=(3,))
 |      >>> d = tf.keras.layers.Dense(2, name='out')
 |      >>> output_1 = d(inputs)
 |      >>> output_2 = d(inputs)
 |      >>> model = tf.keras.models.Model(
 |      ...    inputs=inputs, outputs=[output_1, output_2])
 |      >>> model.add_metric(
 |      ...    tf.reduce_sum(output_2), name='mean', aggregation='mean')
 |      >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
 |      >>> model.fit(x, (y, y))
 |      >>> [m.name for m in model.metrics]
 |      ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
 |      'out_1_acc', 'mean']
 |  
 |  metrics_names
 |      Returns the model's display labels for all outputs.
 |      
 |      Note: `metrics_names` are available only after a `keras.Model` has been
 |      trained/evaluated on actual data.
 |      
 |      Examples:
 |      
 |      >>> inputs = tf.keras.layers.Input(shape=(3,))
 |      >>> outputs = tf.keras.layers.Dense(2)(inputs)
 |      >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
 |      >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae"])
 |      >>> model.metrics_names
 |      []
 |      
 |      >>> x = np.random.random((2, 3))
 |      >>> y = np.random.randint(0, 2, (2, 2))
 |      >>> model.fit(x, y)
 |      >>> model.metrics_names
 |      ['loss', 'mae']
 |      
 |      >>> inputs = tf.keras.layers.Input(shape=(3,))
 |      >>> d = tf.keras.layers.Dense(2, name='out')
 |      >>> output_1 = d(inputs)
 |      >>> output_2 = d(inputs)
 |      >>> model = tf.keras.models.Model(
 |      ...    inputs=inputs, outputs=[output_1, output_2])
 |      >>> model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"])
 |      >>> model.fit(x, (y, y))
 |      >>> model.metrics_names
 |      ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',
 |      'out_1_acc']
 |  
 |  non_trainable_weights
 |      List of all non-trainable weights tracked by this layer.
 |      
 |      Non-trainable weights are *not* updated during training. They are
 |      expected to be updated manually in `call()`.
 |      
 |      Returns:
 |        A list of non-trainable variables.
 |  
 |  state_updates
 |      Deprecated, do NOT use!
 |      
 |      Returns the `updates` from all layers that are stateful.
 |      
 |      This is useful for separating training updates and
 |      state updates, e.g. when we need to update a layer's internal state
 |      during prediction.
 |      
 |      Returns:
 |          A list of update ops.
 |  
 |  trainable_weights
 |      List of all trainable weights tracked by this layer.
 |      
 |      Trainable weights are updated via gradient descent during training.
 |      
 |      Returns:
 |        A list of trainable variables.
 |  
 |  weights
 |      Returns the list of all layer variables/weights.
 |      
 |      Note: This will not track the weights of nested `tf.Modules` that are
 |      not themselves Keras layers.
 |      
 |      Returns:
 |        A list of variables.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from keras.src.engine.training.Model:
 |  
 |  distribute_reduction_method
 |      The method employed to reduce per-replica values during training.
 |      
 |      Unless specified, the value "auto" will be assumed, indicating that
 |      the reduction strategy should be chosen based on the current
 |      running environment.
 |      See `reduce_per_replica` function for more details.
 |  
 |  jit_compile
 |      Specify whether to compile the model with XLA.
 |      
 |      [XLA](https://www.tensorflow.org/xla) is an optimizing compiler
 |      for machine learning. `jit_compile` is not enabled by default.
 |      Note that `jit_compile=True` may not necessarily work for all models.
 |      
 |      For more information on supported operations please refer to the
 |      [XLA documentation](https://www.tensorflow.org/xla). Also refer to
 |      [known XLA issues](https://www.tensorflow.org/xla/known_issues)
 |      for more details.
 |  
 |  layers
 |  
 |  run_eagerly
 |      Settable attribute indicating whether the model should run eagerly.
 |      
 |      Running eagerly means that your model will be run step by step,
 |      like Python code. Your model might run slower, but it should become
 |      easier for you to debug it by stepping into individual layer calls.
 |      
 |      By default, we will attempt to compile your model to a static graph to
 |      deliver the best execution performance.
 |      
 |      Returns:
 |        Boolean, whether the model should run eagerly.
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from keras.src.engine.base_layer.Layer:
 |  
 |  __delattr__(self, name)
 |      Implement delattr(self, name).
 |  
 |  __getstate__(self)
 |  
 |  __setstate__(self, state)
 |  
 |  add_loss(self, losses, **kwargs)
 |      Add loss tensor(s), potentially dependent on layer inputs.
 |      
 |      Some losses (for instance, activity regularization losses) may be
 |      dependent on the inputs passed when calling a layer. Hence, when reusing
 |      the same layer on different inputs `a` and `b`, some entries in
 |      `layer.losses` may be dependent on `a` and some on `b`. This method
 |      automatically keeps track of dependencies.
 |      
 |      This method can be used inside a subclassed layer or model's `call`
 |      function, in which case `losses` should be a Tensor or list of Tensors.
 |      
 |      Example:
 |      
 |      ```python
 |      class MyLayer(tf.keras.layers.Layer):
 |        def call(self, inputs):
 |          self.add_loss(tf.abs(tf.reduce_mean(inputs)))
 |          return inputs
 |      ```
 |      
 |      The same code works in distributed training: the input to `add_loss()`
 |      is treated like a regularization loss and averaged across replicas
 |      by the training loop (both built-in `Model.fit()` and compliant custom
 |      training loops).
 |      
 |      The `add_loss` method can also be called directly on a Functional Model
 |      during construction. In this case, any loss Tensors passed to this Model
 |      must be symbolic and be able to be traced back to the model's `Input`s.
 |      These losses become part of the model's topology and are tracked in
 |      `get_config`.
 |      
 |      Example:
 |      
 |      ```python
 |      inputs = tf.keras.Input(shape=(10,))
 |      x = tf.keras.layers.Dense(10)(inputs)
 |      outputs = tf.keras.layers.Dense(1)(x)
 |      model = tf.keras.Model(inputs, outputs)
 |      # Activity regularization.
 |      model.add_loss(tf.abs(tf.reduce_mean(x)))
 |      ```
 |      
 |      If this is not the case for your loss (if, for example, your loss
 |      references a `Variable` of one of the model's layers), you can wrap your
 |      loss in a zero-argument lambda. These losses are not tracked as part of
 |      the model's topology since they can't be serialized.
 |      
 |      Example:
 |      
 |      ```python
 |      inputs = tf.keras.Input(shape=(10,))
 |      d = tf.keras.layers.Dense(10)
 |      x = d(inputs)
 |      outputs = tf.keras.layers.Dense(1)(x)
 |      model = tf.keras.Model(inputs, outputs)
 |      # Weight regularization.
 |      model.add_loss(lambda: tf.reduce_mean(d.kernel))
 |      ```
 |      
 |      Args:
 |        losses: Loss tensor, or list/tuple of tensors. Rather than tensors,
 |          losses may also be zero-argument callables which create a loss
 |          tensor.
 |        **kwargs: Used for backwards compatibility only.
 |  
 |  add_metric(self, value, name=None, **kwargs)
 |      Adds metric tensor to the layer.
 |      
 |      This method can be used inside the `call()` method of a subclassed layer
 |      or model.
 |      
 |      ```python
 |      class MyMetricLayer(tf.keras.layers.Layer):
 |        def __init__(self):
 |          super(MyMetricLayer, self).__init__(name='my_metric_layer')
 |          self.mean = tf.keras.metrics.Mean(name='metric_1')
 |      
 |        def call(self, inputs):
 |          self.add_metric(self.mean(inputs))
 |          self.add_metric(tf.reduce_sum(inputs), name='metric_2')
 |          return inputs
 |      ```
 |      
 |      This method can also be called directly on a Functional Model during
 |      construction. In this case, any tensor passed to this Model must
 |      be symbolic and be able to be traced back to the model's `Input`s. These
 |      metrics become part of the model's topology and are tracked when you
 |      save the model via `save()`.
 |      
 |      ```python
 |      inputs = tf.keras.Input(shape=(10,))
 |      x = tf.keras.layers.Dense(10)(inputs)
 |      outputs = tf.keras.layers.Dense(1)(x)
 |      model = tf.keras.Model(inputs, outputs)
 |      model.add_metric(math_ops.reduce_sum(x), name='metric_1')
 |      ```
 |      
 |      Note: Calling `add_metric()` with the result of a metric object on a
 |      Functional Model, as shown in the example below, is not supported. This
 |      is because we cannot trace the metric result tensor back to the model's
 |      inputs.
 |      
 |      ```python
 |      inputs = tf.keras.Input(shape=(10,))
 |      x = tf.keras.layers.Dense(10)(inputs)
 |      outputs = tf.keras.layers.Dense(1)(x)
 |      model = tf.keras.Model(inputs, outputs)
 |      model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')
 |      ```
 |      
 |      Args:
 |        value: Metric tensor.
 |        name: String metric name.
 |        **kwargs: Additional keyword arguments for backward compatibility.
 |          Accepted values:
 |          `aggregation` - When the `value` tensor provided is not the result
 |          of calling a `keras.Metric` instance, it will be aggregated by
 |          default using a `keras.Metric.Mean`.
 |  
 |  add_update(self, updates)
 |      Add update op(s), potentially dependent on layer inputs.
 |      
 |      Weight updates (for instance, the updates of the moving mean and
 |      variance in a BatchNormalization layer) may be dependent on the inputs
 |      passed when calling a layer. Hence, when reusing the same layer on
 |      different inputs `a` and `b`, some entries in `layer.updates` may be
 |      dependent on `a` and some on `b`. This method automatically keeps track
 |      of dependencies.
 |      
 |      This call is ignored when eager execution is enabled (in that case,
 |      variable updates are run on the fly and thus do not need to be tracked
 |      for later execution).
 |      
 |      Args:
 |        updates: Update op, or list/tuple of update ops, or zero-arg callable
 |          that returns an update op. A zero-arg callable should be passed in
 |          order to disable running the updates by setting `trainable=False`
 |          on this Layer, when executing in Eager mode.
 |  
 |  add_variable(self, *args, **kwargs)
 |      Deprecated, do NOT use! Alias for `add_weight`.
 |  
 |  add_weight(self, name=None, shape=None, dtype=None, initializer=None, regularizer=None, trainable=None, constraint=None, use_resource=None, synchronization=<VariableSynchronization.AUTO: 0>, aggregation=<VariableAggregationV2.NONE: 0>, **kwargs)
 |      Adds a new variable to the layer.
 |      
 |      Args:
 |        name: Variable name.
 |        shape: Variable shape. Defaults to scalar if unspecified.
 |        dtype: The type of the variable. Defaults to `self.dtype`.
 |        initializer: Initializer instance (callable).
 |        regularizer: Regularizer instance (callable).
 |        trainable: Boolean, whether the variable should be part of the layer's
 |          "trainable_variables" (e.g. variables, biases)
 |          or "non_trainable_variables" (e.g. BatchNorm mean and variance).
 |          Note that `trainable` cannot be `True` if `synchronization`
 |          is set to `ON_READ`.
 |        constraint: Constraint instance (callable).
 |        use_resource: Whether to use a `ResourceVariable` or not.
 |          See [this guide](
 |          https://www.tensorflow.org/guide/migrate/tf1_vs_tf2#resourcevariables_instead_of_referencevariables)
 |           for more information.
 |        synchronization: Indicates when a distributed variable will be
 |          aggregated. Accepted values are constants defined in the class
 |          `tf.VariableSynchronization`. By default the synchronization is set
 |          to `AUTO` and the current `DistributionStrategy` chooses when to
 |          synchronize. If `synchronization` is set to `ON_READ`, `trainable`
 |          must not be set to `True`.
 |        aggregation: Indicates how a distributed variable will be aggregated.
 |          Accepted values are constants defined in the class
 |          `tf.VariableAggregation`.
 |        **kwargs: Additional keyword arguments. Accepted values are `getter`,
 |          `collections`, `experimental_autocast` and `caching_device`.
 |      
 |      Returns:
 |        The variable created.
 |      
 |      Raises:
 |        ValueError: When giving unsupported dtype and no initializer or when
 |          trainable has been set to True with synchronization set as
 |          `ON_READ`.
 |  
 |  build_from_config(self, config)
 |      Builds the layer's states with the supplied config dict.
 |      
 |      By default, this method calls the `build(config["input_shape"])` method,
 |      which creates weights based on the layer's input shape in the supplied
 |      config. If your config contains other information needed to load the
 |      layer's state, you should override this method.
 |      
 |      Args:
 |          config: Dict containing the input shape associated with this layer.
 |  
 |  compute_mask(self, inputs, mask=None)
 |      Computes an output mask tensor.
 |      
 |      Args:
 |          inputs: Tensor or list of tensors.
 |          mask: Tensor or list of tensors.
 |      
 |      Returns:
 |          None or a tensor (or list of tensors,
 |              one per output tensor of the layer).
 |  
 |  compute_output_shape(self, input_shape)
 |      Computes the output shape of the layer.
 |      
 |      This method will cause the layer's state to be built, if that has not
 |      happened before. This requires that the layer will later be used with
 |      inputs that match the input shape provided here.
 |      
 |      Args:
 |          input_shape: Shape tuple (tuple of integers) or `tf.TensorShape`,
 |              or structure of shape tuples / `tf.TensorShape` instances
 |              (one per output tensor of the layer).
 |              Shape tuples can include None for free dimensions,
 |              instead of an integer.
 |      
 |      Returns:
 |          A `tf.TensorShape` instance
 |          or structure of `tf.TensorShape` instances.
 |  
 |  compute_output_signature(self, input_signature)
 |      Compute the output tensor signature of the layer based on the inputs.
 |      
 |      Unlike a TensorShape object, a TensorSpec object contains both shape
 |      and dtype information for a tensor. This method allows layers to provide
 |      output dtype information if it is different from the input dtype.
 |      For any layer that doesn't implement this function,
 |      the framework will fall back to use `compute_output_shape`, and will
 |      assume that the output dtype matches the input dtype.
 |      
 |      Args:
 |        input_signature: Single TensorSpec or nested structure of TensorSpec
 |          objects, describing a candidate input for the layer.
 |      
 |      Returns:
 |        Single TensorSpec or nested structure of TensorSpec objects,
 |          describing how the layer would transform the provided input.
 |      
 |      Raises:
 |        TypeError: If input_signature contains a non-TensorSpec object.
 |  
 |  count_params(self)
 |      Count the total number of scalars composing the weights.
 |      
 |      Returns:
 |          An integer count.
 |      
 |      Raises:
 |          ValueError: if the layer isn't yet built
 |            (in which case its weights aren't yet defined).
 |  
 |  finalize_state(self)
 |      Finalizes the layers state after updating layer weights.
 |      
 |      This function can be subclassed in a layer and will be called after
 |      updating a layer weights. It can be overridden to finalize any
 |      additional layer state after a weight update.
 |      
 |      This function will be called after weights of a layer have been restored
 |      from a loaded model.
 |  
 |  get_build_config(self)
 |      Returns a dictionary with the layer's input shape.
 |      
 |      This method returns a config dict that can be used by
 |      `build_from_config(config)` to create all states (e.g. Variables and
 |      Lookup tables) needed by the layer.
 |      
 |      By default, the config only contains the input shape that the layer
 |      was built with. If you're writing a custom layer that creates state in
 |      an unusual way, you should override this method to make sure this state
 |      is already created when Keras attempts to load its value upon model
 |      loading.
 |      
 |      Returns:
 |          A dict containing the input shape associated with the layer.
 |  
 |  get_input_at(self, node_index)
 |      Retrieves the input tensor(s) of a layer at a given node.
 |      
 |      Args:
 |          node_index: Integer, index of the node
 |              from which to retrieve the attribute.
 |              E.g. `node_index=0` will correspond to the
 |              first input node of the layer.
 |      
 |      Returns:
 |          A tensor (or list of tensors if the layer has multiple inputs).
 |      
 |      Raises:
 |        RuntimeError: If called in Eager mode.
 |  
 |  get_input_mask_at(self, node_index)
 |      Retrieves the input mask tensor(s) of a layer at a given node.
 |      
 |      Args:
 |          node_index: Integer, index of the node
 |              from which to retrieve the attribute.
 |              E.g. `node_index=0` will correspond to the
 |              first time the layer was called.
 |      
 |      Returns:
 |          A mask tensor
 |          (or list of tensors if the layer has multiple inputs).
 |  
 |  get_input_shape_at(self, node_index)
 |      Retrieves the input shape(s) of a layer at a given node.
 |      
 |      Args:
 |          node_index: Integer, index of the node
 |              from which to retrieve the attribute.
 |              E.g. `node_index=0` will correspond to the
 |              first time the layer was called.
 |      
 |      Returns:
 |          A shape tuple
 |          (or list of shape tuples if the layer has multiple inputs).
 |      
 |      Raises:
 |        RuntimeError: If called in Eager mode.
 |  
 |  get_output_at(self, node_index)
 |      Retrieves the output tensor(s) of a layer at a given node.
 |      
 |      Args:
 |          node_index: Integer, index of the node
 |              from which to retrieve the attribute.
 |              E.g. `node_index=0` will correspond to the
 |              first output node of the layer.
 |      
 |      Returns:
 |          A tensor (or list of tensors if the layer has multiple outputs).
 |      
 |      Raises:
 |        RuntimeError: If called in Eager mode.
 |  
 |  get_output_mask_at(self, node_index)
 |      Retrieves the output mask tensor(s) of a layer at a given node.
 |      
 |      Args:
 |          node_index: Integer, index of the node
 |              from which to retrieve the attribute.
 |              E.g. `node_index=0` will correspond to the
 |              first time the layer was called.
 |      
 |      Returns:
 |          A mask tensor
 |          (or list of tensors if the layer has multiple outputs).
 |  
 |  get_output_shape_at(self, node_index)
 |      Retrieves the output shape(s) of a layer at a given node.
 |      
 |      Args:
 |          node_index: Integer, index of the node
 |              from which to retrieve the attribute.
 |              E.g. `node_index=0` will correspond to the
 |              first time the layer was called.
 |      
 |      Returns:
 |          A shape tuple
 |          (or list of shape tuples if the layer has multiple outputs).
 |      
 |      Raises:
 |        RuntimeError: If called in Eager mode.
 |  
 |  load_own_variables(self, store)
 |      Loads the state of the layer.
 |      
 |      You can override this method to take full control of how the state of
 |      the layer is loaded upon calling `keras.models.load_model()`.
 |      
 |      Args:
 |          store: Dict from which the state of the model will be loaded.
 |  
 |  save_own_variables(self, store)
 |      Saves the state of the layer.
 |      
 |      You can override this method to take full control of how the state of
 |      the layer is saved upon calling `model.save()`.
 |      
 |      Args:
 |          store: Dict where the state of the model will be saved.
 |  
 |  set_weights(self, weights)
 |      Sets the weights of the layer, from NumPy arrays.
 |      
 |      The weights of a layer represent the state of the layer. This function
 |      sets the weight values from numpy arrays. The weight values should be
 |      passed in the order they are created by the layer. Note that the layer's
 |      weights must be instantiated before calling this function, by calling
 |      the layer.
 |      
 |      For example, a `Dense` layer returns a list of two values: the kernel
 |      matrix and the bias vector. These can be used to set the weights of
 |      another `Dense` layer:
 |      
 |      >>> layer_a = tf.keras.layers.Dense(1,
 |      ...   kernel_initializer=tf.constant_initializer(1.))
 |      >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))
 |      >>> layer_a.get_weights()
 |      [array([[1.],
 |             [1.],
 |             [1.]], dtype=float32), array([0.], dtype=float32)]
 |      >>> layer_b = tf.keras.layers.Dense(1,
 |      ...   kernel_initializer=tf.constant_initializer(2.))
 |      >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))
 |      >>> layer_b.get_weights()
 |      [array([[2.],
 |             [2.],
 |             [2.]], dtype=float32), array([0.], dtype=float32)]
 |      >>> layer_b.set_weights(layer_a.get_weights())
 |      >>> layer_b.get_weights()
 |      [array([[1.],
 |             [1.],
 |             [1.]], dtype=float32), array([0.], dtype=float32)]
 |      
 |      Args:
 |        weights: a list of NumPy arrays. The number
 |          of arrays and their shape must match
 |          number of the dimensions of the weights
 |          of the layer (i.e. it should match the
 |          output of `get_weights`).
 |      
 |      Raises:
 |        ValueError: If the provided weights list does not match the
 |          layer's specifications.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties inherited from keras.src.engine.base_layer.Layer:
 |  
 |  compute_dtype
 |      The dtype of the layer's computations.
 |      
 |      This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless
 |      mixed precision is used, this is the same as `Layer.dtype`, the dtype of
 |      the weights.
 |      
 |      Layers automatically cast their inputs to the compute dtype, which
 |      causes computations and the output to be in the compute dtype as well.
 |      This is done by the base Layer class in `Layer.__call__`, so you do not
 |      have to insert these casts if implementing your own layer.
 |      
 |      Layers often perform certain internal computations in higher precision
 |      when `compute_dtype` is float16 or bfloat16 for numeric stability. The
 |      output will still typically be float16 or bfloat16 in such cases.
 |      
 |      Returns:
 |        The layer's compute dtype.
 |  
 |  dtype
 |      The dtype of the layer weights.
 |      
 |      This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless
 |      mixed precision is used, this is the same as `Layer.compute_dtype`, the
 |      dtype of the layer's computations.
 |  
 |  dtype_policy
 |      The dtype policy associated with this layer.
 |      
 |      This is an instance of a `tf.keras.mixed_precision.Policy`.
 |  
 |  dynamic
 |      Whether the layer is dynamic (eager-only); set in the constructor.
 |  
 |  inbound_nodes
 |      Return Functional API nodes upstream of this layer.
 |  
 |  input
 |      Retrieves the input tensor(s) of a layer.
 |      
 |      Only applicable if the layer has exactly one input,
 |      i.e. if it is connected to one incoming layer.
 |      
 |      Returns:
 |          Input tensor or list of input tensors.
 |      
 |      Raises:
 |        RuntimeError: If called in Eager mode.
 |        AttributeError: If no inbound nodes are found.
 |  
 |  input_mask
 |      Retrieves the input mask tensor(s) of a layer.
 |      
 |      Only applicable if the layer has exactly one inbound node,
 |      i.e. if it is connected to one incoming layer.
 |      
 |      Returns:
 |          Input mask tensor (potentially None) or list of input
 |          mask tensors.
 |      
 |      Raises:
 |          AttributeError: if the layer is connected to
 |          more than one incoming layers.
 |  
 |  input_shape
 |      Retrieves the input shape(s) of a layer.
 |      
 |      Only applicable if the layer has exactly one input,
 |      i.e. if it is connected to one incoming layer, or if all inputs
 |      have the same shape.
 |      
 |      Returns:
 |          Input shape, as an integer shape tuple
 |          (or list of shape tuples, one tuple per input tensor).
 |      
 |      Raises:
 |          AttributeError: if the layer has no defined input_shape.
 |          RuntimeError: if called in Eager mode.
 |  
 |  losses
 |      List of losses added using the `add_loss()` API.
 |      
 |      Variable regularization tensors are created when this property is
 |      accessed, so it is eager safe: accessing `losses` under a
 |      `tf.GradientTape` will propagate gradients back to the corresponding
 |      variables.
 |      
 |      Examples:
 |      
 |      >>> class MyLayer(tf.keras.layers.Layer):
 |      ...   def call(self, inputs):
 |      ...     self.add_loss(tf.abs(tf.reduce_mean(inputs)))
 |      ...     return inputs
 |      >>> l = MyLayer()
 |      >>> l(np.ones((10, 1)))
 |      >>> l.losses
 |      [1.0]
 |      
 |      >>> inputs = tf.keras.Input(shape=(10,))
 |      >>> x = tf.keras.layers.Dense(10)(inputs)
 |      >>> outputs = tf.keras.layers.Dense(1)(x)
 |      >>> model = tf.keras.Model(inputs, outputs)
 |      >>> # Activity regularization.
 |      >>> len(model.losses)
 |      0
 |      >>> model.add_loss(tf.abs(tf.reduce_mean(x)))
 |      >>> len(model.losses)
 |      1
 |      
 |      >>> inputs = tf.keras.Input(shape=(10,))
 |      >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')
 |      >>> x = d(inputs)
 |      >>> outputs = tf.keras.layers.Dense(1)(x)
 |      >>> model = tf.keras.Model(inputs, outputs)
 |      >>> # Weight regularization.
 |      >>> model.add_loss(lambda: tf.reduce_mean(d.kernel))
 |      >>> model.losses
 |      [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]
 |      
 |      Returns:
 |        A list of tensors.
 |  
 |  name
 |      Name of the layer (string), set in the constructor.
 |  
 |  non_trainable_variables
 |      Sequence of non-trainable variables owned by this module and its submodules.
 |      
 |      Note: this method uses reflection to find variables on the current instance
 |      and submodules. For performance reasons you may wish to cache the result
 |      of calling this method if you don't expect the return value to change.
 |      
 |      Returns:
 |        A sequence of variables for the current module (sorted by attribute
 |        name) followed by variables from all submodules recursively (breadth
 |        first).
 |  
 |  outbound_nodes
 |      Return Functional API nodes downstream of this layer.
 |  
 |  output
 |      Retrieves the output tensor(s) of a layer.
 |      
 |      Only applicable if the layer has exactly one output,
 |      i.e. if it is connected to one incoming layer.
 |      
 |      Returns:
 |        Output tensor or list of output tensors.
 |      
 |      Raises:
 |        AttributeError: if the layer is connected to more than one incoming
 |          layers.
 |        RuntimeError: if called in Eager mode.
 |  
 |  output_mask
 |      Retrieves the output mask tensor(s) of a layer.
 |      
 |      Only applicable if the layer has exactly one inbound node,
 |      i.e. if it is connected to one incoming layer.
 |      
 |      Returns:
 |          Output mask tensor (potentially None) or list of output
 |          mask tensors.
 |      
 |      Raises:
 |          AttributeError: if the layer is connected to
 |          more than one incoming layers.
 |  
 |  output_shape
 |      Retrieves the output shape(s) of a layer.
 |      
 |      Only applicable if the layer has one output,
 |      or if all outputs have the same shape.
 |      
 |      Returns:
 |          Output shape, as an integer shape tuple
 |          (or list of shape tuples, one tuple per output tensor).
 |      
 |      Raises:
 |          AttributeError: if the layer has no defined output shape.
 |          RuntimeError: if called in Eager mode.
 |  
 |  trainable_variables
 |      Sequence of trainable variables owned by this module and its submodules.
 |      
 |      Note: this method uses reflection to find variables on the current instance
 |      and submodules. For performance reasons you may wish to cache the result
 |      of calling this method if you don't expect the return value to change.
 |      
 |      Returns:
 |        A sequence of variables for the current module (sorted by attribute
 |        name) followed by variables from all submodules recursively (breadth
 |        first).
 |  
 |  updates
 |  
 |  variable_dtype
 |      Alias of `Layer.dtype`, the dtype of the weights.
 |  
 |  variables
 |      Returns the list of all layer variables/weights.
 |      
 |      Alias of `self.weights`.
 |      
 |      Note: This will not track the weights of nested `tf.Modules` that are
 |      not themselves Keras layers.
 |      
 |      Returns:
 |        A list of variables.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from keras.src.engine.base_layer.Layer:
 |  
 |  activity_regularizer
 |      Optional regularizer function for the output of this layer.
 |  
 |  input_spec
 |      `InputSpec` instance(s) describing the input format for this layer.
 |      
 |      When you create a layer subclass, you can set `self.input_spec` to
 |      enable the layer to run input compatibility checks when it is called.
 |      Consider a `Conv2D` layer: it can only be called on a single input
 |      tensor of rank 4. As such, you can set, in `__init__()`:
 |      
 |      ```python
 |      self.input_spec = tf.keras.layers.InputSpec(ndim=4)
 |      ```
 |      
 |      Now, if you try to call the layer on an input that isn't rank 4
 |      (for instance, an input of shape `(2,)`, it will raise a
 |      nicely-formatted error:
 |      
 |      ```
 |      ValueError: Input 0 of layer conv2d is incompatible with the layer:
 |      expected ndim=4, found ndim=1. Full shape received: [2]
 |      ```
 |      
 |      Input checks that can be specified via `input_spec` include:
 |      - Structure (e.g. a single input, a list of 2 inputs, etc)
 |      - Shape
 |      - Rank (ndim)
 |      - Dtype
 |      
 |      For more information, see `tf.keras.layers.InputSpec`.
 |      
 |      Returns:
 |        A `tf.keras.layers.InputSpec` instance, or nested structure thereof.
 |  
 |  stateful
 |  
 |  supports_masking
 |      Whether this layer supports computing a mask using `compute_mask`.
 |  
 |  trainable
 |  
 |  ----------------------------------------------------------------------
 |  Class methods inherited from tensorflow.python.module.module.Module:
 |  
 |  with_name_scope(method) from builtins.type
 |      Decorator to automatically enter the module name scope.
 |      
 |      >>> class MyModule(tf.Module):
 |      ...   @tf.Module.with_name_scope
 |      ...   def __call__(self, x):
 |      ...     if not hasattr(self, 'w'):
 |      ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
 |      ...     return tf.matmul(x, self.w)
 |      
 |      Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
 |      names included the module name:
 |      
 |      >>> mod = MyModule()
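 |      >>> mod(tf.ones([1, 2]))
 |      <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
 |      >>> mod.w
 |      <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
 |      numpy=..., dtype=float32)>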
 |      
 |      Args:
 |        method: The method to wrap.
 |      
 |      Returns:
 |        The original method wrapped such that it enters the module's name scope.
 |  
 |  ----------------------------------------------------------------------
 |  Readonly properties inherited from tensorflow.python.module.module.Module:
 |  
 |  name_scope
 |      Returns a `tf.name_scope` instance for this class.
 |  
 |  submodules
 |      Sequence of all sub-modules.
 |      
 |      Submodules are modules which are properties of this module, or found as
 |      properties of modules which are properties of this module (and so on).
 |      
 |      >>> a = tf.Module()
 |      >>> b = tf.Module()
 |      >>> c = tf.Module()
 |      >>> a.b = b
 |      >>> b.c = c
 |      >>> list(a.submodules) == [b, c]
 |      True
 |      >>> list(b.submodules) == [c]
 |      True
 |      >>> list(c.submodules) == []
 |      True
 |      
 |      Returns:
 |        A sequence of all submodules.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from tensorflow.python.trackable.base.Trackable:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)

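12. Using a subset of features

The previous examples did not specify the features, so every column except the label was used as an input feature. The following example shows how to specify the input features.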


# Create the feature objects feature_1 and feature_2.
feature_1 = tfdf.keras.FeatureUsage(name="bill_length_mm")
feature_2 = tfdf.keras.FeatureUsage(name="island")

# Collect the feature objects in the list all_features.
all_features = [feature_1, feature_2]

# Note: this model is trained on only two features, so it will not perform
# as well as the model trained on all the features.

# Create the gradient boosted trees model model_2.
model_2 = tfdf.keras.GradientBoostedTreesModel(
    features=all_features, exclude_non_specified_features=True)

# Compile the model with accuracy as the evaluation metric.
model_2.compile(metrics=["accuracy"])

# Train on train_ds and use test_ds as the validation dataset.
model_2.fit(train_ds, validation_data=test_ds)

# Print the evaluation results on the test dataset as a dictionary.
print(model_2.evaluate(test_ds, return_dict=True))
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.




Use /tmpfs/tmp/tmpoow4zpd8 as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.144474. Found 239 examples.
Reading validation dataset...


[WARNING 23-08-16 11:05:24.8447 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:24.8447 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:24.8447 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".


Num validation examples: tf.Tensor(105, shape=(), dtype=int32)
Validation dataset read in 0:00:00.205776. Found 105 examples.
Training model...
Model trained in 0:00:00.610888
Compiling model...
Model compiled.


[INFO 23-08-16 11:05:25.7975 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpoow4zpd8/model/ with prefix 730fc7c477154271
[INFO 23-08-16 11:05:25.8146 UTC decision_forest.cc:660] Model loaded with 168 root(s), 5352 node(s), and 2 input feature(s).
[INFO 23-08-16 11:05:25.8146 UTC kernel.cc:1075] Use fast generic engine



1/1 [==============================] - 0s 87ms/step - loss: 0.0000e+00 - accuracy: 0.9810
{'loss': 0.0, 'accuracy': 0.9809523820877075}

**Note:** As expected, the accuracy is lower than before.
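The column selection behind the example above can be sketched in plain Python. This is only an illustration of the documented behavior of `features` and `exclude_non_specified_features`, not TF-DF's implementation. The helper `select_input_columns` is hypothetical, and the column list is assumed to be that of the penguins dataset used in this tutorial.

```python
def select_input_columns(columns, label, features=None,
                         exclude_non_specified_features=False):
    """Return the columns a model would consume as input features (sketch)."""
    # The label is never an input feature.
    candidates = [c for c in columns if c != label]
    if features is None or not exclude_non_specified_features:
        # No explicit list, or the list only annotates some columns:
        # every non-label column remains an input feature.
        return candidates
    # Explicit list plus exclusion: keep only the named columns.
    wanted = set(features)
    return [c for c in candidates if c in wanted]


# Assumed column names of the penguins dataset.
penguin_columns = ["species", "island", "bill_length_mm", "bill_depth_mm",
                   "flipper_length_mm", "body_mass_g", "sex", "year"]

all_inputs = select_input_columns(penguin_columns, label="species")
two_inputs = select_input_columns(
    penguin_columns, label="species",
    features=["bill_length_mm", "island"],
    exclude_non_specified_features=True)

print(all_inputs)
print(two_inputs)
```

With `exclude_non_specified_features=True`, only the two named columns remain, which is why `model_2` reports 2 input features in its training log.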

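TF-DF attaches a semantic to each feature. This semantic controls how the model uses the feature. The following semantics are currently supported:

  • Numerical: Generally for quantities or counts with a full ordering, for example the age of a person or the number of items in a bag. Can be a float or an integer. Missing values are represented with a float NaN or with an empty sparse tensor.
  • Categorical: Generally for a type/class in a finite set of possible values without ordering, for example the color RED in the set {RED, BLUE, GREEN}. Can be a string or an integer. Missing values are represented as "" (empty string), the value -2, or an empty sparse tensor.
  • Categorical-Set: A set of categorical values. Well suited to representing tokenized text. Can be strings or integers, stored in a sparse tensor or a ragged tensor (recommended). The order/index of each item does not matter.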
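As a minimal sketch of the missing-value conventions listed above, the following plain-Python helpers replace `None` with the dense marker described for each semantic. The helper names and the sample values are hypothetical and only serve the illustration; sparse and ragged representations are not covered.

```python
import math


def missing_marker(semantic, python_type):
    """Return the dense 'missing' marker described above for a semantic."""
    if semantic == "NUMERICAL":
        return float("nan")
    if semantic == "CATEGORICAL":
        # Empty string for string categories, -2 for integer categories.
        return "" if python_type is str else -2
    raise ValueError(f"no dense missing marker described for {semantic}")


def fill_missing(values, semantic, python_type):
    """Replace None entries with the marker for the given semantic."""
    marker = missing_marker(semantic, python_type)
    return [marker if v is None else v for v in values]


# Made-up sample values, for illustration only.
bill_length = fill_missing([39.1, None, 40.3], "NUMERICAL", float)
island = fill_missing(["Torgersen", None, "Biscoe"], "CATEGORICAL", str)
coded = fill_missing([1, None, 3], "CATEGORICAL", int)

print(bill_length, island, coded)
```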

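If no semantic is specified, it is inferred from the representation type and shown in the training logs:

  • int, float (dense or sparse) → Numerical semantic.
  • str (dense or sparse) → Categorical semantic.
  • int, str (ragged) → Categorical-Set semantic.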
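The inference rules above amount to a small lookup. The sketch below restates them in plain Python. It is not TF-DF's actual inference code: the function name `infer_semantic` and the string labels it returns are chosen here for illustration.

```python
def infer_semantic(element_type, ragged=False):
    """Map a representation type to a semantic, following the rules above."""
    if ragged:
        if element_type in (int, str):
            return "CATEGORICAL_SET"
        raise ValueError(f"unsupported ragged element type: {element_type}")
    if element_type in (int, float):
        return "NUMERICAL"
    if element_type is str:
        return "CATEGORICAL"
    raise ValueError(f"unsupported element type: {element_type}")


print(infer_semantic(float))             # bill_length_mm stored as float
print(infer_semantic(str))               # island stored as str
print(infer_semantic(int))               # year stored as int
print(infer_semantic(str, ragged=True))  # tokenized text
```

Note that an integer column such as `year` falls under the Numerical rule regardless of what the integers mean.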

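In some cases, the inferred semantic is incorrect. For example, an enum stored as an integer is semantically categorical, but it will be detected as numerical. In this case, you should specify the semantic argument in the input. The education_num field of the Adult dataset is a classic example.

This dataset does not contain such a feature. For the demonstration, however, we will make the model treat year as a categorical feature: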

# Limit the cell output height to 300.
%set_cell_height 300

# Declare "year" as a categorical feature.
feature_1 = tfdf.keras.FeatureUsage(name="year", semantic=tfdf.keras.FeatureSemantic.CATEGORICAL)

# Declare the numerical feature "bill_length_mm" (semantic inferred).
feature_2 = tfdf.keras.FeatureUsage(name="bill_length_mm")

# Declare the categorical feature "sex" (semantic inferred).
feature_3 = tfdf.keras.FeatureUsage(name="sex")

# Collect all feature usages in a list.
all_features = [feature_1, feature_2, feature_3]

# Create a gradient boosted trees model that uses only the features in all_features.
model_3 = tfdf.keras.GradientBoostedTreesModel(features=all_features, exclude_non_specified_features=True)

# Compile the model with accuracy as the evaluation metric.
model_3.compile(metrics=["accuracy"])

# Train on the training dataset and validate on the test dataset.
model_3.fit(train_ds, validation_data=test_ds)



Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.




Use /tmpfs/tmp/tmpg9srb1ip as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.143701. Found 239 examples.
Reading validation dataset...


[WARNING 23-08-16 11:05:26.3095 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:26.3095 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:26.3095 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".


Num validation examples: tf.Tensor(105, shape=(), dtype=int32)
Validation dataset read in 0:00:00.152938. Found 105 examples.
Training model...
Model trained in 0:00:00.267350
Compiling model...
Model compiled.


[INFO 23-08-16 11:05:26.8771 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpg9srb1ip/model/ with prefix caa2217206e3449f
[INFO 23-08-16 11:05:26.8819 UTC decision_forest.cc:660] Model loaded with 42 root(s), 1322 node(s), and 3 input feature(s).
[INFO 23-08-16 11:05:26.8819 UTC kernel.cc:1075] Use fast generic engine






Note that year is in the list of CATEGORICAL features (unlike in the first run).

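13. Hyperparameters

Hyperparameters are parameters of the training algorithm that affect the quality of the final model. They are specified in the constructor of the model class. The list of hyperparameters can be viewed with the question mark Colab command (e.g. ?tfdf.keras.GradientBoostedTreesModel).

Alternatively, you can find them on the TensorFlow Decision Forests GitHub or in the Yggdrasil Decision Forests documentation.

The default hyperparameters of each algorithm roughly match those of the initial publication. To ensure consistency, new features and their matching hyperparameters are always disabled by default. That is why tuning the hyperparameters is a good idea.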

# Create a gradient boosted trees model with num_trees=500, the BEST_FIRST_GLOBAL growing strategy, and a maximum depth of 8
model_6 = tfdf.keras.GradientBoostedTreesModel(
    num_trees=500, growing_strategy="BEST_FIRST_GLOBAL", max_depth=8)

# Train the model on the training dataset
model_6.fit(train_ds)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmp3ys0wqar as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.173451. Found 239 examples.
Training model...


[WARNING 23-08-16 11:05:27.1239 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:27.1240 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:27.1240 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
[INFO 23-08-16 11:05:33.8276 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmp3ys0wqar/model/ with prefix 10ca7d5d94bb4882


Model trained in 0:00:06.802956
Compiling model...
Model compiled.


[INFO 23-08-16 11:05:34.0959 UTC decision_forest.cc:660] Model loaded with 1500 root(s), 86196 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:34.0960 UTC abstract_model.cc:1311] Engine "GradientBoostedTreesGeneric" built
[INFO 23-08-16 11:05:34.0960 UTC kernel.cc:1075] Use fast generic engine
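The log above reports 1500 roots although the model was created with num_trees=500. A plausible reading, assuming the usual multi-class gradient boosting scheme in which every boosting iteration fits one tree per class, and assuming the label has three classes: 500 iterations × 3 classes = 1500 trees. The helper below is a hypothetical sketch of that arithmetic, not something the library provides.

```python
def boosting_iterations(num_roots, num_classes):
    """Boosting iterations implied by a root count.

    Assumes one tree per class per iteration (an assumption about how the
    multi-class model is built, not a value read from the library).
    """
    if num_classes <= 0:
        raise ValueError("num_classes must be positive")
    if num_roots % num_classes != 0:
        raise ValueError("root count is not a multiple of the class count")
    return num_roots // num_classes


NUM_CLASSES = 3  # assumption: the label has three classes

# Root counts copied from the "Model loaded with ... root(s)" log lines.
for roots in (42, 1500):
    print(roots, "roots ->", boosting_iterations(roots, NUM_CLASSES), "iterations")
```

Under the same assumption, the 42 roots logged for model_3 would correspond to 14 iterations, far fewer than the budget, which is consistent with training stopping early when validation data is provided.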






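# A more complex, but possibly more accurate, model

# GradientBoostedTreesModel arguments:
# num_trees=500: same tree budget as model_6
# growing_strategy="BEST_FIRST_GLOBAL": grow each tree best-first, always splitting the node with the best loss reduction among all nodes of the tree
# max_depth=8: maximum depth of each tree
# split_axis="SPARSE_OBLIQUE": allow splits on a combination of a few numerical features instead of a single feature
# categorical_algorithm="RANDOM": search splits on categorical features among random candidate sets of categories
model_7 = tfdf.keras.GradientBoostedTreesModel(
    num_trees=500,
    growing_strategy="BEST_FIRST_GLOBAL",
    max_depth=8,
    split_axis="SPARSE_OBLIQUE",
    categorical_algorithm="RANDOM",
    )

# Train the model on the training dataset train_ds
model_7.fit(train_ds)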
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmpkpibv70a as temporary training directory
Reading training dataset...


[WARNING 23-08-16 11:05:34.2860 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:34.2860 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:34.2860 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".


WARNING:tensorflow:5 out of the last 5 calls to  triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.


Training dataset read in 0:00:00.171961. Found 239 examples.
Training model...


[INFO 23-08-16 11:05:42.9485 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpkpibv70a/model/ with prefix c697f69f5a7e4d74


Model trained in 0:00:08.763151
Compiling model...
WARNING:tensorflow:5 out of the last 5 calls to .predict_function_trained at 0x7f823050ec10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.


[INFO 23-08-16 11:05:43.2128 UTC decision_forest.cc:660] Model loaded with 1500 root(s), 85322 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:43.2128 UTC kernel.cc:1075] Use fast generic engine
Model compiled.
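For intuition about split_axis="SPARSE_OBLIQUE": an axis-aligned condition compares a single feature with a threshold, while a sparse oblique condition compares a weighted sum of a few numerical features with a threshold. The sketch below only illustrates that idea. The feature values, weights, and thresholds are made up, and the two functions are hypothetical; they are not how the library evaluates conditions internally.

```python
def axis_aligned(example, feature, threshold):
    """Classic condition: a single feature compared with a threshold."""
    return example[feature] >= threshold


def sparse_oblique(example, weights, threshold):
    """Oblique condition: a weighted sum of a few features compared with a threshold."""
    projection = sum(w * example[name] for name, w in weights.items())
    return projection >= threshold


# Made-up example with three numerical features.
example = {"bill_length_mm": 45.0, "bill_depth_mm": 17.0, "flipper_length_mm": 200.0}

# Axis-aligned: looks at bill_length_mm only.
print(axis_aligned(example, "bill_length_mm", 43.0))

# Sparse oblique: looks at bill_length_mm - bill_depth_mm and ignores
# every feature that has no weight.
weights = {"bill_length_mm": 1.0, "bill_depth_mm": -1.0}
print(sparse_oblique(example, weights, 30.0))
```

A single oblique condition can separate classes along a diagonal boundary that would otherwise need several axis-aligned splits, which is one reason such models may be more accurate at the price of longer training.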






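As new training methods are published and implemented, combinations of hyperparameters can emerge that are good, or almost always better, than the default parameters. To avoid changing the default hyperparameter values, these good combinations are indexed and made available as hyperparameter templates.

For example, the benchmark_rank1 template is the best combination on our internal benchmarks. The templates are versioned to keep training configurations stable, e.g. benchmark_rank1@v1.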

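# Import the library
import tensorflow_decision_forests as tfdf

# Create a gradient boosted trees model from a predefined hyperparameter template.
# The unversioned name "benchmark_rank1" is resolved to a versioned template (see the log below).
model_8 = tfdf.keras.GradientBoostedTreesModel(hyperparameter_template="benchmark_rank1")

# Train the model on the training dataset
model_8.fit(train_ds)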
Resolve hyper-parameter template "benchmark_rank1" to "benchmark_rank1@v1" -> {'growing_strategy': 'BEST_FIRST_GLOBAL', 'categorical_algorithm': 'RANDOM', 'split_axis': 'SPARSE_OBLIQUE', 'sparse_oblique_normalization': 'MIN_MAX', 'sparse_oblique_num_projections_exponent': 1.0}.
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmpzjvgcmpm as temporary training directory
Reading training dataset...
WARNING:tensorflow:6 out of the last 6 calls to  triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.


[WARNING 23-08-16 11:05:43.4453 UTC gradient_boosted_trees.cc:1818] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:43.4453 UTC gradient_boosted_trees.cc:1829] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 23-08-16 11:05:43.4453 UTC gradient_boosted_trees.cc:1843] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
Training dataset read in 0:00:00.169369. Found 239 examples.
Training model...
Model trained in 0:00:03.481820
Compiling model...


[INFO 23-08-16 11:05:46.9935 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpzjvgcmpm/model/ with prefix 5763241a118b4af3
[INFO 23-08-16 11:05:47.0978 UTC decision_forest.cc:660] Model loaded with 900 root(s), 35042 node(s), and 7 input feature(s).
[INFO 23-08-16 11:05:47.0978 UTC abstract_model.cc:1311] Engine "GradientBoostedTreesGeneric" built
[INFO 23-08-16 11:05:47.0978 UTC kernel.cc:1075] Use fast generic engine


WARNING:tensorflow:6 out of the last 6 calls to .predict_function_trained at 0x7f82304539d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.


Model compiled.






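The available templates can be listed with predefined_hyperparameters. Note that different learning algorithms have different templates, even if the names are similar.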

# Import the tfdf library
import tensorflow_decision_forests as tfdf

# Print the predefined hyperparameter templates of the gradient boosted trees model
print(tfdf.keras.GradientBoostedTreesModel.predefined_hyperparameters())
[HyperParameterTemplate(name='better_default', version=1, parameters={'growing_strategy': 'BEST_FIRST_GLOBAL'}, description='A configuration that is generally better than the default parameters without being more expensive.'), HyperParameterTemplate(name='benchmark_rank1', version=1, parameters={'growing_strategy': 'BEST_FIRST_GLOBAL', 'categorical_algorithm': 'RANDOM', 'split_axis': 'SPARSE_OBLIQUE', 'sparse_oblique_normalization': 'MIN_MAX', 'sparse_oblique_num_projections_exponent': 1.0}, description='Top ranking hyper-parameters on our benchmark slightly modified to run in reasonable time.')]
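To make the versioning concrete, here is a minimal sketch of how a template name could be resolved to a versioned name and a parameter dictionary. The `TEMPLATES` table copies the two entries printed above. `resolve_template` is a hypothetical helper, and picking the highest version for an unversioned name is an assumption, since only v1 exists in the output shown here.

```python
# Template parameters copied from the predefined_hyperparameters() output above.
TEMPLATES = {
    ("better_default", 1): {"growing_strategy": "BEST_FIRST_GLOBAL"},
    ("benchmark_rank1", 1): {
        "growing_strategy": "BEST_FIRST_GLOBAL",
        "categorical_algorithm": "RANDOM",
        "split_axis": "SPARSE_OBLIQUE",
        "sparse_oblique_normalization": "MIN_MAX",
        "sparse_oblique_num_projections_exponent": 1.0,
    },
}


def resolve_template(name):
    """Resolve "name" or "name@vN" to (versioned name, parameters).

    Hypothetical helper; an unversioned name maps to the highest known
    version, which is an assumption made for this sketch.
    """
    base, separator, version = name.partition("@v")
    if separator:
        key = (base, int(version))
    else:
        versions = [v for (n, v) in TEMPLATES if n == base]
        if not versions:
            raise KeyError(f"unknown template: {name}")
        key = (base, max(versions))
    if key not in TEMPLATES:
        raise KeyError(f"unknown template: {name}")
    # Return a copy so that callers cannot modify the table.
    return f"{key[0]}@v{key[1]}", dict(TEMPLATES[key])


versioned_name, parameters = resolve_template("benchmark_rank1")
print(versioned_name)
print(parameters)
```

Pinning the versioned name, for example benchmark_rank1@v1, keeps a training configuration reproducible even if a newer version of the template is added later.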

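14. Feature preprocessing

Preprocessing features is sometimes necessary to consume signals with complex structures, to regularize the model, or to apply transfer learning. Preprocessing can be done in one of three ways:

  1. Preprocessing on the Pandas dataframe. This solution is easy to implement and generally suitable for experimentation. However, the preprocessing logic is not exported with the model by model.save().

  2. Keras preprocessing: while more complex than the previous solution, Keras preprocessing is packaged in the model.

  3. TensorFlow feature columns: this API is part of the TF Estimator library (!= Keras) and is planned for deprecation. This solution is interesting when reusing existing preprocessing code.

Note: Using TensorFlow Hub pre-trained embeddings is often a great way to consume text and images with TF-DF. For example, hub.KerasLayer("https://tfhub.dev/google/nnlm-en-dim128/2"). See the Intermediate tutorial for more details.

In the next example, the body_mass_g feature is preprocessed into body_mass_kg = body_mass_g / 1000. The bill_length_mm feature is consumed without preprocessing. Note that such monotonic transformations generally have no influence on decision forest models.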

# Define an input layer of shape (1,) named "body_mass_g"
body_mass_g = tf.keras.layers.Input(shape=(1,), name="body_mass_g")

# Divide "body_mass_g" by 1000 to obtain "body_mass_kg"
body_mass_kg = body_mass_g / 1000.0

# Define an input layer of shape (1,) named "bill_length_mm"
bill_length_mm = tf.keras.layers.Input(shape=(1,), name="bill_length_mm")

# Raw inputs, keyed by the name of each input layer
raw_inputs = {"body_mass_g": body_mass_g, "bill_length_mm": bill_length_mm}

# Processed inputs, keyed by the feature names that the forest will see
processed_inputs = {"body_mass_kg": body_mass_kg, "bill_length_mm": bill_length_mm}

# Build the preprocessing model: raw_inputs in, processed_inputs out
preprocessor = tf.keras.Model(inputs=raw_inputs, outputs=processed_inputs)

# Create a random forest that applies the preprocessor before the decision forest
model_4 = tfdf.keras.RandomForestModel(preprocessing=preprocessor)

# Train the model on the training dataset
model_4.fit(train_ds)

# Print the model summary
model_4.summary()



Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmpmthx2t9p as temporary training directory
Reading training dataset...


/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/functional.py:639: UserWarning: Input dict contained keys ['island', 'bill_depth_mm', 'flipper_length_mm', 'sex', 'year'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)


Training dataset read in 0:00:00.226996. Found 239 examples.
Training model...
Model trained in 0:00:00.072877
Compiling model...
Model compiled.
WARNING:tensorflow:5 out of the last 12 calls to  triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.


[INFO 23-08-16 11:05:47.5902 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpmthx2t9p/model/ with prefix 124e9d64a10e49f0
[INFO 23-08-16 11:05:47.6086 UTC decision_forest.cc:660] Model loaded with 300 root(s), 6310 node(s), and 2 input feature(s).
[INFO 23-08-16 11:05:47.6086 UTC kernel.cc:1075] Use fast generic engine
Model: "random_forest_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 model (Functional)          {'body_mass_kg': (None,   0         
                              1),                                
                              'bill_length_mm': (Non             
                             e, 1)}                              
                                                                 
=================================================================
Total params: 1 (1.00 Byte)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 1 (1.00 Byte)
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (2):
	bill_length_mm
	body_mass_kg

No weights

Variable Importance: INV_MEAN_MIN_DEPTH:
    1. "bill_length_mm"  0.866916 ################
    2.   "body_mass_kg"  0.488533 

Variable Importance: NUM_AS_ROOT:
    1. "bill_length_mm" 263.000000 ################
    2.   "body_mass_kg" 37.000000 

Variable Importance: NUM_NODES:
    1. "bill_length_mm" 1537.000000 ################
    2.   "body_mass_kg" 1468.000000 

Variable Importance: SUM_SCORE:
    1. "bill_length_mm" 41227.008434 ################
    2.   "body_mass_kg" 27680.406197 



Winner takes all: true
Out-of-bag evaluation: accuracy:0.920502 logloss:0.636824
Number of trees: 300
Total number of nodes: 6310

Number of nodes by tree:
Count: 300 Average: 21.0333 StdDev: 3.13882
Min: 11 Max: 29 Ignored: 0
----------------------------------------------
[ 11, 12)  1   0.33%   0.33%
[ 12, 13)  0   0.00%   0.33%
[ 13, 14)  1   0.33%   0.67%
[ 14, 15)  0   0.00%   0.67%
[ 15, 16) 15   5.00%   5.67% ##
[ 16, 17)  0   0.00%   5.67%
[ 17, 18) 35  11.67%  17.33% ####
[ 18, 19)  0   0.00%  17.33%
[ 19, 20) 56  18.67%  36.00% #######
[ 20, 21)  0   0.00%  36.00%
[ 21, 22) 79  26.33%  62.33% ##########
[ 22, 23)  0   0.00%  62.33%
[ 23, 24) 58  19.33%  81.67% #######
[ 24, 25)  0   0.00%  81.67%
[ 25, 26) 40  13.33%  95.00% #####
[ 26, 27)  0   0.00%  95.00%
[ 27, 28) 13   4.33%  99.33% ##
[ 28, 29)  0   0.00%  99.33%
[ 29, 29]  2   0.67% 100.00%

Depth by leafs:
Count: 3305 Average: 4.01755 StdDev: 1.39146
Min: 1 Max: 8 Ignored: 0
----------------------------------------------
[ 1, 2)  20   0.61%   0.61%
[ 2, 3) 368  11.13%  11.74% ####
[ 3, 4) 918  27.78%  39.52% #########
[ 4, 5) 973  29.44%  68.96% ##########
[ 5, 6) 517  15.64%  84.60% #####
[ 6, 7) 318   9.62%  94.22% ###
[ 7, 8) 145   4.39%  98.61% #
[ 8, 8]  46   1.39% 100.00%

Number of training obs by leaf:
Count: 3305 Average: 21.6944 StdDev: 26.3178
Min: 5 Max: 107 Ignored: 0
----------------------------------------------
[   5,  10) 2102  63.60%  63.60% ##########
[  10,  15)  237   7.17%  70.77% #
[  15,  20)   54   1.63%  72.41%
[  20,  25)   17   0.51%  72.92%
[  25,  30)   53   1.60%  74.52%
[  30,  35)   74   2.24%  76.76%
[  35,  41)   99   3.00%  79.76%
[  41,  46)   58   1.75%  81.51%
[  46,  51)   23   0.70%  82.21%
[  51,  56)   18   0.54%  82.75%
[  56,  61)   58   1.75%  84.51%
[  61,  66)   70   2.12%  86.63%
[  66,  71)  102   3.09%  89.71%
[  71,  77)  109   3.30%  93.01% #
[  77,  82)   76   2.30%  95.31%
[  82,  87)   70   2.12%  97.43%
[  87,  92)   40   1.21%  98.64%
[  92,  97)   23   0.70%  99.33%
[  97, 102)   16   0.48%  99.82%
[ 102, 107]    6   0.18% 100.00%

Attribute in nodes:
	1537 : bill_length_mm [NUMERICAL]
	1468 : body_mass_kg [NUMERICAL]

Attribute in nodes with depth <= 0:
	263 : bill_length_mm [NUMERICAL]
	37 : body_mass_kg [NUMERICAL]

Attribute in nodes with depth <= 1:
	446 : bill_length_mm [NUMERICAL]
	434 : body_mass_kg [NUMERICAL]

Attribute in nodes with depth <= 2:
	917 : body_mass_kg [NUMERICAL]
	755 : bill_length_mm [NUMERICAL]

Attribute in nodes with depth <= 3:
	1195 : body_mass_kg [NUMERICAL]
	1143 : bill_length_mm [NUMERICAL]

Attribute in nodes with depth <= 5:
	1477 : bill_length_mm [NUMERICAL]
	1421 : body_mass_kg [NUMERICAL]

Condition type in nodes:
	3005 : HigherCondition
Condition type in nodes with depth <= 0:
	300 : HigherCondition
Condition type in nodes with depth <= 1:
	880 : HigherCondition
Condition type in nodes with depth <= 2:
	1672 : HigherCondition
Condition type in nodes with depth <= 3:
	2338 : HigherCondition
Condition type in nodes with depth <= 5:
	2898 : HigherCondition
Node format: NOT_SET

Training OOB:
	trees: 1, Out-of-bag evaluation: accuracy:0.875 logloss:4.50546
	trees: 13, Out-of-bag evaluation: accuracy:0.890295 logloss:2.35926
	trees: 23, Out-of-bag evaluation: accuracy:0.891213 logloss:1.76382
	trees: 35, Out-of-bag evaluation: accuracy:0.903766 logloss:1.61533
	trees: 46, Out-of-bag evaluation: accuracy:0.912134 logloss:1.61544
	trees: 59, Out-of-bag evaluation: accuracy:0.912134 logloss:1.33186
	trees: 69, Out-of-bag evaluation: accuracy:0.916318 logloss:1.19735
	trees: 80, Out-of-bag evaluation: accuracy:0.920502 logloss:1.20323
	trees: 90, Out-of-bag evaluation: accuracy:0.916318 logloss:1.06613
	trees: 102, Out-of-bag evaluation: accuracy:0.916318 logloss:0.920117
	trees: 112, Out-of-bag evaluation: accuracy:0.916318 logloss:0.919398
	trees: 122, Out-of-bag evaluation: accuracy:0.916318 logloss:0.918544
	trees: 132, Out-of-bag evaluation: accuracy:0.916318 logloss:0.917733
	trees: 143, Out-of-bag evaluation: accuracy:0.920502 logloss:0.916464
	trees: 154, Out-of-bag evaluation: accuracy:0.920502 logloss:0.916065
	trees: 167, Out-of-bag evaluation: accuracy:0.920502 logloss:0.915384
	trees: 178, Out-of-bag evaluation: accuracy:0.920502 logloss:0.781669
	trees: 188, Out-of-bag evaluation: accuracy:0.924686 logloss:0.782319
	trees: 200, Out-of-bag evaluation: accuracy:0.920502 logloss:0.774758
	trees: 210, Out-of-bag evaluation: accuracy:0.924686 logloss:0.774025
	trees: 221, Out-of-bag evaluation: accuracy:0.924686 logloss:0.770531
	trees: 231, Out-of-bag evaluation: accuracy:0.924686 logloss:0.77066
	trees: 241, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767545
	trees: 251, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767962
	trees: 261, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767063
	trees: 271, Out-of-bag evaluation: accuracy:0.924686 logloss:0.767585
	trees: 281, Out-of-bag evaluation: accuracy:0.924686 logloss:0.766893
	trees: 292, Out-of-bag evaluation: accuracy:0.924686 logloss:0.634927
	trees: 300, Out-of-bag evaluation: accuracy:0.920502 logloss:0.636824
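The remark that a monotonic transformation such as grams to kilograms generally does not change a decision forest can be checked on a toy example. A threshold split depends only on the order of the feature values, so rescaling moves the threshold but not the partition of the examples. The sketch below uses made-up data and a simple Gini-based split search written for this illustration; it is not the splitter used by TF-DF.

```python
def gini(labels):
    """Gini impurity of a list of labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def best_partition(xs, ys):
    """Indices sent to the left branch by the best threshold split on xs."""
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    best_score, best_left = float("inf"), None
    for cut in range(1, len(order)):
        # Only split between two distinct feature values.
        if xs[order[cut - 1]] == xs[order[cut]]:
            continue
        left, right = order[:cut], order[cut:]
        score = (len(left) * gini([ys[i] for i in left])
                 + len(right) * gini([ys[i] for i in right]))
        if score < best_score:
            best_score, best_left = score, frozenset(left)
    return best_left


# Made-up body masses (grams) and labels.
body_mass_g = [3750.0, 3800.0, 3250.0, 5700.0, 5400.0, 4650.0]
species = ["Adelie", "Adelie", "Adelie", "Gentoo", "Gentoo", "Gentoo"]
body_mass_kg = [g / 1000.0 for g in body_mass_g]

same_split = best_partition(body_mass_g, species) == best_partition(body_mass_kg, species)
print("Same partition in grams and kilograms:", same_split)
```

The same argument does not hold for transformations that change the order of the values or that combine several features, which is where preprocessing can actually matter for tree models.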

The following example re-implements the same logic using TensorFlow feature columns.

# Convert grams to kilograms
def g_to_kg(x):
    return x / 1000

# Define the feature columns "body_mass_g" and "bill_length_mm"
feature_columns = [
    tf.feature_column.numeric_column("body_mass_g", normalizer_fn=g_to_kg),  # apply g_to_kg to "body_mass_g"
    tf.feature_column.numeric_column("bill_length_mm"),
]

# Create a preprocessing layer that applies the feature columns to the input data
preprocessing = tf.keras.layers.DenseFeatures(feature_columns)

# Create a random forest model that uses the preprocessing layer
model_5 = tfdf.keras.RandomForestModel(preprocessing=preprocessing)

# Train the model on the training dataset
model_5.fit(train_ds)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10582/2850711544.py:5: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.


Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmpiqyvbd4a as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.163336. Found 239 examples.
Training model...
Model trained in 0:00:00.050388
Compiling model...
Model compiled.
WARNING:tensorflow:6 out of the last 13 calls to  triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.


[INFO 23-08-16 11:05:47.9585 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpiqyvbd4a/model/ with prefix b3488dac1bd3468f
[INFO 23-08-16 11:05:47.9775 UTC decision_forest.cc:660] Model loaded with 300 root(s), 6310 node(s), and 2 input feature(s).
[INFO 23-08-16 11:05:47.9776 UTC kernel.cc:1075] Use fast generic engine




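15. Training a regression model

The previous examples train a classification model (TF-DF does not differentiate between binary and multi-class classification). In the next example, we train a regression model on the Abalone dataset. The objective of this dataset is to predict the number of rings on the shell of an abalone.

Note: The CSV file was assembled by appending UCI's header and data files. No preprocessing was applied.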

# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/abalone_raw.csv -O /tmp/abalone.csv

# Read the CSV file into a DataFrame named dataset_df
dataset_df = pd.read_csv("/tmp/abalone.csv")

# Print the first 3 rows of the DataFrame
print(dataset_df.head(3))
  Type  LongestShell  Diameter  Height  WholeWeight  ShuckedWeight  \
0    M         0.455     0.365   0.095       0.5140         0.2245   
1    M         0.350     0.265   0.090       0.2255         0.0995   
2    F         0.530     0.420   0.135       0.6770         0.2565   

   VisceraWeight  ShellWeight  Rings  
0         0.1010         0.15     15  
1         0.0485         0.07      7  
2         0.1415         0.21      9  
# Split the dataset into a training and a testing dataset.
train_ds_pd, test_ds_pd = split_dataset(dataset_df)

# Print the number of examples in each split.
print("{} examples in training, {} examples for testing.".format(len(train_ds_pd), len(test_ds_pd)))

# Name of the label column.
label = "Rings"

# Convert the Pandas dataframes into TensorFlow datasets for a regression task.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label, task=tfdf.keras.Task.REGRESSION)
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label, task=tfdf.keras.Task.REGRESSION)
2885 examples in training, 1292 examples for testing.
# Configure the model: a random forest for a regression task.
# Note: this reuses the name model_7 from section 13.
model_7 = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION)

# Train the model on the training dataset train_ds.
model_7.fit(train_ds)



Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.


Use /tmpfs/tmp/tmpr6p677s9 as temporary training directory
Reading training dataset...
Training dataset read in 0:00:00.210193. Found 2885 examples.
Training model...


[INFO 23-08-16 11:05:49.2476 UTC kernel.cc:1243] Loading model from path /tmpfs/tmp/tmpr6p677s9/model/ with prefix 8e87bf0ca0c24b13


Model trained in 0:00:01.408931
Compiling model...


[INFO 23-08-16 11:05:50.0462 UTC decision_forest.cc:660] Model loaded with 300 root(s), 260570 node(s), and 8 input feature(s).
[INFO 23-08-16 11:05:50.0463 UTC kernel.cc:1075] Use fast generic engine


Model compiled.






# Evaluate the model on the test dataset.
model_7.compile(metrics=["mse"])  # compile with the mean squared error (MSE) as the metric
evaluation = model_7.evaluate(test_ds, return_dict=True)  # evaluate and return the results as a dictionary

print(evaluation)  # print the evaluation dictionary
print()
print(f"MSE: {evaluation['mse']}")  # mean squared error
print(f"RMSE: {math.sqrt(evaluation['mse'])}")  # root mean squared error
WARNING:tensorflow:5 out of the last 5 calls to .test_function at 0x7f82300f5310> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.



1/2 [==============>...............] - ETA: 0s - loss: 0.0000e+00 - mse: 4.7546
2/2 [==============================] - 0s 15ms/step - loss: 0.0000e+00 - mse: 4.6777
{'loss': 0.0, 'mse': 4.677670955657959}

MSE: 4.677670955657959
RMSE: 2.1627923977252093
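As a reminder of what the two numbers mean, the sketch below computes the MSE and RMSE by hand. The labels are the three Rings values printed by head(3) above; the predictions are made up for illustration and `mse` is a helper written for this sketch, not a TF-DF function. The last line recomputes the RMSE from the MSE reported in the evaluation output.

```python
import math


def mse(labels, predictions):
    """Mean squared error between two equally long sequences."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if not labels:
        raise ValueError("at least one example is required")
    return sum((y - p) ** 2 for y, p in zip(labels, predictions)) / len(labels)


rings = [15, 7, 9]             # labels from the first three rows of the dataset
predicted = [12.0, 8.0, 10.0]  # made-up predictions, for illustration only

error = mse(rings, predicted)
print("MSE:", error)
print("RMSE:", math.sqrt(error))

# RMSE recomputed from the MSE reported by model_7.evaluate above.
reported_rmse = math.sqrt(4.677670955657959)
print("Reported RMSE:", reported_rmse)
```

The RMSE is expressed in the unit of the label, so a value of about 2.16 means that the predictions are typically off by roughly two rings.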
