Keras 3 是一个多后端深度学习框架,支持 JAX、TensorFlow、PyTorch 和 OpenVINO(仅推理)。轻松构建和训练适用于计算机视觉、自然语言处理、音频处理、时间序列预测、推荐系统等领域的模型。
加入近三百万开发者的行列,从新兴初创公司到全球企业,共同释放 Keras 3 的强大能力。
Keras 3 在 PyPI 上以 keras
提供。请注意,Keras 2 仍以 tf-keras
包的形式提供。
keras
:pip install keras --upgrade
使用 keras
时,还需安装所选后端:tensorflow
、jax
或 torch
。注意,某些 Keras 3 功能(如特定预处理层和 tf.data
流水线)需要安装 tensorflow
。
Keras 3 兼容 Linux 和 MacOS 系统。Windows 用户建议使用 WSL2 运行 Keras。安装本地开发版本的步骤如下:
pip install -r requirements.txt
python pip_build.py --install
keras_export
公共 API 的 PR 时,运行 API 生成脚本:./shell/api_gen.sh
requirements.txt
文件将安装仅支持 CPU 的 TensorFlow、JAX 和 PyTorch。对于 GPU 支持,我们还为 TensorFlow、JAX 和 PyTorch 提供了单独的 requirements-{backend}-cuda.txt
文件。这些文件通过 pip
安装所有 CUDA 依赖项,并需要预先安装 NVIDIA 驱动程序。建议为每个后端创建干净的 Python 环境以避免 CUDA 版本冲突。例如,以下是使用 conda
创建 Jax GPU 环境的方法:
conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install
您可以通过导出环境变量 KERAS_BACKEND
或编辑本地配置文件 ~/.keras/keras.json
来配置后端。可用的后端选项有:"tensorflow"
、"jax"
、"torch"
、"openvino"
。例如:
export KERAS_BACKEND="jax"
在 Colab 中,可以执行:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
注意:必须在导入 keras
之前配置后端,且导入后无法更改后端。
注意:OpenVINO 后端仅用于推理,即仅设计用于通过 model.predict()
方法运行模型预测。
Keras 3 旨在作为 tf.keras
的替代品(使用 TensorFlow 后端时)。只需将现有的 tf.keras
代码迁移,确保 model.save()
调用使用最新的 .keras
格式即可。
如果您的 tf.keras
模型不包含自定义组件,可以立即在 JAX 或 PyTorch 上运行。
如果包含自定义组件(如自定义层或自定义 train_step()
),通常只需几分钟即可将其转换为与后端无关的实现。
此外,Keras 模型可以使用任何格式的数据集,无论您使用哪种后端:您可以使用现有的 tf.data.Dataset
流水线或 PyTorch DataLoader
训练模型。
Module
或 JAX 原生模型函数的一部分。更多信息请参阅 Keras 3 发布公告。