]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-42213: Check connection in sqlite3.Connection.__enter__ (GH-26512)
authorErlend Egeberg Aasland <erlend.aasland@innova.no>
Thu, 3 Jun 2021 15:53:47 +0000 (17:53 +0200)
committerGitHub <noreply@github.com>
Thu, 3 Jun 2021 15:53:47 +0000 (17:53 +0200)
Try to harden connection close:

- add tests that exercise stuff against a closed database
- add wrapper for sqlite3_close_v2()
- check connection on __enter__
- explicitly free pending statements before close()
- sqlite3_close_v2() always returns SQLITE_OK

Lib/sqlite3/test/dbapi.py
Modules/_sqlite/connection.c

index 39c9bf5b61143db5f221d1e0b4fe0509d5d86d75..ab3313533940a35fb1366aedaff5585a2f4fd65f 100644 (file)
@@ -135,6 +135,26 @@ class ConnectionTests(unittest.TestCase):
     def test_close(self):
         self.cx.close()
 
+    def test_use_after_close(self):
+        sql = "select 1"
+        cu = self.cx.cursor()
+        res = cu.execute(sql)
+        self.cx.close()
+        self.assertRaises(sqlite.ProgrammingError, res.fetchall)
+        self.assertRaises(sqlite.ProgrammingError, cu.execute, sql)
+        self.assertRaises(sqlite.ProgrammingError, cu.executemany, sql, [])
+        self.assertRaises(sqlite.ProgrammingError, cu.executescript, sql)
+        self.assertRaises(sqlite.ProgrammingError, self.cx.execute, sql)
+        self.assertRaises(sqlite.ProgrammingError,
+                          self.cx.executemany, sql, [])
+        self.assertRaises(sqlite.ProgrammingError, self.cx.executescript, sql)
+        self.assertRaises(sqlite.ProgrammingError,
+                          self.cx.create_function, "t", 1, lambda x: x)
+        self.assertRaises(sqlite.ProgrammingError, self.cx.cursor)
+        with self.assertRaises(sqlite.ProgrammingError):
+            with self.cx:
+                pass
+
     def test_exceptions(self):
         # Optional DB-API extension.
         self.assertEqual(self.cx.Warning, sqlite.Warning)
index e15629c8aba245aef208a157bf7922262dd83936..62c4dc3bbb3943c82ce29ddc346b8948d9b98ac7 100644 (file)
@@ -258,6 +258,16 @@ connection_clear(pysqlite_Connection *self)
     return 0;
 }
 
+static void
+connection_close(pysqlite_Connection *self)
+{
+    if (self->db) {
+        int rc = sqlite3_close_v2(self->db);
+        assert(rc == SQLITE_OK);
+        self->db = NULL;
+    }
+}
+
 static void
 connection_dealloc(pysqlite_Connection *self)
 {
@@ -266,9 +276,7 @@ connection_dealloc(pysqlite_Connection *self)
     tp->tp_clear((PyObject *)self);
 
     /* Clean up if user has not called .close() explicitly. */
-    if (self->db) {
-        sqlite3_close_v2(self->db);
-    }
+    connection_close(self);
 
     tp->tp_free(self);
     Py_DECREF(tp);
@@ -353,24 +361,12 @@ static PyObject *
 pysqlite_connection_close_impl(pysqlite_Connection *self)
 /*[clinic end generated code: output=a546a0da212c9b97 input=3d58064bbffaa3d3]*/
 {
-    int rc;
-
     if (!pysqlite_check_thread(self)) {
         return NULL;
     }
 
     pysqlite_do_all_statements(self, ACTION_FINALIZE, 1);
-
-    if (self->db) {
-        rc = sqlite3_close_v2(self->db);
-
-        if (rc != SQLITE_OK) {
-            _pysqlite_seterror(self->db);
-            return NULL;
-        } else {
-            self->db = NULL;
-        }
-    }
+    connection_close(self);
 
     Py_RETURN_NONE;
 }
@@ -1820,6 +1816,9 @@ static PyObject *
 pysqlite_connection_enter_impl(pysqlite_Connection *self)
 /*[clinic end generated code: output=457b09726d3e9dcd input=127d7a4f17e86d8f]*/
 {
+    if (!pysqlite_check_connection(self)) {
+        return NULL;
+    }
     return Py_NewRef((PyObject *)self);
 }