]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-44839: Raise more specific errors in sqlite3 (GH-27613)
authorSerhiy Storchaka <storchaka@gmail.com>
Fri, 6 Aug 2021 18:28:47 +0000 (21:28 +0300)
committerGitHub <noreply@github.com>
Fri, 6 Aug 2021 18:28:47 +0000 (21:28 +0300)
MemoryError raised in user-defined functions will now preserve
its type. OverflowError will now be converted to DataError.
Previously both were converted to OperationalError.

Lib/sqlite3/test/userfunctions.py
Misc/NEWS.d/next/Library/2021-08-05-14-59-39.bpo-44839.MURNL9.rst [new file with mode: 0644]
Modules/_sqlite/connection.c

index b3da3c425b8035dfc8cac6bdd2eca0ad9e8f6c5b..9681dbdde2b0921ed60c015fbb1f61cbca7ec390 100644 (file)
 
 import contextlib
 import functools
+import gc
 import io
+import sys
 import unittest
 import unittest.mock
-import gc
 import sqlite3 as sqlite
 
+from test.support import bigmemtest
+
+
 def with_tracebacks(strings):
     """Convenience decorator for testing callback tracebacks."""
     strings.append('Traceback')
@@ -69,6 +73,10 @@ def func_returnlonglong():
     return 1<<31
 def func_raiseexception():
     5/0
+def func_memoryerror():
+    raise MemoryError
+def func_overflowerror():
+    raise OverflowError
 
 def func_isstring(v):
     return type(v) is str
@@ -187,6 +195,8 @@ class FunctionTests(unittest.TestCase):
         self.con.create_function("returnblob", 0, func_returnblob)
         self.con.create_function("returnlonglong", 0, func_returnlonglong)
         self.con.create_function("raiseexception", 0, func_raiseexception)
+        self.con.create_function("memoryerror", 0, func_memoryerror)
+        self.con.create_function("overflowerror", 0, func_overflowerror)
 
         self.con.create_function("isstring", 1, func_isstring)
         self.con.create_function("isint", 1, func_isint)
@@ -279,6 +289,20 @@ class FunctionTests(unittest.TestCase):
             cur.fetchone()
         self.assertEqual(str(cm.exception), 'user-defined function raised exception')
 
+    @with_tracebacks(['func_memoryerror', 'MemoryError'])
+    def test_func_memory_error(self):
+        cur = self.con.cursor()
+        with self.assertRaises(MemoryError):
+            cur.execute("select memoryerror()")
+            cur.fetchone()
+
+    @with_tracebacks(['func_overflowerror', 'OverflowError'])
+    def test_func_overflow_error(self):
+        cur = self.con.cursor()
+        with self.assertRaises(sqlite.DataError):
+            cur.execute("select overflowerror()")
+            cur.fetchone()
+
     def test_param_string(self):
         cur = self.con.cursor()
         for text in ["foo", str()]:
@@ -384,6 +408,25 @@ class FunctionTests(unittest.TestCase):
         del x,y
         gc.collect()
 
+    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
+    @bigmemtest(size=2**31, memuse=3, dry_run=False)
+    def test_large_text(self, size):
+        cur = self.con.cursor()
+        for size in 2**31-1, 2**31:
+            self.con.create_function("largetext", 0, lambda size=size: "b" * size)
+            with self.assertRaises(sqlite.DataError):
+                cur.execute("select largetext()")
+
+    @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform')
+    @bigmemtest(size=2**31, memuse=2, dry_run=False)
+    def test_large_blob(self, size):
+        cur = self.con.cursor()
+        for size in 2**31-1, 2**31:
+            self.con.create_function("largeblob", 0, lambda size=size: b"b" * size)
+            with self.assertRaises(sqlite.DataError):
+                cur.execute("select largeblob()")
+
+
 class AggregateTests(unittest.TestCase):
     def setUp(self):
         self.con = sqlite.connect(":memory:")
