]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-45828: Use unraisable exceptions within sqlite3 callbacks (FH-29591)
authorErlend Egeberg Aasland <erlend.aasland@innova.no>
Mon, 29 Nov 2021 15:22:32 +0000 (16:22 +0100)
committerGitHub <noreply@github.com>
Mon, 29 Nov 2021 15:22:32 +0000 (15:22 +0000)
Doc/library/sqlite3.rst
Doc/whatsnew/3.11.rst
Lib/test/test_sqlite3/test_hooks.py
Lib/test/test_sqlite3/test_userfunctions.py
Misc/NEWS.d/next/Library/2021-11-17-11-40-21.bpo-45828.kQU35U.rst [new file with mode: 0644]
Modules/_sqlite/connection.c

index 9fffe4da901a85f42dda695c1319ddc7daaebbc7..fb38182370ea9a7d8dda70cbba82d5cd0a863359 100644 (file)
@@ -329,9 +329,27 @@ Module functions and constants
 
    By default you will not get any tracebacks in user-defined functions,
    aggregates, converters, authorizer callbacks etc. If you want to debug them,
-   you can call this function with *flag* set to ``True``. Afterwards, you will
-   get tracebacks from callbacks on ``sys.stderr``. Use :const:`False` to
-   disable the feature again.
+   you can call this function with *flag* set to :const:`True`. Afterwards, you
+   will get tracebacks from callbacks on :data:`sys.stderr`. Use :const:`False`
+   to disable the feature again.
+
+   Register an :func:`unraisable hook handler <sys.unraisablehook>` for an
+   improved debug experience::
+
+      >>> import sqlite3
+      >>> sqlite3.enable_callback_tracebacks(True)
+      >>> cx = sqlite3.connect(":memory:")
+      >>> cx.set_trace_callback(lambda stmt: 5/0)
+      >>> cx.execute("select 1")
+      Exception ignored in: <function <lambda> at 0x10b4e3ee0>
+      Traceback (most recent call last):
+        File "<stdin>", line 1, in <lambda>
+      ZeroDivisionError: division by zero
+      >>> import sys
+      >>> sys.unraisablehook = lambda unraisable: print(unraisable)
+      >>> cx.execute("select 1")
+      UnraisableHookArgs(exc_type=<class 'ZeroDivisionError'>, exc_value=ZeroDivisionError('division by zero'), exc_traceback=<traceback object at 0x10b559900>, err_msg=None, object=<function <lambda> at 0x10b4e3ee0>)
+      <sqlite3.Cursor object at 0x10b1fe840>
 
 
 .. _sqlite3-connection-objects:
index 9751f894f9a9a544f5c566400324bbb5fc902c17..3b65921a92619688293c55f39569c778b002e6c7 100644 (file)
@@ -248,7 +248,6 @@ sqlite3
   (Contributed by Aviv Palivoda, Daniel Shahaf, and Erlend E. Aasland in
   :issue:`16379` and :issue:`24139`.)
 
-
 * Add :meth:`~sqlite3.Connection.setlimit` and
   :meth:`~sqlite3.Connection.getlimit` to :class:`sqlite3.Connection` for
   setting and getting SQLite limits by connection basis.
@@ -258,6 +257,12 @@ sqlite3
   threading mode the underlying SQLite library has been compiled with.
   (Contributed by Erlend E. Aasland in :issue:`45613`.)
 
+* :mod:`sqlite3` C callbacks now use unraisable exceptions if callback
+  tracebacks are enabled. Users can now register an
+  :func:`unraisable hook handler <sys.unraisablehook>` to improve their debug
+  experience.
+  (Contributed by Erlend E. Aasland in :issue:`45828`.)
+
 
 threading
 ---------
index bf454b2aa887f3f0c9b760309933cd78c8e36c9d..9e5e53ad223f0a9127bd7f96a33237988e915663 100644 (file)
@@ -197,7 +197,7 @@ class ProgressTests(unittest.TestCase):
         con.execute("select 1 union select 2 union select 3").fetchall()
         self.assertEqual(action, 0, "progress handler was not cleared")
 
