ANTLR 工具本身是用 Java 语言编写的，要在 Python 环境中使用，需要先做一些简单的环境准备（安装 Java 运行环境、ANTLR 工具，以及与 ANTLR 工具版本一致的 antlr4-python3-runtime），参考：Python中使用Antlr4的环境准备
词法文件 – MODBLexerRules.g4
lexer grammar MODBLexerRules ;
// Keywords. Each keyword is spelled with the single-letter fragments defined at the
// end of this grammar, so keywords are case-insensitive (SELECT, select, SeLeCt).
// They are declared before ID: when a keyword and ID match input of the same
// length, ANTLR picks the rule declared first, so keywords are reserved words and
// cannot be used as table, column or database names.
K_AND: A N D ;
K_AS: A S ;
K_ASC: A S C ;
K_BY: B Y ;
K_CREATE: C R E A T E ;
K_DELETE: D E L E T E ;
K_DROP: D R O P ;
K_DATABASE: D A T A B A S E ;
K_DESC: D E S C ;
K_EXISTS: E X I S T S ;
K_FROM: F R O M ;
K_IS: I S ;
K_IN: I N ;
K_INSERT: I N S E R T ;
K_INTO: I N T O ;
K_NOT: N O T ;
K_OR: O R ;
K_SELECT: S E L E C T ;
K_SET: S E T ;
K_TABLE: T A B L E ;
K_TO: T O ;
K_UPDATE: U P D A T E ;
K_VALUES: V A L U E S ;
K_WHERE: W H E R E ;
K_NULL: N U L L ;
// Data-type keywords. VARCHAR2 wins over VARCHAR by longest match.
K_VARCHAR: V A R C H A R ;
K_VARCHAR2: V A R C H A R '2' ;
K_CHAR: C H A R ;
K_INTEGER: I N T E G E R ;
K_INT: I N T ;
K_FLOAT: F L O A T ;
K_BOOLEAN: B O O L E A N ;
// Punctuation and operator tokens (longest match makes '>=' win over '>', etc.)
LPAREN: '(' ;
RPAREN: ')' ;
COMMA: ',' ;
SEMI: ';' ;
DOT: '.' ;
PLUS: '+' ;
MINUS: '-' ;
STAR: '*' ;
DIV: '/' ;
GT: '>' ;
GE: '>=' ;
LT: '<' ;
LE: '<=' ;
EQUAL: '=' ;
NOT_EQUAL: '!=' ;
LG: '<>';
AT: '@' ;
// Numeric literal: optional sign, integer and/or fractional part, optional exponent
// (e.g. 1, -2, 3.5, .5, 1e10). Because the sign belongs to the token, "-2" inside
// VALUES(...) or after a comparison operator lexes as a single LITERAL_NUMBER.
LITERAL_NUMBER
: (PLUS | MINUS)? ([0-9]+ (DOT [0-9]*)? | DOT [0-9]+) (E (PLUS | MINUS)? [0-9]+)?
;
// String literal in double or single quotes. The match is non-greedy and there is
// no escape syntax, so the text cannot contain its own delimiter.
QUOTED_STRING
: '"' .*? '"' // any text enclosed in double quotes
| '\'' .*? '\'' // any text enclosed in single quotes
;
// Identifier: a letter or '_', then letters, digits, '_' or '$'.
ID: [a-zA-Z_][a-zA-Z0-9_$]* ;
// Whitespace is discarded; comments are sent to the hidden channel, so the parser
// never sees them.
WS
: [ \t\r\n] + -> skip
;
// C-style block comment
COMMENT
: '/*' .*? '*/' -> channel(HIDDEN)
;
// SQL line comment starting with two dashes, up to the end of the line or input
LINE_COMMENT
: '--' .*? (('\r'? '\n') | EOF) -> channel(HIDDEN)
;
// Letter fragments: each matches one letter in either case. The keyword rules
// above are spelled with them, which makes keywords case-insensitive.
fragment A : [aA] ;
fragment B : [bB] ;
fragment C : [cC] ;
fragment D : [dD] ;
fragment E : [eE] ;
fragment F : [fF] ;
fragment G : [gG] ;
fragment H : [hH] ;
fragment I : [iI] ;
fragment J : [jJ] ;
fragment K : [kK] ;
fragment L : [lL] ;
fragment M : [mM] ;
fragment N : [nN] ;
fragment O : [oO] ;
fragment P : [pP] ;
fragment Q : [qQ] ;
fragment R : [rR] ;
fragment S : [sS] ;
fragment T : [tT] ;
fragment U : [uU] ;
fragment V : [vV] ;
fragment W : [wW] ;
fragment X : [xX] ;
fragment Y : [yY] ;
fragment Z : [zZ] ;
语法文件 – MODB.g4
grammar MODB ;
import MODBLexerRules ;
// Entry rule: exactly one statement followed by ';'. The trailing EOF makes any
// text after the ';' a syntax error; without it, extra input is silently ignored.
sqlStmt
: (ddlStmt | dmlStmt | dqlStmt | connDatabase) SEMI EOF
;
// DATABASE <name>: connect to (switch to) an existing database
connDatabase
: K_DATABASE database
;
// Data definition: CREATE ... / DROP ...
ddlStmt
: createStmt
| dropStmt
;
// Data manipulation: DELETE / INSERT / UPDATE
dmlStmt
: deleteStmt
| insertStmt
| updateStmt
;
// Data query: SELECT
dqlStmt
: selectStmt
;
// CREATE statements: a table (with column definitions) or a database
createStmt
: createTableStmt
| createDatabaseStmt
;
// DROP statements: only tables can be dropped
dropStmt
: dropTableStmt
;
// CREATE TABLE <table> (<column> <type>, ...)
createTableStmt
: K_CREATE K_TABLE table tableDefine
;
// CREATE DATABASE <database>
createDatabaseStmt
: K_CREATE K_DATABASE database
;
// DROP TABLE <table>
dropTableStmt
: K_DROP K_TABLE table
;
// DELETE FROM <table> [WHERE ...]; without WHERE every row is deleted
deleteStmt
: K_DELETE K_FROM table whereClause?
;
// INSERT INTO <table> [(<column>, ...)] VALUES (...); without a column list the
// values are matched to all table columns in order
insertStmt
: K_INSERT K_INTO table (LPAREN column (COMMA column)* RPAREN)? valuesClause
;
// UPDATE <table> SET ... [WHERE ...]
updateStmt
: K_UPDATE table setStmt whereClause?
;
// SELECT <projection> FROM <table> [WHERE ...]
selectStmt
: K_SELECT selectOptions
;
selectOptions
: projectionClause fromClause whereClause?
;
// Comma-separated list of column names and/or '*'
projectionClause
: selectList (COMMA selectList)*
;
selectList
: column | STAR
;
// FROM takes exactly one table (no joins)
fromClause
: K_FROM table
;
// SET clause: either "col = value, ..." or "(col, ...) = (value, ...)";
// "* = (value, ...)" assigns every column in table order
setStmt
: K_SET (multipleColumnFormat | singleColumnFormat)
;
singleColumnFormat
: column EQUAL valueItem (COMMA column EQUAL valueItem)*
;
multipleColumnFormat
: ( STAR | ( LPAREN column (COMMA column)* RPAREN ) ) EQUAL multipleColumnValues
;
multipleColumnValues
: LPAREN valueItem (COMMA valueItem)* RPAREN
;
whereClause
: K_WHERE condition
;
// A condition is a single "column <operator> constant" comparison, or two
// conditions joined by AND/OR. AND and OR share one left-recursive alternative, so
// the parser gives them equal precedence and groups them left to right.
// SQL's "AND binds tighter than OR" is not expressed by this rule.
condition
: column relationOperator constantExpression # conditionNoRecursive
| condition op=(K_AND | K_OR) condition # conditionRecursive
;
// Comparison operators accepted in WHERE ('<>' is lexed as LG but is not accepted here)
relationOperator
: GT | GE | LT | LE | EQUAL | NOT_EQUAL
;
// VALUES (value, ...)
valuesClause
: K_VALUES LPAREN valueItem (COMMA valueItem)* RPAREN
;
// A value is NULL or a constant
valueItem
: K_NULL | constantExpression
;
constantExpression
: quotedString | literalNumber | literalBoolean
;
// Table, database and column names are plain identifiers
table
: ID
;
database
: ID
;
// (<column definition>, ...)
tableDefine
: LPAREN columnDefine (COMMA columnDefine)* RPAREN
;
// <column> <data type>
columnDefine
: column dataType
;
column
: ID
;
// Column data types. FLOAT and VARCHAR2 have lexer keywords and are now accepted too.
dataType
: characterDataType
| numericDataType
| K_BOOLEAN
;
// CHAR[(size)] | VARCHAR[(max[, reserve])] | VARCHAR2[(max[, reserve])]
characterDataType
: K_CHAR (LPAREN size RPAREN)?
| (K_VARCHAR | K_VARCHAR2) (LPAREN maxx (COMMA reserve)? RPAREN)?
;
// INT | INTEGER | FLOAT
numericDataType
: K_INT | K_INTEGER | K_FLOAT
;
size
: LITERAL_NUMBER ;
maxx
: LITERAL_NUMBER ;
reserve
: LITERAL_NUMBER ;
quotedString
: QUOTED_STRING
;
literalNumber
: LITERAL_NUMBER
;
// NOTE(review): literalBoolean matches the same QUOTED_STRING token as quotedString.
// In constantExpression, ANTLR resolves that ambiguity in favour of the earlier
// alternative (quotedString), so this rule is effectively never chosen there.
literalBoolean
: QUOTED_STRING
;
使用访问器模式生成词法分析器和语法分析器
# Generate the Python lexer, parser and visitor (MODBLexer / MODBParser / MODBVisitor).
# NOTE(review): antlr4vpy3 appears to be a shortcut from the environment-prep article,
# presumably for "antlr4 -Dlanguage=Python3 -visitor" - confirm. The generated files
# must end up in the "modb" package that the Python code imports from.
antlr4vpy3 MODB.g4
编写访问器的具体逻辑 – modb/create_visitor.py（与生成的词法分析器、语法分析器放在同一个 modb 包中）
# The code below uses pandas, and writes .xlsx files through the openpyxl engine
# (ExcelWriter(engine='openpyxl')), so install both in advance. The ANTLR Python
# runtime is covered by the environment preparation mentioned at the top.
pip install pandas openpyxl
# 词法分析器和语法分析器所在的包名是modb
from modb.MODBVisitor import MODBVisitor
from modb.MODBParser import MODBParser
import os
import pandas as pd
import sys
import uuid
# Directory holding one "<database name>.xlsx" workbook per database.
DBSPACE = "dbspace"
def db_writer(dbfile, mode):
    """Open a pandas ``ExcelWriter`` (openpyxl engine) on the workbook *dbfile*.

    *mode* is passed through unchanged; the visitor uses ``'a'`` to append to an
    existing database file.
    """
    return pd.ExcelWriter(dbfile, mode=mode, engine='openpyxl')
