]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-129928: Raise more accurate exception for incorrect sqlite3 UDF creation (#129941)
authorErlend E. Aasland <erlend@python.org>
Tue, 11 Feb 2025 07:26:01 +0000 (08:26 +0100)
committerGitHub <noreply@github.com>
Tue, 11 Feb 2025 07:26:01 +0000 (08:26 +0100)
Consistently raise ProgrammingError if the user tries to create an UDF
with an invalid number of parameters.

Lib/test/test_sqlite3/test_userfunctions.py
Misc/NEWS.d/next/Library/2025-02-10-08-44-11.gh-issue-129928.QuiZEz.rst [new file with mode: 0644]
Modules/_sqlite/connection.c

index 5bb2eff55ebc8f153e8211d50138a13915f618ff..3abc43a3b1afde3166b1df96d8b8e4fb472b9b02 100644 (file)
@@ -171,7 +171,7 @@ class FunctionTests(unittest.TestCase):
         self.con.close()
 
     def test_func_error_on_create(self):
-        with self.assertRaises(sqlite.OperationalError):
+        with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
             self.con.create_function("bla", -100, lambda x: 2*x)
 
     def test_func_too_many_args(self):
@@ -507,9 +507,8 @@ class WindowFunctionTests(unittest.TestCase):
         self.assertEqual(self.cur.fetchall(), self.expected)
 
     def test_win_error_on_create(self):
-        self.assertRaises(sqlite.ProgrammingError,
-                          self.con.create_window_function,
-                          "shouldfail", -100, WindowSumInt)
+        with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
+            self.con.create_window_function("shouldfail", -100, WindowSumInt)
 
     @with_tracebacks(BadWindow)
     def test_win_exception_in_method(self):
@@ -638,7 +637,7 @@ class AggregateTests(unittest.TestCase):
         self.con.close()
 
     def test_aggr_error_on_create(self):
-        with self.assertRaises(sqlite.OperationalError):
+        with self.assertRaisesRegex(sqlite.ProgrammingError, "not -100"):
             self.con.create_function("bla", -100, AggrSum)
 
     @with_tracebacks(AttributeError, msg_regex="AggrNoStep")
diff --git a/Misc/NEWS.d/next/Library/2025-02-10-08-44-11.gh-issue-129928.QuiZEz.rst b/Misc/NEWS.d/next/Library/2025-02-10-08-44-11.gh-issue-129928.QuiZEz.rst
new file mode 100644 (file)
index 0000000..9c24eb3
--- /dev/null
@@ -0,0 +1,2 @@
+Raise :exc:`sqlite3.ProgrammingError` if a user-defined SQL function with
+invalid number of parameters is created. Patch by Erlend Aasland.
index 16afd7eada113f45848f42e92c98db0f75938b80..a4191dd0a1cb2c4d0547b606eaa85a4a30567e0f 100644 (file)
@@ -1139,6 +1139,20 @@ destructor_callback(void *ctx)
     }
 }
 
+static int
+check_num_params(pysqlite_Connection *self, const int n, const char *name)
+{
+    int limit = sqlite3_limit(self->db, SQLITE_LIMIT_FUNCTION_ARG, -1);
+    assert(limit >= 0);
+    if (n < -1 || n > limit) {
+        PyErr_Format(self->ProgrammingError,
+                     "'%s' must be between -1 and %d, not %d",
+                     name, limit, n);
+        return -1;
+    }
+    return 0;
+}
+
 /*[clinic input]
 _sqlite3.Connection.create_function as pysqlite_connection_create_function
 
@@ -1167,6 +1181,9 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
     }
+    if (check_num_params(self, narg, "narg") < 0) {
+        return NULL;
+    }
 
     if (deterministic) {
         flags |= SQLITE_DETERMINISTIC;
@@ -1307,10 +1324,12 @@ create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls,
                         "SQLite 3.25.0 or higher");
         return NULL;
     }
-
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
     }
+    if (check_num_params(self, num_params, "num_params") < 0) {
+        return NULL;
+    }
 
     int flags = SQLITE_UTF8;
     int rc;
@@ -1367,6 +1386,9 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
     if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
         return NULL;
     }
+    if (check_num_params(self, n_arg, "n_arg") < 0) {
+        return NULL;
+    }
 
     callback_context *ctx = create_callback_context(cls, aggregate_class);
     if (ctx == NULL) {