应用对接的方式包含:
mysql
mongodb
http
这三种方式都可以把预测器对外暴露给应用使用;其中最常见的是 http 和 mysql 两种方式。
下面介绍mysql的方式实现mindsdb的模型预测过程;
mysqlProxy.py中的handle()开始:
核心代码:根据客户端发来的命令类型(p.type.value)的不同,分别进行处理:
try:
if p.type.value == COMMANDS.COM_QUERY:
sql = self.decode_utf(p.sql.value)
sql = SqlStatementParser.clear_sql(sql)
log.debug(f'COM_QUERY: {sql}')
### query的核心代码
self.query_answer(sql)
elif p.type.value == COMMANDS.COM_STMT_PREPARE:
# https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
sql = self.decode_utf(p.sql.value)
statement = SqlStatementParser(sql)
log.debug(f'COM_STMT_PREPARE: {statement.sql}')
self.answer_stmt_prepare(statement)
elif p.type.value == COMMANDS.COM_STMT_EXECUTE:
self.answer_stmt_execute(p.stmt_id.value, p.parameters)
elif p.type.value == COMMANDS.COM_STMT_FETCH:
self.answer_stmt_fetch(p.stmt_id.value, p.limit.value)
elif p.type.value == COMMANDS.COM_STMT_CLOSE:
self.answer_stmt_close(p.stmt_id.value)
elif p.type.value == COMMANDS.COM_QUIT:
log.debug('Session closed, on client disconnect')
self.session = None
break
elif p.type.value == COMMANDS.COM_INIT_DB:
new_database = p.database.value.decode()
self.change_default_db(new_database)
self.packet(OkPacket).send()
elif p.type.value == COMMANDS.COM_FIELD_LIST:
# this command is deprecated, but console client still use it.
self.packet(OkPacket).send()
else:
log.warning('Command has no specific handler, return OK msg')
log.debug(str(p))
# p.pprintPacket() TODO: Make a version of print packet
# that sends it to debug isntead
self.packet(OkPacket).send()
SQL 语句处理如下:
~~~python
def query_answer(self, sql):
try:
try:
## 如果查询的mindsdb库,执行如下
statement = parse_sql(sql, dialect='mindsdb')
except Exception:
statement = parse_sql(sql, dialect='mysql')
except Exception:
# not all statemts are parse by parse_sql
log.warning(f'SQL statement are not parsed by mindsdb_sql: {sql}')
pass
~~~
parse_sql的处理逻辑如下
~~~python
from mindsdb_sql.exceptions import ParsingException
def get_lexer_parser(dialect):
    """Return a ``(lexer, parser)`` pair for the requested SQL dialect.

    The dialect-specific modules are imported inside the matching branch, so
    only the selected dialect's lexer and parser modules are imported.

    Args:
        dialect: Name of the dialect; one of ``'sqlite'``, ``'mysql'`` or
            ``'mindsdb'``.

    Returns:
        A tuple ``(lexer, parser)`` of newly constructed instances.

    Raises:
        ParsingException: If ``dialect`` is not one of the names above.
    """
    if dialect == 'sqlite':
        from mindsdb_sql.parser.lexer import SQLLexer
        from mindsdb_sql.parser.parser import SQLParser
        lexer, parser = SQLLexer(), SQLParser()
    elif dialect == 'mysql':
        from mindsdb_sql.parser.dialects.mysql.lexer import MySQLLexer
        from mindsdb_sql.parser.dialects.mysql.parser import MySQLParser
        lexer, parser = MySQLLexer(), MySQLParser()
    elif dialect == 'mindsdb':
        from mindsdb_sql.parser.dialects.mindsdb.lexer import MindsDBLexer
        from mindsdb_sql.parser.dialects.mindsdb.parser import MindsDBParser
        lexer, parser = MindsDBLexer(), MindsDBParser()
    else:
        # The message must list every dialect handled above, including 'mindsdb'.
        raise ParsingException(
            f'Unknown dialect {dialect}. Available options are: sqlite, mysql, mindsdb.'
        )
    return lexer, parser