def get_col_names(db_writer, tab_name):
    """Return the header row (column names) of sheet *tab_name* as an array.

    The writer is handed to ``pd.read_excel`` as its source. NOTE(review): this
    relies on the pandas ExcelWriter being usable as a path-like source; confirm
    for the pandas version in use.
    """
    header = pd.read_excel(db_writer, sheet_name=tab_name).columns
    return header.values
def create_database(dbfile):
    """Create a new database workbook at *dbfile*.

    The workbook starts with a single empty ``systables`` sheet (columns
    ``tab_id`` and ``tab_name``), which catalogs the user tables.

    The parent directory (normally ``DBSPACE``) is created if it does not exist
    yet. Previously ``create database`` failed until the directory had been
    created by hand.
    """
    parent = os.path.dirname(dbfile)
    if parent:  # dirname is '' for a bare file name; makedirs('') would raise
        os.makedirs(parent, exist_ok=True)
    df_systables = pd.DataFrame(columns=["tab_id", "tab_name"])
    df_systables.to_excel(dbfile, sheet_name='systables', index=False)
def create_table(tab_name, col_names, db_writer):
    """Add an empty sheet for table *tab_name* (lower-cased) with header *col_names*.

    The workbook is saved immediately.
    """
    sheet = tab_name.lower()
    empty_table = pd.DataFrame(columns=col_names)
    empty_table.to_excel(excel_writer=db_writer, sheet_name=sheet, index=False)
    db_writer.save()
