]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-105539: Explict resource management for connection objects in sqlite3 tests (...
authorErlend E. Aasland <erlend@python.org>
Thu, 17 Aug 2023 06:45:48 +0000 (08:45 +0200)
committerGitHub <noreply@github.com>
Thu, 17 Aug 2023 06:45:48 +0000 (08:45 +0200)
- Use memory_database() helper
- Move test utility functions to util.py
- Add convenience memory database mixin
- Add check() helper for closed connection tests

Lib/test/test_sqlite3/test_backup.py
Lib/test/test_sqlite3/test_dbapi.py
Lib/test/test_sqlite3/test_dump.py
Lib/test/test_sqlite3/test_factory.py
Lib/test/test_sqlite3/test_hooks.py
Lib/test/test_sqlite3/test_regression.py
Lib/test/test_sqlite3/test_transactions.py
Lib/test/test_sqlite3/test_userfunctions.py
Lib/test/test_sqlite3/util.py [new file with mode: 0644]

index 87ab29c54d65e268303109e15fb6c006f6df77fb..4584d976bce0c6fa48a51ffccb2f8b5603fcd0e2 100644 (file)
@@ -1,6 +1,8 @@
 import sqlite3 as sqlite
 import unittest
 
+from .util import memory_database
+
 
 class BackupTests(unittest.TestCase):
     def setUp(self):
@@ -32,32 +34,32 @@ class BackupTests(unittest.TestCase):
             self.cx.backup(self.cx)
 
     def test_bad_target_closed_connection(self):
-        bck = sqlite.connect(':memory:')
-        bck.close()
-        with self.assertRaises(sqlite.ProgrammingError):
-            self.cx.backup(bck)
+        with memory_database() as bck:
+            bck.close()
+            with self.assertRaises(sqlite.ProgrammingError):
+                self.cx.backup(bck)
 
     def test_bad_source_closed_connection(self):
-        bck = sqlite.connect(':memory:')
-        source = sqlite.connect(":memory:")
-        source.close()
-        with self.assertRaises(sqlite.ProgrammingError):
-            source.backup(bck)
+        with memory_database() as bck:
+            source = sqlite.connect(":memory:")
+            source.close()
+            with self.assertRaises(sqlite.ProgrammingError):
+                source.backup(bck)
 
     def test_bad_target_in_transaction(self):
-        bck = sqlite.connect(':memory:')
-        bck.execute('CREATE TABLE bar (key INTEGER)')
-        bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
-        with self.assertRaises(sqlite.OperationalError) as cm:
-            self.cx.backup(bck)
+        with memory_database() as bck:
+            bck.execute('CREATE TABLE bar (key INTEGER)')
+            bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
+            with self.assertRaises(sqlite.OperationalError) as cm:
+                self.cx.backup(bck)
 
     def test_keyword_only_args(self):
         with self.assertRaises(TypeError):
-            with sqlite.connect(':memory:') as bck:
+            with memory_database() as bck:
                 self.cx.backup(bck, 1)
 
     def test_simple(self):
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck)
             self.verify_backup(bck)
 
@@ -67,7 +69,7 @@ class BackupTests(unittest.TestCase):
         def progress(status, remaining, total):
             journal.append(status)
 
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck, pages=1, progress=progress)
             self.verify_backup(bck)
 
@@ -81,7 +83,7 @@ class BackupTests(unittest.TestCase):
         def progress(status, remaining, total):
             journal.append(remaining)
 
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck, progress=progress)
             self.verify_backup(bck)
 
@@ -94,7 +96,7 @@ class BackupTests(unittest.TestCase):
         def progress(status, remaining, total):
             journal.append(remaining)
 
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck, pages=-1, progress=progress)
             self.verify_backup(bck)
 
@@ -103,7 +105,7 @@ class BackupTests(unittest.TestCase):
 
     def test_non_callable_progress(self):
         with self.assertRaises(TypeError) as cm:
-            with sqlite.connect(':memory:') as bck:
+            with memory_database() as bck:
                 self.cx.backup(bck, pages=1, progress='bar')
         self.assertEqual(str(cm.exception), 'progress argument must be a callable')
 
@@ -116,7 +118,7 @@ class BackupTests(unittest.TestCase):
                 self.cx.commit()
             journal.append(remaining)
 
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck, pages=1, progress=progress)
             self.verify_backup(bck)
 
@@ -140,12 +142,12 @@ class BackupTests(unittest.TestCase):
         self.assertEqual(str(err.exception), 'nearly out of space')
 
     def test_database_source_name(self):
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck, name='main')
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck, name='temp')
         with self.assertRaises(sqlite.OperationalError) as cm:
-            with sqlite.connect(':memory:') as bck:
+            with memory_database() as bck:
                 self.cx.backup(bck, name='non-existing')
         self.assertIn("unknown database", str(cm.exception))
 
@@ -153,7 +155,7 @@ class BackupTests(unittest.TestCase):
         self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)')
         self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)])
         self.cx.commit()
-        with sqlite.connect(':memory:') as bck:
+        with memory_database() as bck:
             self.cx.backup(bck, name='attached_db')
             self.verify_backup(bck)
 
index c9a9e1353938c6c6d4706f3d1a4ec1fe95a7b21f..df3c2ea8d1dbdabcffbd53e0218286d092ee3446 100644 (file)
@@ -33,26 +33,13 @@ from test.support import (
     SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess,
     is_emscripten, is_wasi
 )
+from test.support import gc_collect
 from test.support import threading_helper
 from _testcapi import INT_MAX, ULLONG_MAX
 from os import SEEK_SET, SEEK_CUR, SEEK_END
 from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath
 
-
-# Helper for temporary memory databases
-def memory_database(*args, **kwargs):
-    cx = sqlite.connect(":memory:", *args, **kwargs)
-    return contextlib.closing(cx)
-
-
-# Temporarily limit a database connection parameter
-@contextlib.contextmanager
-def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128):
-    try:
-        _prev = cx.setlimit(category, limit)
-        yield limit
-    finally:
-        cx.setlimit(category, _prev)
+from .util import memory_database, cx_limit
 
 
 class ModuleTests(unittest.TestCase):
@@ -326,9 +313,9 @@ class ModuleTests(unittest.TestCase):
             self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK")
 
     def test_disallow_instantiation(self):