-    @with_tracebacks(['bad_progress', 'ZeroDivisionError'])
+    @with_tracebacks(ZeroDivisionError, name="bad_progress")
     def test_error_in_progress_handler(self):
         con = sqlite.connect(":memory:")
         def bad_progress():
@@ -208,7 +208,7 @@ class ProgressTests(unittest.TestCase):
                 create table foo(a, b)
                 """)
 
-    @with_tracebacks(['__bool__', 'ZeroDivisionError'])
+    @with_tracebacks(ZeroDivisionError, name="bad_progress")
     def test_error_in_progress_handler_result(self):
         con = sqlite.connect(":memory:")
         class BadBool:
index 62a11a5431b7b8329788066015c9023628574f73..996437b1a4bee89e87d4558ff2348b6e8b826aa1 100644 (file)
@@ -25,46 +25,52 @@ import contextlib
 import functools
 import gc
 import io
+import re
 import sys
 import unittest
 import unittest.mock
 import sqlite3 as sqlite
 
-from test.support import bigmemtest
+from test.support import bigmemtest, catch_unraisable_exception
 from .test_dbapi import cx_limit
 
 
-def with_tracebacks(strings, traceback=True):
+def with_tracebacks(exc, regex="", name=""):
     """Convenience decorator for testing callback tracebacks."""
-    if traceback:
-        strings.append('Traceback')
-
     def decorator(func):
+        _regex = re.compile(regex) if regex else None
         @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
-            # First, run the test with traceback enabled.
-            with check_tracebacks(self, strings):
-                func(self, *args, **kwargs)
+            with catch_unraisable_exception() as cm:
+                # First, run the test with traceback enabled.
+                with check_tracebacks(self, cm, exc, _regex, name):
+                    func(self, *args, **kwargs)
 
             # Then run the test with traceback disabled.
             func(self, *args, **kwargs)
         return wrapper
     return decorator
 
+
 @contextlib.contextmanager
-def check_tracebacks(self, strings):
+def check_tracebacks(self, cm, exc, regex, obj_name):
     """Convenience context manager for testing callback tracebacks."""
     sqlite.enable_callback_tracebacks(True)
     try:
         buf = io.StringIO()
         with contextlib.redirect_stderr(buf):
             yield
-        tb = buf.getvalue()
-        for s in strings:
-            self.assertIn(s, tb)
+
+        self.assertEqual(cm.unraisable.exc_type, exc)
+        if regex:
+            msg = str(cm.unraisable.exc_value)
+            self.assertIsNotNone(regex.search(msg))
+        if obj_name:
+            self.assertEqual(cm.unraisable.object.__name__, obj_name)
     finally:
         sqlite.enable_callback_tracebacks(False)
 
+
 def func_returntext():
     return "foo"
 def func_returntextwithnull():
@@ -299,7 +305,7 @@ class FunctionTests(unittest.TestCase):
         val = cur.fetchone()[0]
         self.assertEqual(val, 1<<31)
 
-    @with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError'])
+    @with_tracebacks(ZeroDivisionError, name="func_raiseexception")
     def test_func_exception(self):
         cur = self.con.cursor()
         with self.assertRaises(sqlite.OperationalError) as cm:
@@ -307,14 +313,14 @@ class FunctionTests(unittest.TestCase):
             cur.fetchone()
         self.assertEqual(str(cm.exception), 'user-defined function raised exception')
 
-    @with_tracebacks(['func_memoryerror', 'MemoryError'])
+    @with_tracebacks(MemoryError, name="func_memoryerror")
     def test_func_memory_error(self):
         cur = self.con.cursor()
         with self.assertRaises(MemoryError):
             cur.execute("select memoryerror()")
             cur.fetchone()
 
-    @with_tracebacks(['func_overflowerror', 'OverflowError'])
+    @with_tracebacks(OverflowError, name="func_overflowerror")
     def test_func_overflow_error(self):
         cur = self.con.cursor()
         with self.assertRaises(sqlite.DataError):
@@ -426,22 +432,21 @@ class FunctionTests(unittest.TestCase):
         del x,y
         gc.collect()
 
+    @with_tracebacks(OverflowError)
     def test_func_return_too_large_int(self):
         cur = self.con.cursor()
         for value in 2**63, -2**63-1, 2**64:
             self.con.create_function("largeint", 0, lambda value=value: value)
-            with check_tracebacks(self, ['OverflowError']):
-                with self.assertRaises(sqlite.DataError):
-                    cur.execute("select largeint()")
+            with self.assertRaises(sqlite.DataError):
+                cur.execute("select largeint()")
 
+    @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr")
     def test_func_return_text_with_surrogates(self):
         cur = self.con.cursor()
         self.con.create_function("pychr", 1, chr)
         for value in 0xd8ff, 0xdcff:
-            with check_tracebacks(self,
-                    ['UnicodeEncodeError', 'surrogates not allowed']):
-                with self.assertRaises(sqlite.OperationalError):
-                    cur.execute("select pychr(?)", (value,))
+            with self.assertRaises(sqlite.OperationalError):
+                cur.execute("select pychr(?)", (value,))
 
     @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
     @bigmemtest(size=2**31, memuse=3, dry_run=False)
@@ -510,7 +515,7 @@ class AggregateTests(unittest.TestCase):
             val = cur.fetchone()[0]
         self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
 
-    @with_tracebacks(['__init__', '5/0', 'ZeroDivisionError'])
+    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
     def test_aggr_exception_in_init(self):
         cur = self.con.cursor()
         with self.assertRaises(sqlite.OperationalError) as cm:
@@ -518,7 +523,7 @@ class AggregateTests(unittest.TestCase):
             val = cur.fetchone()[0]
         self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
 
-    @with_tracebacks(['step', '5/0', 'ZeroDivisionError'])
+    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep")
     def test_aggr_exception_in_step(self):
         cur = self.con.cursor()
         with self.assertRaises(sqlite.OperationalError) as cm:
@@ -526,7 +531,7 @@ class AggregateTests(unittest.TestCase):
             val = cur.fetchone()[0]
         self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
 
-    @with_tracebacks(['finalize', '5/0', 'ZeroDivisionError'])
+    @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize")
     def test_aggr_exception_in_finalize(self):
         cur = self.con.cursor()
         with self.assertRaises(sqlite.OperationalError) as cm:
@@ -643,11 +648,11 @@ class AuthorizerRaiseExceptionTests(AuthorizerTests):
             raise ValueError
         return sqlite.SQLITE_OK
 
-    @with_tracebacks(['authorizer_cb', 'ValueError'])
+    @with_tracebacks(ValueError, name="authorizer_cb")
     def test_table_access(self):
         super().test_table_access()
 
-    @with_tracebacks(['authorizer_cb', 'ValueError'])
+    @with_tracebacks(ValueError, name="authorizer_cb")
     def test_column_access(self):
         super().test_table_access()
 
diff --git a/Misc/NEWS.d/next/Library/2021-11-17-11-40-21.bpo-45828.kQU35U.rst b/Misc/NEWS.d/next/Library/2021-11-17-11-40-21.bpo-45828.kQU35U.rst
new file mode 100644 (file)
index 0000000..07ec273
--- /dev/null
@@ -0,0 +1,2 @@
+:mod:`sqlite` C callbacks now use unraisable exceptions if callback
+tracebacks are enabled. Patch by Erlend E. Aasland.
index 0bc9d1d0eeda5ff7128323a11d0448750be295b7..4f0baa649e1d0535b74d2b9cccb2fd2019b9d063 100644 (file)
@@ -691,7 +691,7 @@ print_or_clear_traceback(callback_context *ctx)
     assert(ctx != NULL);
     assert(ctx->state != NULL);
     if (ctx->state->enable_callback_tracebacks) {
-        PyErr_Print();
+        PyErr_WriteUnraisable(ctx->callable);
     }
     else {
         PyErr_Clear();