def drop_table(tab_name, db_writer):
    """Remove sheet *tab_name* from the workbook, save it, and forget the sheet.

    Raises ``KeyError`` if the sheet is unknown.
    """
    workbook = db_writer.book
    worksheet = workbook[tab_name]
    workbook.remove(worksheet)
    db_writer.save()
    # Keep the writer's sheet mapping (used for "table exists" checks) in sync.
    db_writer.sheets.pop(tab_name)
def insert_row(tab_name, data: dict, db_writer):
    """Append one row to sheet *tab_name* and save the workbook.

    *data* maps column names to values.
    """
    row_label = uuid.uuid4()  # always a new label, so .loc enlarges the frame
    table = pd.read_excel(db_writer, sheet_name=tab_name)
    table.loc[row_label] = data
    # Ask the writer to overwrite the existing sheet. NOTE(review): setting this
    # attribute after construction is only honoured by some pandas versions - confirm.
    db_writer.if_sheet_exists = 'replace'
    table.to_excel(excel_writer=db_writer, sheet_name=tab_name, index=False)
    db_writer.save()
def delete_row(tab_name, filter: list, db_writer):
    """Delete the rows of sheet *tab_name* that match *filter*, then save the workbook.

    An empty *filter* deletes every row; the header row is kept.
    """
    table = pd.read_excel(db_writer, sheet_name=tab_name)
    mask = result_filter(table, filter)
    doomed = table.index if mask is None else table[mask].index
    table.drop(doomed, inplace=True)
    db_writer.if_sheet_exists = 'replace'
    table.to_excel(excel_writer=db_writer, sheet_name=tab_name, index=False)
    db_writer.save()
