一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】

不知道大家入门上手机器学习项目是首先入坑的哪个深度学习框架,对于我来说,最先看到的听到的就是Tensorflow了,但是实际上手做项目开发的时候却发现了一个很重要的问题,不容易上手,基于原生的tf框架来直接开发模总是有不小的难度,后来发现了Keras,简直就是深度学习的福音,本质上来讲,Keras是对后端深度学习框架的高级封装。

框架介绍

当涉及深度学习和机器学习时,TensorFlow、PyTorch和JAX都是非常流行的框架。以下是它们的详细介绍以及各自的优点和缺点:

【TensorFlow】
TensorFlow 是由 Google Brain 团队开发的开源框架,可以实现深度学习和机器学习模型的构建和训练。
优点:

深度学习生态系统完善,支持常见的深度学习模型开发和部署。
支持分布式计算,适合大规模数据和模型训练。
提供了 TensorFlow Serving 等支持模型部署和服务化的工具。
在生产环境中表现稳定,有着广泛的产业应用。
缺点:

前端API相对较复杂,学习曲线较陡。
相比其他框架,可读性稍低,需要编写更多的代码。
在一些场景下的灵活性和易用性不如 PyTorch。
【PyTorch】
PyTorch 是 Facebook 开发的深度学习框架,也是一个开源项目。它和 TensorFlow 一样,可以用于实现深度学习算法。
优点:

前端API设计简洁直观,易于学习和使用。
动态计算图的特性使得调试更加方便,同时也更具灵活性。
提供了针对自然语言处理和计算机视觉等任务的高层抽象库,如 TorchText 和 TorchVision。
缺点:

缺乏在大规模分布式训练中的一些优化和工具支持。
对生产部署支持相对不足,相较 TensorFlow 在工业界的部署相对不如。
【JAX】
JAX 是 Google 开源的一个库,它提供了用于数值计算(尤其是自动微分和加速机器学习模型)的高性能Python接口。
优点:

采用 XLA (Accelerated Linear Algebra)进行加速,能够在 CPU 或 GPU 上自动并行化。
支持自动微分,可以用于构建自定义的优化器和模型。
完全可扩展,可以在大规模计算中实现高性能计算。
缺点:

相对 TensorFlow 和 PyTorch,社区和生态系统相对较小。
API 相对不够丰富,需要编写更多的自定义代码。
综上所述,TensorFlow、PyTorch 和 JAX 都有各自的优势和劣势,选择哪个框架要根据具体的需求和背景来决定。例如,如果需要处理大规模的生产环境应用,可能更倾向于选择 TensorFlow;而如果需要快速的原型开发和实验,PyTorch 是一个更好的选择。而 JAX 则适合于寻求高性能计算的研究和实验工作。

接下来我们来详细看官方本次发布的Keras3.0的详细介绍,官方介绍在这里,如下所示:

一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】_第1张图片

经过五个月的广泛公测,我们很高兴地宣布Keras 3.0的正式发布。Keras 3是对Keras的全面重写,使您能够在JAX、TensorFlow或PyTorch之上运行Keras工作流,并释放全新的大规模模型培训和部署功能。你可以选择最适合你的框架,并根据你当前的目标从一个框架切换到另一个框架。您还可以使用Keras作为一种低级的跨框架语言来开发自定义组件,如层、模型或度量,这些组件可以在JAX、TensorFlow或PyTorch中的本地工作流中使用——只需一个代码库。

欢迎使用多框架机器学习

您已经熟悉了使用Keras的好处——它通过专注于出色的UX、API设计和可调试性来实现高速开发。这也是一个经过战斗测试的框架,已被超过250万开发者选择,为世界上一些最复杂、规模最大的ML系统提供了动力,如Waymo自动驾驶车队和YouTube推荐引擎。但是使用新的多后端Keras 3还有什么额外的好处呢?

始终为您的模型获得最佳性能

在我们的基准测试中,我们发现JAX通常在GPU、TPU和CPU上提供最佳的训练和推理性能,但结果因模型而异,因为非XLA TensorFlow在GPU上偶尔会更快。能够动态选择为您的模型提供最佳性能的后端,而无需对代码进行任何更改,这意味着您可以保证以最高的效率进行培训和服务。

解锁您的模型的生态系统可选性

任何Keras 3模型都可以实例化为PyTorch模块,可以导出为TensorFlow SavedModel,也可以实例化为无状态JAX函数。这意味着您可以将Keras 3模型与PyTorch生态系统包、全套TensorFlow部署和生产工具(如TF Serving、TF.js和TFLite)以及JAX大规模TPU培训基础设施一起使用。使用Keras3API编写一个model.py,并访问ML世界所提供的一切。

