Connection supports execute query plan and using application defined function for index on expressions

master
ygl 9 years ago
parent a7d54f6aa2
commit 71bbad9815
  1. 58
      src/_sqlite3.py
  2. 4
      src/_sqlite3_build.py
  3. 53
      tests/test_function.py

@ -478,6 +478,34 @@ class Connection(object):
if ret != _lib.SQLITE_OK:
raise self.OperationalError("Error creating function")
@_check_thread_wrap
@_check_closed_wrap
def create_det_function(self, name, num_args, callback):
"""
create a deterministic function
:param name:
:param num_args:
:param callback:
:return:
"""
try:
closure = self.__func_cache[callback]
except KeyError:
@_ffi.callback("void(sqlite3_context*, int, sqlite3_value**)")
def closure(context, nargs, c_params):
_function_callback(callback, context, nargs, c_params)
self.__func_cache[callback] = closure
if isinstance(name, unicode):
name = name.encode('utf-8')
ret = _lib.sqlite3_create_function(self._db, name, num_args,
(_lib.SQLITE_UTF8 | _lib.SQLITE_DETERMINISTIC),
_ffi.NULL,
closure, _ffi.NULL, _ffi.NULL)
if ret != _lib.SQLITE_OK:
raise self.OperationalError("Error creating function")
@_check_thread_wrap
@_check_closed_wrap
def create_aggregate(self, name, num_args, cls):
@ -666,6 +694,36 @@ class Connection(object):
else:
_lib.sqlite3_backup_finish(bk_obj)
def execute_query_plan(self, sql, params):
if isinstance(sql, unicode):
sql = sql.encode('utf-8')
stmt_obj = Statement(self, sql)
stmt_obj._set_params(params)
explain_stmt = _ffi.new('sqlite3_stmt **')
c_prefix = _ffi.new("char[]", "EXPLAIN QUERY PLAN %s")
c_sql = _lib.sqlite3_sql(stmt_obj._statement)
c_explain_sql = _lib.sqlite3_mprintf(c_prefix, c_sql)
if c_explain_sql == _ffi.NULL:
raise self._get_exception(_lib.SQLITE_NOMEM)
ret = _lib.sqlite3_prepare_v2(self._db, c_explain_sql, -1,
explain_stmt, _ffi.NULL)
_lib.sqlite3_free(c_explain_sql)
if ret != _lib.SQLITE_OK:
raise self._get_exception(ret)
result = []
while _lib.sqlite3_step(explain_stmt[0]) == _lib.SQLITE_ROW:
row = [_lib.sqlite3_column_int(explain_stmt[0], 0),
_lib.sqlite3_column_int(explain_stmt[0], 1),
_lib.sqlite3_column_int(explain_stmt[0], 2)]
detail_c = _lib.sqlite3_column_text(explain_stmt[0], 3)
detail_len = _lib.sqlite3_column_bytes(explain_stmt[0], 3)
buf_obj = _ffi.buffer(detail_c, detail_len)
detail_str = self.text_factory(buf_obj)
row.append(detail_str)
result.append(row)
_lib.sqlite3_finalize(explain_stmt[0])
return result
class Cursor(object):
__initialized = False

@ -102,6 +102,7 @@ static void *const SQLITE_TRANSIENT;
#define SQLITE_CREATE_VTABLE ...
#define SQLITE_DROP_VTABLE ...
#define SQLITE_FUNCTION ...
#define SQLITE_DETERMINISTIC ...
const char *sqlite3_libversion(void);
@ -232,6 +233,9 @@ int sqlite3_backup_finish(sqlite3_backup*);
int sqlite3_backup_remaining(sqlite3_backup*);
int sqlite3_backup_pagecount(sqlite3_backup*);
int sqlite3_sleep(int);
char* sqlite3_mprintf(const char*, ...);
void sqlite3_free(void*);
const char *sqlite3_sql(sqlite3_stmt *pStmt);
""")

@ -0,0 +1,53 @@
import pytest
import base64
import src._sqlite3 as _sqlite3
class ConnWrap:
def __init__(self, conn_obj):
self.conn_obj = conn_obj
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None and exc_val is None and exc_tb is None:
self.conn_obj.commit()
else:
self.conn_obj.rollback()
self.conn_obj.close()
def SetB64EncodeFunc(conn_obj):
def __helper(bytes_obj):
if isinstance(bytes_obj, unicode):
temp = bytes_obj.encode('UTF-8')
else:
temp = bytes_obj if bytes_obj is not None else ''
return base64.b64encode(temp)
conn_obj.create_det_function('b64encode', 1, __helper)
@pytest.mark.skipif(False, reason='')
class TestSqlite(object):
@pytest.mark.skipif(False, reason='')
def test_det_function(self):
with ConnWrap(_sqlite3.connect("./bk1.db")) as conn:
conn.conn_obj.execute("drop table if exists test_det_func")
conn.conn_obj.execute("create table if not exists test_det_func(id text, val integer)")
SetB64EncodeFunc(conn.conn_obj)
conn.conn_obj.execute("create index test_det_func_idx1 on test_det_func(b64encode(id))")
with ConnWrap(_sqlite3.connect("./bk1.db")) as conn:
SetB64EncodeFunc(conn.conn_obj)
for i in xrange(1000):
conn.conn_obj.execute("insert into test_det_func(id,val) values (?,?)", (unicode(i), i+1))
with ConnWrap(_sqlite3.connect("./bk1.db")) as conn:
SetB64EncodeFunc(conn.conn_obj)
res = conn.conn_obj.execute_query_plan("select * from test_det_func where b64encode(id)=?", ('1',))
rows = [row for row in res if row[3].find(u'USING INDEX') > 0 and row[3].find(u'<expr>') > 0]
assert len(rows) == 1
Loading…
Cancel
Save