sklearn如何保存模型

问题

用sklearn训练的模型,如何将其参数保存,方便下次调用

模型

gbr = GBR(random_state=1412) # 实例化
gbr.fit(X, y.ravel()) # 训练模型

方法

常用方法 joblib 和 pickle 库

保存模型

  • joblib
# from sklearn.externals import joblib # 低版本Scikit-learn 0.21版本以下
import joblib # 新版本 Scikit-learn
joblib.dump(gbr, "train_model.m")
  • pickle
import pickle
with open('train_model.pkl', 'wb') as f:
    pickle.dump(gbr, f)

读取模型

  • joblib
import joblib
gbr = joblib.load("train_model.m")
  • pickle
import pickle
with open('train_model.pkl', 'rb') as f:
    gbr = pickle.load(f)

不同架构(Java、C++等)

  • 官方建议使用Open Neural Network Exchange 格式或Predictive Model Markup Language (PMML) 格式导出

ONNX 是模型的二进制序列化。它的开发是为了提高数据模型的可互操作表示的可用性。它旨在促进数据模型在不同机器学习框架之间的转换,并提高它们在不同计算架构上的可移植性。更多详细信息可从ONNX 教程中获得。为了将 scikit-learn 模型转换为 ONNX,我们开发了一个特定的工具sklearn-onnx。

PMML 是XML文档标准的一种实现,定义为表示数据模型以及用于生成它们的数据。PMML 是人类和机器可读的,是在不同平台上进行模型验证和长期存档的不错选择。另一方面,与一般的 XML 一样,当性能至关重要时,它的冗长对生产没有帮助。要将 scikit-learn 模型转换为 PMML,您可以使用在 Affero GPLv3 许可下分发的例如sklearn2pmml 。

  • 参考文章3给出了存储成json格式的方式
import json
import numpy as np
class MyLogReg(LogisticRegression):
    # Override the class constructor
    def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None):
        LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
        self.X_train = X_train
        self.Y_train = Y_train
    # A method for saving object data to JSON file
    def save_json(self, filepath):
        dict_ = {}
        dict_['C'] = self.C
        dict_['max_iter'] = self.max_iter
        dict_['solver'] = self.solver
        dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
        dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'
        # Creat json and save to file
        json_txt = json.dumps(dict_, indent=4)
        with open(filepath, 'w') as file:
            file.write(json_txt)
    # A method for loading data from JSON file
    def load_json(self, filepath):
        with open(filepath, 'r') as file:
            dict_ = json.load(file)
        self.C = dict_['C']
        self.max_iter = dict_['max_iter']
        self.solver = dict_['solver']
        self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
        self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None

存储和查看方法

filepath = "mylogreg.json"
# Create a model and train it
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)
mylogreg.save_json(filepath)
# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()
json_mylogreg.load_json(filepath)
json_mylogreg

限制

Pickle 和 Joblib

  • 兼容性问题
    Pickle 和 Joblib 的最大缺点就是其兼容性问题,可能与不同模型不同版本的 scikit-learn 或 Python 版本有关。
  • 安全问题
    Pickle(以及扩展的 Joblib)在可维护性和安全性方面存在一些问题。

JSON

  • 安全性较低
  • 适用于实例变量较少的对象

使用 JSON 进行数据序列化实际上是将对象保存为字符串格式,所以我们可以用文本编辑器打开和修改 mylogreg.json 文件。尽管这种方法对开发人员来说很方便,但其他人员也可以随意查看和修改 JSON 文件的内容,因此安全性较低。而且,这种方法更适用于实例变量较少的对象,例如 sklearn 模型,因为任何新变量的添加都需要更改保存和载入的方法。


相关文章:

  1. Model persistence
  2. sklearn2pmml
  3. sklearn 模型的保存与加载

你可能感兴趣的:(僧旅,sklearn,人工智能,python)