利用JAX的大规模模型并行性和数据并行性。Keras3包含一个全新的发行版API,Keras.distribution名称空间,目前为JAX后端实现(即将在TensorFlow和PyTorch后端实现)。它使在任意模型规模和集群规模上进行模型并行、数据并行以及两者的组合变得容易。因为它使模型定义、训练逻辑和分片配置彼此分离,所以它使您的分发工作流易于开发和维护。请参阅我们的入门指南。

最大限度地扩大您的开源模型发布范围

想要发布一个预先训练好的模型吗?想要尽可能多的人能够使用它吗?如果你在纯TensorFlow或PyTorch中实现它,大约一半的社区都可以使用它。如果你在Keras 3中实现它,那么任何人都可以立即使用它,无论他们选择的框架是什么(即使他们自己不是Keras用户)。在不增加开发成本的情况下实现两倍的效果。

使用任何来源的数据管道

Keras 3 fit()/evaluatate()/predict()例程与tf.data.Dataset对象、PyTorch DataLoader对象、NumPy数组、Pandas数据帧兼容,无论您使用的是后端。您可以在PyTorch DataLoader上训练Keras 3+TensorFlow模型,也可以在tf.data.Dataset上训练Keras3+PyTorch模型。

完整的Keras API,可用于JAX、TensorFlow和PyTorch

Keras 3实现了完整的Keras API,并使其与TensorFlow、JAX和PyTorch一起使用-超过一百层、数十个度量、丢失函数、优化器和回调、Keras训练和评估循环以及Keras保存和序列化基础设施。所有你熟悉和喜爱的API都在这里。

任何只使用内置层的Keras模型都将立即与所有支持的后端一起工作。事实上,您现有的仅使用内置层的tf.keras模型可以立即在JAX和PyTorch中运行!没错,您的代码库刚刚获得了一组全新的功能。

一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】_第2张图片

编写多框架层、模型、度量

Keras 3使您能够创建在任何框架中都能正常工作的组件(如任意自定义层或预训练模型)。特别是,Keras3允许您访问跨所有后端工作的Keras.ops命名空间。它包含:

NumPy API的完整实现。不是类似NumPy的东西——只是字面上的NumPy API,具有相同的函数和相同的参数。您将获得ops.matmul、ops.sum、ops.stack、ops.einsum等。

NumPy中没有的一组特定于神经网络的函数,如ops.softmax、ops.binary_crossentropy、ops.cov等。

只要您只使用keras.ops中的操作,您的自定义层、自定义损失、自定义度量和自定义优化器将使用JAX、PyTorch和TensorFlow使用相同的代码。这意味着您只能维护一个组件实现(例如,一个model.py和一个检查点文件),并且您可以在所有框架中使用它,使用完全相同的数字。

…与任何JAX、TensorFlow和PyTorch工作流无缝配合

Keras 3不仅仅适用于以Keras为中心的工作流,您可以在其中定义Keras模型、Keras优化器、Keras损失和度量,并调用fit()、evaluate()和predict()。它还意味着可以与底层后端本机工作流无缝协作:您可以采用Keras模型(或任何其他组件,如损失或度量),并开始在JAX训练循环、TensorFlow训练循环或PyTorch训练循环中使用它,或者作为JAX或PyTorc模型的一部分,零摩擦。Keras3在JAX和PyTorch中提供了与tf.Keras之前在TensorFlow中所做的完全相同程度的底层实现灵活性。

您可以:

编写一个低级JAX训练循环,使用optax优化器JAX.grad、JAX.jit、JAX.pmap来训练Keras模型。

编写一个低级TensorFlow训练循环,使用tf.GradientTape和tf.distribute训练Keras模型。

编写一个低级PyTorch训练循环,使用torch.optim优化器、torch损失函数和torch.nn.parallel.DistributtedDataParallel包装器来训练Keras模型。

在PyTorch模块中使用Keras层(因为它们也是模块实例!)

在Keras模型中使用任何PyTorch模块,就好像它是Keras层一样。

等等

一种新的分布式API,用于大规模数据并行和模型并行

我们一直在研究的模型越来越大,所以我们想为多设备模型分片问题提供一个Kerasic解决方案。我们设计的API使模型定义、训练逻辑和分片配置完全分离,这意味着可以将模型编写为在单个设备上运行。然后,当需要对任意模型进行训练时,可以将任意分片配置添加到任意模型中。

数据并行性(在多个设备上相同地复制一个小模型)只需两行即可处理:

一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】_第3张图片

模型并行性使您可以沿着多个命名维度为模型变量和中间输出张量指定分片布局。在典型情况下,您可以将可用设备组织为二维网格(称为设备网格),其中第一个维度用于数据并行,第二个维度用于模型并行。然后,您可以将模型配置为沿模型维度进行分片,并沿数据维度进行复制。

