]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Move initialize do_rollback() outside of the dialect
authorMatthew Wilkes <git@matthewwilkes.name>
Thu, 9 May 2019 22:04:35 +0000 (18:04 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 May 2019 01:51:41 +0000 (21:51 -0400)
Moved the "rollback" which occurs during dialect initialization so that it
occurs after additional dialect-specific initialize steps, in particular
those of the psycopg2 dialect which would inadvertently leave transactional
state on the first new connection, which could interfere with some
psycopg2-specific APIs which require that no transaction is started.  Pull
request courtesy Matthew Wilkes.

Fixes: #4663
Closes: #4664
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4664
Pull-request-sha: e544fe671d443ed06b210ba1cd1d7ba9c5653831

Change-Id: If40a15a1679b4eec0b8b8222f678697728009c30
(cherry picked from commit f601791a914d3181252493800871c458ad6c46d1)

doc/build/changelog/unreleased_13/4663.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
test/dialect/postgresql/test_dialect.py
test/engine/test_execute.py

diff --git a/doc/build/changelog/unreleased_13/4663.rst b/doc/build/changelog/unreleased_13/4663.rst
new file mode 100644 (file)
index 0000000..07b943e
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+   :tags: bug, engine, postgresql
+   :tickets: 4663
+
+   Moved the "rollback" which occurs during dialect initialization so that it
+   occurs after additional dialect-specific initialize steps, in particular
+   those of the psycopg2 dialect which would inadvertently leave transactional
+   state on the first new connection, which could interfere with some
+   psycopg2-specific APIs which require that no transaction is started.  Pull
+   request courtesy Matthew Wilkes.
+
index 51e2c4603b04746b1a1039a592d09653536cc143..f6c30cbf47a3a915291c1ff39e7d09b61a0da519 100644 (file)
@@ -312,8 +312,6 @@ class DefaultDialect(interfaces.Dialect):
         ):
             self._description_decoder = self.description_encoding = None
 
-        self.do_rollback(connection.connection)
-
     def on_connect(self):
         """return a callable which sets up a newly created DBAPI connection.
 
index e367ef890423f2ad554d09f3bbae4b86c0ed7298..d3a22e5ac8bef8ac139761befbd37286396a2752 100644 (file)
@@ -197,6 +197,7 @@ class DefaultEngineStrategy(EngineStrategy):
                 )
                 c._execution_options = util.immutabledict()
                 dialect.initialize(c)
+                dialect.do_rollback(c.connection)
 
             event.listen(pool, "first_connect", first_connect, once=True)
 
index c68af2abb89d441749877f10f983e43dfaaed5c5..25cba6269c387022e2c5dd3a9b953b914e4162a6 100644 (file)
@@ -39,6 +39,7 @@ from sqlalchemy.testing.assertions import assert_raises_message
 from sqlalchemy.testing.assertions import AssertsCompiledSQL
 from sqlalchemy.testing.assertions import AssertsExecutionResults
 from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.assertions import ne_
 from sqlalchemy.testing.mock import Mock
 from ...engine import test_execute
 
@@ -500,6 +501,14 @@ class MiscBackendTest(
                 "c %s NOT NULL" % expected,
             )
 
+    @testing.requires.psycopg2_compatibility
+    def test_initial_transaction_state(self):
+        from psycopg2.extensions import STATUS_IN_TRANSACTION
+
+        engine = engines.testing_engine()
+        with engine.connect() as conn:
+            ne_(conn.connection.status, STATUS_IN_TRANSACTION)
+
 
 class AutocommitTextTest(test_execute.AutocommitTextTest):
     __only_on__ = "postgresql"
index e18cdfad4cf8758490414e55ba1f1b86e8122b7a..480da712214eaf2bf1afb424d279235e96a90f38 100644 (file)
@@ -660,6 +660,15 @@ class ExecuteTest(fixtures.TestBase):
         eq_(conn._execution_options, {"autocommit": True})
         conn.close()
 
+    def test_initialize_rollback(self):
+        """test a rollback happens during first connect"""
+        eng = create_engine(testing.db.url)
+        with patch.object(eng.dialect, "do_rollback") as do_rollback:
+            assert do_rollback.call_count == 0
+            connection = eng.connect()
+            assert do_rollback.call_count == 1
+        connection.close()
+
     @testing.requires.ad_hoc_engines
     def test_dialect_init_uses_options(self):
         eng = create_engine(testing.db.url)