From 71bbad98151ee2ea3c3913cef7acd1a0af0d8fac Mon Sep 17 00:00:00 2001 From: ygl Date: Sat, 11 Feb 2017 20:23:28 +0800 Subject: [PATCH] Connection supports execute query plan and using application defined function for index on expressions --- src/_sqlite3.py | 58 ++++++++++++++++++++++++++++++++++++++++++ src/_sqlite3_build.py | 4 +++ tests/test_function.py | 53 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+) create mode 100644 tests/test_function.py diff --git a/src/_sqlite3.py b/src/_sqlite3.py index 1816275..5bdd6c5 100644 --- a/src/_sqlite3.py +++ b/src/_sqlite3.py @@ -478,6 +478,34 @@ class Connection(object): if ret != _lib.SQLITE_OK: raise self.OperationalError("Error creating function") + @_check_thread_wrap + @_check_closed_wrap + def create_det_function(self, name, num_args, callback): + """ + create a deterministic function + :param name: + :param num_args: + :param callback: + :return: + """ + try: + closure = self.__func_cache[callback] + except KeyError: + @_ffi.callback("void(sqlite3_context*, int, sqlite3_value**)") + def closure(context, nargs, c_params): + _function_callback(callback, context, nargs, c_params) + + self.__func_cache[callback] = closure + + if isinstance(name, unicode): + name = name.encode('utf-8') + ret = _lib.sqlite3_create_function(self._db, name, num_args, + (_lib.SQLITE_UTF8 | _lib.SQLITE_DETERMINISTIC), + _ffi.NULL, + closure, _ffi.NULL, _ffi.NULL) + if ret != _lib.SQLITE_OK: + raise self.OperationalError("Error creating function") + @_check_thread_wrap @_check_closed_wrap def create_aggregate(self, name, num_args, cls): @@ -666,6 +694,36 @@ class Connection(object): else: _lib.sqlite3_backup_finish(bk_obj) + def execute_query_plan(self, sql, params): + if isinstance(sql, unicode): + sql = sql.encode('utf-8') + stmt_obj = Statement(self, sql) + stmt_obj._set_params(params) + explain_stmt = _ffi.new('sqlite3_stmt **') + c_prefix = _ffi.new("char[]", "EXPLAIN QUERY PLAN %s") + c_sql = _lib.sqlite3_sql(stmt_obj._statement) + c_explain_sql = _lib.sqlite3_mprintf(c_prefix, c_sql) + if c_explain_sql == _ffi.NULL: + raise self._get_exception(_lib.SQLITE_NOMEM) + ret = _lib.sqlite3_prepare_v2(self._db, c_explain_sql, -1, + explain_stmt, _ffi.NULL) + _lib.sqlite3_free(c_explain_sql) + if ret != _lib.SQLITE_OK: + raise self._get_exception(ret) + result = [] + while _lib.sqlite3_step(explain_stmt[0]) == _lib.SQLITE_ROW: + row = [_lib.sqlite3_column_int(explain_stmt[0], 0), + _lib.sqlite3_column_int(explain_stmt[0], 1), + _lib.sqlite3_column_int(explain_stmt[0], 2)] + detail_c = _lib.sqlite3_column_text(explain_stmt[0], 3) + detail_len = _lib.sqlite3_column_bytes(explain_stmt[0], 3) + buf_obj = _ffi.buffer(detail_c, detail_len) + detail_str = self.text_factory(buf_obj) + row.append(detail_str) + result.append(row) + _lib.sqlite3_finalize(explain_stmt[0]) + return result + class Cursor(object): __initialized = False diff --git a/src/_sqlite3_build.py b/src/_sqlite3_build.py index 7f0f165..06d04a3 100644 --- a/src/_sqlite3_build.py +++ b/src/_sqlite3_build.py @@ -102,6 +102,7 @@ static void *const SQLITE_TRANSIENT; #define SQLITE_CREATE_VTABLE ... #define SQLITE_DROP_VTABLE ... #define SQLITE_FUNCTION ... +#define SQLITE_DETERMINISTIC ... const char *sqlite3_libversion(void); @@ -232,6 +233,9 @@ int sqlite3_backup_finish(sqlite3_backup*); int sqlite3_backup_remaining(sqlite3_backup*); int sqlite3_backup_pagecount(sqlite3_backup*); int sqlite3_sleep(int); +char* sqlite3_mprintf(const char*, ...); +void sqlite3_free(void*); +const char *sqlite3_sql(sqlite3_stmt *pStmt); """) diff --git a/tests/test_function.py b/tests/test_function.py new file mode 100644 index 0000000..17db0c2 --- /dev/null +++ b/tests/test_function.py @@ -0,0 +1,53 @@ +import pytest +import base64 +import src._sqlite3 as _sqlite3 + + +class ConnWrap: + def __init__(self, conn_obj): + self.conn_obj = conn_obj + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None and exc_val is None and exc_tb is None: + self.conn_obj.commit() + else: + self.conn_obj.rollback() + self.conn_obj.close() + + +def SetB64EncodeFunc(conn_obj): + def __helper(bytes_obj): + if isinstance(bytes_obj, unicode): + temp = bytes_obj.encode('UTF-8') + else: + temp = bytes_obj if bytes_obj is not None else '' + return base64.b64encode(temp) + conn_obj.create_det_function('b64encode', 1, __helper) + + +@pytest.mark.skipif(False, reason='') +class TestSqlite(object): + @pytest.mark.skipif(False, reason='') + def test_det_function(self): + with ConnWrap(_sqlite3.connect("./bk1.db")) as conn: + conn.conn_obj.execute("drop table if exists test_det_func") + + conn.conn_obj.execute("create table if not exists test_det_func(id text, val integer)") + SetB64EncodeFunc(conn.conn_obj) + conn.conn_obj.execute("create index test_det_func_idx1 on test_det_func(b64encode(id))") + + with ConnWrap(_sqlite3.connect("./bk1.db")) as conn: + SetB64EncodeFunc(conn.conn_obj) + for i in xrange(1000): + conn.conn_obj.execute("insert into test_det_func(id,val) values (?,?)", (unicode(i), i+1)) + + with ConnWrap(_sqlite3.connect("./bk1.db")) as conn: + SetB64EncodeFunc(conn.conn_obj) + res = conn.conn_obj.execute_query_plan("select * from test_det_func where b64encode(id)=?", ('1',)) + rows = [row for row in res if row[3].find(u'USING INDEX') > 0 and row[3].find(u'') > 0] + assert len(rows) == 1 + +