API允许您通过正则表达式配置每个变量和每个输出张量的布局。这样可以很容易地为整个变量类别快速指定相同的布局。

一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】_第4张图片

新的发行版API旨在成为多后端,但目前仅适用于JAX后端。TensorFlow和PyTorch的支持即将到来。开始使用此指南!

预训练模型

有一系列预先训练好的模型,您今天可以在Keras 3中开始使用。

所有40个Keras应用程序模型(Keras.Applications命名空间)在所有后端都可用。KerasCV和KerasNLP中的大量预训练模型也适用于所有后端。这包括:

一觉醒来!Keras 3.0史诗级更新,大一统深度学习三大后端框架【Tensorflow/PyTorch/Jax】_第5张图片

支持所有后端的跨框架数据管道

多框架ML也意味着多框架数据的加载和预处理。Keras 3模型可以使用广泛的数据管道进行训练——无论您使用的是JAX、PyTorch还是TensorFlow后端。它只是起作用。

tf.data.Dataset管道:可扩展生产ML的参考。

torch.utils.data.DataLoader对象。

NumPy数组和Pandas数据帧。

Keras自己的Keras.utils.PyDataset对象。

复杂性的逐步披露

复杂性的渐进披露是Keras API核心的设计原则。Keras不会强迫你遵循一种“真正”的方式来构建和训练模型。相反,它支持各种不同的工作流,从非常高级到非常低级,对应于不同的用户配置文件。

这意味着您可以从简单的工作流开始,例如使用Sequential和Functional模型,并使用fit()对它们进行训练。当您需要更大的灵活性时,您可以轻松地自定义不同的组件,同时重用大多数以前的代码。随着你的需求变得更加具体,你不会突然从复杂性的悬崖上跌落,也不需要切换到不同的工具集。

我们把这个原则带到了我们所有的后台。例如,您可以自定义训练循环中发生的事情,同时仍然利用fit()的功能,而不必从头开始编写自己的训练循环——只需重写train_step方法。

以下是它在PyTorch和TensorFlow中的工作方式:

这是JAX版本的链接。

class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.compute_loss(x, y, y_pred)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y = data

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        (
            trainable_variables,
            optimizer_variables,
        ) = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        new_metrics_vars = []
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred
                )
            logs = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

新的无状态API,用于层、模型、度量和优化器

你喜欢函数式编程吗?你会得到款待的。

Keras中的所有有状态对象(即拥有在训练或评估过程中更新的数值变量的对象)现在都有一个无状态的API,从而可以在JAX函数中使用它们(要求完全无状态):

所有层和模型都有一个stateless_call()方法,该方法镜像__call__()。

所有优化器都有一个stateless_apply()方法,该方法镜像apply(()。

所有度量都有一个镜像update_state()的stateless_update_state)方法和一个镜像result()的stateless_sult()方法。

这些方法没有任何副作用:它们将目标对象的状态变量的当前值作为输入,并将更新值作为输出的一部分返回,例如:

outputs, updated_non_trainable_variables = layer.stateless_call(
    trainable_variables,
    non_trainable_variables,
    inputs,
)

您永远不必自己实现这些方法——只要您实现了有状态版本(例如call()或update_state()),它们就会自动可用。

从Keras 2移动到Keras 3

Keras 3与Keras 2高度向后兼容:它实现了Keras 2的完整公共API表面,这里列出了有限的例外情况。大多数用户不需要进行任何代码更改就可以在Keras3上开始运行他们的Keras脚本。

较大的代码库可能需要一些代码更改,因为它们更有可能遇到上面列出的异常之一,并且更有可能使用私有API或不推荐使用的API(tf.compat.v1.keras命名空间、实验命名空间、keras.src私有命名空间)。为了帮助您迁移到Keras 3,我们发布了一份完整的迁移指南,其中包含您可能遇到的所有问题的快速修复程序。

您还可以选择忽略Keras 3中的更改,只需将Keras 2与TensorFlow一起使用——这对于那些没有积极开发但需要使用更新的依赖关系继续运行的项目来说是一个很好的选择。您有两种可能性:

1、如果您将keras作为一个独立的包访问,只需切换到使用Python包tf_keras即可,您可以通过pip-install-tf_keras安装该包。代码和API完全没有变化——它是Keras 2.15,具有不同的包名称。我们将继续修复tf_keras中的错误,并定期发布新版本。但是,由于该软件包现在处于维护模式,因此不会添加任何新功能或性能改进。

2、如果您通过tf.keras访问keras,那么在TensorFlow 2.16之前不会立即发生更改。TensorFlow 2.16+默认情况下将使用Keras 3。在TensorFlow 2.16+中,要继续使用Keras 2,可以先安装tf_Keras,然后导出环境变量tf_USE_LEGACY_Keras=1。这将指导TensorFlow 2.16+将tf.keras解析为本地安装的tf_keras包。请注意,这可能影响的不仅仅是您自己的代码:它将影响Python进程中导入tf.keras的任何包。为了确保您的更改只影响您自己的代码,您应该使用tf_keras包。

