将Spark的ML/MLlib机器学习库中算法生成的模型存入数据库

在使用Spark的ML/MLlib过程中,想要将算法生成的模型进行保存,方便下次调用,模型的save方法可以将模型以文件的形式保存到磁盘中,但是如果代码运行在其他环境中想要调用模型的话,需要将模型文件copy到其他环境中并配置好模型文件的路径,这样就很麻烦。所以将模型保存至数据库中,其他环境调用起来就很方便。

解决思路

写入:将模型转换为二进制流存入数据库
读取:将数据库中读取的数据进行反序列化,强制转换为对应的模型类型

具体实现

表:

CREATE TABLE `tb_sms_model` (
  `id` varchar(255) NOT NULL,
  `model` blob COMMENT '模型',
  `create_time` datetime DEFAULT NULL COMMENT '创建时间',
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;

模型对象(.java):

public class SensitiveSMSModel implements Serializable {

    private String id;// 主键id

    private byte[] model; // 模型

    private String createTime;// 创建时间

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public byte[] getModel() {
        return model;
    }

    public void setModel(byte[] model) {
        this.model = model;
    }

    public String getCreateTime() {
        return createTime;
    }

    public void setCreateTime(String createTime) {
        this.createTime = createTime;
    }
}

生成的模型(.scala):

        ......// 此处省略一万行生成训练集和测试集的代码
        // 训练模型,trainDataRdd:训练集
        val model = new NaiveBayes().fit(trainDataRdd) // model是训练出来的朴素贝叶斯模型
        / /利用模型做预测,testtrainDataRdd:测试集
        val predictions = model.transform(testtrainDataRdd)
        println("predictln out:")
        predictions.show

写入:
调用下面这个方法将模型转换为二进制流(.scala):

	/**
    * 通过二进制流形式保存模型
    * @param model 可能是不同算法模型的model,例如:NaiveBayesModel、SVMModel...
    * @Param modelType 模型类型,我将它作为数据库的id
    * @return 返回的对象直接存入数据库即可
    */
  def saveModel(model:Object,modelType:String):SensitiveSMSModel = {
	// 保存模型
        // 1.原来的方式是将模型以文件的形式保存在磁盘中
        //model.write.overwrite().save("D:\\sourceCode\\Spark\\model\\model1")
        // 2.将model转换为二进制流存入数据库中
        val os = new ByteArrayOutputStream() // 定义一个字节数组输出流
        val out = new ObjectOutputStream(os) // 对象输出流
        out.writeObject(model)
        val modelByte = os.toByteArray // byte[]
        var sensitiveSMSModel = new SensitiveSMSModel() // 申明一个模型对象
        sensitiveSMSModel.setId(modelType)
        sensitiveSMSModel.setModel(modelByte)
        os.close()
        out.close()
        sensitiveSMSModel // 返回对象
  }

读取:
调用下面的方法反序列化模型(.scala):

/**
    * 反序列化模型
    * @param model 是从数据库中读取的对象,主要对model.getModel进行反序列化,model.getModel是byte[]类型
    * @return 这里的T为泛型,例如调用这个方法时传入的模型类型为NaiveBayesModel,则该方法的最后(in.readObject().asInstanceOf[T])返回的将会是一个朴素贝叶斯的模型
    */
  def loadModel[T](model:SensitiveSMSModel):T = {
    val is = new ByteArrayInputStream(model.getModel) // 字节数组输入流
    val in = new ObjectInputStream(is) // 执行反序列化
    in.readObject().asInstanceOf[T] //类型强转
  }

你可能感兴趣的:(算法)