diff --git a/Misc/NEWS.d/next/Library/2021-08-05-14-59-39.bpo-44839.MURNL9.rst b/Misc/NEWS.d/next/Library/2021-08-05-14-59-39.bpo-44839.MURNL9.rst
new file mode 100644 (file)
index 0000000..62ad62c
--- /dev/null
@@ -0,0 +1,4 @@
+:class:`MemoryError` raised in user-defined functions will now produce a
+``MemoryError`` in :mod:`sqlite3`. :class:`OverflowError` will now be converted
+to :class:`~sqlite3.DataError`. Previously
+:class:`~sqlite3.OperationalError` was produced in these cases.
index aae6c66d63fabaa69272911cedac14ea80eb3755..0dab3e85160e820cb8ff069e2c46fa049f8ae435 100644 (file)
@@ -619,6 +619,29 @@ error:
     return NULL;
 }
 
+// Checks the Python exception and sets the appropriate SQLite error code.
+static void
+set_sqlite_error(sqlite3_context *context, const char *msg)
+{
+    assert(PyErr_Occurred());
+    if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
+        sqlite3_result_error_nomem(context);
+    }
+    else if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
+        sqlite3_result_error_toobig(context);
+    }
+    else {
+        sqlite3_result_error(context, msg, -1);
+    }
+    pysqlite_state *state = pysqlite_get_state(NULL);
+    if (state->enable_callback_tracebacks) {
+        PyErr_Print();
+    }
+    else {
+        PyErr_Clear();
+    }
+}
+
 static void
 _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
 {
@@ -645,14 +668,7 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv
         Py_DECREF(py_retval);
     }
     if (!ok) {
-        pysqlite_state *state = pysqlite_get_state(NULL);
-        if (state->enable_callback_tracebacks) {
-            PyErr_Print();
-        }
-        else {
-            PyErr_Clear();
-        }
-        sqlite3_result_error(context, "user-defined function raised exception", -1);
+        set_sqlite_error(context, "user-defined function raised exception");
     }
 
     PyGILState_Release(threadstate);
@@ -676,18 +692,9 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
 
     if (*aggregate_instance == NULL) {
         *aggregate_instance = _PyObject_CallNoArg(aggregate_class);
-
-        if (PyErr_Occurred()) {
-            *aggregate_instance = 0;
-
-            pysqlite_state *state = pysqlite_get_state(NULL);
-            if (state->enable_callback_tracebacks) {
-                PyErr_Print();
-            }
-            else {
-                PyErr_Clear();
-            }
-            sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1);
+        if (!*aggregate_instance) {
+            set_sqlite_error(context,
+                    "user-defined aggregate's '__init__' method raised error");
             goto error;
         }
     }
@@ -706,14 +713,8 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
     Py_DECREF(args);
 
     if (!function_result) {
-        pysqlite_state *state = pysqlite_get_state(NULL);
-        if (state->enable_callback_tracebacks) {
-            PyErr_Print();
-        }
-        else {
-            PyErr_Clear();
-        }
-        sqlite3_result_error(context, "user-defined aggregate's 'step' method raised error", -1);
+        set_sqlite_error(context,
+                "user-defined aggregate's 'step' method raised error");
     }
 
 error:
@@ -761,14 +762,8 @@ _pysqlite_final_callback(sqlite3_context *context)
         Py_DECREF(function_result);
     }
     if (!ok) {
-        pysqlite_state *state = pysqlite_get_state(NULL);
-        if (state->enable_callback_tracebacks) {
-            PyErr_Print();
-        }
-        else {
-            PyErr_Clear();
-        }
-        sqlite3_result_error(context, "user-defined aggregate's 'finalize' method raised error", -1);
+        set_sqlite_error(context,
+                "user-defined aggregate's 'finalize' method raised error");
     }
 
     /* Restore the exception (if any) of the last call to step(),