]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue #11688: Add sqlite3.Connection.set_trace_callback(). Patch by Torsten Landschoff.
authorAntoine Pitrou <solipsis@pitrou.net>
Sun, 3 Apr 2011 22:12:04 +0000 (00:12 +0200)
committerAntoine Pitrou <solipsis@pitrou.net>
Sun, 3 Apr 2011 22:12:04 +0000 (00:12 +0200)
Doc/library/sqlite3.rst
Lib/sqlite3/test/hooks.py
Misc/NEWS
Modules/_sqlite/connection.c

index 736767443036cb97d00b44bc4fdd911688527084..01c306c6d6c0863896220f11c3d41f718fe98fcb 100644 (file)
@@ -369,6 +369,22 @@ Connection Objects
    method with :const:`None` for *handler*.
 
 
+.. method:: Connection.set_trace_callback(trace_callback)
+
+   Registers *trace_callback* to be called for each SQL statement that is
+   actually executed by the SQLite backend.
+
+   The only argument passed to the callback is the statement (as string) that
+   is being executed. The return value of the callback is ignored. Note that
+   the backend does not only run statements passed to the :meth:`Cursor.execute`
+   methods.  Other sources include the transaction management of the Python
+   module and the execution of triggers defined in the current database.
+
+   Passing :const:`None` as *trace_callback* will disable the trace callback.
+
+   .. versionadded:: 3.3
+
+
 .. method:: Connection.enable_load_extension(enabled)
 
    This routine allows/disallows the SQLite engine to load SQLite extensions
index a6161fac85399699904c1c64ce4edc3c4c5e684e..94da785cc3e7e7dbafcc7b01c5240ba026c5893d 100644 (file)
@@ -175,10 +175,56 @@ class ProgressTests(unittest.TestCase):
         con.execute("select 1 union select 2 union select 3").fetchall()
         self.assertEqual(action, 0, "progress handler was not cleared")
 
+class TraceCallbackTests(unittest.TestCase):
+    def CheckTraceCallbackUsed(self):
+        """
+        Test that the trace callback is invoked once it is set.
+        """
+        con = sqlite.connect(":memory:")
+        traced_statements = []
+        def trace(statement):
+            traced_statements.append(statement)
+        con.set_trace_callback(trace)
+        con.execute("create table foo(a, b)")
+        self.assertTrue(traced_statements)
+        self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
+
+    def CheckClearTraceCallback(self):
+        """
+        Test that setting the trace callback to None clears the previously set callback.
+        """
+        con = sqlite.connect(":memory:")
+        traced_statements = []
+        def trace(statement):
+            traced_statements.append(statement)
+        con.set_trace_callback(trace)
+        con.set_trace_callback(None)
+        con.execute("create table foo(a, b)")
+        self.assertFalse(traced_statements, "trace callback was not cleared")
+
+    def CheckUnicodeContent(self):
+        """
+        Test that the statement can contain unicode literals.
+        """
+        unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
+        con = sqlite.connect(":memory:")
+        traced_statements = []
+        def trace(statement):
+            traced_statements.append(statement)
+        con.set_trace_callback(trace)
+        con.execute("create table foo(x)")
+        con.execute("insert into foo(x) values (?)", (unicode_value,))
+        con.commit()
+        self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
+                "Unicode data garbled in trace callback")
+
+
+
 def suite():
     collation_suite = unittest.makeSuite(CollationTests, "Check")
     progress_suite = unittest.makeSuite(ProgressTests, "Check")
-    return unittest.TestSuite((collation_suite, progress_suite))
+    trace_suite = unittest.makeSuite(TraceCallbackTests, "Check")
+    return unittest.TestSuite((collation_suite, progress_suite, trace_suite))
 
 def test():
     runner = unittest.TextTestRunner()
index 9afa06c418f325d5d91d07e881a7be582dac9359..0b3d22d571e9d47856ae0bf78154977f4525e489 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -87,6 +87,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #11688: Add sqlite3.Connection.set_trace_callback().  Patch by
+  Torsten Landschoff.
+
 - Issue #11746: Fix SSLContext.load_cert_chain() to accept elliptic curve
   private keys.
 
index 4bc16c394b1187e96978e42b451142eca85eb7c7..76d635a8fd12ab2f6a7b1bc5aec24b83ab6d3fc8 100644 (file)
@@ -904,6 +904,38 @@ static int _progress_handler(void* user_arg)
     return rc;
 }
 
+static void _trace_callback(void* user_arg, const char* statement_string)
+{
+    PyObject *py_statement = NULL;
+    PyObject *ret = NULL;
+
+#ifdef WITH_THREAD
+    PyGILState_STATE gilstate;
+
+    gilstate = PyGILState_Ensure();
+#endif
+    py_statement = PyUnicode_DecodeUTF8(statement_string,
+            strlen(statement_string), "replace");
+    if (py_statement) {
+        ret = PyObject_CallFunctionObjArgs((PyObject*)user_arg, py_statement, NULL);
+        Py_DECREF(py_statement);
+    }
+
+    if (ret) {
+        Py_DECREF(ret);
+    } else {
+        if (_enable_callback_tracebacks) {
+            PyErr_Print();
+        } else {
+            PyErr_Clear();
+        }
+    }
+
+#ifdef WITH_THREAD
+    PyGILState_Release(gilstate);
+#endif
+}
+
 static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
 {
     PyObject* authorizer_cb;
@@ -963,6 +995,34 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s
     return Py_None;
 }
 
+static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
+{
+    PyObject* trace_callback;
+
+    static char *kwlist[] = { "trace_callback", NULL };
+
+    if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
+        return NULL;
+    }
+
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:set_trace_callback",
+                                      kwlist, &trace_callback)) {
+        return NULL;
+    }
+
+    if (trace_callback == Py_None) {
+        /* None clears the trace callback previously set */
+        sqlite3_trace(self->db, 0, (void*)0);
+    } else {
+        if (PyDict_SetItem(self->function_pinboard, trace_callback, Py_None) == -1)
+            return NULL;
+        sqlite3_trace(self->db, _trace_callback, trace_callback);
+    }
+
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
 #ifdef HAVE_LOAD_EXTENSION
 static PyObject* pysqlite_enable_load_extension(pysqlite_Connection* self, PyObject* args)
 {
@@ -1516,6 +1576,8 @@ static PyMethodDef connection_methods[] = {
     #endif
     {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS,
         PyDoc_STR("Sets progress handler callback. Non-standard.")},
+    {"set_trace_callback", (PyCFunction)pysqlite_connection_set_trace_callback, METH_VARARGS|METH_KEYWORDS,
+        PyDoc_STR("Sets a trace callback called for each SQL statement (passed as unicode). Non-standard.")},
     {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS,
         PyDoc_STR("Executes a SQL statement. Non-standard.")},
     {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS,