机器学习模型的保存和加载

当我们的数据集的数量非常庞大的时候,并不适合每次运行的时候都加载一遍,那样的话,所需要的时间就非常庞大。因此我们需要进行模型保存
    1. 模型保存API
        joblib.dump(estimator, filename)
            estimator: 就是我们训练完成的模型
            filename:就是我们要保存的文件名,通常,文件名的后缀用.pkl来保存
    2. 模型加载
        joblib.load(filename)
            filename: 传入文件路径的字符串即可

模型保存代码:

# 对乳腺癌进行分类和评估(通过ROC曲线和AUC指标)
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, plot_roc_curve # 用来绘制ROC曲线
from sklearn.metrics import roc_auc_score # 用来计算AUC指标
from sklearn.metrics import classification_report
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression



# 1)数据集获取
data = load_breast_cancer()
# 2)数据集分离
x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, random_state=22)
# 3)特诊工程标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
# 4)逻辑回归流程
# 注意,这里可以采用网格搜索和交叉验证来进行出来,找到合适的estimator
estimator = LogisticRegression(solver='liblinear', penalty='l2', C=1.0)
estimator.fit(x_test, y_test)
# 对模型进行保存
joblib.dump(estimator, '逻辑回归.pkl')

模型加载代码:

import joblib
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, plot_roc_curve # 用来绘制ROC曲线
from sklearn.metrics import roc_auc_score # 用来计算AUC指标
from sklearn.metrics import classification_report
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# 模型加载
estimator = joblib.load('逻辑回归.pkl')

estimator.coef_, estimator.intercept_
# 准确率(注意,这里的x_test和y_test并不会保存下来,因此需要在保存模型的同时,保存测试集)
estimator.score(x_test, y_test)

# 5)精确率、召回率、F1-score
report = classification_report(y_test, estimator.predict(x_test), labels=[0, 1], target_names=['良性', '恶性'])
print(report)
# 6)ROC曲线和AUC指标
print(roc_curve(y_test, estimator.predict(x_test)))
print(roc_auc_score(y_test, estimator.predict(x_test)))
plot_roc_curve(estimator,x_test, y_test)
plt.plot([0, 1], [0, 1], 'r--', label='random classify')
plt.legend()
plt.show()
         precision    recall  f1-score   support

          良性       0.98      0.93      0.95        55
          恶性       0.96      0.99      0.97        88

    accuracy                           0.97       143
   macro avg       0.97      0.96      0.96       143
weighted avg       0.97      0.97      0.96       143

(array([0.        , 0.07272727, 1.        ]), array([0.        , 0.98863636, 1.        ]), array([2, 1, 0]))
0.9579545454545455

机器学习模型的保存和加载_第1张图片

学习地址:

 黑马程序员3天快速入门python机器学习_哔哩哔哩_bilibili

你可能感兴趣的:(机器学习,机器学习,人工智能,python,sklearn,深度学习)