]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-44430: Refactor `sqlite3` threading tests (GH-26748)
authorErlend Egeberg Aasland <erlend.aasland@innova.no>
Sun, 20 Jun 2021 19:26:36 +0000 (21:26 +0200)
committerGitHub <noreply@github.com>
Sun, 20 Jun 2021 19:26:36 +0000 (20:26 +0100)
1. Rewrite ThreadTests with a _run_test() helper method that does the heavy lifting
2. Add test.support.threading_helper.reap_threads to _run_test()
3. Use _run_test() in all threading tests
4. Add test case for sqlite3.Connection.set_trace_callback
5. Add test case for sqlite3.Connection.create_collation

Lib/sqlite3/test/dbapi.py

index e8bd8c59947cdd666e453c80d9285fcd6ea70a2a..c9bcac9bb832757754a3239cf68db17dcc654f11 100644 (file)
@@ -581,158 +581,54 @@ class ThreadTests(unittest.TestCase):
     def setUp(self):
         self.con = sqlite.connect(":memory:")
         self.cur = self.con.cursor()
-        self.cur.execute("create table test(id integer primary key, name text, bin binary, ratio number, ts timestamp)")
+        self.cur.execute("create table test(name text)")
 
     def tearDown(self):
         self.cur.close()
         self.con.close()
 
-    def test_con_cursor(self):
-        def run(con, errors):
-            try:
-                cur = con.cursor()
-                errors.append("did not raise ProgrammingError")
-                return
-            except sqlite.ProgrammingError:
-                return
-            except:
-                errors.append("raised wrong exception")
-
-        errors = []
-        t = threading.Thread(target=run, kwargs={"con": self.con, "errors": errors})
-        t.start()
-        t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
-
-    def test_con_commit(self):
-        def run(con, errors):
-            try:
-                con.commit()
-                errors.append("did not raise ProgrammingError")
-                return
-            except sqlite.ProgrammingError:
-                return
-            except:
-                errors.append("raised wrong exception")
-
-        errors = []
-        t = threading.Thread(target=run, kwargs={"con": self.con, "errors": errors})
-        t.start()
-        t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
-
-    def test_con_rollback(self):
-        def run(con, errors):
-            try:
-                con.rollback()
-                errors.append("did not raise ProgrammingError")
-                return
-            except sqlite.ProgrammingError:
-                return
-            except:
-                errors.append("raised wrong exception")
-
-        errors = []
-        t = threading.Thread(target=run, kwargs={"con": self.con, "errors": errors})
-        t.start()
-        t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
-
-    def test_con_close(self):
-        def run(con, errors):
-            try:
-                con.close()
-                errors.append("did not raise ProgrammingError")
-                return
-            except sqlite.ProgrammingError:
-                return
-            except:
-                errors.append("raised wrong exception")
-
-        errors = []
-        t = threading.Thread(target=run, kwargs={"con": self.con, "errors": errors})
-        t.start()
-        t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
-
-    def test_cur_implicit_begin(self):
-        def run(cur, errors):
-            try:
-                cur.execute("insert into test(name) values ('a')")
-                errors.append("did not raise ProgrammingError")
-                return
-            except sqlite.ProgrammingError:
-                return
-            except:
-                errors.append("raised wrong exception")
-
-        errors = []
-        t = threading.Thread(target=run, kwargs={"cur": self.cur, "errors": errors})
-        t.start()
-        t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
-
-    def test_cur_close(self):
-        def run(cur, errors):
-            try:
-                cur.close()
-                errors.append("did not raise ProgrammingError")
-                return
-            except sqlite.ProgrammingError:
-                return
-            except:
-                errors.append("raised wrong exception")
-
-        errors = []
-        t = threading.Thread(target=run, kwargs={"cur": self.cur, "errors": errors})
-        t.start()
-        t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
-
-    def test_cur_execute(self):
-        def run(cur, errors):
+    @threading_helper.reap_threads
+    def _run_test(self, fn, *args, **kwds):
+        def run(err):
             try:
-                cur.execute("select name from test")
-                errors.append("did not raise ProgrammingError")
-                return
+                fn(*args, **kwds)
+                err.append("did not raise ProgrammingError")
             except sqlite.ProgrammingError:
-                return
+                pass
             except:
-                errors.append("raised wrong exception")
+                err.append("raised wrong exception")
 
-        errors = []
-        self.cur.execute("insert into test(name) values ('a')")
-        t = threading.Thread(target=run, kwargs={"cur": self.cur, "errors": errors})
+        err = []
+        t = threading.Thread(target=run, kwargs={"err": err})
         t.start()
         t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
-
-    def test_cur_iter_next(self):
-        def run(cur, errors):
-            try:
-                row = cur.fetchone()
-                errors.append("did not raise ProgrammingError")
-                return
-            except sqlite.ProgrammingError:
-                return
-            except:
-                errors.append("raised wrong exception")
+        if err:
+            self.fail("\n".join(err))
+
+    def test_check_connection_thread(self):
+        fns = [
+            lambda: self.con.cursor(),
+            lambda: self.con.commit(),
+            lambda: self.con.rollback(),
+            lambda: self.con.close(),
+            lambda: self.con.set_trace_callback(None),
+            lambda: self.con.create_collation("foo", None),
+        ]
+        for fn in fns:
+            with self.subTest(fn=fn):
+                self._run_test(fn)
+
+    def test_check_cursor_thread(self):
+        fns = [
+            lambda: self.cur.execute("insert into test(name) values('a')"),
+            lambda: self.cur.close(),
+            lambda: self.cur.execute("select name from test"),
+            lambda: self.cur.fetchone(),
+        ]
+        for fn in fns:
+            with self.subTest(fn=fn):
+                self._run_test(fn)
 
-        errors = []
-        self.cur.execute("insert into test(name) values ('a')")
-        self.cur.execute("select name from test")
-        t = threading.Thread(target=run, kwargs={"cur": self.cur, "errors": errors})
-        t.start()
-        t.join()
-        if len(errors) > 0:
-            self.fail("\n".join(errors))
 
     @threading_helper.reap_threads
     def test_dont_check_same_thread(self):