]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
test sqlite w/ savepoint workaround in session fixture test
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Mar 2022 16:17:47 +0000 (11:17 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Mar 2022 16:18:25 +0000 (11:18 -0500)
Fixes: #7795
Change-Id: Ib790581555656c088f86c00080c70d19ca295a03
(cherry picked from commit fbacb1991585202a5bf22acb0d36b5c979bcfad8)

lib/sqlalchemy/testing/engines.py
test/orm/test_transaction.py
test/requirements.py

index a92d476ac54e978790d2e49d3fb2322217cab097..b8be6b9bd551462b3b28316b027d81bad00348d3 100644 (file)
@@ -276,10 +276,12 @@ def testing_engine(
     future=None,
     asyncio=False,
     transfer_staticpool=False,
+    _sqlite_savepoint=False,
 ):
     """Produce an engine configured by --options with optional overrides."""
 
     if asyncio:
+        assert not _sqlite_savepoint
         from sqlalchemy.ext.asyncio import (
             create_async_engine as create_engine,
         )
@@ -294,9 +296,11 @@ def testing_engine(
     if not options:
         use_reaper = True
         scope = "function"
+        sqlite_savepoint = False
     else:
         use_reaper = options.pop("use_reaper", True)
         scope = options.pop("scope", "function")
+        sqlite_savepoint = options.pop("sqlite_savepoint", False)
 
     url = url or config.db.url
 
@@ -312,6 +316,16 @@ def testing_engine(
 
     engine = create_engine(url, **options)
 
+    if sqlite_savepoint and engine.name == "sqlite":
+        # apply SQLite savepoint workaround
+        @event.listens_for(engine, "connect")
+        def do_connect(dbapi_connection, connection_record):
+            dbapi_connection.isolation_level = None
+
+        @event.listens_for(engine, "begin")
+        def do_begin(conn):
+            conn.exec_driver_sql("BEGIN")
+
     if transfer_staticpool:
         from sqlalchemy.pool import StaticPool
 
index 603ec079a767750c4c59273a795f033b7d107dcd..e077220e19b918158367b90d254f70a81c437853 100644 (file)
@@ -2526,10 +2526,10 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
 class JoinIntoAnExternalTransactionFixture(object):
     """Test the "join into an external transaction" examples"""
 
-    __leave_connections_for_teardown__ = True
-
     def setup_test(self):
-        self.engine = testing.db
+        self.engine = engines.testing_engine(
+            options={"use_reaper": False, "sqlite_savepoint": True}
+        )
         self.connection = self.engine.connect()
 
         self.metadata = MetaData()
@@ -2590,7 +2590,7 @@ class NewStyleJoinIntoAnExternalTransactionTest(
         # bind an individual Session to the connection
         self.session = Session(bind=self.connection, future=True)
 
-        if testing.requires.savepoints.enabled:
+        if testing.requires.compat_savepoints.enabled:
             self.nested = self.connection.begin_nested()
 
             @event.listens_for(self.session, "after_transaction_end")
@@ -2607,7 +2607,7 @@ class NewStyleJoinIntoAnExternalTransactionTest(
         if self.trans.is_active:
             self.trans.rollback()
 
-    @testing.requires.savepoints
+    @testing.requires.compat_savepoints
     def test_something_with_context_managers(self):
         A = self.A
 
@@ -2673,7 +2673,7 @@ class LegacyJoinIntoAnExternalTransactionTest(
         # bind an individual Session to the connection
         self.session = Session(bind=self.connection)
 
-        if testing.requires.savepoints.enabled:
+        if testing.requires.compat_savepoints.enabled:
             # start the session in a SAVEPOINT...
             self.session.begin_nested()
 
index 1780e3b21a2d524288f37006ae3ea103582b9f0f..4c9ac40c54df7a8f687974f8221711b7a7b09232 100644 (file)
@@ -558,6 +558,16 @@ class DefaultRequirements(SuiteRequirements):
             "savepoints not supported",
         )
 
+    @property
+    def compat_savepoints(self):
+        """Target database must support savepoints, or a compat
+        recipe e.g. for sqlite will be used"""
+
+        return skip_if(
+            ["sybase", ("mysql", "<", (5, 0, 3))],
+            "savepoints not supported",
+        )
+
     @property
     def savepoints_w_release(self):
         return self.savepoints + skip_if(