-        cx = sqlite.connect(":memory:")
-        check_disallow_instantiation(self, type(cx("select 1")))
-        check_disallow_instantiation(self, sqlite.Blob)
+        with memory_database() as cx:
+            check_disallow_instantiation(self, type(cx("select 1")))
+            check_disallow_instantiation(self, sqlite.Blob)
 
     def test_complete_statement(self):
         self.assertFalse(sqlite.complete_statement("select t"))
@@ -342,6 +329,7 @@ class ConnectionTests(unittest.TestCase):
         cu = self.cx.cursor()
         cu.execute("create table test(id integer primary key, name text)")
         cu.execute("insert into test(name) values (?)", ("foo",))
+        cu.close()
 
     def tearDown(self):
         self.cx.close()
@@ -412,21 +400,22 @@ class ConnectionTests(unittest.TestCase):
 
     def test_in_transaction(self):
         # Can't use db from setUp because we want to test initial state.
-        cx = sqlite.connect(":memory:")
-        cu = cx.cursor()
-        self.assertEqual(cx.in_transaction, False)
-        cu.execute("create table transactiontest(id integer primary key, name text)")
-        self.assertEqual(cx.in_transaction, False)
-        cu.execute("insert into transactiontest(name) values (?)", ("foo",))
-        self.assertEqual(cx.in_transaction, True)
-        cu.execute("select name from transactiontest where name=?", ["foo"])
-        row = cu.fetchone()
-        self.assertEqual(cx.in_transaction, True)
-        cx.commit()
-        self.assertEqual(cx.in_transaction, False)
-        cu.execute("select name from transactiontest where name=?", ["foo"])
-        row = cu.fetchone()
-        self.assertEqual(cx.in_transaction, False)
+        with memory_database() as cx:
+            cu = cx.cursor()
+            self.assertEqual(cx.in_transaction, False)
+            cu.execute("create table transactiontest(id integer primary key, name text)")
+            self.assertEqual(cx.in_transaction, False)
+            cu.execute("insert into transactiontest(name) values (?)", ("foo",))
+            self.assertEqual(cx.in_transaction, True)
+            cu.execute("select name from transactiontest where name=?", ["foo"])
+            row = cu.fetchone()
+            self.assertEqual(cx.in_transaction, True)
+            cx.commit()
+            self.assertEqual(cx.in_transaction, False)
+            cu.execute("select name from transactiontest where name=?", ["foo"])
+            row = cu.fetchone()
+            self.assertEqual(cx.in_transaction, False)
+            cu.close()
 
     def test_in_transaction_ro(self):
         with self.assertRaises(AttributeError):
@@ -450,10 +439,9 @@ class ConnectionTests(unittest.TestCase):
                 self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))
 
     def test_interrupt_on_closed_db(self):
-        cx = sqlite.connect(":memory:")
-        cx.close()
+        self.cx.close()
         with self.assertRaises(sqlite.ProgrammingError):
-            cx.interrupt()
+            self.cx.interrupt()
 
     def test_interrupt(self):
         self.assertIsNone(self.cx.interrupt())
@@ -521,29 +509,29 @@ class ConnectionTests(unittest.TestCase):
                     self.assertEqual(cx.isolation_level, level)
 
     def test_connection_reinit(self):
-        db = ":memory:"
-        cx = sqlite.connect(db)
-        cx.text_factory = bytes
-        cx.row_factory = sqlite.Row
-        cu = cx.cursor()
-        cu.execute("create table foo (bar)")
-        cu.executemany("insert into foo (bar) values (?)",
-                       ((str(v),) for v in range(4)))
-        cu.execute("select bar from foo")
-
-        rows = [r for r in cu.fetchmany(2)]
-        self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
-        self.assertEqual([r[0] for r in rows], [b"0", b"1"])
-
-        cx.__init__(db)
-        cx.execute("create table foo (bar)")
-        cx.executemany("insert into foo (bar) values (?)",
-                       ((v,) for v in ("a", "b", "c", "d")))
-
-        # This uses the old database, old row factory, but new text factory
-        rows = [r for r in cu.fetchall()]
-        self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
-        self.assertEqual([r[0] for r in rows], ["2", "3"])
+        with memory_database() as cx:
+            cx.text_factory = bytes
+            cx.row_factory = sqlite.Row
+            cu = cx.cursor()
+            cu.execute("create table foo (bar)")
+            cu.executemany("insert into foo (bar) values (?)",
+                           ((str(v),) for v in range(4)))
+            cu.execute("select bar from foo")
+
+            rows = [r for r in cu.fetchmany(2)]
+            self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
+            self.assertEqual([r[0] for r in rows], [b"0", b"1"])
+
+            cx.__init__(":memory:")
+            cx.execute("create table foo (bar)")
+            cx.executemany("insert into foo (bar) values (?)",
+                           ((v,) for v in ("a", "b", "c", "d")))
+
+            # This uses the old database, old row factory, but new text factory
+            rows = [r for r in cu.fetchall()]
+            self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
+            self.assertEqual([r[0] for r in rows], ["2", "3"])
+            cu.close()
 
     def test_connection_bad_reinit(self):
         cx = sqlite.connect(":memory:")
@@ -591,11 +579,11 @@ class ConnectionTests(unittest.TestCase):
             "parameters in Python 3.15."
         )
         with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
-            sqlite.connect(":memory:", 1.0)
+            cx = sqlite.connect(":memory:", 1.0)
+            cx.close()
         self.assertEqual(cm.filename, __file__)
 
 
-
 class UninitialisedConnectionTests(unittest.TestCase):
     def setUp(self):
         self.cx = sqlite.Connection.__new__(sqlite.Connection)
@@ -1571,12 +1559,12 @@ class ThreadTests(unittest.TestCase):
             except sqlite.Error:
                 err.append("multi-threading not allowed")
 
-        con = sqlite.connect(":memory:", check_same_thread=False)
-        err = []
-        t = threading.Thread(target=run, kwargs={"con": con, "err": err})
-        t.start()
-        t.join()
-        self.assertEqual(len(err), 0, "\n".join(err))
+        with memory_database(check_same_thread=False) as con:
+            err = []
+            t = threading.Thread(target=run, kwargs={"con": con, "err": err})
+            t.start()
+            t.join()
+            self.assertEqual(len(err), 0, "\n".join(err))
 
 
 class ConstructorTests(unittest.TestCase):
@@ -1602,9 +1590,16 @@ class ConstructorTests(unittest.TestCase):
         b = sqlite.Binary(b"\0'")
 
 class ExtensionTests(unittest.TestCase):