def update_row(tab_name, filter: list, data: dict, db_writer):
    """Set columns to new values on the rows of *tab_name* matching *filter*.

    *data* maps column names to their new values. An empty *filter* updates every
    row. The workbook is saved afterwards.
    """
    table = pd.read_excel(db_writer, sheet_name=tab_name)
    mask = result_filter(table, filter)
    targets = list(data)
    new_values = list(data.values())
    if mask is not None:
        table.loc[mask, targets] = new_values
    else:
        table[targets] = new_values
    db_writer.if_sheet_exists = 'replace'
    table.to_excel(excel_writer=db_writer, sheet_name=tab_name, index=False)
    db_writer.save()
def show_rows(tab_name, col_names, filter: list, db_writer):
    """Print the rows of *tab_name* matching *filter*, restricted to *col_names*.

    A ``"*"`` anywhere in *col_names* selects every column.
    """
    table = pd.read_excel(db_writer, sheet_name=tab_name)
    # Evaluate the filter on the full table, so WHERE may use columns that are
    # not part of the projection.
    mask = result_filter(table, filter)
    projected = table if "*" in col_names else table[col_names]
    print(projected if mask is None else projected[mask])
def result_filter(df, ft: list):
    """Build a boolean row mask for *df* from a flattened WHERE condition list.

    *ft* alternates condition tuples ``(column, value, operator)`` with connector
    strings ``'&'`` (AND) and ``'|'`` (OR), e.g. ``[c1, '&', c2, '|', c3]``.
    Connectors follow SQL precedence, where AND binds tighter than OR, so the
    example means ``(c1 AND c2) OR c3``. Any connector other than ``'&'`` is
    treated as OR.

    Returns ``None`` when *ft* is empty (no filtering). *ft* is not modified.
    """
    if not ft:
        return None
    combined = None     # OR of the conjunctions completed so far
    conjunction = None  # AND of the conditions in the current conjunction
    connector = None
    for item in ft:
        if isinstance(item, str):
            connector = item
            continue
        column, value, relation = item
        mask = relationOperator(df[column], value, relation)
        if conjunction is None:
            conjunction = mask
        elif connector == '&':
            conjunction = conjunction & mask
        else:
            # An OR closes the current conjunction and starts a new one.
            combined = conjunction if combined is None else combined | conjunction
            conjunction = mask
    return conjunction if combined is None else combined | conjunction


def relationOperator(left, right, relation: str):
    """Apply the SQL comparison *relation* between *left* and *right*.

    Supports ``>``, ``>=``, ``<``, ``<=``, ``=`` (also ``==``) and the inequality
    operators ``!=`` / ``<>``. When *left* is a pandas Series the comparison is
    element-wise and a boolean Series is returned.

    Raises:
        ValueError: for an unsupported operator. Previously every unknown
            operator, including ``!=``, silently fell back to equality.
    """
    if relation == ">":
        return left > right
    if relation == ">=":
        return left >= right
    if relation == "<":
        return left < right
    if relation == "<=":
        return left <= right
    if relation in ("=", "=="):
        return left == right
    if relation in ("!=", "<>"):
        return left != right
    raise ValueError("unsupported relation operator: {!r}".format(relation))
