]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
run handle error for commit/rollback fail and cancel transaction
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Jan 2021 22:23:52 +0000 (17:23 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Jan 2021 22:51:34 +0000 (17:51 -0500)
Fixed bug in asyncpg dialect where a failure during a "commit" or less
likely a "rollback" should cancel the entire transaction; it's no longer
possible to emit rollback. Previously the connection would continue to
await a rollback that could not succeed as asyncpg would reject it.

Fixes: #5824
Change-Id: I5a4916740c269b410f4d1a78ed25191de344b9d0

doc/build/changelog/unreleased_14/5824.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/testing/fixtures.py
test/dialect/postgresql/test_async_pg_py3k.py
test/orm/test_deprecations.py

diff --git a/doc/build/changelog/unreleased_14/5824.rst b/doc/build/changelog/unreleased_14/5824.rst
new file mode 100644 (file)
index 0000000..cbdcbc6
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, postgresql, asyncio
+    :tickets: 5824
+
+    Fixed bug in asyncpg dialect where a failure during a "commit" or less
+    likely a "rollback" should cancel the entire transaction; it's no longer
+    possible to emit rollback. Previously the connection would continue to
+    await a rollback that could not succeed as asyncpg would reject it.
index e542c77f43354424e2aa7c13c77c92fc9d46e88e..424ed0d5070d07f782859b6dd46a58a49aa9f6d7 100644 (file)
@@ -615,6 +615,10 @@ class AsyncAdapt_asyncpg_connection:
         return prepared_stmt, attributes
 
     def _handle_exception(self, error):
+        if self._connection.is_closed():
+            self._transaction = None
+            self._started = False
+
         if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
             exception_mapping = self.dbapi._asyncpg_error_translate
 
@@ -669,15 +673,23 @@ class AsyncAdapt_asyncpg_connection:
 
     def rollback(self):
         if self._started:
-            self.await_(self._transaction.rollback())
-            self._transaction = None
-            self._started = False
+            try:
+                self.await_(self._transaction.rollback())
+            except Exception as error:
+                self._handle_exception(error)
+            finally:
+                self._transaction = None
+                self._started = False
 
     def commit(self):
         if self._started:
-            self.await_(self._transaction.commit())
-            self._transaction = None
-            self._started = False
+            try:
+                self.await_(self._transaction.commit())
+            except Exception as error:
+                self._handle_exception(error)
+            finally:
+                self._transaction = None
+                self._started = False
 
     def close(self):
         self.rollback()
index dcdeee5c902340a7897f082f09cfa27f77ac3acd..08cab051e66009856954479ff0ceb8cc7affc3a5 100644 (file)
@@ -106,6 +106,14 @@ class TestBase(object):
 
         engines.testing_reaper._drop_testing_engines("fixture")
 
+    @config.fixture()
+    def async_testing_engine(self, testing_engine):
+        def go(**kw):
+            kw["asyncio"] = True
+            return testing_engine(**kw)
+
+        return go
+
     @config.fixture()
     def metadata(self, request):
         """Provide bound MetaData for a single test, dropping afterwards."""
index f6d48f3c65b0a4360d594544f3af3d65f0e1db40..62c8f5dde98161d4166ce346e5fd3937101b8e0f 100644 (file)
@@ -2,15 +2,16 @@ import random
 
 from sqlalchemy import Column
 from sqlalchemy import exc
+from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy.dialects.postgresql import ENUM
-from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.testing import async_test
-from sqlalchemy.testing import engines
+from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 
 
@@ -18,28 +19,9 @@ class AsyncPgTest(fixtures.TestBase):
     __requires__ = ("async_dialect",)
     __only_on__ = "postgresql+asyncpg"
 
-    @testing.fixture
-    def async_engine(self):
-        return create_async_engine(testing.db.url)
-
-    @testing.fixture()
-    def metadata(self):
-        # TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3
-        # merges
-
-        from sqlalchemy.testing import util as testing_util
-        from sqlalchemy.sql import schema
-
-        metadata = schema.MetaData()
-
-        try:
-            yield metadata
-        finally:
-            testing_util.drop_all_tables_from_metadata(metadata, testing.db)
-
     @async_test
     async def test_detect_stale_ddl_cache_raise_recover(
-        self, metadata, async_engine
+        self, metadata, async_testing_engine
     ):
         async def async_setup(engine, strlen):
             metadata.clear()
@@ -68,9 +50,10 @@ class AsyncPgTest(fixtures.TestBase):
             Column("name", String),
         )
 
-        await async_setup(async_engine, 30)
+        first_engine = async_testing_engine()
+        second_engine = async_testing_engine()
 
