mindsdb source code analysis

Applications can connect to mindsdb in the following ways:

  • mysql

  • mongodb

  • http

All three can expose predictors to client applications; in practice the most common choices are HTTP and MySQL.

Below we walk through how the model prediction workflow is driven through the MySQL interface.
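Before diving into the proxy code, it helps to see what "the MySQL way" looks like from the client side. The snippet below is a minimal sketch, not MindsDB code: it assumes a local MindsDB instance with its MySQL API listening on port 47335, the pymysql package installed, and placeholder credentials that you would replace with your own.

~~~python
# Minimal sketch (assumptions: local MindsDB, MySQL API on 47335, placeholder
# credentials) showing that an ordinary MySQL client library can talk to MindsDB.
import pymysql

conn = pymysql.connect(
    host='127.0.0.1',
    port=47335,        # assumed default port of MindsDB's MySQL API
    user='mindsdb',    # placeholder credentials -- adjust to your deployment
    password='',
)
try:
    with conn.cursor() as cur:
        # Every statement sent here reaches mysqlProxy's handle() as a COM_QUERY packet.
        cur.execute('SHOW DATABASES')
        for row in cur.fetchall():
            print(row)
finally:
    conn.close()
~~~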

We start from handle() in mysqlProxy.py.

The core code dispatches on the type of the incoming command packet:

~~~python
try:
    if p.type.value == COMMANDS.COM_QUERY:
        sql = self.decode_utf(p.sql.value)
        sql = SqlStatementParser.clear_sql(sql)
        log.debug(f'COM_QUERY: {sql}')
        ### core entry point for handling a plain SQL query
        self.query_answer(sql)
    elif p.type.value == COMMANDS.COM_STMT_PREPARE:
        # https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
        sql = self.decode_utf(p.sql.value)
        statement = SqlStatementParser(sql)
        log.debug(f'COM_STMT_PREPARE: {statement.sql}')
        self.answer_stmt_prepare(statement)
    elif p.type.value == COMMANDS.COM_STMT_EXECUTE:
        self.answer_stmt_execute(p.stmt_id.value, p.parameters)
    elif p.type.value == COMMANDS.COM_STMT_FETCH:
        self.answer_stmt_fetch(p.stmt_id.value, p.limit.value)
    elif p.type.value == COMMANDS.COM_STMT_CLOSE:
        self.answer_stmt_close(p.stmt_id.value)
    elif p.type.value == COMMANDS.COM_QUIT:
        log.debug('Session closed, on client disconnect')
        self.session = None
        break
    elif p.type.value == COMMANDS.COM_INIT_DB:
        new_database = p.database.value.decode()
        self.change_default_db(new_database)
        self.packet(OkPacket).send()
    elif p.type.value == COMMANDS.COM_FIELD_LIST:
        # this command is deprecated, but console client still use it.
        self.packet(OkPacket).send()
    else:
        log.warning('Command has no specific handler, return OK msg')
        log.debug(str(p))
        # p.pprintPacket() TODO: Make a version of print packet
        # that sends it to debug instead
        self.packet(OkPacket).send()
~~~
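To see these branches fire from the client side, here is a hedged sketch (again not MindsDB code; it assumes mysql-connector-python and the same local instance as above): a plain execute() arrives as COM_QUERY, a prepared cursor goes through COM_STMT_PREPARE / COM_STMT_EXECUTE, and switching the default database sends COM_INIT_DB. Whether a given statement is accepted over the prepared-statement path depends on the MindsDB version, so treat this purely as an illustration of the packet types.

~~~python
# Sketch mapping client-side operations to the COM_* branches handled above.
# Host, port and credentials are assumptions.
import mysql.connector

conn = mysql.connector.connect(host='127.0.0.1', port=47335, user='mindsdb', password='')

cur = conn.cursor()
cur.execute('SELECT 1')              # sent as COM_QUERY -> self.query_answer(sql)
print(cur.fetchall())

pcur = conn.cursor(prepared=True)    # server-side prepared statements
pcur.execute('SELECT %s', (1,))      # COM_STMT_PREPARE, then COM_STMT_EXECUTE
print(pcur.fetchall())

conn.cmd_init_db('mindsdb')          # COM_INIT_DB -> self.change_default_db('mindsdb')

conn.close()                         # disconnect -> COM_QUIT branch
~~~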

The SQL statement is then handled as follows:

~~~python
def query_answer(self, sql):
    try:
        try:
            ## try the MindsDB dialect first ...
            statement = parse_sql(sql, dialect='mindsdb')
        except Exception:
            ## ... and fall back to the plain MySQL dialect
            statement = parse_sql(sql, dialect='mysql')
    except Exception:
        # not all statements are parsed by parse_sql
        log.warning(f'SQL statement are not parsed by mindsdb_sql: {sql}')
        pass

    # ... the rest of query_answer dispatches on type(statement); see below
~~~
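The nested try/except above simply prefers the MindsDB dialect and falls back to plain MySQL. A small standalone sketch of the same pattern, assuming only that the mindsdb_sql package is installed:

~~~python
# Standalone sketch of the dialect-fallback pattern used in query_answer().
from mindsdb_sql import parse_sql

def parse_with_fallback(sql: str):
    """Try the MindsDB dialect first, then fall back to the MySQL dialect."""
    try:
        return parse_sql(sql, dialect='mindsdb')
    except Exception:
        return parse_sql(sql, dialect='mysql')

# A plain SELECT parses under either dialect; MindsDB-specific statements such as
# CREATE PREDICTOR only parse with dialect='mindsdb', which is why it is tried first.
print(parse_with_fallback('SELECT a, b FROM t WHERE a > 1'))
~~~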

The logic of parse_sql is as follows:

~~~python
from mindsdb_sql.exceptions import ParsingException

def get_lexer_parser(dialect):
    if dialect == 'sqlite':
        from mindsdb_sql.parser.lexer import SQLLexer
        from mindsdb_sql.parser.parser import SQLParser
        lexer, parser = SQLLexer(), SQLParser()
    elif dialect == 'mysql':
        from mindsdb_sql.parser.dialects.mysql.lexer import MySQLLexer
        from mindsdb_sql.parser.dialects.mysql.parser import MySQLParser
        lexer, parser = MySQLLexer(), MySQLParser()
    elif dialect == 'mindsdb':
        from mindsdb_sql.parser.dialects.mindsdb.lexer import MindsDBLexer
        from mindsdb_sql.parser.dialects.mindsdb.parser import MindsDBParser
        lexer, parser = MindsDBLexer(), MindsDBParser()
    else:
        raise ParsingException(f'Unknown dialect {dialect}. Available options are: sqlite, mysql.')
    return lexer, parser


def parse_sql(sql, dialect='sqlite'):
    lexer, parser = get_lexer_parser(dialect)
    tokens = lexer.tokenize(sql)
    ast = parser.parse(tokens)
    return ast
~~~

As you can see, parsing is split into two parts: a lexer and a syntax parser.

The goal is to convert a SQL string into an AST object. For example, the statement

~~~sql
CREATE PREDICTOR pred
FROM integration_name
WITH (select * FROM table_name)
AS ds_name
PREDICT f1 as f1_alias, f2
~~~

is converted into

~~~python
CreatePredictor(
    name=Identifier('pred'),
    integration_name=Identifier('integration_name'),
    query_str="select * FROM table_name",
    datasource_name=Identifier('ds_name'),
    targets=[
        Identifier('f1', alias=Identifier('f1_alias')),
        Identifier('f2'),
    ],
)
~~~
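A quick way to see this transformation for yourself is the sketch below; it assumes the mindsdb_sql package is installed and uses the CREATE PREDICTOR grammar shown above, which may differ slightly between mindsdb_sql versions.

~~~python
# Parse the CREATE PREDICTOR statement shown above and inspect the resulting AST.
from mindsdb_sql import parse_sql

sql = '''
CREATE PREDICTOR pred
FROM integration_name
WITH (select * FROM table_name)
AS ds_name
PREDICT f1 as f1_alias, f2
'''

statement = parse_sql(sql, dialect='mindsdb')

print(type(statement).__name__)   # CreatePredictor
print(statement.name)             # Identifier for 'pred'
print(statement.query_str)        # "select * FROM table_name"
print(statement.targets)          # [Identifier('f1', alias=...), Identifier('f2')]
~~~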

Back in the mysqlProxy source: since our goal is the predictor, we look at the branch for type(statement) == CreatePredictor, whose core is the answer_create_predictor method:


~~~python
elif type(statement) == CreatePredictor:
    self.answer_create_predictor(statement)
~~~

Inside answer_create_predictor() we find the call:

~~~python
model_interface.learn(predictor_name, ds, predict, ds_data['id'], kwargs=kwargs, delete_ds_on_fail=True)
~~~

This call lands in ModelController; the code below kicks off training, during which a suitable model is selected automatically:

~~~python
    @mark_process(name='learn')
    def learn(self, name: str, from_data: dict, to_predict: str, datasource_id: int, kwargs: dict,
              company_id: int, delete_ds_on_fail: Optional[bool] = False) -> None:
        predictor_record = db.session.query(db.Predictor).filter_by(company_id=company_id, name=name).first()
        if predictor_record is not None:
            raise Exception('Predictor name must be unique.')

        df, problem_definition, join_learn_process, json_ai_override = self._unpack_old_args(from_data, kwargs, to_predict)
        ### build the problem definition
        problem_definition = ProblemDefinition.from_dict(problem_definition)
        predictor_record = db.Predictor(
            company_id=company_id,
            name=name,
            datasource_id=datasource_id,
            mindsdb_version=mindsdb_version,
            lightwood_version=lightwood_version,
            to_predict=problem_definition.target,
            learn_args=problem_definition.to_dict(),
            data={'name': name}
        )

        db.session.add(predictor_record)
        db.session.commit()
        predictor_id = predictor_record.id
        ### core logic for training the model
        p = LearnProcess(df, problem_definition, predictor_id, delete_ds_on_fail, json_ai_override)
        p.start()
        if join_learn_process:
            p.join()
            if not IS_PY36:
                p.close()
        db.session.refresh(predictor_record)

        data = {}
        if predictor_record.update_status == 'available':
            data['status'] = 'complete'
        elif predictor_record.json_ai is None and predictor_record.code is None:
            data['status'] = 'generating'
        elif predictor_record.data is None:
            data['status'] = 'editable'
        elif 'training_log' in predictor_record.data:
            data['status'] = 'training'
        elif 'error' not in predictor_record.data:
            data['status'] = 'complete'
        else:
            data['status'] = 'error'
~~~
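Note that learn() itself does not train anything: it records a db.Predictor row and hands off to LearnProcess, waiting for it only when join_learn_process is set. Below is a stripped-down sketch of that start/optionally-join pattern using plain multiprocessing; train() here is a hypothetical stand-in for what LearnProcess does, not MindsDB code.

~~~python
# Sketch of the "spawn training in a separate process" pattern used by learn().
import multiprocessing as mp

def train(predictor_id: int) -> None:
    # hypothetical stand-in for run_generate() + run_fit()
    print(f'training predictor {predictor_id} ...')

def start_training(predictor_id: int, join_learn_process: bool = False) -> None:
    p = mp.Process(target=train, args=(predictor_id,))
    p.start()                  # training runs in the background ...
    if join_learn_process:     # ... unless the caller asked to wait for it
        p.join()
        p.close()              # Process.close() exists on Python >= 3.7, hence the IS_PY36 check

if __name__ == '__main__':
    start_training(1, join_learn_process=True)
~~~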

From here we can quickly locate LearnProcess, which boils down to two functions: run_generate and run_fit.

run_generate is the core code that generates the predictor's code:

~~~python
@mark_process(name='learn')
def run_generate(df: DataFrame, problem_definition: ProblemDefinition, predictor_id: int, json_ai_override: dict = None) -> int:
    json_ai = lightwood.json_ai_from_problem(df, problem_definition)
    if json_ai_override is None:
        json_ai_override = {}

    json_ai_override = brack_to_mod(json_ai_override)
    json_ai = json_ai.to_dict()
    rep_recur(json_ai, json_ai_override)

    json_ai = JsonAI.from_dict(json_ai)
    ## automatically generate the predictor's Python code
    code = lightwood.code_from_json_ai(json_ai)

    predictor_record = Predictor.query.with_for_update().get(predictor_id)
    predictor_record.json_ai = json_ai.to_dict()
    predictor_record.code = code
    db.session.commit()
~~~
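run_generate is a thin wrapper around lightwood's high-level API. The sketch below reproduces the same two steps on a toy DataFrame; the data and the target column are made up, and only lightwood and pandas are assumed to be installed.

~~~python
# Sketch: generate a JsonAI spec and predictor code with lightwood's high-level API.
import pandas as pd
import lightwood
from lightwood.api.types import ProblemDefinition

# Toy data, purely illustrative: predict y from x.
df = pd.DataFrame({'x': range(100), 'y': [i * 2 for i in range(100)]})

# Step 1: let lightwood propose a JsonAI specification for this problem.
json_ai = lightwood.json_ai_from_problem(df, ProblemDefinition.from_dict({'target': 'y'}))

# Step 2: turn the JsonAI spec into executable predictor code -- this is the
# string that run_generate() stores in predictor_record.code.
code = lightwood.code_from_json_ai(json_ai)
print(code[:500])
~~~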

run_fit then compiles the generated code into a predictor object, trains it on the data, and saves the result into MindsDB's predictor store:

~~~python
@mark_process(name='learn')
def run_fit(predictor_id: int, df: pd.DataFrame) -> None:
    try:
        predictor_record = Predictor.query.with_for_update().get(predictor_id)
        assert predictor_record is not None

        fs_store = FsStore()
        config = Config()

        predictor_record.data = {'training_log': 'training'}
        session.commit()
        predictor: lightwood.PredictorInterface = lightwood.predictor_from_code(predictor_record.code)
        predictor.learn(df)

        session.refresh(predictor_record)

        fs_name = f'predictor_{predictor_record.company_id}_{predictor_record.id}'
        pickle_path = os.path.join(config['paths']['predictors'], fs_name)
        ### save the trained model to pickle_path
        predictor.save(pickle_path)

        fs_store.put(fs_name, fs_name, config['paths']['predictors'])

        predictor_record.data = predictor.model_analysis.to_dict()
        predictor_record.dtype_dict = predictor.dtype_dict
        session.commit()

        dbw = DatabaseWrapper(predictor_record.company_id)
        mi = WithKWArgsWrapper(ModelInterface(), company_id=predictor_record.company_id)
    except Exception as e:
        session.refresh(predictor_record)
        predictor_record.data = {'error': f'{traceback.format_exc()}\nMain error: {e}'}
        session.commit()
        raise e

    try:
        dbw.register_predictors([mi.get_model_data(predictor_record.name)])
    except Exception as e:
        log.warn(e)
~~~
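run_fit mirrors the remaining part of lightwood's high-level API: compile the generated code into a predictor object, train it, and pickle it to disk. Continuing the toy sketch shown after run_generate (df and code come from that snippet; the save path is a made-up local file):

~~~python
# Sketch: compile, train, query and persist a predictor, as run_fit() does.
import pandas as pd
import lightwood

predictor = lightwood.predictor_from_code(code)   # compile the generated code
predictor.learn(df)                               # train on the datasource's DataFrame

# Query the trained predictor ...
print(predictor.predict(pd.DataFrame({'x': [123]})))

# ... and persist it, as run_fit() does into MindsDB's predictors directory.
predictor.save('./predictor_example.pickle')      # hypothetical local path
~~~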

Reading through the code above gives us a rough picture of how MindsDB creates a predictor from a MySQL SQL statement and then stores that predictor in MindsDB's file system. Two core components are involved: mindsdb_sql, which turns SQL statements into structured AST objects for later processing, and lightwood, MindsDB's AutoML component that automatically selects and builds the machine-learning model.