class CreateVisitor(MODBVisitor):
    """Execute one parsed MODB statement against an Excel-workbook database.

    Statement-level ``visitXxx`` methods do the work through the module-level
    helpers (``create_table``, ``insert_row``, ``delete_row``, ``update_row``,
    ``show_rows`` ...) and report results and errors on stdout. Value-level
    methods return plain Python values:
    - names, lower-cased;
    - value lists;
    - the flattened WHERE filter list ``[(column, value, operator), '&' | '|', ...]``.

    ``self.dbwriter`` is the ``pd.ExcelWriter`` of the connected database, or
    ``None`` while no ``database <name>;`` statement has succeeded.
    """

    def __init__(self, dbwriter):
        super().__init__()
        self.dbwriter = dbwriter

    # ------------------------------------------------------------------ helpers

    def _is_connected(self):
        """Return True if a database is connected; otherwise report the error."""
        if self.dbwriter is None:
            sys.stdout.write("ERROR: 数据库未连接\n")
            return False
        return True

    def _table_exists(self, tab_name):
        """Return True if *tab_name* is a sheet of the connected workbook; otherwise report it."""
        if tab_name not in self.dbwriter.sheets:
            sys.stdout.write("ERROR: {} 表不存在\n".format(tab_name))
            return False
        return True

    def _filters(self, where_clause):
        """Return the filter list of an optional WHERE clause ([] when it is absent)."""
        return self.visit(where_clause) if where_clause else []

    @staticmethod
    def _columns_known(col_names, table_cols):
        """Return True if every name in *col_names* is in *table_cols*; otherwise report the unknown ones.

        Checked up front because, NOTE(review) per pandas row/column enlargement
        semantics, an unknown (e.g. misspelled) column would otherwise be silently
        dropped on INSERT or added as a new column on UPDATE.
        """
        missing = [name for name in col_names if name not in table_cols]
        if missing:
            sys.stdout.write("ERROR: 列 {} 不存在\n".format(", ".join(missing)))
            return False
        return True

    @staticmethod
    def _literal_value(text):
        """Convert the text of a QUOTED_STRING or LITERAL_NUMBER token to a Python value.

        Quoted strings lose their surrounding quotes, and their content is taken
        verbatim. Numbers become ``int`` when possible, otherwise ``float``.
        This replaces calling ``eval`` on user-supplied SQL text.
        """
        if text[:1] in ("'", '"'):
            return text[1:-1]
        try:
            return int(text)
        except ValueError:
            return float(text)

    # --------------------------------------------------------------- statements

    def visitSqlStmt(self, ctx: MODBParser.SqlStmtContext):
        """Dispatch to whichever statement kind the parser recognised."""
        for stmt in (ctx.ddlStmt(), ctx.dmlStmt(), ctx.dqlStmt(), ctx.connDatabase()):
            if stmt:
                self.visit(stmt)

    def visitConnDatabase(self, ctx: MODBParser.ConnDatabaseContext):
        """DATABASE <name>: open the database workbook, closing any previous connection."""
        db_name = self.visit(ctx.database())
        dbfile = os.path.join(DBSPACE, "{}.xlsx".format(db_name))
        if not os.path.exists(dbfile):
            sys.stdout.write("ERROR: 数据库 {} 不存在\n".format(db_name))
            return
        if self.dbwriter is not None:
            self.dbwriter.close()
        self.dbwriter = db_writer(dbfile, mode='a')
        sys.stdout.write("INFO: 数据库 {} 已连接\n".format(db_name))

    def visitDdlStmt(self, ctx: MODBParser.DdlStmtContext):
        create_stmt = ctx.createStmt()
        if create_stmt:
            self.visit(create_stmt)
        drop_stmt = ctx.dropStmt()
        if drop_stmt:
            self.visit(drop_stmt)

    def visitDmlStmt(self, ctx: MODBParser.DmlStmtContext):
        """DELETE / INSERT / UPDATE; all of them require a connected database."""
        if not self._is_connected():
            return
        for stmt in (ctx.deleteStmt(), ctx.insertStmt(), ctx.updateStmt()):
            if stmt:
                self.visit(stmt)

    def visitUpdateStmt(self, ctx: MODBParser.UpdateStmtContext):
        """UPDATE <table> SET ... [WHERE ...]: overwrite matching rows (all rows without WHERE)."""
        tab_name = self.visit(ctx.table())
        if not self._table_exists(tab_name):
            return
        filterlist = self._filters(ctx.whereClause())
        colnames, colvalues = self.visit(ctx.setStmt())
        table_cols = list(get_col_names(self.dbwriter, tab_name))
        if len(colnames) == 0:
            # "SET * = (...)" targets every column in table order.
            colnames = table_cols
        if len(colnames) != len(colvalues):
            sys.stdout.write("ERROR: 列的数量和值的数量不匹配\n")
        elif self._columns_known(colnames, table_cols):
            row = dict(zip(colnames, colvalues))
            update_row(tab_name=tab_name, filter=filterlist, data=row, db_writer=self.dbwriter)

    def visitSetStmt(self, ctx: MODBParser.SetStmtContext):
        """Return ``(column names, values)``; the names are empty for ``SET * = (...)``."""
        colnames = []
        colvalues = []
        multi_format = ctx.multipleColumnFormat()
        if multi_format:
            colnames, colvalues = self.visit(multi_format)
        single_format = ctx.singleColumnFormat()
        if single_format:
            colnames, colvalues = self.visit(single_format)
        return colnames, colvalues

    def visitMultipleColumnFormat(self, ctx: MODBParser.MultipleColumnFormatContext):
        col_names = [self.visit(col_name) for col_name in ctx.column()]
        col_values = self.visit(ctx.multipleColumnValues())
        return col_names, col_values

    def visitMultipleColumnValues(self, ctx: MODBParser.MultipleColumnValuesContext):
        return [self.visit(value_item) for value_item in ctx.valueItem()]

    def visitSingleColumnFormat(self, ctx: MODBParser.SingleColumnFormatContext):
        col_names = [self.visit(col_name) for col_name in ctx.column()]
        col_values = [self.visit(value_item) for value_item in ctx.valueItem()]
        return col_names, col_values

    def visitDeleteStmt(self, ctx: MODBParser.DeleteStmtContext):
        """DELETE FROM <table> [WHERE ...]: without WHERE every row is deleted."""
        tab_name = self.visit(ctx.table())
        if self._table_exists(tab_name):
            filterlist = self._filters(ctx.whereClause())
            delete_row(tab_name=tab_name, filter=filterlist, db_writer=self.dbwriter)

    def visitDqlStmt(self, ctx: MODBParser.DqlStmtContext):
        """SELECT requires a connected database, like the DML statements."""
        if self._is_connected():
            self.visit(ctx.selectStmt())

    def visitSelectStmt(self, ctx: MODBParser.SelectStmtContext):
        self.visit(ctx.selectOptions())

    def visitSelectOptions(self, ctx: MODBParser.SelectOptionsContext):
        """Print the projected rows of the FROM table that satisfy the WHERE clause."""
        tab_name = self.visit(ctx.fromClause())
        if self._table_exists(tab_name):
            col_names = self.visit(ctx.projectionClause())
            filterlist = self._filters(ctx.whereClause())
            show_rows(tab_name=tab_name, col_names=col_names, filter=filterlist, db_writer=self.dbwriter)

    def visitProjectionClause(self, ctx: MODBParser.ProjectionClauseContext):
        """Return the projected column names (``'*'`` kept as-is)."""
        return [self.visit(sl) for sl in ctx.selectList()]

    def visitSelectList(self, ctx: MODBParser.SelectListContext):
        return ctx.getText().lower()

    def visitFromClause(self, ctx: MODBParser.FromClauseContext):
        """Return the table name."""
        return self.visit(ctx.table())

    def visitWhereClause(self, ctx: MODBParser.WhereClauseContext):
        """Return the flattened filter list of the condition."""
        return self.visit(ctx.condition())

    def visitConditionNoRecursive(self, ctx: MODBParser.ConditionNoRecursiveContext):
        key = self.visit(ctx.column())
        value = self.visit(ctx.constantExpression())
        relation = self.visit(ctx.relationOperator())
        return [(key, value, relation)]

    def visitRelationOperator(self, ctx: MODBParser.RelationOperatorContext):
        return ctx.getText()

    def visitConditionRecursive(self, ctx: MODBParser.ConditionRecursiveContext):
        """Flatten ``left AND/OR right`` into ``left + ['&' | '|'] + right`` (token order)."""
        left = self.visit(ctx.condition(0))
        right = self.visit(ctx.condition(1))
        operator = "&" if ctx.op.type == MODBParser.K_AND else "|"
        left.append(operator)
        left.extend(right)
        return left

    def visitConstantExpression(self, ctx: MODBParser.ConstantExpressionContext):
        """Return the Python value of a string or numeric literal."""
        return self._literal_value(ctx.getText())

    def visitCreateStmt(self, ctx: MODBParser.CreateStmtContext):
        """CREATE TABLE needs a connection; CREATE DATABASE does not."""
        create_table_stmt = ctx.createTableStmt()
        if create_table_stmt and self._is_connected():
            self.visit(create_table_stmt)
        create_database_stmt = ctx.createDatabaseStmt()
        if create_database_stmt:
            self.visit(create_database_stmt)

    def visitDropStmt(self, ctx: MODBParser.DropStmtContext):
        """DROP TABLE requires a connected database."""
        if self._is_connected():
            self.visit(ctx.dropTableStmt())

    def visitInsertStmt(self, ctx: MODBParser.InsertStmtContext):
        """INSERT INTO <table> [(cols)] VALUES (...): append one row.

        Without a column list, the values map onto all table columns in order.
        """
        tb_name = self.visit(ctx.table())
        if not self._table_exists(tb_name):
            return
        table_cols = list(get_col_names(self.dbwriter, tb_name))
        col_names = [self.visit(col_name) for col_name in ctx.column()] or table_cols
        values = self.visit(ctx.valuesClause())
        if len(col_names) != len(values):
            sys.stdout.write("ERROR: 列的数量和值的数量不匹配\n")
        elif self._columns_known(col_names, table_cols):
            row = dict(zip(col_names, values))
            insert_row(tab_name=tb_name, data=row, db_writer=self.dbwriter)
            sys.stdout.write("INFO: 1 条数据插入成功\n")

    def visitDropTableStmt(self, ctx: MODBParser.DropTableStmtContext):
        """DROP TABLE <table>: remove the sheet and its systables catalog entry."""
        tb_name = self.visit(ctx.table())
        if tb_name == 'systables':
            # The catalog sheet itself must survive; dropping it breaks the database.
            sys.stdout.write("ERROR: 不能删除系统表 systables\n")
        elif self._table_exists(tb_name):
            drop_table(tab_name=tb_name, db_writer=self.dbwriter)
            delete_row(tab_name='systables', filter=[('tab_name', tb_name, '=')], db_writer=self.dbwriter)
            sys.stdout.write("INFO: {} 表已删除\n".format(tb_name))

    def visitCreateTableStmt(self, ctx: MODBParser.CreateTableStmtContext):
        """CREATE TABLE: add the sheet (column names only) and record it in systables."""
        tb_name = self.visit(ctx.table())
        if tb_name in self.dbwriter.sheets:
            sys.stdout.write("ERROR: {} 表已存在\n".format(tb_name))
            return
        columns = self.visit(ctx.tableDefine())  # list of (column name, data type)
        col_names = [col[0] for col in columns]
        duplicates = sorted({name for name in col_names if col_names.count(name) > 1})
        if duplicates:
            sys.stdout.write("ERROR: 列名重复: {}\n".format(", ".join(duplicates)))
            return
        create_table(tab_name=tb_name, col_names=col_names, db_writer=self.dbwriter)
        insert_row(tab_name='systables', data={'tab_id': uuid.uuid4(), 'tab_name': tb_name}, db_writer=self.dbwriter)
        sys.stdout.write("INFO: {} 表已创建\n".format(tb_name))

    def visitCreateDatabaseStmt(self, ctx: MODBParser.CreateDatabaseStmtContext):
        """CREATE DATABASE <name>: create ``<DBSPACE>/<name>.xlsx`` if it does not exist."""
        db_name = self.visit(ctx.database())
        dbfile = os.path.join(DBSPACE, "{}.xlsx".format(db_name))
        if os.path.exists(dbfile):
            sys.stdout.write("ERROR: 数据库 {} 已存在\n".format(db_name))
        else:
            create_database(dbfile)
            sys.stdout.write("INFO: 数据库 {} 已创建\n".format(db_name))

    # ------------------------------------------------------------------- values

    def visitDatabase(self, ctx: MODBParser.DatabaseContext):
        """Return the database name (lower-cased)."""
        return ctx.getText().lower()

    def visitValuesClause(self, ctx: MODBParser.ValuesClauseContext):
        """Return the list of values."""
        return [self.visit(value) for value in ctx.valueItem()]

    def visitValueItem(self, ctx: MODBParser.ValueItemContext):
        """Return the value: ``None`` for NULL, otherwise the literal's Python value."""
        # eval("NULL") used to raise NameError here.
        if ctx.K_NULL():
            return None
        return self.visit(ctx.constantExpression())

    def visitTable(self, ctx: MODBParser.TableContext):
        """Return the table name (lower-cased)."""
        return ctx.getText().lower()

    def visitTableDefine(self, ctx: MODBParser.TableDefineContext):
        """Return the column definitions as ``(name, data type)`` tuples."""
        return [tuple(self.visit(column_define)) for column_define in ctx.columnDefine()]

    def visitColumnDefine(self, ctx: MODBParser.ColumnDefineContext):
        """Return ``(column name, data type name)``."""
        col_name = self.visit(ctx.column())
        data_type_name = self.visit(ctx.dataType())
        return col_name, data_type_name

    def visitColumn(self, ctx: MODBParser.ColumnContext):
        """Return the column name (lower-cased)."""
        return ctx.getText().lower()

    def visitDataType(self, ctx: MODBParser.DataTypeContext):
        """Return the data type text (lower-cased)."""
        return ctx.getText().lower()
