Saving and Loading scikit-learn Models

scikit-learn versions can differ between environments, so it helps to record the version when saving a model.

21.1 Saving and Loading a scikit-learn Model
Problem
You have trained a scikit-learn model and want to save it and load it elsewhere.

Solution
Save the model as a pickle file:

# save and load a scikit-learn model
# load libraries
import joblib
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

# load data
iris = datasets.load_iris()
features = iris.data
target = iris.target

# create random forest classifier object
classifier = RandomForestClassifier()

# train model
model = classifier.fit(features, target)

# save model as pickle file
joblib.dump(model, "model.pkl")
['model.pkl']
Once the model is saved we can use scikit-learn in our destination application (e.g., web application) to load the model:

# load model from file
classifier = joblib.load("model.pkl")
And use it to make predictions:

# create new observation
new_observation = [[5.2, 3.2, 1.1, 0.1]]

# predict observation's class
classifier.predict(new_observation)
array([0])
Discussion
The first step in using a model in production is to save that model as a file that can be loaded by another application or workflow. We can accomplish this by saving the model as a pickle file, a Python-specific data format. Specifically, to save the model we use joblib, a library that extends pickle for cases when we have large NumPy arrays, a common occurrence for trained models in scikit-learn.
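
As a rough illustration of the round trip described above, the sketch below saves a trained model with joblib and checks that the reloaded copy predicts the same classes. The temporary directory, the `random_state=0` setting, and the `compress=3` level are choices made for this example only, not part of the original recipe.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

# train a small model on the iris data
iris = datasets.load_iris()
model = RandomForestClassifier(random_state=0).fit(iris.data, iris.target)

# write the model to a temporary directory so no file is left behind
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")

    # compress=3 is an illustrative choice; it trades some speed for a smaller file
    joblib.dump(model, path, compress=3)

    # load the model back from disk
    restored = joblib.load(path)

# the restored model should reproduce the original predictions
same_predictions = np.array_equal(
    model.predict(iris.data), restored.predict(iris.data)
)
print(same_predictions)
```

Compression is optional; leaving it out (the default) gives a larger file that is typically faster to write and read.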

When saving scikit-learn models, be aware that saved models might not be compatible between versions of scikit-learn; therefore, it can be helpful to include the version of scikit-learn used to train the model in the filename:

# import library (versions can differ, so we record the version when saving)
import sklearn

# get scikit-learn version
scikit_version = sklearn.__version__

# save model as pickle file
joblib.dump(model, "model_{version}.pkl".format(version=scikit_version))
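
A filename records the version, but nothing checks it at load time. One possible way to enforce the check, sketched here under assumptions, is to write the version to a small JSON file next to the model and compare it before loading. The helper names `save_with_version` and `load_checked` and the `.json` sidecar convention are hypothetical, invented for this example; they are not part of scikit-learn or joblib.

```python
import json
import os
import tempfile

import joblib
import sklearn
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier


def save_with_version(model, path):
    """Hypothetical helper: save the model plus a JSON sidecar with the version."""
    joblib.dump(model, path)
    with open(path + ".json", "w") as f:
        json.dump({"sklearn_version": sklearn.__version__}, f)


def load_checked(path):
    """Hypothetical helper: load only if the recorded version matches."""
    with open(path + ".json") as f:
        saved_version = json.load(f)["sklearn_version"]
    if saved_version != sklearn.__version__:
        raise RuntimeError(
            "model was saved with scikit-learn {}, but {} is installed".format(
                saved_version, sklearn.__version__
            )
        )
    return joblib.load(path)


# train a model, save it with its version, and load it back
iris = datasets.load_iris()
model = RandomForestClassifier().fit(iris.data, iris.target)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")
    save_with_version(model, path)
    restored = load_checked(path)

print(type(restored).__name__)
```

Keeping the version in a separate file means it can be read without unpickling the model, which matters because unpickling itself may fail or misbehave when versions differ. An exact-match check is strict; a real deployment might accept a range of versions instead.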