-        second_engine = engines.testing_engine(asyncio=True)
+        await async_setup(first_engine, 30)
 
         async with second_engine.connect() as conn:
             result = await conn.execute(
@@ -82,7 +65,7 @@ class AsyncPgTest(fixtures.TestBase):
             rows = result.fetchall()
             assert len(rows) >= 29
 
-        await async_setup(async_engine, 20)
+        await async_setup(first_engine, 20)
 
         async with second_engine.connect() as conn:
             with testing.expect_raises_message(
@@ -112,7 +95,7 @@ class AsyncPgTest(fixtures.TestBase):
 
     @async_test
     async def test_detect_stale_type_cache_raise_recover(
-        self, metadata, async_engine
+        self, metadata, async_testing_engine
     ):
         async def async_setup(engine, enums):
             metadata = MetaData()
@@ -141,13 +124,13 @@ class AsyncPgTest(fixtures.TestBase):
             ),
         )
 
-        await async_setup(async_engine, ("beans", "means", "keens"))
-
-        second_engine = engines.testing_engine(
-            asyncio=True,
-            options={"connect_args": {"prepared_statement_cache_size": 0}},
+        first_engine = async_testing_engine()
+        second_engine = async_testing_engine(
+            options={"connect_args": {"prepared_statement_cache_size": 0}}
         )
 
+        await async_setup(first_engine, ("beans", "means", "keens"))
+
         async with second_engine.connect() as conn:
             await conn.execute(
                 t1.insert(),
@@ -157,7 +140,7 @@ class AsyncPgTest(fixtures.TestBase):
                 ],
             )
 
-        await async_setup(async_engine, ("faux", "beau", "flow"))
+        await async_setup(first_engine, ("faux", "beau", "flow"))
 
         async with second_engine.connect() as conn:
             with testing.expect_raises_message(
@@ -180,3 +163,91 @@ class AsyncPgTest(fixtures.TestBase):
                     for i in range(10)
                 ],
             )
+
+    @async_test
+    async def test_failed_commit_recover(self, metadata, async_testing_engine):
+
+        Table("t1", metadata, Column("id", Integer, primary_key=True))
+
+        t2 = Table(
+            "t2",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "t1_id",
+                Integer,
+                ForeignKey("t1.id", deferrable=True, initially="deferred"),
+            ),
+        )
+
+        engine = async_testing_engine()
+
+        async with engine.connect() as conn:
+            await conn.run_sync(metadata.create_all)
+
+            await conn.execute(t2.insert().values(id=1, t1_id=2))
+
+            with testing.expect_raises_message(
+                exc.IntegrityError, 'insert or update on table "t2"'
+            ):
+                await conn.commit()
+
+            await conn.rollback()
+
+            eq_((await conn.execute(select(1))).scalar(), 1)
+
+    @async_test
+    async def test_rollback_twice_no_problem(
+        self, metadata, async_testing_engine
+    ):
+
+        engine = async_testing_engine()
+
+        async with engine.connect() as conn:
+
+            trans = await conn.begin()
+
+            await trans.rollback()
+
+            await conn.rollback()
+
+    @async_test
+    async def test_closed_during_execute(self, metadata, async_testing_engine):
+
+        engine = async_testing_engine()
+
+        async with engine.connect() as conn:
+            await conn.begin()
+
+            with testing.expect_raises_message(
+                exc.DBAPIError, "connection was closed"
+            ):
+                await conn.exec_driver_sql(
+                    "select pg_terminate_backend(pg_backend_pid())"
+                )
+
+    @async_test
+    async def test_failed_rollback_recover(
+        self, metadata, async_testing_engine
+    ):
+
+        engine = async_testing_engine()
+
+        async with engine.connect() as conn:
+            await conn.begin()
+
+            (await conn.execute(select(1))).scalar()
+
+            raw_connection = await conn.get_raw_connection()
+            # close the asyncpg transaction directly
+            await raw_connection._transaction.rollback()
+
+            with testing.expect_raises_message(
+                exc.InterfaceError, "already rolled back"
+            ):
+                await conn.rollback()
+
+            # recovers no problem
+
+            await conn.begin()
+            await conn.rollback()
index ce012d381f73c7de7d8615bb5835bcfd58ecb396..90bb291f8c0ecd115c036d983389f4693f14c254 100644 (file)
@@ -2199,11 +2199,6 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture):
 
 class AutocommitClosesOnFailTest(fixtures.MappedTest):
     __requires__ = ("deferrable_fks",)
-    __only_on__ = ("postgresql+psycopg2",)  # needs #5824 for asyncpg
-
-    # this test has a lot of problems, am investigating asyncpg
-    # issues separately.  just get this legacy use case to pass for now.
-    __only_on__ = ("postgresql+psycopg2",)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -2247,9 +2242,11 @@ class AutocommitClosesOnFailTest(fixtures.MappedTest):
 
         # with a deferred constraint, this fails at COMMIT time instead
         # of at INSERT time.
-        session.add(T2(t1_id=123))
+        session.add(T2(id=1, t1_id=123))
 
-        assert_raises(sa.exc.IntegrityError, session.flush)
+        assert_raises(
+            (sa.exc.IntegrityError, sa.exc.DatabaseError), session.flush
+        )
 
         assert session._legacy_transaction() is None