]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-31746: Prevent segfaults when sqlite3.Connection is uninitialised (GH-27431)
authorErlend Egeberg Aasland <erlend.aasland@innova.no>
Thu, 29 Jul 2021 19:45:32 +0000 (21:45 +0200)
committerGitHub <noreply@github.com>
Thu, 29 Jul 2021 19:45:32 +0000 (20:45 +0100)
Lib/sqlite3/test/dbapi.py
Modules/_sqlite/connection.c

index 20cca33e23834bd0446e8b54eb19d87b9b63aa37..408f9945f2c970adf99d4ff2a957f8dc374bf16f 100644 (file)
@@ -243,6 +243,26 @@ class ConnectionTests(unittest.TestCase):
             self.assertEqual(cu.fetchone()[0], n)
 
 
+class UninitialisedConnectionTests(unittest.TestCase):
+    def setUp(self):
+        self.cx = sqlite.Connection.__new__(sqlite.Connection)
+
+    def test_uninit_operations(self):
+        funcs = (
+            lambda: self.cx.isolation_level,
+            lambda: self.cx.total_changes,
+            lambda: self.cx.in_transaction,
+            lambda: self.cx.iterdump(),
+            lambda: self.cx.cursor(),
+            lambda: self.cx.close(),
+        )
+        for func in funcs:
+            with self.subTest(func=func):
+                self.assertRaisesRegex(sqlite.ProgrammingError,
+                                       "Base Connection.__init__ not called",
+                                       func)
+
+
 class OpenTests(unittest.TestCase):
     _sql = "create table test(id integer)"
 
@@ -951,6 +971,7 @@ def suite():
         ModuleTests,
         SqliteOnConflictTests,
         ThreadTests,
+        UninitialisedConnectionTests,
     ]
     return unittest.TestSuite(
         [unittest.TestLoader().loadTestsFromTestCase(t) for t in tests]
index a95b75a0fe14a6acfbcd098b69c315f027dd3901..dd332e5e4d81a731c381b59ac24ffd99768d2332 100644 (file)
@@ -111,8 +111,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
 
     const char *database = PyBytes_AsString(database_obj);
 
-    self->initialized = 1;
-
     self->begin_statement = NULL;
 
     Py_CLEAR(self->statement_cache);
@@ -147,7 +145,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
         Py_INCREF(isolation_level);
     }
     Py_CLEAR(self->isolation_level);
-    if (pysqlite_connection_set_isolation_level(self, isolation_level, NULL) < 0) {
+    if (pysqlite_connection_set_isolation_level(self, isolation_level, NULL) != 0) {
         Py_DECREF(isolation_level);
         return -1;
     }
@@ -195,6 +193,8 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
         return -1;
     }
 
+    self->initialized = 1;
+
     return 0;
 }
 
@@ -371,6 +371,13 @@ pysqlite_connection_close_impl(pysqlite_Connection *self)
         return NULL;
     }
 
+    if (!self->initialized) {
+        pysqlite_state *state = pysqlite_get_state(NULL);
+        PyErr_SetString(state->ProgrammingError,
+                        "Base Connection.__init__ not called.");
+        return NULL;
+    }
+
     pysqlite_do_all_statements(self, ACTION_FINALIZE, 1);
     connection_close(self);
 
@@ -1258,6 +1265,9 @@ int pysqlite_check_thread(pysqlite_Connection* self)
 
 static PyObject* pysqlite_connection_get_isolation_level(pysqlite_Connection* self, void* unused)
 {
+    if (!pysqlite_check_connection(self)) {
+        return NULL;
+    }
     return Py_NewRef(self->isolation_level);
 }
 
@@ -1289,11 +1299,17 @@ pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* iso
         return -1;
     }
     if (isolation_level == Py_None) {
-        PyObject *res = pysqlite_connection_commit(self, NULL);
-        if (!res) {
-            return -1;
+        /* We might get called during connection init, so we cannot use
+         * pysqlite_connection_commit() here. */
+        if (self->db && !sqlite3_get_autocommit(self->db)) {
+            int rc;
+            Py_BEGIN_ALLOW_THREADS
+            rc = sqlite3_exec(self->db, "COMMIT", NULL, NULL, NULL);
+            Py_END_ALLOW_THREADS
+            if (rc != SQLITE_OK) {
+                return _pysqlite_seterror(self->db);
+            }
         }
-        Py_DECREF(res);
 
         self->begin_statement = NULL;
     } else {