客户端程序 – moclient.py
from antlr4 import *
from modb.MODBParser import MODBParser
from modb.MODBLexer import MODBLexer
from modb.create_visitor import CreateVisitor
class MODBClient:
    """Interactive MODB shell that reads SQL statements from stdin and executes them.

    The connected database's ExcelWriter is kept in ``self.db_writer`` between
    statements and closed when the shell exits.
    """

    def __init__(self):
        self.db_writer = None

    def run(self):
        """Prompt for statements until ``exit`` or end of input.

        Empty lines are ignored. End of input (Ctrl-D / Ctrl-Z, or the end of a
        piped script) ends the session cleanly instead of raising ``EOFError``.
        """
        try:
            while True:
                sql = input("MODB>").rstrip()
                if sql == "":
                    continue
                if sql == "exit":
                    break
                self.run_sql(sql)
        except EOFError:
            pass  # end of input: finish the session just like "exit"
        finally:
            # Release the connected workbook's file handle.
            if self.db_writer is not None:
                self.db_writer.close()
                self.db_writer = None

    def run_sql(self, sql: str):
        """Parse and execute one statement, reading more lines until it ends with ';'.

        A statement with syntax errors is reported and NOT executed. ANTLR recovers
        from syntax errors and still returns a parse tree, and executing such a
        partial tree could, for example, run a DELETE without its WHERE clause.
        Exceptions raised while executing are printed rather than propagated.
        """
        sql = sql.rstrip()
        # Read continuation lines in a loop (not recursively), so a long
        # multi-line statement cannot hit the recursion limit.
        while not sql.endswith(";"):
            sql = "\n".join([sql, input()]).rstrip()
        parser = MODBParser(CommonTokenStream(MODBLexer(InputStream(sql))))
        sqltree = parser.sqlStmt()
        # NOTE(review): only parser errors are counted here; characters the lexer
        # cannot tokenize are reported by ANTLR but simply dropped.
        if parser.getNumberOfSyntaxErrors() > 0:
            # ANTLR's default error listener has already printed the details.
            print("ERROR: SQL 语法错误")
            return
        visitor = CreateVisitor(dbwriter=self.db_writer)
        try:
            visitor.visit(sqltree)
        except Exception as e:  # top-level boundary: report and keep the shell alive
            print(e)
        finally:
            # Keep whatever connection the visitor ended up with, even after an error.
            self.db_writer = visitor.dbwriter