def parse_sql(sql, dialect='sqlite'):
    """Parse a SQL string into an AST using the given dialect.

    Args:
        sql: The SQL text to parse.
        dialect: Dialect name passed to ``get_lexer_parser``; defaults to
            ``'sqlite'``.

    Returns:
        Whatever the dialect's parser produces for the token stream.
    """
    sql_lexer, sql_parser = get_lexer_parser(dialect)
    # Lexing produces the token stream that the parser turns into an AST.
    return sql_parser.parse(sql_lexer.tokenize(sql))
~~~
可以看出,parse_sql 的处理分成两个部分:词法解析器(lexer)和语法解析器(parser);
其目标就是将SQL的字符串转换成对象;
CREATE PREDICTOR pred
FROM integration_name
WITH (select * FROM table_name)
AS ds_name
PREDICT f1 as f1_alias, f2
转换成
CreatePredictor(
name=Identifier('pred'),
integration_name=Identifier('integration_name'),
query_str="select * FROM table_name",
datasource_name=Identifier('ds_name'),
targets=[Identifier('f1', alias=Identifier('f1_alias')),
Identifier('f2')],
)
接下来接着看mysqlProxy的源码,目标是predictor,所以我们就看type(statement) == CreatePredictor:的源码,其核心是answer_create_predictor的方法
elif type(statement) == CreatePredictor:
self.answer_create_predictor(statement)
发现answer_create_predictor()方法里面有方法
model_interface.learn(predictor_name, ds, predict, ds_data['id'], kwargs=kwargs, delete_ds_on_fail=True)
该方法位于 ModelController 中,最终模型对象经过以下代码进行自我学习,选择合适的模型:
@mark_process(name='learn')
def learn(self, name: str, from_data: dict, to_predict: str, datasource_id: int, kwargs: dict,
company_id: int, delete_ds_on_fail: Optional[bool] = False) -> None:
predictor_record = db.session.query(db.Predictor).filter_by(company_id=company_id, name=name).first()
if predictor_record is not None:
raise Exception('Predictor name must be unique.')
df, problem_definition, join_learn_process, json_ai_override = self._unpack_old_args(from_data, kwargs, to_predict)
### 问题定义,
problem_definition = ProblemDefinition.from_dict(problem_definition)
predictor_record = db.Predictor(
company_id=company_id,
name=name,
datasource_id=datasource_id,
mindsdb_version=mindsdb_version,
lightwood_version=lightwood_version,
to_predict=problem_definition.target,
learn_args=problem_definition.to_dict(),
data={'name': name}
)
db.session.add(predictor_record)
db.session.commit()
predictor_id = predictor_record.id
### 训练模型的核心逻辑
p = LearnProcess(df, problem_definition, predictor_id, delete_ds_on_fail, json_ai_override)
p.start()
if join_learn_process:
p.join()
if not IS_PY36:
p.close()
db.session.refresh(predictor_record)
data = {}
if predictor_record.update_status == 'available':
data['status'] = 'complete'
elif predictor_record.json_ai is None and predictor_record.code is None:
data['status'] = 'generating'
elif predictor_record.data is None:
data['status'] = 'editable'
elif 'training_log' in predictor_record.data:
data['status'] = 'training'
elif 'error' not in predictor_record.data:
data['status'] = 'complete'
else:
data['status'] = 'error'
接下来可以快速地定位到 LearnProcess,它主要包括两个步骤:run_generate 和 run_fit。
其中 run_generate 是产生预测器 code 的核心代码:
@mark_process(name='learn')
def run_generate(df: DataFrame, problem_definition: ProblemDefinition, predictor_id: int, json_ai_override: dict = None) -> int:
    """Generate the JsonAI spec and predictor code, then store both on the record.

    Args:
        df: Data passed to ``lightwood.json_ai_from_problem``.
        problem_definition: Problem definition passed alongside ``df``.
        predictor_id: Primary key of the ``Predictor`` row to update.
        json_ai_override: Optional overrides applied onto the generated
            JsonAI dict; ``None`` is treated as an empty dict.

    Note:
        The function body has no ``return`` statement, so it yields ``None``
        even though the signature is annotated with ``int``.
    """
    generated = lightwood.json_ai_from_problem(df, problem_definition)

    overrides = brack_to_mod({} if json_ai_override is None else json_ai_override)

    # Apply the overrides on the plain-dict form, then rebuild the JsonAI object.
    json_ai_dict = generated.to_dict()
    rep_recur(json_ai_dict, overrides)
    final_json_ai = JsonAI.from_dict(json_ai_dict)

    # Generate the predictor's Python source code from the final JsonAI.
    generated_code = lightwood.code_from_json_ai(final_json_ai)

    record = Predictor.query.with_for_update().get(predictor_id)
    record.json_ai = final_json_ai.to_dict()
    record.code = generated_code
    db.session.commit()
