]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dont erase transaction if rollback/commit failed outside of asyncpg
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Sep 2024 14:37:29 +0000 (10:37 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Sep 2024 17:14:36 +0000 (13:14 -0400)
Fixed critical issue in the asyncpg driver where a rollback or commit that
fails specifically for the ``MissingGreenlet`` condition or any other error
that is not raised by asyncpg itself would discard the asyncpg transaction
in any case, even though the transaction were still idle, leaving to a
server side condition with an idle transaction that then goes back into the
connection pool.   The flags for "transaction closed" are now not reset for
errors that are raised outside of asyncpg itself.  When asyncpg itself
raises an error for ``.commit()`` or ``.rollback()``, asyncpg does then
discard of this transaction.

Fixes: #11819
Change-Id: I12f0532788b03ea63fb47a7af21e07c37effb070

doc/build/changelog/unreleased_14/11819.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
test/dialect/postgresql/test_async_pg_py3k.py

diff --git a/doc/build/changelog/unreleased_14/11819.rst b/doc/build/changelog/unreleased_14/11819.rst
new file mode 100644 (file)
index 0000000..6211eb4
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 11819
+    :versions: 2.0.33, 1.4.54
+
+    Fixed critical issue in the asyncpg driver where a rollback or commit that
+    fails specifically for the ``MissingGreenlet`` condition or any other error
+    that is not raised by asyncpg itself would discard the asyncpg transaction
+    in any case, even though the transaction were still idle, leaving to a
+    server side condition with an idle transaction that then goes back into the
+    connection pool.   The flags for "transaction closed" are now not reset for
+    errors that are raised outside of asyncpg itself.  When asyncpg itself
+    raises an error for ``.commit()`` or ``.rollback()``, asyncpg does then
+    discard of this transaction.
index cb6b75154f31b0fb3189e677324b1e0b0fd0f3f3..90471556fc06cca084f771df593adc135088c9a4 100644 (file)
@@ -865,27 +865,47 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
         else:
             self._started = True
 
+    async def _rollback_and_discard(self):
+        try:
+            await self._transaction.rollback()
+        finally:
+            # if asyncpg .rollback() was actually called, then whether or
+            # not it raised or succeeded, the transation is done, discard it
+            self._transaction = None
+            self._started = False
+
+    async def _commit_and_discard(self):
+        try:
+            await self._transaction.commit()
+        finally:
+            # if asyncpg .commit() was actually called, then whether or
+            # not it raised or succeeded, the transation is done, discard it
+            self._transaction = None
+            self._started = False
+
     def rollback(self):
         if self._started:
             assert self._transaction is not None
             try:
-                await_(self._transaction.rollback())
-            except Exception as error:
-                self._handle_exception(error)
-            finally:
+                await_(self._rollback_and_discard())
                 self._transaction = None
                 self._started = False
+            except Exception as error:
+                # don't dereference asyncpg transaction if we didn't
+                # actually try to call rollback() on it
+                self._handle_exception(error)
 
     def commit(self):
         if self._started:
             assert self._transaction is not None
             try:
-                await_(self._transaction.commit())
-            except Exception as error:
-                self._handle_exception(error)
-            finally:
+                await_(self._commit_and_discard())
                 self._transaction = None
                 self._started = False
+            except Exception as error:
+                # don't dereference asyncpg transaction if we didn't
+                # actually try to call commit() on it
+                self._handle_exception(error)
 
     def close(self):
         self.rollback()
index c09acf5b47242c4fd21ce1429aa948ca43b745fe..feff60c5789142d973e5cdce55b59b0b7fa5f6b7 100644 (file)
@@ -13,6 +13,7 @@ from sqlalchemy import testing
 from sqlalchemy.dialects.postgresql import ENUM
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 
@@ -165,6 +166,54 @@ class AsyncPgTest(fixtures.TestBase):
                 ],
             )
 
+    @testing.variation("trans", ["commit", "rollback"])
+    @async_test
+    async def test_dont_reset_open_transaction(
+        self, trans, async_testing_engine
+    ):
+        """test for #11819"""
+
+        engine = async_testing_engine()
+
+        control_conn = await engine.connect()
+        await control_conn.execution_options(isolation_level="AUTOCOMMIT")
+
+        conn = await engine.connect()
+        txid_current = (
+            await conn.exec_driver_sql("select txid_current()")
+        ).scalar()
+
+        with expect_raises(exc.MissingGreenlet):
+            if trans.commit:
+                conn.sync_connection.connection.dbapi_connection.commit()
+            elif trans.rollback:
+                conn.sync_connection.connection.dbapi_connection.rollback()
+            else:
+                trans.fail()
+
+        trans_exists = (
+            await control_conn.exec_driver_sql(
+                f"SELECT count(*) FROM pg_stat_activity "
+                f"where backend_xid={txid_current}"
+            )
+        ).scalar()
+        eq_(trans_exists, 1)
+
+        if trans.commit:
+            await conn.commit()
+        elif trans.rollback:
+            await conn.rollback()
+        else:
+            trans.fail()
+
+        trans_exists = (
+            await control_conn.exec_driver_sql(
+                f"SELECT count(*) FROM pg_stat_activity "
+                f"where backend_xid={txid_current}"
+            )
+        ).scalar()
+        eq_(trans_exists, 0)
+
     @async_test
     async def test_failed_commit_recover(self, metadata, async_testing_engine):
         Table("t1", metadata, Column("id", Integer, primary_key=True))