]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
test sqlite w/ savepoint workaround in session fixture test
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Mar 2022 16:17:47 +0000 (11:17 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Mar 2022 16:17:47 +0000 (11:17 -0500)
Fixes: #7795
Change-Id: Ib790581555656c088f86c00080c70d19ca295a03

lib/sqlalchemy/testing/engines.py
test/orm/test_transaction.py
test/requirements.py

index 79adb8c3cd22b3c514af253ca6805b58988b0c5c..4496b8dede4a8a9db976343c586b292927602d8e 100644 (file)
@@ -306,8 +306,10 @@ def testing_engine(
     options=None,
     asyncio=False,
     transfer_staticpool=False,
+    _sqlite_savepoint=False,
 ):
     if asyncio:
+        assert not _sqlite_savepoint
         from sqlalchemy.ext.asyncio import (
             create_async_engine as create_engine,
         )
@@ -318,9 +320,11 @@ def testing_engine(
     if not options:
         use_reaper = True
         scope = "function"
+        sqlite_savepoint = False
     else:
         use_reaper = options.pop("use_reaper", True)
         scope = options.pop("scope", "function")
+        sqlite_savepoint = options.pop("sqlite_savepoint", False)
 
     url = url or config.db.url
 
@@ -336,6 +340,16 @@ def testing_engine(
 
     engine = create_engine(url, **options)
 
+    if sqlite_savepoint and engine.name == "sqlite":
+        # apply SQLite savepoint workaround
+        @event.listens_for(engine, "connect")
+        def do_connect(dbapi_connection, connection_record):
+            dbapi_connection.isolation_level = None
+
+        @event.listens_for(engine, "begin")
+        def do_begin(conn):
+            conn.exec_driver_sql("BEGIN")
+
     if transfer_staticpool:
         from sqlalchemy.pool import StaticPool
 
index 96a00ff54c37706be6dd77f879f999ddae482cfe..bc84d444758ad7eb48774b639cdb53fa90930617 100644 (file)
@@ -2346,10 +2346,10 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
 class JoinIntoAnExternalTransactionFixture:
     """Test the "join into an external transaction" examples"""
 
-    __leave_connections_for_teardown__ = True
-
     def setup_test(self):
-        self.engine = testing.db
+        self.engine = engines.testing_engine(
+            options={"use_reaper": False, "sqlite_savepoint": True}
+        )
         self.connection = self.engine.connect()
 
         self.metadata = MetaData()
@@ -2410,7 +2410,7 @@ class NewStyleJoinIntoAnExternalTransactionTest(
         # bind an individual Session to the connection
         self.session = Session(bind=self.connection)
 
-        if testing.requires.savepoints.enabled:
+        if testing.requires.compat_savepoints.enabled:
             self.nested = self.connection.begin_nested()
 
             @event.listens_for(self.session, "after_transaction_end")
@@ -2427,7 +2427,7 @@ class NewStyleJoinIntoAnExternalTransactionTest(
         if self.trans.is_active:
             self.trans.rollback()
 
-    @testing.requires.savepoints
+    @testing.requires.compat_savepoints
     def test_something_with_context_managers(self):
         A = self.A
 
@@ -2478,7 +2478,7 @@ class LegacyJoinIntoAnExternalTransactionTest(
         # bind an individual Session to the connection
         self.session = Session(bind=self.connection)
 
-        if testing.requires.savepoints.enabled:
+        if testing.requires.compat_savepoints.enabled:
             # start the session in a SAVEPOINT...
             self.session.begin_nested()
 
index 46d0ba466f0d79f364edfadde00c8f1e4e25df75..df6b5d62f5362ceedbb9759706ea7a9f06f20b85 100644 (file)
@@ -483,6 +483,16 @@ class DefaultRequirements(SuiteRequirements):
             "savepoints not supported",
         )
 
+    @property
+    def compat_savepoints(self):
+        """Target database must support savepoints, or a compat
+        recipe e.g. for sqlite will be used"""
+
+        return skip_if(
+            ["sybase", ("mysql", "<", (5, 0, 3))],
+            "savepoints not supported",
+        )
+
     @property
     def savepoints_w_release(self):
         return self.savepoints + skip_if(