@mark_process(name='learn')
def run_fit(predictor_id: int, df: pd.DataFrame) -> None:
    """Fit the predictor stored under ``predictor_id`` on ``df`` and persist it.

    Steps, in order: load the ``Predictor`` row, flag it as training, build
    the predictor from the stored code and call ``learn(df)``, save it under
    ``config['paths']['predictors']``, put it into the ``FsStore``, and write
    the model analysis and dtype dict back to the row. Finally the predictor
    is registered through ``DatabaseWrapper``; a failure there is only logged.

    Args:
        predictor_id: Primary key of the ``Predictor`` row to train.
        df: Training data passed to ``predictor.learn``.

    Raises:
        Exception: Any error raised while loading, training or saving is
            re-raised after being recorded in ``predictor_record.data``.
    """
    # Bound up front so the error handler can tell whether the row was loaded;
    # otherwise a failing lookup would be masked by an UnboundLocalError.
    predictor_record = None
    try:
        predictor_record = Predictor.query.with_for_update().get(predictor_id)
        assert predictor_record is not None

        fs_store = FsStore()
        config = Config()

        # Flag the row as training before the fit starts.
        predictor_record.data = {'training_log': 'training'}
        session.commit()
        predictor: lightwood.PredictorInterface = lightwood.predictor_from_code(predictor_record.code)
        predictor.learn(df)

        session.refresh(predictor_record)

        fs_name = f'predictor_{predictor_record.company_id}_{predictor_record.id}'
        pickle_path = os.path.join(config['paths']['predictors'], fs_name)
        # Save the trained model to the file at pickle_path.
        predictor.save(pickle_path)

        fs_store.put(fs_name, fs_name, config['paths']['predictors'])

        predictor_record.data = predictor.model_analysis.to_dict()
        predictor_record.dtype_dict = predictor.dtype_dict
        session.commit()

        dbw = DatabaseWrapper(predictor_record.company_id)
        mi = WithKWArgsWrapper(ModelInterface(), company_id=predictor_record.company_id)
    except Exception as e:
        # Record the failure on the row only if the row was actually loaded.
        if predictor_record is not None:
            session.refresh(predictor_record)
            predictor_record.data = {'error': f'{traceback.format_exc()}\nMain error: {e}'}
            session.commit()
        # Bare raise keeps the original traceback intact.
        raise

    try:
        dbw.register_predictors([mi.get_model_data(predictor_record.name)])
    except Exception as e:
        # Registration is best-effort: the model is already trained and saved.
        log.warning(e)
阅读以上代码,可以大致梳理清楚 mindsdb 如何通过 mysql 的 SQL 语句创建预测器,并把预测器保存在 mindsdb 的文件系统当中。其中涉及到 mindsdb 的两个核心组件:mindsdb_sql 和 lightwood。mindsdb_sql 负责把 SQL 语句解析并封装成结构化对象,便于后续处理;lightwood 是 mindsdb 中负责自动选择并训练机器学习模型的组件。