在LuaJIT中,利用FFI可以嵌入C代码。目前LuaJIT的最新版本是2.0.4,在这之前的版本中这个扩展库表现如何我还不清楚,不过在2.0.4中用起来真的很爽。
既然是嵌入C代码,那么要说让lua支持面向对象,不如说是让C语言模拟面向对象编程,然后让luajit的ffi嵌入。
要用文字把这个问题彻底描述清楚,以我有限的表达能力恐怕做不到,所以直接用代码来说明吧。
// C++ version: the behavior we want to emulate from C.
class foo_type {
public:
    // Print a + b.
    void foo1()
    {
        const int sum = a + b;
        printf("%d", sum);
    }

    // Print a + b + n.
    void foo2(int n)
    {
        const int sum = a + b + n;
        printf("%d", sum);
    }

    int a;
    int b;
};
// Usage: member calls receive the object implicitly (as `this`).
foo_type obj;
obj.foo1();
obj.foo2(100);
//在C语言要做到同样的事,最简单的做法如下。
/* Plain-C version: the "object" becomes an explicit pointer argument. */
typedef struct{
    int a;
    int b;
}foo_type;

/* Print obj->a + obj->b. */
void foo1(foo_type *obj)
{
    printf("%d", obj->a + obj->b);
}

/* Print obj->a + obj->b + n. */
void foo2(foo_type *obj, int n)
{
    printf("%d", obj->a + obj->b + n);
}
// Usage: the object is passed explicitly as the first argument.
foo_type obj;
foo1(&obj);
foo2(&obj, 100);
/*****************************************
//C++从汇编语言的角度看
obj.foo1();
lea ecx, obj
call foo1
-----------------------
obj.foo2(100);
push 100
lea ecx, obj
call foo2
-----------------------
//C语言从汇编语言的角度看
foo1(&obj);
lea eax, obj
push eax
call foo1
-----------------------
foo2(&obj, 100);
push 100
lea eax, obj
push eax
call foo2
那么就可以看到,C和C++在实现这种功能时的区别之处主要在于thiscall使用了ecx寄存器向下传递对象指针
而在C语言,只能把结构指针push下去,这就有点像微软的COM包装类的stdcall的类成员了。
*****************************************/
//从汇编层面找到问题所在之后,解决思路也就有了:可以写一段shellcode(thunk),它把上层传入的参数原样向下传递,再把结构指针以push参数的方式最后入栈,然后call目标函数即可。
/* Object state shared by the thunk examples below. */
typedef struct{
    int a;
    int b;
}foo_type;

/* Print obj->a + obj->b. */
void foo1(foo_type *obj)
{
    printf("%d", obj->a + obj->b);
}

/* Print obj->a + obj->b + n. */
void foo2(foo_type *obj, int n)
{
    printf("%d", obj->a + obj->b + n);
}
/* "Class" view of foo_type: callers invoke these pointers without passing
 * an object; they are filled with generated thunks that supply &ft. */
typedef struct foo_class_tag {
    void (*foo1)();         /* ends up calling foo1(&ft)    */
    void (*foo2)(int n);    /* ends up calling foo2(&ft, n) */
} foo_class;