+    def setUp(self):
+        self.con = sqlite.connect(":memory:")
+        self.cur = self.con.cursor()
+
+    def tearDown(self):
+        self.cur.close()
+        self.con.close()
+
     def test_script_string_sql(self):
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
+        cur = self.cur
         cur.executescript("""
             -- bla bla
             /* a stupid comment */
@@ -1616,40 +1611,40 @@ class ExtensionTests(unittest.TestCase):
         self.assertEqual(res, 5)
 
     def test_script_syntax_error(self):
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
         with self.assertRaises(sqlite.OperationalError):
-            cur.executescript("create table test(x); asdf; create table test2(x)")
+            self.cur.executescript("""
+                CREATE TABLE test(x);
+                asdf;
+                CREATE TABLE test2(x)
+            """)
 
     def test_script_error_normal(self):
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
         with self.assertRaises(sqlite.OperationalError):
-            cur.executescript("create table test(sadfsadfdsa); select foo from hurz;")
+            self.cur.executescript("""
+                CREATE TABLE test(sadfsadfdsa);
+                SELECT foo FROM hurz;
+            """)
 
     def test_cursor_executescript_as_bytes(self):
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
         with self.assertRaises(TypeError):
-            cur.executescript(b"create table test(foo); insert into test(foo) values (5);")
+            self.cur.executescript(b"""
+                CREATE TABLE test(foo);
+                INSERT INTO test(foo) VALUES (5);
+            """)
 
     def test_cursor_executescript_with_null_characters(self):
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
         with self.assertRaises(ValueError):
-            cur.executescript("""
-                create table a(i);\0
-                insert into a(i) values (5);
-                """)
+            self.cur.executescript("""
+                CREATE TABLE a(i);\0
+                INSERT INTO a(i) VALUES (5);
+            """)
 
     def test_cursor_executescript_with_surrogates(self):
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
         with self.assertRaises(UnicodeEncodeError):
-            cur.executescript("""
-                create table a(s);
-                insert into a(s) values ('\ud8ff');
-                """)
+            self.cur.executescript("""
+                CREATE TABLE a(s);
+                INSERT INTO a(s) VALUES ('\ud8ff');
+            """)
 
     def test_cursor_executescript_too_large_script(self):
         msg = "query string is too large"
@@ -1659,19 +1654,18 @@ class ExtensionTests(unittest.TestCase):
                 cx.executescript("select 'too large'".ljust(lim+1))
 
     def test_cursor_executescript_tx_control(self):
-        con = sqlite.connect(":memory:")
+        con = self.con
         con.execute("begin")
         self.assertTrue(con.in_transaction)
         con.executescript("select 1")
         self.assertFalse(con.in_transaction)
 
     def test_connection_execute(self):
-        con = sqlite.connect(":memory:")
-        result = con.execute("select 5").fetchone()[0]
+        result = self.con.execute("select 5").fetchone()[0]
         self.assertEqual(result, 5, "Basic test of Connection.execute")
 
     def test_connection_executemany(self):
-        con = sqlite.connect(":memory:")
+        con = self.con
         con.execute("create table test(foo)")
         con.executemany("insert into test(foo) values (?)", [(3,), (4,)])
         result = con.execute("select foo from test order by foo").fetchall()
@@ -1679,47 +1673,44 @@ class ExtensionTests(unittest.TestCase):
         self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany")
 
     def test_connection_executescript(self):
-        con = sqlite.connect(":memory:")
-        con.executescript("create table test(foo); insert into test(foo) values (5);")
+        con = self.con
+        con.executescript("""
+            CREATE TABLE test(foo);
+            INSERT INTO test(foo) VALUES (5);
+        """)
         result = con.execute("select foo from test").fetchone()[0]
         self.assertEqual(result, 5, "Basic test of Connection.executescript")
 
+
 class ClosedConTests(unittest.TestCase):
+    def check(self, fn, *args, **kwds):
+        regex = "Cannot operate on a closed database."
+        with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
+            fn(*args, **kwds)
+
+    def setUp(self):
+        self.con = sqlite.connect(":memory:")
+        self.cur = self.con.cursor()
+        self.con.close()
+
     def test_closed_con_cursor(self):
-        con = sqlite.connect(":memory:")
-        con.close()
-        with self.assertRaises(sqlite.ProgrammingError):
-            cur = con.cursor()
+        self.check(self.con.cursor)
 
     def test_closed_con_commit(self):
-        con = sqlite.connect(":memory:")
-        con.close()
-        with self.assertRaises(sqlite.ProgrammingError):
-            con.commit()
+        self.check(self.con.commit)
 
     def test_closed_con_rollback(self):
-        con = sqlite.connect(":memory:")
-        con.close()
-        with self.assertRaises(sqlite.ProgrammingError):
-            con.rollback()
+        self.check(self.con.rollback)
 
     def test_closed_cur_execute(self):
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
-        con.close()
-        with self.assertRaises(sqlite.ProgrammingError):
-            cur.execute("select 4")
+        self.check(self.cur.execute, "select 4")
 
     def test_closed_create_function(self):
-        con = sqlite.connect(":memory:")
-        con.close()
-        def f(x): return 17
-        with self.assertRaises(sqlite.ProgrammingError):
-            con.create_function("foo", 1, f)
+        def f(x):
+            return 17
+        self.check(self.con.create_function, "foo", 1, f)
 
     def test_closed_create_aggregate(self):
-        con = sqlite.connect(":memory:")
-        con.close()
         class Agg:
             def __init__(self):
                 pass
@@ -1727,29 +1718,21 @@ class ClosedConTests(unittest.TestCase):
                 pass
             def finalize(self):
                 return 17
-        with self.assertRaises(sqlite.ProgrammingError):
-            con.create_aggregate("foo", 1, Agg)
+        self.check(self.con.create_aggregate, "foo", 1, Agg)
 
     def test_closed_set_authorizer(self):
-        con = sqlite.connect(":memory:")
-        con.close()
         def authorizer(*args):
             return sqlite.DENY
-        with self.assertRaises(sqlite.ProgrammingError):
-            con.set_authorizer(authorizer)
+        self.check(self.con.set_authorizer, authorizer)
 
     def test_closed_set_progress_callback(self):
-        con = sqlite.connect(":memory:")
-        con.close()
-        def progress(): pass
-        with self.assertRaises(sqlite.ProgrammingError):
-            con.set_progress_handler(progress, 100)
+        def progress():
+            pass
+        self.check(self.con.set_progress_handler, progress, 100)
 
     def test_closed_call(self):
-        con = sqlite.connect(":memory:")
-        con.close()
-        with self.assertRaises(sqlite.ProgrammingError):
-            con()
+        self.check(self.con)
+
 
 class ClosedCurTests(unittest.TestCase):
     def test_closed(self):
index d0c24b9c60e61392902fcafb6dd1ab03f4a7d7be..5f6811fb5cc0a52ac19343aee17c67352224b0c9 100644 (file)
@@ -2,16 +2,12 @@
 
 import unittest
 import sqlite3 as sqlite
-from .test_dbapi import memory_database
 
+from .util import memory_database
+from .util import MemoryDatabaseMixin
 
-class DumpTests(unittest.TestCase):
-    def setUp(self):
-        self.cx = sqlite.connect(":memory:")
-        self.cu = self.cx.cursor()
 
-    def tearDown(self):
-        self.cx.close()
+class DumpTests(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_table_dump(self):
         expected_sqls = [
index d63589483e1042a9204ca030ea2c5492624b8fc7..a7c4417862aff747de08d937fb555e522e97deb5 100644 (file)
@@ -24,6 +24,9 @@ import unittest
 import sqlite3 as sqlite
 from collections.abc import Sequence
 
+from .util import memory_database
+from .util import MemoryDatabaseMixin
+
 
 def dict_factory(cursor, row):
     d = {}
@@ -45,10 +48,12 @@ class ConnectionFactoryTests(unittest.TestCase):
             def __init__(self, *args, **kwargs):
                 sqlite.Connection.__init__(self, *args, **kwargs)
 
-        for factory in DefectFactory, OkFactory:
-            with self.subTest(factory=factory):
-                con = sqlite.connect(":memory:", factory=factory)
-                self.assertIsInstance(con, factory)
+        with memory_database(factory=OkFactory) as con:
+            self.assertIsInstance(con, OkFactory)
+        regex = "Base Connection.__init__ not called."
+        with self.assertRaisesRegex(sqlite.ProgrammingError, regex):
+            with memory_database(factory=DefectFactory) as con:
+                self.assertIsInstance(con, DefectFactory)
 
     def test_connection_factory_relayed_call(self):
         # gh-95132: keyword args must not be passed as positional args
@@ -57,9 +62,9 @@ class ConnectionFactoryTests(unittest.TestCase):
                 kwargs["isolation_level"] = None
                 super(Factory, self).__init__(*args, **kwargs)
 
-        con = sqlite.connect(":memory:", factory=Factory)
-        self.assertIsNone(con.isolation_level)
-        self.assertIsInstance(con, Factory)
+        with memory_database(factory=Factory) as con:
+            self.assertIsNone(con.isolation_level)
+            self.assertIsInstance(con, Factory)
 
     def test_connection_factory_as_positional_arg(self):
         class Factory(sqlite.Connection):
@@ -74,18 +79,13 @@ class ConnectionFactoryTests(unittest.TestCase):
             r"parameters in Python 3.15."
         )
         with self.assertWarnsRegex(DeprecationWarning, regex) as cm:
-            con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory)
+            with memory_database(5.0, 0, None, True, Factory) as con:
+                self.assertIsNone(con.isolation_level)
+                self.assertIsInstance(con, Factory)
         self.assertEqual(cm.filename, __file__)
-        self.assertIsNone(con.isolation_level)
-        self.assertIsInstance(con, Factory)
 
 
-class CursorFactoryTests(unittest.TestCase):
-    def setUp(self):
-        self.con = sqlite.connect(":memory:")
-
-    def tearDown(self):
-        self.con.close()
+class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_is_instance(self):
         cur = self.con.cursor()
@@ -103,9 +103,8 @@ class CursorFactoryTests(unittest.TestCase):
         # invalid callable returning non-cursor
         self.assertRaises(TypeError, self.con.cursor, lambda con: None)
 
-class RowFactoryTestsBackwardsCompat(unittest.TestCase):
-    def setUp(self):
-        self.con = sqlite.connect(":memory:")
+
+class RowFactoryTestsBackwardsCompat(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_is_produced_by_factory(self):
         cur = self.con.cursor(factory=MyCursor)
@@ -114,12 +113,8 @@ class RowFactoryTestsBackwardsCompat(unittest.TestCase):
         self.assertIsInstance(row, dict)
         cur.close()
 
-    def tearDown(self):
-        self.con.close()
 
-class RowFactoryTests(unittest.TestCase):
-    def setUp(self):
-        self.con = sqlite.connect(":memory:")
+class RowFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_custom_factory(self):
         self.con.row_factory = lambda cur, row: list(row)
@@ -265,12 +260,8 @@ class RowFactoryTests(unittest.TestCase):
         self.assertRaises(TypeError, self.con.cursor, FakeCursor)
         self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
 
-    def tearDown(self):
-        self.con.close()
 
-class TextFactoryTests(unittest.TestCase):
-    def setUp(self):
-        self.con = sqlite.connect(":memory:")
+class TextFactoryTests(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_unicode(self):
         austria = "Österreich"
@@ -291,15 +282,17 @@ class TextFactoryTests(unittest.TestCase):
         self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
         self.assertTrue(row[0].endswith("reich"), "column must contain original data")
 
-    def tearDown(self):
-        self.con.close()
 
 class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
+
     def setUp(self):
         self.con = sqlite.connect(":memory:")
         self.con.execute("create table test (value text)")
         self.con.execute("insert into test (value) values (?)", ("a\x00b",))
 
+    def tearDown(self):
+        self.con.close()
+
     def test_string(self):
         # text_factory defaults to str
         row = self.con.execute("select value from test").fetchone()
@@ -325,9 +318,6 @@ class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
         self.assertIs(type(row[0]), bytes)
         self.assertEqual(row[0], b"a\x00b")
 
-    def tearDown(self):
-        self.con.close()
-
 
 if __name__ == "__main__":
     unittest.main()
index 89230c08cc91430c328df59b2ea749d1936fe49e..33f0af99532a104fcb51d8ed511c0bcbd95312fd 100644 (file)
@@ -26,34 +26,31 @@ import unittest
 
 from test.support.os_helper import TESTFN, unlink
 
-from test.test_sqlite3.test_dbapi import memory_database, cx_limit
-from test.test_sqlite3.test_userfunctions import with_tracebacks
+from .util import memory_database, cx_limit, with_tracebacks
+from .util import MemoryDatabaseMixin
 
 
-class CollationTests(unittest.TestCase):
+class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
+
     def test_create_collation_not_string(self):
-        con = sqlite.connect(":memory:")
         with self.assertRaises(TypeError):
-            con.create_collation(None, lambda x, y: (x > y) - (x < y))
+            self.con.create_collation(None, lambda x, y: (x > y) - (x < y))
 
     def test_create_collation_not_callable(self):
-        con = sqlite.connect(":memory:")
         with self.assertRaises(TypeError) as cm:
-            con.create_collation("X", 42)
+            self.con.create_collation("X", 42)
         self.assertEqual(str(cm.exception), 'parameter must be callable')
 
     def test_create_collation_not_ascii(self):
-        con = sqlite.connect(":memory:")
-        con.create_collation("collä", lambda x, y: (x > y) - (x < y))
+        self.con.create_collation("collä", lambda x, y: (x > y) - (x < y))
 
     def test_create_collation_bad_upper(self):
         class BadUpperStr(str):
             def upper(self):
                 return None
-        con = sqlite.connect(":memory:")
         mycoll = lambda x, y: -((x > y) - (x < y))
-        con.create_collation(BadUpperStr("mycoll"), mycoll)
-        result = con.execute("""
+        self.con.create_collation(BadUpperStr("mycoll"), mycoll)
+        result = self.con.execute("""
             select x from (
             select 'a' as x
             union
@@ -68,8 +65,7 @@ class CollationTests(unittest.TestCase):
             # reverse order
             return -((x > y) - (x < y))
 
-        con = sqlite.connect(":memory:")
-        con.create_collation("mycoll", mycoll)
+        self.con.create_collation("mycoll", mycoll)
         sql = """
             select x from (
             select 'a' as x
@@ -79,21 +75,20 @@ class CollationTests(unittest.TestCase):
             select 'c' as x
             ) order by x collate mycoll
             """
-        result = con.execute(sql).fetchall()
+        result = self.con.execute(sql).fetchall()
         self.assertEqual(result, [('c',), ('b',), ('a',)],
                          msg='the expected order was not returned')
 
-        con.create_collation("mycoll", None)
+        self.con.create_collation("mycoll", None)
         with self.assertRaises(sqlite.OperationalError) as cm:
-            result = con.execute(sql).fetchall()
+            result = self.con.execute(sql).fetchall()
         self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
 
     def test_collation_returns_large_integer(self):
         def mycoll(x, y):
             # reverse order
             return -((x > y) - (x < y)) * 2**32
-        con = sqlite.connect(":memory:")
-        con.create_collation("mycoll", mycoll)
+        self.con.create_collation("mycoll", mycoll)
         sql = """
             select x from (
             select 'a' as x
@@ -103,7 +98,7 @@ class CollationTests(unittest.TestCase):
             select 'c' as x
             ) order by x collate mycoll
             """
-        result = con.execute(sql).fetchall()
+        result = self.con.execute(sql).fetchall()
         self.assertEqual(result, [('c',), ('b',), ('a',)],
                          msg="the expected order was not returned")
 
@@ -112,7 +107,7 @@ class CollationTests(unittest.TestCase):
         Register two different collation functions under the same name.
         Verify that the last one is actually used.
         """
-        con = sqlite.connect(":memory:")
+        con = self.con
         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
         con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
         result = con.execute("""
@@ -126,25 +121,26 @@ class CollationTests(unittest.TestCase):
         Register a collation, then deregister it. Make sure an error is raised if we try
         to use it.
         """
-        con = sqlite.connect(":memory:")
+        con = self.con
         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
         con.create_collation("mycoll", None)
         with self.assertRaises(sqlite.OperationalError) as cm:
             con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
         self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
 
-class ProgressTests(unittest.TestCase):
+
+class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
+
     def test_progress_handler_used(self):
         """
         Test that the progress handler is invoked once it is set.
         """
-        con = sqlite.connect(":memory:")
         progress_calls = []
         def progress():
             progress_calls.append(None)
             return 0
-        con.set_progress_handler(progress, 1)
-        con.execute("""
+        self.con.set_progress_handler(progress, 1)
+        self.con.execute("""
             create table foo(a, b)
             """)
         self.assertTrue(progress_calls)
@@ -153,7 +149,7 @@ class ProgressTests(unittest.TestCase):
         """
         Test that the opcode argument is respected.
         """
-        con = sqlite.connect(":memory:")
+        con = self.con
         progress_calls = []
         def progress():
             progress_calls.append(None)
@@ -176,11 +172,10 @@ class ProgressTests(unittest.TestCase):
         """
         Test that returning a non-zero value stops the operation in progress.
         """
-        con = sqlite.connect(":memory:")
         def progress():
             return 1
-        con.set_progress_handler(progress, 1)
-        curs = con.cursor()
+        self.con.set_progress_handler(progress, 1)
+        curs = self.con.cursor()
         self.assertRaises(
             sqlite.OperationalError,
             curs.execute,
@@ -190,7 +185,7 @@ class ProgressTests(unittest.TestCase):
         """
         Test that setting the progress handler to None clears the previously set handler.
         """
-        con = sqlite.connect(":memory:")
+        con = self.con
         action = 0
         def progress():
             nonlocal action
@@ -203,31 +198,30 @@ class ProgressTests(unittest.TestCase):
 
     @with_tracebacks(ZeroDivisionError, name="bad_progress")
     def test_error_in_progress_handler(self):
-        con = sqlite.connect(":memory:")
         def bad_progress():
             1 / 0
-        con.set_progress_handler(bad_progress, 1)
+        self.con.set_progress_handler(bad_progress, 1)
         with self.assertRaises(sqlite.OperationalError):
-            con.execute("""
+            self.con.execute("""
                 create table foo(a, b)
                 """)
 
     @with_tracebacks(ZeroDivisionError, name="bad_progress")
     def test_error_in_progress_handler_result(self):
-        con = sqlite.connect(":memory:")
         class BadBool:
             def __bool__(self):
                 1 / 0
         def bad_progress():
             return BadBool()
-        con.set_progress_handler(bad_progress, 1)
+        self.con.set_progress_handler(bad_progress, 1)
         with self.assertRaises(sqlite.OperationalError):
-            con.execute("""
+            self.con.execute("""
                 create table foo(a, b)
                 """)
 
 
-class TraceCallbackTests(unittest.TestCase):
+class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
+
     @contextlib.contextmanager
     def check_stmt_trace(self, cx, expected):
         try:
@@ -242,12 +236,11 @@ class TraceCallbackTests(unittest.TestCase):
         """
         Test that the trace callback is invoked once it is set.
         """
-        con = sqlite.connect(":memory:")
         traced_statements = []
         def trace(statement):
             traced_statements.append(statement)
-        con.set_trace_callback(trace)
-        con.execute("create table foo(a, b)")
+        self.con.set_trace_callback(trace)
+        self.con.execute("create table foo(a, b)")
         self.assertTrue(traced_statements)
         self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
 
@@ -255,7 +248,7 @@ class TraceCallbackTests(unittest.TestCase):
         """
         Test that setting the trace callback to None clears the previously set callback.
         """
-        con = sqlite.connect(":memory:")
+        con = self.con
         traced_statements = []
         def trace(statement):
             traced_statements.append(statement)
@@ -269,7 +262,7 @@ class TraceCallbackTests(unittest.TestCase):
         Test that the statement can contain unicode literals.
         """
         unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
-        con = sqlite.connect(":memory:")
+        con = self.con
         traced_statements = []
         def trace(statement):
             traced_statements.append(statement)
index 7e8221e7227e6e0435ef800410bc31e777ad7c96..db4e13222da9da18f40dbdb664c1c16ad15a33eb 100644 (file)
@@ -28,15 +28,12 @@ import functools
 
 from test import support
 from unittest.mock import patch
-from test.test_sqlite3.test_dbapi import memory_database, cx_limit
 
+from .util import memory_database, cx_limit
+from .util import MemoryDatabaseMixin
 
-class RegressionTests(unittest.TestCase):
-    def setUp(self):
-        self.con = sqlite.connect(":memory:")
 
-    def tearDown(self):
-        self.con.close()
+class RegressionTests(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_pragma_user_version(self):
         # This used to crash pysqlite because this pragma command returns NULL for the column name
@@ -45,28 +42,24 @@ class RegressionTests(unittest.TestCase):
 
     def test_pragma_schema_version(self):
         # This still crashed pysqlite <= 2.2.1
-        con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
-        try:
+        with memory_database(detect_types=sqlite.PARSE_COLNAMES) as con:
             cur = self.con.cursor()
             cur.execute("pragma schema_version")
-        finally:
-            cur.close()
-            con.close()
 
     def test_statement_reset(self):
         # pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are
         # reset before a rollback, but only those that are still in the
         # statement cache. The others are not accessible from the connection object.
-        con = sqlite.connect(":memory:", cached_statements=5)
-        cursors = [con.cursor() for x in range(5)]
-        cursors[0].execute("create table test(x)")
-        for i in range(10):
-            cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
+        with memory_database(cached_statements=5) as con:
+            cursors = [con.cursor() for x in range(5)]
+            cursors[0].execute("create table test(x)")
+            for i in range(10):
+                cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
 
-        for i in range(5):
-            cursors[i].execute(" " * i + "select x from test")
+            for i in range(5):
+                cursors[i].execute(" " * i + "select x from test")
 
-        con.rollback()
+            con.rollback()
 
     def test_column_name_with_spaces(self):
         cur = self.con.cursor()
@@ -81,17 +74,15 @@ class RegressionTests(unittest.TestCase):
         # cache when closing the database. statements that were still
         # referenced in cursors weren't closed and could provoke "
         # "OperationalError: Unable to close due to unfinalised statements".
-        con = sqlite.connect(":memory:")
         cursors = []
         # default statement cache size is 100
         for i in range(105):
-            cur = con.cursor()
+            cur = self.con.cursor()
             cursors.append(cur)
             cur.execute("select 1 x union select " + str(i))
-        con.close()
 
     def test_on_conflict_rollback(self):
-        con = sqlite.connect(":memory:")
+        con = self.con
         con.execute("create table foo(x, unique(x) on conflict rollback)")
         con.execute("insert into foo(x) values (1)")
         try:
@@ -126,16 +117,16 @@ class RegressionTests(unittest.TestCase):
         a statement. This test exhibits the problem.
         """
         SELECT = "select * from foo"
-        con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
-        cur = con.cursor()
-        cur.execute("create table foo(bar timestamp)")
-        with self.assertWarnsRegex(DeprecationWarning, "adapter"):
-            cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
-        cur.execute(SELECT)
-        cur.execute("drop table foo")
-        cur.execute("create table foo(bar integer)")
-        cur.execute("insert into foo(bar) values (5)")
-        cur.execute(SELECT)
+        with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
+            cur = con.cursor()
+            cur.execute("create table foo(bar timestamp)")
+            with self.assertWarnsRegex(DeprecationWarning, "adapter"):
+                cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
+            cur.execute(SELECT)
+            cur.execute("drop table foo")
+            cur.execute("create table foo(bar integer)")
+            cur.execute("insert into foo(bar) values (5)")
+            cur.execute(SELECT)
 
     def test_bind_mutating_list(self):
         # Issue41662: Crash when mutate a list of parameters during iteration.
@@ -144,11 +135,11 @@ class RegressionTests(unittest.TestCase):
                 parameters.clear()
                 return "..."
         parameters = [X(), 0]
-        con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
-        con.execute("create table foo(bar X, baz integer)")
-        # Should not crash
-        with self.assertRaises(IndexError):
-            con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
+        with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
+            con.execute("create table foo(bar X, baz integer)")
+            # Should not crash
+            with self.assertRaises(IndexError):
+                con.execute("insert into foo(bar, baz) values (?, ?)", parameters)
 
     def test_error_msg_decode_error(self):
         # When porting the module to Python 3.0, the error message about
@@ -173,7 +164,7 @@ class RegressionTests(unittest.TestCase):
             def __del__(self):
                 con.isolation_level = ""
 
-        con = sqlite.connect(":memory:")
+        con = self.con
         con.isolation_level = None
         for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE":
             with self.subTest(level=level):
@@ -204,8 +195,7 @@ class RegressionTests(unittest.TestCase):
             def __init__(self, con):
                 pass
 
-        con = sqlite.connect(":memory:")
-        cur = Cursor(con)
+        cur = Cursor(self.con)
         with self.assertRaises(sqlite.ProgrammingError):
             cur.execute("select 4+5").fetchall()
         with self.assertRaisesRegex(sqlite.ProgrammingError,
@@ -238,7 +228,9 @@ class RegressionTests(unittest.TestCase):
         2.5.3 introduced a regression so that these could no longer
         be created.
         """
-        con = sqlite.connect(":memory:", isolation_level=None)
+        with memory_database(isolation_level=None) as con:
+            self.assertIsNone(con.isolation_level)
+            self.assertFalse(con.in_transaction)
 
     def test_pragma_autocommit(self):
         """
@@ -273,9 +265,7 @@ class RegressionTests(unittest.TestCase):
         Recursively using a cursor, such as when reusing it from a generator led to segfaults.
         Now we catch recursive cursor usage and raise a ProgrammingError.
         """
-        con = sqlite.connect(":memory:")
-
-        cur = con.cursor()
+        cur = self.con.cursor()
         cur.execute("create table a (bar)")
         cur.execute("create table b (baz)")
 
@@ -295,29 +285,30 @@ class RegressionTests(unittest.TestCase):
         since the microsecond string "456" actually represents "456000".
         """
 
-        con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
-        cur = con.cursor()
-        cur.execute("CREATE TABLE t (x TIMESTAMP)")
+        with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con:
+            cur = con.cursor()
+            cur.execute("CREATE TABLE t (x TIMESTAMP)")
 
-        # Microseconds should be 456000
-        cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
+            # Microseconds should be 456000
+            cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
 
-        # Microseconds should be truncated to 123456
-        cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
+            # Microseconds should be truncated to 123456
+            cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
 
-        cur.execute("SELECT * FROM t")
-        with self.assertWarnsRegex(DeprecationWarning, "converter"):
-            values = [x[0] for x in cur.fetchall()]
+            cur.execute("SELECT * FROM t")
+            with self.assertWarnsRegex(DeprecationWarning, "converter"):
+                values = [x[0] for x in cur.fetchall()]
 
-        self.assertEqual(values, [
-            datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
-            datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
-        ])
+            self.assertEqual(values, [
+                datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
+                datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
+            ])
 
     def test_invalid_isolation_level_type(self):
         # isolation level is a string, not an integer
-        self.assertRaises(TypeError,
-                          sqlite.connect, ":memory:", isolation_level=123)
+        regex = "isolation_level must be str or None"
+        with self.assertRaisesRegex(TypeError, regex):
+            memory_database(isolation_level=123).__enter__()
 
 
     def test_null_character(self):
@@ -333,7 +324,7 @@ class RegressionTests(unittest.TestCase):
                                        cur.execute, query)
 
     def test_surrogates(self):
-        con = sqlite.connect(":memory:")
+        con = self.con
         self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
         self.assertRaises(UnicodeEncodeError, con, "select '\udcff'")
         cur = con.cursor()
@@ -359,7 +350,7 @@ class RegressionTests(unittest.TestCase):
         to return rows multiple times when fetched from cursors
         after commit. See issues 10513 and 23129 for details.
         """
-        con = sqlite.connect(":memory:")
+        con = self.con
         con.executescript("""
         create table t(c);
         create table t2(c);
@@ -391,10 +382,9 @@ class RegressionTests(unittest.TestCase):
         """
         def callback(*args):
             pass
-        con = sqlite.connect(":memory:")
-        cur = sqlite.Cursor(con)
+        cur = sqlite.Cursor(self.con)
         ref = weakref.ref(cur, callback)
-        cur.__init__(con)
+        cur.__init__(self.con)
         del cur
         # The interpreter shouldn't crash when ref is collected.
         del ref
@@ -425,6 +415,7 @@ class RegressionTests(unittest.TestCase):
 
     def test_table_lock_cursor_replace_stmt(self):
         with memory_database() as con:
+            con = self.con
             cur = con.cursor()
             cur.execute("create table t(t)")
             cur.executemany("insert into t values(?)",
index 5d211dd47b0b6bd98e76dd38cdc9760d799841c8..b7b231d2225852461de2c510e7f9b57278ad2dcb 100644 (file)
@@ -28,7 +28,8 @@ from test.support import LOOPBACK_TIMEOUT
 from test.support.os_helper import TESTFN, unlink
 from test.support.script_helper import assert_python_ok
 
-from test.test_sqlite3.test_dbapi import memory_database
+from .util import memory_database
+from .util import MemoryDatabaseMixin
 
 
 TIMEOUT = LOOPBACK_TIMEOUT / 10
@@ -132,14 +133,14 @@ class TransactionTests(unittest.TestCase):
 
     def test_rollback_cursor_consistency(self):
         """Check that cursors behave correctly after rollback."""
-        con = sqlite.connect(":memory:")
-        cur = con.cursor()
-        cur.execute("create table test(x)")
-        cur.execute("insert into test(x) values (5)")
-        cur.execute("select 1 union select 2 union select 3")
+        with memory_database() as con:
+            cur = con.cursor()
+            cur.execute("create table test(x)")
+            cur.execute("insert into test(x) values (5)")
+            cur.execute("select 1 union select 2 union select 3")
 
-        con.rollback()
-        self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
+            con.rollback()
+            self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)])
 
     def test_multiple_cursors_and_iternext(self):
         # gh-94028: statements are cleared and reset in cursor iternext.
@@ -218,10 +219,7 @@ class RollbackTests(unittest.TestCase):
 
 
 
-class SpecialCommandTests(unittest.TestCase):
-    def setUp(self):
-        self.con = sqlite.connect(":memory:")
-        self.cur = self.con.cursor()
+class SpecialCommandTests(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_drop_table(self):
         self.cur.execute("create table test(i)")
@@ -233,14 +231,8 @@ class SpecialCommandTests(unittest.TestCase):
         self.cur.execute("insert into test(i) values (5)")
         self.cur.execute("pragma count_changes=1")
 
-    def tearDown(self):
-        self.cur.close()
-        self.con.close()
-
 
-class TransactionalDDL(unittest.TestCase):
-    def setUp(self):
-        self.con = sqlite.connect(":memory:")
+class TransactionalDDL(MemoryDatabaseMixin, unittest.TestCase):
 
     def test_ddl_does_not_autostart_transaction(self):
         # For backwards compatibility reasons, DDL statements should not
@@ -268,9 +260,6 @@ class TransactionalDDL(unittest.TestCase):
         with self.assertRaises(sqlite.OperationalError):
             self.con.execute("select * from test")
 
-    def tearDown(self):
-        self.con.close()
-
 
 class IsolationLevelFromInit(unittest.TestCase):
     CREATE = "create table t(t)"
index 05c2fb3aa6f8f2b472166698d11c5d9435efbedc..5d12636dcd2b63e3e5a2b46d07983ab593bbce30 100644 (file)
 #    misrepresented as being the original software.
 # 3. This notice may not be removed or altered from any source distribution.
 
-import contextlib
-import functools
-import io
-import re
 import sys
 import unittest
 import sqlite3 as sqlite
 
 from unittest.mock import Mock, patch
-from test.support import bigmemtest, catch_unraisable_exception, gc_collect
-
-from test.test_sqlite3.test_dbapi import cx_limit
-
-
-def with_tracebacks(exc, regex="", name=""):
-    """Convenience decorator for testing callback tracebacks."""
-    def decorator(func):
-        _regex = re.compile(regex) if regex else None
-        @functools.wraps(func)
-        def wrapper(self, *args, **kwargs):
-            with catch_unraisable_exception() as cm:
-                # First, run the test with traceback enabled.
-                with check_tracebacks(self, cm, exc, _regex, name):
-                    func(self, *args, **kwargs)
-
-            # Then run the test with traceback disabled.
-            func(self, *args, **kwargs)
-        return wrapper
-    return decorator
-
-
-@contextlib.contextmanager
-def check_tracebacks(self, cm, exc, regex, obj_name):
-    """Convenience context manager for testing callback tracebacks."""
-    sqlite.enable_callback_tracebacks(True)
-    try:
-        buf = io.StringIO()
-        with contextlib.redirect_stderr(buf):
-            yield
-
-        self.assertEqual(cm.unraisable.exc_type, exc)
-        if regex:
-            msg = str(cm.unraisable.exc_value)
-            self.assertIsNotNone(regex.search(msg))
-        if obj_name:
-            self.assertEqual(cm.unraisable.object.__name__, obj_name)
-    finally:
-        sqlite.enable_callback_tracebacks(False)
+from test.support import bigmemtest, gc_collect
+
+from .util import cx_limit, memory_database
+from .util import with_tracebacks, check_tracebacks
 
 
 def func_returntext():
@@ -405,19 +366,19 @@ class FunctionTests(unittest.TestCase):
     def test_function_destructor_via_gc(self):
         # See bpo-44304: The destructor of the user function can
         # crash if is called without the GIL from the gc functions
-        dest = sqlite.connect(':memory:')
         def md5sum(t):
             return
 
-        dest.create_function("md5", 1, md5sum)
-        x = dest("create table lang (name, first_appeared)")
-        del md5sum, dest
+        with memory_database() as dest:
+            dest.create_function("md5", 1, md5sum)
+            x = dest("create table lang (name, first_appeared)")
+            del md5sum, dest
 
-        y = [x]
-        y.append(y)
+            y = [x]
+            y.append(y)
 
-        del x,y
-        gc_collect()
+            del x,y
+            gc_collect()
 
     @with_tracebacks(OverflowError)
     def test_func_return_too_large_int(self):
@@ -514,6 +475,10 @@ class WindowFunctionTests(unittest.TestCase):
         """
         self.con.create_window_function("sumint", 1, WindowSumInt)
 
+    def tearDown(self):
+        self.cur.close()
+        self.con.close()
+
     def test_win_sum_int(self):
         self.cur.execute(self.query % "sumint")
         self.assertEqual(self.cur.fetchall(), self.expected)
@@ -634,6 +599,7 @@ class AggregateTests(unittest.TestCase):
             """)
         cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
             ("foo", 5, 3.14, None, memoryview(b"blob"),))
+        cur.close()
 
         self.con.create_aggregate("nostep", 1, AggrNoStep)
         self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
@@ -646,9 +612,7 @@ class AggregateTests(unittest.TestCase):
         self.con.create_aggregate("aggtxt", 1, AggrText)
 
     def tearDown(self):
-        #self.cur.close()
-        #self.con.close()
-        pass
+        self.con.close()
 
     def test_aggr_error_on_create(self):
         with self.assertRaises(sqlite.OperationalError):
@@ -775,7 +739,7 @@ class AuthorizerTests(unittest.TestCase):
         self.con.set_authorizer(self.authorizer_cb)
 
     def tearDown(self):
-        pass
+        self.con.close()
 
     def test_table_access(self):
         with self.assertRaises(sqlite.DatabaseError) as cm:
diff --git a/Lib/test/test_sqlite3/util.py b/Lib/test/test_sqlite3/util.py
new file mode 100644 (file)
index 0000000..505406c
--- /dev/null
@@ -0,0 +1,78 @@
+import contextlib
+import functools
+import io
+import re
+import sqlite3
+import test.support
+import unittest
+
+
+# Helper for temporary memory databases
+def memory_database(*args, **kwargs):
+    cx = sqlite3.connect(":memory:", *args, **kwargs)
+    return contextlib.closing(cx)
+
+
+# Temporarily limit a database connection parameter
+@contextlib.contextmanager
+def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
+    try:
+        _prev = cx.setlimit(category, limit)
+        yield limit
+    finally:
+        cx.setlimit(category, _prev)
+
+
+def with_tracebacks(exc, regex="", name=""):
+    """Convenience decorator for testing callback tracebacks."""
+    def decorator(func):
+        _regex = re.compile(regex) if regex else None
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs):
+            with test.support.catch_unraisable_exception() as cm:
+                # First, run the test with traceback enabled.
+                with check_tracebacks(self, cm, exc, _regex, name):
+                    func(self, *args, **kwargs)
+
+            # Then run the test with traceback disabled.
+            func(self, *args, **kwargs)
+        return wrapper
+    return decorator
+
+
+@contextlib.contextmanager
+def check_tracebacks(self, cm, exc, regex, obj_name):
+    """Convenience context manager for testing callback tracebacks."""
+    sqlite3.enable_callback_tracebacks(True)
+    try:
+        buf = io.StringIO()
+        with contextlib.redirect_stderr(buf):
+            yield
+
+        self.assertEqual(cm.unraisable.exc_type, exc)
+        if regex:
+            msg = str(cm.unraisable.exc_value)
+            self.assertIsNotNone(regex.search(msg))
+        if obj_name:
+            self.assertEqual(cm.unraisable.object.__name__, obj_name)
+    finally:
+        sqlite3.enable_callback_tracebacks(False)
+
+
+class MemoryDatabaseMixin:
+
+    def setUp(self):
+        self.con = sqlite3.connect(":memory:")
+        self.cur = self.con.cursor()
+
+    def tearDown(self):
+        self.cur.close()
+        self.con.close()
+
+    @property
+    def cx(self):
+        return self.con
+
+    @property
+    def cu(self):
+        return self.cur