if __name__ == '__main__':
    # Start the interactive shell when the file is run as a script (python moclient.py).
    client = MODBClient()
    client.run()
# 运行前先在当前目录下创建 dbspace 目录，用于存放数据库文件（每个数据库对应 dbspace 下的一个 .xlsx 文件）
(venv) D:\PythonProjects\MODB>python moclient.py
MODB>create database dbdemo;
INFO: 数据库 dbdemo 已创建
MODB>database dbdemo;
INFO: 数据库 dbdemo 已连接
MODB>CREATE TABLE EMP
(
EMPNO INT,
ENAME VARCHAR(10),
JOB VARCHAR(9),
MGR INT,
SAL INT,
COMM INT,
DEPTNO INT
);
INFO: emp 表已创建
MODB>insert into emp values(1, 'momo', 'tester', 1, 2, 3, 1);
INFO: 1 条数据插入成功
MODB>insert into emp(empno, ename, job) values(2, 'xiaoming', 'developer');
INFO: 1 条数据插入成功
MODB>select * from emp;
empno ename job mgr sal comm deptno
0 1 momo tester 1.0 2.0 3.0 1.0
1 2 xiaoming developer NaN NaN NaN NaN
MODB>select ename from emp where empno=1;
ename
0 momo
MODB>update emp set ename='无-涯子' where ename='momo';
MODB>select * from emp;
empno ename job mgr sal comm deptno
0 1 无-涯子 tester 1.0 2.0 3.0 1.0
1 2 xiaoming developer NaN NaN NaN NaN
MODB>delete from emp where empno=2;
MODB>select * from emp;
empno ename job mgr sal comm deptno
0 1 无-涯子 tester 1 2 3 1
MODB>create table dept(id int, name varchar(20));
INFO: dept 表已创建
MODB>select * from systables;
tab_id tab_name
0 bd71c96c-5d02-4388-89f4-0268f6c90ef4 emp
1 5a2aa571-3b25-4601-b0e1-aa823698e20b dept
MODB>drop table dept;
INFO: dept 表已删除
MODB>select * from systables;
tab_id tab_name
0 bd71c96c-5d02-4388-89f4-0268f6c90ef4 emp
MODB>exit
(venv) D:\PythonProjects\MODB>