常见问题解答

Q: Keras 3是否与旧版Keras 2兼容?

使用tf.keras开发的代码通常可以像使用keras 3(使用TensorFlow后端)一样运行。您应该注意的不兼容性数量有限,所有这些都在本迁移指南中介绍。

当涉及到同时使用来自tf.keras和keras 3的API时,这是不可能的——它们是不同的包,运行在完全不同的引擎上。

Q: 在旧版Keras 2中开发的预训练模型是否适用于Keras 3?

一般来说,是的。任何tf.keras模型都应该使用带有TensorFlow后端的keras 3(确保以.keras v3格式保存)。此外,如果模型只使用内置的Keras层,那么它也可以在带有JAX和PyTorch后端的Keras 3中开箱即用。

如果模型包含使用TensorFlow API编写的自定义层,则通常很容易将代码转换为后端无关的。例如,我们只花了几个小时就将keras应用程序中的所有40个遗留tf.keras模型转换为后端不可知的模型。

Q: 我可以在一个后端保存一个Keras 3模型并在另一个后端重新加载它吗?

是的,你可以。在保存的.keras文件中没有后端专门化。您保存的Keras模型与框架无关,可以使用任何后端重新加载。

但是,请注意,重新加载包含具有不同后端的自定义组件的模型需要使用与后端无关的API(例如keras.ops)来实现自定义组件。

Q: 我可以在tf.data管道中使用Keras 3组件吗?

对于TensorFlow后端,Keras 3与tf.data完全兼容(例如,您可以将序列模型映射到tf.data管道中)。

使用不同的后端,Keras 3对tf.data的支持有限。您将无法将任意层或模型映射到tf.data管道中。但是,您可以将特定的Keras 3预处理层与tf.data一起使用,例如IntegerLookup或CategoryEncoding。

当涉及到使用tf.data管道(不使用Keras)来提供对.fit()、.evaluate()或.predict()的调用时,所有后端都是现成的。

Q: Keras 3型号在不同后端运行时表现相同吗?

是的,后端的数字是相同的。但是,请记住以下注意事项:

RNG行为在不同的后端之间是不同的(即使在种子设定之后-您的结果在每个后端都是确定的,但在后端之间是不同的)。所以随机权重初始化值和退出值在后端会有所不同。

由于浮点实现的性质,在float32中,每个函数执行的结果在1e-7精度以内是相同的。因此,当长时间训练一个模型时,微小的数值差异会积累起来,最终可能导致显著的数值差异。

由于PyTorch中缺少对使用非对称填充的平均池的支持,使用padding=“same”的平均池层可能会导致边框行/列上的数字不同。这在实践中并不经常发生——在40个Keras应用程序视觉模型中,只有一个受到影响。

Q: Keras 3是否支持分布式训练?

JAX、TensorFlow和PyTorch支持数据并行分发。JAX通过keras.distribution API支持模型并行分发。

使用TensorFlow:

Keras3与tf.distribute兼容——只需打开一个Distribution策略范围并在其中创建/训练您的模型。

使用PyTorch:

Keras 3与PyTorch的DistributedDataParallel实用程序兼容。这里有一个例子。

使用JAX:

您可以使用keras.distribution API在JAX中进行数据并行和模型并行分发。例如,要进行数据并行分发,只需要以下代码段:

distribution = keras.distribution.DataParallel(devices=keras.distribution.list_devices())
keras.distribution.set_distribution(distribution)

有关模型并行分布,请参见以下指南。

您还可以通过JAX.sharding等JAX API分发培训。

Q: 我的自定义Keras层是否可以用于本地PyTorch模块或亚麻模块?

如果它们只使用kerasapi(例如Keras.ops名称空间)编写,那么是的,您的Keras层将使用本机PyTorch和JAX代码直接工作。在PyTorch中,只需像其他PyTorch模块一样使用Keras层。在JAX中,确保使用无状态层API,即layer.statese\u call()。

Q: 您将来会添加更多后端吗?那么框架XYZ呢?

我们愿意添加新的后端,只要目标框架有一个大的用户群,或者有一些独特的技术优势。然而,添加和维护一个新的后端是一个很大的负担,所以我们将仔细考虑每个新的后端候选人在个案的基础上,我们不太可能添加许多新的后端。我们不会添加任何尚未完善的新框架。我们现在可能会考虑添加一个用Mojo编写的后端。如果这是一些你可能会发现有用的东西,请让Mojo团队知道。

你可能感兴趣的:(深度学习,keras,tensorflow)