From a8102ba496c4c11eae6b904a962cf352902f0de7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 7 Mar 2022 11:17:47 -0500 Subject: [PATCH] test sqlite w/ savepoint workaround in session fixture test Fixes: #7795 Change-Id: Ib790581555656c088f86c00080c70d19ca295a03 (cherry picked from commit fbacb1991585202a5bf22acb0d36b5c979bcfad8) --- lib/sqlalchemy/testing/engines.py | 14 ++++++++++++++ test/orm/test_transaction.py | 12 ++++++------ test/requirements.py | 10 ++++++++++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index a92d476ac5..b8be6b9bd5 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -276,10 +276,12 @@ def testing_engine( future=None, asyncio=False, transfer_staticpool=False, + _sqlite_savepoint=False, ): """Produce an engine configured by --options with optional overrides.""" if asyncio: + assert not _sqlite_savepoint from sqlalchemy.ext.asyncio import ( create_async_engine as create_engine, ) @@ -294,9 +296,11 @@ def testing_engine( if not options: use_reaper = True scope = "function" + sqlite_savepoint = False else: use_reaper = options.pop("use_reaper", True) scope = options.pop("scope", "function") + sqlite_savepoint = options.pop("sqlite_savepoint", False) url = url or config.db.url @@ -312,6 +316,16 @@ def testing_engine( engine = create_engine(url, **options) + if sqlite_savepoint and engine.name == "sqlite": + # apply SQLite savepoint workaround + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN") + if transfer_staticpool: from sqlalchemy.pool import StaticPool diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 603ec079a7..e077220e19 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -2526,10 +2526,10 @@ class NaturalPKRollbackTest(fixtures.MappedTest): class JoinIntoAnExternalTransactionFixture(object): """Test the "join into an external transaction" examples""" - __leave_connections_for_teardown__ = True - def setup_test(self): - self.engine = testing.db + self.engine = engines.testing_engine( + options={"use_reaper": False, "sqlite_savepoint": True} + ) self.connection = self.engine.connect() self.metadata = MetaData() @@ -2590,7 +2590,7 @@ class NewStyleJoinIntoAnExternalTransactionTest( # bind an individual Session to the connection self.session = Session(bind=self.connection, future=True) - if testing.requires.savepoints.enabled: + if testing.requires.compat_savepoints.enabled: self.nested = self.connection.begin_nested() @event.listens_for(self.session, "after_transaction_end") @@ -2607,7 +2607,7 @@ class NewStyleJoinIntoAnExternalTransactionTest( if self.trans.is_active: self.trans.rollback() - @testing.requires.savepoints + @testing.requires.compat_savepoints def test_something_with_context_managers(self): A = self.A @@ -2673,7 +2673,7 @@ class LegacyJoinIntoAnExternalTransactionTest( # bind an individual Session to the connection self.session = Session(bind=self.connection) - if testing.requires.savepoints.enabled: + if testing.requires.compat_savepoints.enabled: # start the session in a SAVEPOINT... self.session.begin_nested() diff --git a/test/requirements.py b/test/requirements.py index 1780e3b21a..4c9ac40c54 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -558,6 +558,16 @@ class DefaultRequirements(SuiteRequirements): "savepoints not supported", ) + @property + def compat_savepoints(self): + """Target database must support savepoints, or a compat + recipe e.g. for sqlite will be used""" + + return skip_if( + ["sybase", ("mysql", "<", (5, 0, 3))], + "savepoints not supported", + ) + @property def savepoints_w_release(self): return self.savepoints + skip_if( -- 2.47.2