foo_type ft;    /* the object state the thunks bind to */
foo_class obj;  /* the function-pointer table callers use */
// Thunk template for foo1: push the object pointer, call the C function,
// drop the pushed pointer, return. 32-bit x86 only (addresses stored as ULONG).
BYTE scT1[] = {
    0x68, 0, 0, 0, 0,   // push imm32     (offset 0; operand at +1 = &ft)
    0xE8, 0, 0, 0, 0,   // call rel32     (offset 5; operand at +6)
    0x83, 0xC4, 0x04,   // add esp, 4
    0xC3                // ret
};
BYTE *psc = (BYTE*)VirtualAlloc(NULL, sizeof(scT1), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
memcpy(psc, scT1, sizeof(scT1));
*(ULONG*)(psc + 1) = (ULONG)&ft;
// rel32 is relative to the instruction after the 5-byte call at offset 5, i.e. psc + 10.
*(ULONG*)(psc + 6) = (ULONG)&foo1 - (ULONG)(psc + 10);
// Win32 requires flushing the instruction cache after generating code in memory.
FlushInstructionCache(GetCurrentProcess(), psc, sizeof(scT1));
// Store the thunk address into the function pointer (explicit cast: BYTE* is not a ULONG).
*(ULONG*)&(obj.foo1) = (ULONG)psc;
// From here on it can be used like this:
obj.foo1();
// Thunk template for foo2: re-push the caller's argument, push the object
// pointer, call, drop both pushes, return. 32-bit x86 only.
BYTE scT2[] = {
    0xFF, 0x74, 0x24, 0x04, // push dword ptr [esp + 4]  (copy of the caller's n)
    0x68, 0, 0, 0, 0,       // push imm32  (offset 4; operand at +5 = &ft)
    0xE8, 0, 0, 0, 0,       // call rel32  (offset 9; operand at +10)
    0x83, 0xC4, 0x08,       // add esp, 8
    0xC3                    // ret
};
// A distinct name: `psc` is already defined by the foo1 example above.
BYTE *psc2 = (BYTE*)VirtualAlloc(NULL, sizeof(scT2), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
memcpy(psc2, scT2, sizeof(scT2));
*(ULONG*)(psc2 + 5) = (ULONG)&ft;
// The call starts at offset 9, so rel32 is relative to psc2 + 9 + 5 = psc2 + 14.
*(ULONG*)(psc2 + 10) = (ULONG)&foo2 - (ULONG)(psc2 + 14);
FlushInstructionCache(GetCurrentProcess(), psc2, sizeof(scT2));
*(ULONG*)&(obj.foo2) = (ULONG)psc2;
obj.foo2(100);
//这样就模拟了面向对象形式的调用,之后就需要考虑的就是一些如何整合管理资源之类的问题,还需要将代码尽可能的利用一些宏来简化,这些东西就不在这里说了,这个思路深入下去,相信大家都会知道该怎么做了。
下面是一套我封装的luajit使用ODBC访问数据库的代码,没有注释,将就着看吧。
//luaodbc.h**************************************************
#pragma once
/*
 * Public interface of the LuaJIT <-> ODBC bridge.
 *
 * Both structs are tables of function pointers. luadbc_init() and exec()
 * fill them with small generated 32-bit x86 thunks that push the hidden
 * object pointer and forward to the implementation, so callers invoke
 * them without passing an object (e.g. conn->open("DSN")).
 * The layout must stay identical to the ffi.cdef declarations on the Lua side.
 */
#if defined(_WIN32)
#define LUAODBC_API __declspec(dllexport)
#else
#define LUAODBC_API
#endif

#ifndef __cplusplus
#include <stdbool.h>
#endif

#ifdef __cplusplus
extern "C"{
#endif
typedef struct{
    void (*close)();
    bool (*set)(int i, int type, char *data, int size);
    bool (*get)(int i, int type);
    void *(*data)(int i);                 /* buffer of column i (1-based); NULL if out of range */
    bool (*next)();
    char *(*getstr)(int i);
    int (*getint)(int i);
    char *(*getstr_fn)(const char *name);
    int (*getint_fn)(const char *name);   /* implementation returns int, not a string */
    bool (*setstr)(int i, const char *str);
    bool (*setint)(int i, int num);
    bool (*setstr_fn)(const char *name, const char *str);
    bool (*setint_fn)(const char *name, int num);
}luastmt_class;
typedef struct{
    bool (*open)(const char*dsn);
    void (*close)();
    luastmt_class *(*exec)(const char *sql);
    void (*tran_begin)();
    void (*tran_end)(int ct);
}luadbc_class;
/* Create a connection object; NULL on failure. Release it with luadbc_exit. */
LUAODBC_API luadbc_class *luadbc_init();
/* Close and free an object returned by luadbc_init; unknown pointers are ignored. */
LUAODBC_API void luadbc_exit(luadbc_class *obj);
#ifdef __cplusplus
}
#endif
//luaodbc.cpp**************************************************
#include "stdafx.h"
#include "luaodbc.h"
#include <windows.h>
#include <sqlext.h>
#include <list>
#include <string>
#include <algorithm>
extern "C"{
/* Status bit flags kept in odbcType::status and stmtType::status. */
#define LOS_CLOSED        0x0   /* no flag set */
#define LOS_INITILZED     0x1   /* environment handle allocated (luadbc_init) */
#define LOS_CONNECTIONED  0x2   /* connection open (luaodbc_open) */
#define LOS_TRAN          0x4   /* manual-commit transaction begun (luaodbc_tran_begin) */
#define LOS_STMTOPENED    0x8   /* statement handle allocated and executed (luaodbc_exec) */
// Per-column metadata and fetch buffer for the current result set.
typedef struct _ciType{
char name[68];      // lower-cased column name copied from SQLDescribeColA
int type;           // SQL data type reported by SQLDescribeColA
int type_len;       // column size reported by SQLDescribeColA
char *data;         // heap buffer (new char[]) that SQLGetData writes into
int data_size;      // size of data in bytes, compared against before reallocating
int nds;            // length/indicator written by the last SQLGetData
}ciType;
// One pooled ODBC statement plus the function table handed out for it.
typedef struct _stmtType{
SQLHANDLE hstmt;             // statement handle; valid while LOS_STMTOPENED is set
ciType *cols;                // column array with cp_size entries
int col_count;               // number of columns in the current result set
int cp_size;                 // allocated length of cols
unsigned int status;         // LOS_* flags
void *parent;                // owning odbcType (set in luaodbc_newstmt)
luastmt_class thisobj;       // table returned by exec(); entries point at thunks in class_memory
unsigned char *class_memory; // VirtualAlloc'd executable page holding the thunks
int gci;                     // column last read by luastmt_get; reset to 0 on each fetch
_stmtType *next;             // next statement in the owning connection's pool
}stmtType;
// Internal view of luadbc_class: the same five leading function pointers,
// followed by a back-pointer to the owning odbcType. luadbc_exit casts the
// public luadbc_class* to this type to recover the odbcType.
// NOTE(review): several return types differ from luadbc_class (int vs bool,
// bool vs luastmt_class*); only the pointer layout appears to matter here.
typedef struct{
int (*open)(const char*dsn);
void (*close)();
bool (*exec)(const char *sql);
void (*tran_begin)();
void (*tran_end)(int ct);
void *dbc_obj;               // set to the owning odbcType in luadbc_init
}luadbc_class_ex;
// One ODBC environment + connection and its statement pool.
typedef struct _odbcType{
SQLHANDLE henv, hdbc;        // environment handle; connection handle (NULL when not connected)
_stmtType *stmt_list;        // pooled statements, reused once closed
unsigned int status;         // LOS_* flags
luadbc_class_ex thisobj;     // returned to callers as luadbc_class*
unsigned char *class_memory; // VirtualAlloc'd executable page holding the thunks
}odbcType;
// Thunk templates (32-bit x86). scTn forwards n-1 of the caller's 4-byte
// arguments (push dword ptr [esp+k] each), pushes an object pointer patched
// in at offset SCTPn, calls a rel32 target patched in at SCTPn + 5, pops its
// own pushes with add esp and returns with a plain ret. So the target must
// leave its arguments on the stack (caller-cleanup), and the caller of the
// thunk cleans up its own arguments.
BYTE scT1[] = {
0x68, 0, 0, 0, 0,          // push imm32               ; object pointer
0xE8, 0, 0, 0, 0,          // call rel32               ; target
0x83, 0xC4, 0x04,          // add esp, 4
0xC3                       // ret
};
BYTE scT2[] = {
0xFF, 0x74, 0x24, 0x04,    // push dword ptr [esp+4]   ; caller arg 1
0x68, 0, 0, 0, 0,          // push imm32               ; object pointer
0xE8, 0, 0, 0, 0,          // call rel32
0x83, 0xC4, 0x08,          // add esp, 8
0xC3                       // ret
};
BYTE scT3[] = {
0xFF, 0x74, 0x24, 0x08,    // push dword ptr [esp+8]   ; caller arg 2
0xFF, 0x74, 0x24, 0x08,    // push dword ptr [esp+8]   ; caller arg 1 (esp moved by 4)
0x68, 0, 0, 0, 0,          // push imm32               ; object pointer
0xE8, 0, 0, 0, 0,          // call rel32
0x83, 0xC4, 0x0C,          // add esp, 12
0xC3                       // ret
};
BYTE scT4[] = {
0xFF, 0x74, 0x24, 0x0C,    // push dword ptr [esp+12]  ; caller arg 3
0xFF, 0x74, 0x24, 0x0C,    // push dword ptr [esp+12]  ; caller arg 2
0xFF, 0x74, 0x24, 0x0C,    // push dword ptr [esp+12]  ; caller arg 1
0x68, 0, 0, 0, 0,          // push imm32               ; object pointer
0xE8, 0, 0, 0, 0,          // call rel32
0x83, 0xC4, 0x10,          // add esp, 16
0xC3                       // ret
};
BYTE scT5[] = {
0xFF, 0x74, 0x24, 0x10,    // push dword ptr [esp+16]  ; caller arg 4
0xFF, 0x74, 0x24, 0x10,    // push dword ptr [esp+16]  ; caller arg 3
0xFF, 0x74, 0x24, 0x10,    // push dword ptr [esp+16]  ; caller arg 2
0xFF, 0x74, 0x24, 0x10,    // push dword ptr [esp+16]  ; caller arg 1
0x68, 0, 0, 0, 0,          // push imm32               ; object pointer
0xE8, 0, 0, 0, 0,          // call rel32
0x83, 0xC4, 0x14,          // add esp, 20
0xC3                       // ret
};
/* Byte offset of the imm32 operand of the "push <object>" instruction in
 * scT1..scT5: one 0x68 opcode byte plus 4 bytes per forwarding push.
 * Parenthesized so each expansion behaves as a single value in any expression. */
#define SCTP1 1
#define SCTP2 (SCTP1 + 4)
#define SCTP3 (SCTP2 + 4)
#define SCTP4 (SCTP3 + 4)
#define SCTP5 (SCTP4 + 4)
// Registry of objects returned by luadbc_init, stored as int (pointer-to-int
// cast, so this assumes 32-bit pointers). luadbc_exit only tears down
// pointers found here. Guarded by g_safe_luaodbc_cs.
std::list<int> g_safe_luaodbc_list;
CRITICAL_SECTION g_safe_luaodbc_cs;
// Static-initialization hook: constructing g_luaodbc_link initializes
// g_safe_luaodbc_cs during static initialization of this file.
// NOTE(review): DeleteCriticalSection is never called.
class luaodbc_startup{
public:
luaodbc_startup()
{
InitializeCriticalSection(&g_safe_luaodbc_cs);
static_link = 1;
}
int static_link;
};
static luaodbc_startup g_luaodbc_link;
// NOTE(review): copying static_link into another static looks intended to keep
// g_luaodbc_link referenced; its effect is not visible from here — confirm.
// Identifiers beginning with double underscores are reserved for the implementation.
static int ___g_static_link = g_luaodbc_link.static_link;
/* Free the statement handle if one is open; the stmtType stays in the pool. */
void luastmt_close(stmtType *v)
{
    if(!(v->status & LOS_STMTOPENED))
        return;
    SQLFreeHandle(SQL_HANDLE_STMT, v->hstmt);
    v->hstmt = NULL;
    // The flag is known to be set here, so toggling it clears it.
    v->status ^= LOS_STMTOPENED;
}
/* Replace column i's (0-based) fetch buffer with a new, uninitialized buffer
 * of `size` bytes and record that capacity. Old contents are discarded. */
void luaodbc_coldata_realloc(stmtType *v, int i, int size)
{
    delete[] v->cols[i].data;   // allocated with new char[]
    v->cols[i].data = new char[size];
    // luastmt_get compares data_size against the reported length to decide
    // whether to grow again; leaving it stale makes that loop retry forever.
    v->cols[i].data_size = size;
}
/* Write `data` into column i (1-based) of the current row through a
 * positioned update (SQLBindCol + SQLSetPos SQL_UPDATE).
 * type is an SQL_C_* type and size the buffer length in bytes (for
 * SQL_C_CHAR callers pass strlen + 1). Returns TRUE only if the update
 * itself succeeded. */
BOOL luastmt_set(stmtType *v, int i, int type, char *data, int size)
{
    if(!data || i <= 0)
        return FALSE;
    if(!(v->status & LOS_STMTOPENED))
        return FALSE;
    SQLINTEGER ncb = size;
    if(type == SQL_C_CHAR)
    {
        // For character data the length indicator is the byte count of the
        // value itself; counting the NUL terminator would store it too.
        int len = 0;
        while(len < size && data[len] != '\0')
            len++;
        ncb = len;
    }
    SQLFreeStmt(v->hstmt, SQL_UNBIND);
    SQLRETURN rc = SQLBindCol(v->hstmt, (SQLUSMALLINT)i, (SQLSMALLINT)type, (SQLPOINTER)data, (SQLINTEGER)size, &ncb);
    BOOL ok = FALSE;
    if(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO)
    {
        rc = SQLSetPos(v->hstmt, 1, SQL_UPDATE, SQL_LOCK_NO_CHANGE);
        ok = (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) ? TRUE : FALSE;
    }
    // Unbind again so the driver no longer references `data`, which is often a caller's local.
    SQLFreeStmt(v->hstmt, SQL_UNBIND);
    return ok;
}
/* Read column i (1-based) of the current row into v->cols[i - 1].data,
 * converted to the SQL_C_* type `type`; the length/indicator goes to .nds.
 * Returns TRUE on success. For a SQL NULL value the buffer is zero-filled.
 * If SQLGetData fails and reports more data than the buffer holds, the
 * buffer is grown and the read retried a bounded number of times. */
BOOL luastmt_get(stmtType *v, int i, int type)
{
    if(i <= 0 || i > v->col_count)
        return FALSE;
    if(!(v->status & LOS_STMTOPENED))
        return FALSE;
    ciType *col = &v->cols[i - 1];
    for(int attempt = 0; attempt < 3; attempt++)
    {
        // Re-reading the column last read on this row (or retrying): reposition first.
        if(v->gci == i)
            SQLSetPos(v->hstmt, 1, SQL_POSITION, SQL_LOCK_NO_CHANGE);
        else
            v->gci = i;
        SQLRETURN rc = SQLGetData(v->hstmt, (SQLUSMALLINT)i, (SQLSMALLINT)type,
            col->data, col->data_size, (SQLINTEGER*)&(col->nds));
        if(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO)
        {
            // No value is written for NULL; clear the buffer so the previous
            // row's value is not returned by getstr/getint.
            if(col->nds == SQL_NULL_DATA)
                memset(col->data, 0, col->data_size);
            return TRUE;
        }
        if(col->data_size >= col->nds)
            return FALSE;
        luaodbc_coldata_realloc(v, i - 1, col->nds + 1);
        col->data_size = col->nds + 1;   // keep capacity in sync so retries terminate
    }
    return FALSE;
}
/* Return the fetch buffer of column i (1-based), or NULL for an index outside
 * the current result set. */
void *luastmt_dataptr(stmtType *v, int i)
{
    // NULL rather than `false`: false is not a null pointer constant in C++11.
    if(i <= 0 || i > v->col_count)
        return NULL;
    return v->cols[i - 1].data;
}
/* Advance to the next row. Returns TRUE only when a row was fetched. */
BOOL luastmt_next(stmtType *v)
{
    if(!(v->status & LOS_STMTOPENED))
        return FALSE;
    v->gci = 0;   // new row: the next luastmt_get does not need to reposition
    SQLRETURN rc = SQLFetch(v->hstmt);
    // Treating SQL_ERROR / SQL_INVALID_HANDLE as "more rows" would make
    // `while rs.next() do ... end` loop forever on a failing cursor.
    return (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) ? TRUE : FALSE;
}
/* Column i (1-based) as a NUL-terminated string; "" if it cannot be read. */
const char *luastmt_getstr(stmtType *v, int i)
{
    const BOOL ok = luastmt_get(v, i, SQL_C_CHAR);
    return ok ? v->cols[i - 1].data : "";
}
/* Column i (1-based) as a 32-bit signed integer; -1 if it cannot be read. */
int luastmt_getint(stmtType *v, int i)
{
    if(luastmt_get(v, i, SQL_C_SLONG))
        return *(int*)(v->cols[i - 1].data);
    return -1;
}
/* Return the 1-based index of the column whose stored (lower-cased) name
 * equals `name` compared case-insensitively, or -1 if none matches.
 * name must not be NULL (all callers check). */
int _colindexfromename(stmtType *v, const char *name)
{
    std::string key = name;
    // Convert through unsigned char: a negative char passed to tolower is
    // undefined behavior (non-ASCII names), and passing `tolower` itself to
    // std::transform is ambiguous when <locale> overloads it.
    for(char &c : key)
        c = (char)tolower((unsigned char)c);
    for(int i = 0; i < v->col_count; i++)
    {
        if(key == v->cols[i].name)
            return i + 1;
    }
    return -1;
}
/* Read the column named `name` as a string; "" if name is NULL or unreadable. */
const char *luastmt_getstr_fn(stmtType *v, const char *name)
{
    return name ? luastmt_getstr(v, _colindexfromename(v, name)) : "";
}
/* Read the column named `name` as an int; -1 if name is NULL or unreadable. */
int luastmt_getint_fn(stmtType *v, const char *name)
{
    return name ? luastmt_getint(v, _colindexfromename(v, name)) : -1;
}
/* Write a string (buffer includes its NUL terminator) into column i. */
BOOL luastmt_setstr(stmtType *v, int i, const char *str)
{
    if(str == NULL || i <= 0)
        return FALSE;
    const int bytes = (int)(strlen(str) + 1);
    return luastmt_set(v, i, SQL_C_CHAR, const_cast<char*>(str), bytes);
}
/* Write a 32-bit integer into column i. */
BOOL luastmt_setint(stmtType *v, int i, int num)
{
    if(v == NULL || i <= 0)
        return FALSE;
    char *raw = reinterpret_cast<char*>(&num);
    return luastmt_set(v, i, SQL_C_SLONG, raw, (int)sizeof(num));
}
/* Name-based variants: resolve the column name, then delegate. */
BOOL luastmt_setstr_fn(stmtType *v, const char *name, const char *str)
{
    if(v == NULL || name == NULL)
        return FALSE;
    const int column = _colindexfromename(v, name);
    return luastmt_setstr(v, column, str);
}
BOOL luastmt_setint_fn(stmtType *v, const char *name, int num)
{
    if(name == NULL)
        return FALSE;
    const int column = _colindexfromename(v, name);
    return luastmt_setint(v, column, num);
}
/* Generate the statement thunks in v->class_memory and point each
 * v->thisobj member at one, so thisobj.X(args...) reaches luastmt_X(v, args...).
 * The template must forward exactly as many arguments as the target takes
 * after `v` (scT1: 0, scT2: 1, scT3: 2, scT5: 4). */
void luastmt_class_init(stmtType *v)
{
    BYTE *psc = v->class_memory;
    if(!psc)
        return;   // no executable page: leave the table NULL rather than write through NULL
    // Copy template _sc_t_ to psc, patch the pushed object pointer at _pup_
    // and the call's rel32 at _pup_ + 5 (relative to the call's end, _pup_ + 9),
    // store the thunk address in v->thisobj._member_ and advance psc.
#define make_luastmt_shellcode(_member_, _sc_t_, _pup_, _proc_) \
    *(DWORD*)&(v->thisobj._member_) = (DWORD)psc; \
    memcpy_s(psc, sizeof(_sc_t_), _sc_t_, sizeof(_sc_t_)); \
    *(DWORD*)(psc + (_pup_)) = (DWORD)v; \
    *(DWORD*)(psc + (_pup_) + 5) = (DWORD)_proc_ - (DWORD)(psc + (_pup_) + 9); \
    psc += sizeof(_sc_t_)
    make_luastmt_shellcode(close, scT1, SCTP1, luastmt_close);
    make_luastmt_shellcode(set, scT5, SCTP5, luastmt_set);
    make_luastmt_shellcode(get, scT3, SCTP3, luastmt_get);
    // luastmt_dataptr takes a column index, so it needs the 1-argument template.
    make_luastmt_shellcode(data, scT2, SCTP2, luastmt_dataptr);
    make_luastmt_shellcode(next, scT1, SCTP1, luastmt_next);
    make_luastmt_shellcode(getstr, scT2, SCTP2, luastmt_getstr);
    make_luastmt_shellcode(getint, scT2, SCTP2, luastmt_getint);
    make_luastmt_shellcode(getstr_fn, scT2, SCTP2, luastmt_getstr_fn);
    make_luastmt_shellcode(getint_fn, scT2, SCTP2, luastmt_getint_fn);
    make_luastmt_shellcode(setstr, scT3, SCTP3, luastmt_setstr);
    make_luastmt_shellcode(setint, scT3, SCTP3, luastmt_setint);
    make_luastmt_shellcode(setstr_fn, scT3, SCTP3, luastmt_setstr_fn);
    make_luastmt_shellcode(setint_fn, scT3, SCTP3, luastmt_setint_fn);
#undef make_luastmt_shellcode
    // Win32 requires this after generating code in memory.
    FlushInstructionCache(GetCurrentProcess(), v->class_memory, (SIZE_T)(psc - v->class_memory));
}
/* Free every column buffer and the column array itself, leaving the
 * statement with no columns (cols == NULL, cp_size == 0). */
void luastmt_colinfo_release(stmtType *v)
{
    if(!v->cols)
        return;
    for(int i = 0; i < v->cp_size; i++)
    {
        delete[] v->cols[i].data;   // buffers come from new char[]
        v->cols[i].data = NULL;
    }
    delete[] v->cols;               // array comes from new ciType[]
    v->cols = NULL;
    v->cp_size = 0;
    v->col_count = 0;
}
/* Describe the current result set: grow the column array if needed, store
 * each column's lower-cased name, SQL type and size, and make sure each
 * column has a fetch buffer large enough for luastmt_get. */
void luastmt_colinfo_init(stmtType *v)
{
    const int kMinBuffer = 16;                        // covers fixed-size C types such as SQL_C_SLONG
    const int kUnknownSizeBuffer = 64 * 1024;         // used when the driver reports size 0 (unknown)
    const SQLUINTEGER kMaxDeclaredSize = 1024 * 1024; // cap for huge declared sizes (e.g. 2^31-1)

    SQLSMALLINT count = 0;
    SQLRETURN rc = SQLNumResultCols(v->hstmt, &count);
    if((rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO) || count < 0)
        count = 0;
    if(count > v->cp_size)
    {
        luastmt_colinfo_release(v);
        v->cp_size = count;
        v->cols = new ciType[count];
        memset(v->cols, 0, sizeof(ciType) * count);
    }
    v->col_count = count;
    char szColName[68];
    SQLSMALLINT cbColName, sqlColType, ibScale, fNullable;
    SQLUINTEGER cbColDef;
    for(SQLSMALLINT i = 0; i < count; i++)
    {
        rc = SQLDescribeColA(v->hstmt, i + 1, (SQLCHAR*)szColName, (SQLSMALLINT)sizeof(szColName),
            &cbColName, &sqlColType, &cbColDef, &ibScale, &fNullable);
        if(rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO)
        {
            cbColName = 0;
            sqlColType = 0;
            cbColDef = 0;
        }
        // cbColName is the full name length even when the name was truncated
        // to fit szColName; clamp it before indexing the buffer.
        if(cbColName < 0)
            cbColName = 0;
        if(cbColName > (SQLSMALLINT)(sizeof(szColName) - 1))
            cbColName = (SQLSMALLINT)(sizeof(szColName) - 1);
        szColName[cbColName] = '\0';
        for(SQLSMALLINT j = 0; j < cbColName; j++)
            szColName[j] = (char)tolower((unsigned char)szColName[j]);
        memcpy_s(v->cols[i].name, sizeof(v->cols[i].name), szColName, cbColName + 1);

        int need;
        if(cbColDef == 0)
            need = kUnknownSizeBuffer;
        else if(cbColDef > kMaxDeclaredSize)
            need = (int)kMaxDeclaredSize + 1;   // cbColDef + 1 could overflow int
        else
            need = (int)cbColDef + 1;
        if(need < kMinBuffer)
            need = kMinBuffer;                  // SQLGetData ignores the length for fixed-size types
        if(v->cols[i].data_size < need)
        {
            delete[] v->cols[i].data;
            v->cols[i].data_size = need;
            v->cols[i].data = new char[need];
        }
        v->cols[i].type = sqlColType;
        v->cols[i].type_len = cbColDef;
    }
}
/* Destroy a pooled statement: column buffers, thunk page, then the record. */
void luastmt_release(stmtType *v)
{
    luastmt_colinfo_release(v);
    unsigned char *thunk_page = v->class_memory;
    VirtualFree(thunk_page, 0, MEM_RELEASE);
    delete v;
}
/* Allocate an idle statement record owned by `parent`, with an initial column
 * array and its generated function table. */
stmtType *luaodbc_newstmt(_odbcType *parent)
{
    const int kColumnSlots = 64;     // initial capacity of the column array
    const SIZE_T kThunkBytes = 512;  // room for all statement thunks (they need under 300 bytes)
    stmtType *v = new stmtType;
    memset(v, 0, sizeof(stmtType));  // also NULLs hstmt, next and every thisobj entry
    v->cp_size = kColumnSlots;
    v->cols = new ciType[kColumnSlots];
    memset(v->cols, 0, sizeof(ciType) * kColumnSlots);
    v->status = LOS_CLOSED;
    v->parent = parent;
    v->class_memory = (unsigned char *)VirtualAlloc(NULL, kThunkBytes, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
    // On allocation failure keep the table NULL instead of writing code through a NULL page.
    if(v->class_memory)
        luastmt_class_init(v);
    return v;
}
/* Switch the connection to manual-commit mode and mark a transaction as open. */
void luaodbc_tran_begin(odbcType *v)
{
    SQLPOINTER autocommit_off = (SQLPOINTER)SQL_FALSE;
    SQLSetConnectAttrA(v->hdbc, SQL_ATTR_AUTOCOMMIT, autocommit_off, NULL);
    v->status = v->status | LOS_TRAN;
}
/* End the transaction with completion type ct (passed to SQLEndTran, e.g.
 * SQL_COMMIT or SQL_ROLLBACK) and restore auto-commit mode. */
void luaodbc_tran_end(odbcType *v, int ct)
{
    SQLEndTran(SQL_HANDLE_DBC, v->hdbc, (short)ct);
    SQLSetConnectAttrA(v->hdbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_TRUE, NULL);
    // Clear, not toggle: with ^= a tran_end without a prior tran_begin would
    // set LOS_TRAN and cause a spurious rollback in luaodbc_close later.
    v->status &= ~(unsigned int)LOS_TRAN;
}
/* Close the connection if open: free every open statement handle (records
 * stay pooled), roll back an open transaction, disconnect and free hdbc. */
void luaodbc_close(odbcType *v)
{
    if(!(v->status & LOS_CONNECTIONED))
        return;
    // A plain loop tolerates an empty pool, unlike the former do-while.
    for(stmtType *s = v->stmt_list; s != NULL; s = s->next)
        luastmt_close(s);
    if(v->status & LOS_TRAN)
        luaodbc_tran_end(v, SQL_ROLLBACK);
    SQLDisconnect(v->hdbc);
    SQLFreeHandle(SQL_HANDLE_DBC, v->hdbc);
    v->status &= ~(unsigned int)LOS_CONNECTIONED;
    v->hdbc = NULL;
}
/* (Re)connect to data source `dsn`; any existing connection is closed first.
 * Returns TRUE on success. */
BOOL luaodbc_open(odbcType *v, const char *dsn)
{
    luaodbc_close(v);
    SQLHANDLE conn;
    SQLRETURN rc = SQLAllocHandle(SQL_HANDLE_DBC, v->henv, &conn);
    const bool allocated = (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO);
    if(!allocated)
        return FALSE;
    rc = SQLConnectA(conn, (SQLCHAR*)dsn, SQL_NTS, NULL, 0, NULL, 0);
    if(rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO)
    {
        v->hdbc = conn;
        v->status |= LOS_CONNECTIONED;
        return TRUE;
    }
    SQLFreeHandle(SQL_HANDLE_DBC, conn);
    return FALSE;
}
/* Execute `sql` on a pooled statement and return its function table, or
 * NULL when not connected or when allocation/execution fails. The first
 * statement record without an open handle is reused; otherwise a new one
 * is appended to v->stmt_list. */
luastmt_class *luaodbc_exec(odbcType *v, const char *sql)
{
    // Check the connection before touching (and possibly growing) the pool.
    if(!sql || !(v->status & LOS_CONNECTIONED))
        return NULL;
    stmtType *s = v->stmt_list;
    stmtType *last = NULL;
    while(s && (s->status & LOS_STMTOPENED))
    {
        last = s;
        s = s->next;
    }
    if(!s)
    {
        s = luaodbc_newstmt(v);
        if(last)
            last->next = s;
        else
            v->stmt_list = s;
    }
    SQLHANDLE hstmt;
    SQLRETURN rc = SQLAllocHandle(SQL_HANDLE_STMT, v->hdbc, &hstmt);
    if(rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO)
        return NULL;
    SQLCancel(hstmt);
    // Keyset-driven cursor with SQL_CONCUR_VALUES concurrency, presumably so the
    // positioned SQLSetPos calls in luastmt_set/luastmt_get work.
    // (The first branch is also taken by non-MSVC compilers, where _MSC_VER is 0.)
#if(_MSC_VER < 1300)
    SQLSetStmtOption(hstmt, SQL_CONCURRENCY, SQL_CONCUR_VALUES);
    SQLSetStmtOption(hstmt, SQL_CURSOR_TYPE, SQL_CURSOR_KEYSET_DRIVEN);
#else
    SQLSetStmtAttr(hstmt, SQL_ATTR_CONCURRENCY, (SQLPOINTER)SQL_CONCUR_VALUES, 0);
    SQLSetStmtAttr(hstmt, SQL_ATTR_CURSOR_TYPE, (SQLPOINTER)SQL_CURSOR_KEYSET_DRIVEN, 0);
#endif
    rc = SQLExecDirectA(hstmt, (SQLCHAR*)sql, SQL_NTS);
    if(rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO)
    {
        SQLFreeHandle(SQL_HANDLE_STMT, hstmt);
        return NULL;
    }
    s->hstmt = hstmt;
    luastmt_colinfo_init(s);
    s->status |= LOS_STMTOPENED;
    return &(s->thisobj);
}
/* Generate the connection thunks in v->class_memory and point each
 * v->thisobj member at one, so thisobj.X(args...) reaches luaodbc_X(v, args...).
 * Template choice = number of forwarded arguments (scT1: 0, scT2: 1). */
void luaodbc_class_init(odbcType *v)
{
    BYTE *psc = v->class_memory;
    if(!psc)
        return;   // no executable page: leave the table NULL rather than write through NULL
    // Copy template _sc_t_ to psc, patch the pushed object pointer at _pup_
    // and the call's rel32 at _pup_ + 5 (relative to the call's end, _pup_ + 9),
    // store the thunk address in v->thisobj._member_ and advance psc.
#define make_luaodbc_shellcode(_member_, _sc_t_, _pup_, _proc_) \
    *(DWORD*)&(v->thisobj._member_) = (DWORD)psc; \
    memcpy_s(psc, sizeof(_sc_t_), _sc_t_, sizeof(_sc_t_)); \
    *(DWORD*)(psc + (_pup_)) = (DWORD)v; \
    *(DWORD*)(psc + (_pup_) + 5) = (DWORD)_proc_ - (DWORD)(psc + (_pup_) + 9); \
    psc += sizeof(_sc_t_)
    make_luaodbc_shellcode(open, scT2, SCTP2, luaodbc_open);
    make_luaodbc_shellcode(close, scT1, SCTP1, luaodbc_close);
    make_luaodbc_shellcode(exec, scT2, SCTP2, luaodbc_exec);
    make_luaodbc_shellcode(tran_begin, scT1, SCTP1, luaodbc_tran_begin);
    make_luaodbc_shellcode(tran_end, scT2, SCTP2, luaodbc_tran_end);
#undef make_luaodbc_shellcode
    // Win32 requires this after generating code in memory.
    FlushInstructionCache(GetCurrentProcess(), v->class_memory, (SIZE_T)(psc - v->class_memory));
}
/* Create an ODBC environment (ODBC 3 behavior), one pooled statement and
 * the generated function table, register the object and return it.
 * Returns NULL on failure, with everything allocated so far released. */
__declspec(dllexport) luadbc_class *luadbc_init()
{
    odbcType *v = new odbcType;
    memset(v, 0, sizeof(odbcType));
    SQLRETURN rc = SQLAllocHandle(SQL_HANDLE_ENV, NULL, &(v->henv));
    if(rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO)
    {
        delete v;   // previously leaked on this path
        return NULL;
    }
    rc = SQLSetEnvAttr(v->henv, SQL_ATTR_ODBC_VERSION, (SQLPOINTER)SQL_OV_ODBC3, SQL_IS_INTEGER);
    if(rc != SQL_SUCCESS && rc != SQL_SUCCESS_WITH_INFO)
    {
        SQLFreeHandle(SQL_HANDLE_ENV, v->henv);
        delete v;
        return NULL;
    }
    v->class_memory = (unsigned char *)VirtualAlloc(NULL, 512, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
    if(!v->class_memory)
    {
        SQLFreeHandle(SQL_HANDLE_ENV, v->henv);
        delete v;
        return NULL;
    }
    v->stmt_list = luaodbc_newstmt(v);
    v->status = LOS_INITILZED;
    luaodbc_class_init(v);
    v->thisobj.dbc_obj = v;   // back-pointer used by luadbc_exit
    EnterCriticalSection(&g_safe_luaodbc_cs);
    g_safe_luaodbc_list.push_back((int)&(v->thisobj));
    LeaveCriticalSection(&g_safe_luaodbc_cs);
    return (luadbc_class*)&(v->thisobj);
}
/* Destroy an object returned by luadbc_init. Pointers that are not (or no
 * longer) registered are ignored, so a double exit is harmless. */
__declspec(dllexport) void luadbc_exit(luadbc_class *obj)
{
    bool registered = false;
    EnterCriticalSection(&g_safe_luaodbc_cs);
    std::list<int>::iterator pos = std::find(g_safe_luaodbc_list.begin(), g_safe_luaodbc_list.end(), (int)obj);
    if(pos != g_safe_luaodbc_list.end())
    {
        g_safe_luaodbc_list.erase(pos);
        registered = true;
    }
    LeaveCriticalSection(&g_safe_luaodbc_cs);
    if(!registered)
        return;
    // Recover the owning odbcType through the back-pointer after the public table.
    odbcType *v = (odbcType*)(((luadbc_class_ex*)obj)->dbc_obj);
    VirtualFree(v->class_memory, 0, MEM_RELEASE);
    luaodbc_close(v);
    stmtType *s = v->stmt_list;
    do
    {
        stmtType *victim = s;
        s = s->next;
        luastmt_release(victim);
    } while(s);
    SQLFreeHandle(SQL_HANDLE_ENV, v->henv);
    delete v;
}
}
-- Calling the bridge from LuaJIT:
local ffi = require("ffi")
ffi.cdef[[
typedef struct{
void (*close)();
bool (*set)(int i, int type, char *data, int size);
bool (*get)(int i, int type);
void *(*data)(int i);
bool (*next)();
char *(*getstr)(int i);
int (*getint)(int i);
char *(*getstr_fn)(const char *name);
int (*getint_fn)(const char *name);
bool (*setstr)(int i, const char *str);
bool (*setint)(int i, int num);
bool (*setstr_fn)(const char *name, const char *str);
bool (*setint_fn)(const char *name, int num);
}luastmt_class;
typedef struct{
bool (*open)(const char*dsn);
void (*close)();
luastmt_class *(*exec)(const char *sql);
void (*tran_begin)();
void (*tran_end)(int ct);
}luadbc_class;
luadbc_class *luadbc_init();
void luadbc_exit(luadbc_class *obj);
]]
-- NOTE(review): ffi.C only sees symbols already loaded into the process.
-- If the bridge is built as its own DLL, use `local lib = ffi.load("<dll name>")`
-- and call lib.luadbc_init / lib.luadbc_exit instead.
local lib = ffi.C
local conn = lib.luadbc_init()
-- luadbc_init returns NULL when the ODBC environment cannot be created.
if conn == nil then
    error("luadbc_init failed")
end
if conn.open("USERDSN") ~= false then
    print("已连接!")
    local rs = conn.exec("SELECT * FROM characters")
    if rs ~= nil then
        print("开始遍历数据!")
        while rs.next() do
            print(ffi.string(rs.getstr(3)))
        end
        rs.close()
    end
    conn.close()
end
-- Free the connection object, its pooled statements and generated thunks.
lib.luadbc_exit(conn)