常见的神经网络权重文件格式及其详细说明

常见的神经网络权重文件格式及其详细说明的表格:

扩展名 所属框架/工具 如何生成 表示内容 使用方法 注意事项
.pt.pth PyTorch torch.save(model.state_dict(), "model.pt") PyTorch模型的状态字典(权重和参数)或整个模型 加载方式:model.load_state_dict(torch.load("model.pt")) 如果保存整个模型(含结构和权重),可能导致跨设备加载问题。pth通常用于旧版本。
.h5.hdf5 Keras/TensorFlow model.save("model.h5") 完整的Keras模型(含结构、权重和优化器状态) 加载方式:keras.models.load_model("model.h5") HDF5格式依赖h5py库;部分自定义层可能需要手动定义。
.pkl.pickle Python通用(如scikit-learn) pickle.dump(model, open("model.pkl", "wb")) 序列化的Python对象(如模型参数或全量模型) 加载方式:model = pickle.load(open("model.pkl", "rb")) 存在安全风险(反序列化恶意代码);建议仅在可信来源使用。
.ckpt TensorFlow 使用tf.train.CheckpointModelCheckpoint回调 TensorFlow检查点文件(模型参数、优化器状态等) 恢复训练:model.load_weights("model.ckpt") 检查点文件包含多个文件(如.index.data-xxx等),需一起保留。
.pb TensorFlow(SavedModel) tf.saved_model.save(model, "model_dir") TensorFlow的计算图结构和权重(Protocol Buffers格式) 加载方式:tf.keras.models.load_model("model_dir") 或 tf.saved_model.load("model_dir") 跨语言兼容性好,支持C++/Java等语言调用。
.onnx ONNX(跨框架标准) torch.onnx.export(model, input, "model.onnx") 跨框架的标准化模型(含结构和权重) 加载工具:ONNX Runtime(ort.InferenceSession("model.onnx"))或其他框架的转换工具 需验证框架支持的操作;可能需调整算子兼容性。
.weights Darknet/YOLO Darknet训练时自动生成(如YOLOv3的yolov3.weights Darknet模型的权重参数,需配合配置文件(.cfg)使用 加载方式:darknet.load_net("yolov3.cfg", "yolov3.weights") 无模型结构信息,必须与对应的.cfg文件匹配使用。
.bin.safetensors Hugging Face Transformers库 model.save_pretrained("dir")会生成pytorch_model.bin 模型权重文件,通常与配置文件(config.json)配合使用 加载方式:model.from_pretrained("dir") .safetensors是更安全的格式(替代.bin),避免恶意代码注入。
.tflite TensorFlow Lite 转换工具:tf.lite.TFLiteConverter.from_saved_model("model_dir").convert() 轻量化模型,适用于移动端/嵌入式设备 移动端推理:使用TensorFlow Lite的Interpreter API加载。 模型可能经过量化(精度降低但体积减小)。
.params MXNet net.save_parameters("model.params") MXNet模型的权重参数 加载方式:net.load_parameters("model.params") 需预先定义网络结构再加载参数。
.joblib scikit-learn joblib.dump(model, "model.joblib") 序列化后的scikit-learn模型(适用于大文件) 加载方式:model = joblib.load("model.joblib") pickle更高效,但主要服务于传统机器学习模型,少用于神经网络。

使用场景总结

  • 训练/推理用途.pt(PyTorch)、.h5(Keras)、.ckpt(TensorFlow)常用于训练过程中的保存与恢复。
  • 跨框架部署.onnx适合不同框架间的模型转换;.pb(TensorFlow)适合生产环境部署。
  • 嵌入式/移动端.tflite针对移动设备优化。
  • 安全性优先:优先选择.safetensors替代.pkl.bin
  • 协作与共享.weights+.cfg(YOLO)、saved_model(目录)包含完整信息,便于团队协作。

你可能感兴趣的:(基础知识,神经网络,人工智能,深度学习)