]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use standard path for asyncio create w/ exception handler
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Sep 2025 17:54:13 +0000 (13:54 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Sep 2025 14:45:22 +0000 (10:45 -0400)
Refactored all asyncio dialects so that exceptions which occur on failed
connection attempts are appropriately wrapped with SQLAlchemy exception
objects, allowing for consistent error handling.

Fixes: #11956
Change-Id: Ic3fdbf334f059f92b03896b6429efa50968ca8a8

doc/build/changelog/unreleased_21/11956.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/aioodbc.py
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
test/ext/asyncio/test_engine.py

diff --git a/doc/build/changelog/unreleased_21/11956.rst b/doc/build/changelog/unreleased_21/11956.rst
new file mode 100644 (file)
index 0000000..7cae83d
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 11956
+
+    Refactored all asyncio dialects so that exceptions which occur on failed
+    connection attempts are appropriately wrapped with SQLAlchemy exception
+    objects, allowing for consistent error handling.
index 39f45dc26531a3794f380c2ddbda11e8eaaad427..1a44c7ebe60e14e210e38bb3ef2720391b1dc006 100644 (file)
@@ -130,9 +130,11 @@ class AsyncAdapt_aioodbc_dbapi(AsyncAdapt_dbapi_module):
     def connect(self, *arg, **kw):
         creator_fn = kw.pop("async_creator_fn", self.aioodbc.connect)
 
-        return AsyncAdapt_aioodbc_connection(
-            self,
-            await_(creator_fn(*arg, **kw)),
+        return await_(
+            AsyncAdapt_aioodbc_connection.create(
+                self,
+                creator_fn(*arg, **kw),
+            )
         )
 
 
index 29ca0fc98fe44a9e4de38c7998f3a68facecaa57..0d565e300a4be09d8f28c0a341886731ab36e6e0 100644 (file)
@@ -15,6 +15,7 @@ import sys
 import types
 from typing import Any
 from typing import AsyncIterator
+from typing import Awaitable
 from typing import Deque
 from typing import Iterator
 from typing import NoReturn
@@ -364,6 +365,20 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection):
 
     _connection: AsyncIODBAPIConnection
 
+    @classmethod
+    async def create(
+        cls,
+        dbapi: Any,
+        connection_awaitable: Awaitable[AsyncIODBAPIConnection],
+        **kw: Any,
+    ) -> Self:
+        try:
+            connection = await connection_awaitable
+        except Exception as error:
+            cls._handle_exception_no_connection(dbapi, error)
+        else:
+            return cls(dbapi, connection, **kw)
+
     def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
         self.dbapi = dbapi
         self._connection = connection
@@ -385,11 +400,17 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection):
         cursor.execute(operation, parameters)
         return cursor
 
-    def _handle_exception(self, error: Exception) -> NoReturn:
+    @classmethod
+    def _handle_exception_no_connection(
+        cls, dbapi: Any, error: Exception
+    ) -> NoReturn:
         exc_info = sys.exc_info()
 
         raise error.with_traceback(exc_info[2])
 
+    def _handle_exception(self, error: Exception) -> NoReturn:
+        self._handle_exception_no_connection(self.dbapi, error)
+
     def rollback(self) -> None:
         try:
             await_(self._connection.rollback())
index f630773318d615037d068c9529f50e2f185510e3..f72f947dd33e482abb249e0c46d12489e0f381b3 100644 (file)
@@ -148,9 +148,11 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
     def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
         creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
 
-        return AsyncAdapt_aiomysql_connection(
-            self,
-            await_(creator_fn(*arg, **kw)),
+        return await_(
+            AsyncAdapt_aiomysql_connection.create(
+                self,
+                creator_fn(*arg, **kw),
+            )
         )
 
     def _init_cursors_subclasses(
index 952ea171e78c208a6899d15e82f08c9c9339d7ac..837f164bcc6c5f7149ae519b7aeb8b775b6cbe8c 100644 (file)
@@ -81,9 +81,12 @@ class AsyncAdapt_asyncmy_connection(
     _cursor_cls = AsyncAdapt_asyncmy_cursor
     _ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
 
-    def _handle_exception(self, error: Exception) -> NoReturn:
+    @classmethod
+    def _handle_exception_no_connection(
+        cls, dbapi: Any, error: Exception
+    ) -> NoReturn:
         if isinstance(error, AttributeError):
-            raise self.dbapi.InternalError(
+            raise dbapi.InternalError(
                 "network operation failed due to asyncmy attribute error"
             )
 
@@ -153,9 +156,11 @@ class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
     def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
         creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
 
-        return AsyncAdapt_asyncmy_connection(
-            self,
-            await_(creator_fn(*arg, **kw)),
+        return await_(
+            AsyncAdapt_asyncmy_connection.create(
+                self,
+                creator_fn(*arg, **kw),
+            )
         )
 
 
index 7c4a56ff37bafbc6002d354106eceed9ce17cb77..1fbcabb6dd6e7026cb7be9f15ad36acc1f9d605c 100644 (file)
@@ -850,8 +850,8 @@ class OracledbAdaptDBAPI(AsyncAdapt_dbapi_module):
 
     def connect(self, *arg, **kw):
         creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)
-        return AsyncAdapt_oracledb_connection(
-            self, await_(creator_fn(*arg, **kw))
+        return await_(
+            AsyncAdapt_oracledb_connection.create(self, creator_fn(*arg, **kw))
         )
 
 
index 51bc8b11bd3d9fcc86d69f571c6ab44f1a2eec32..65d6076ca4975d916da5ac5f68607b4e7d64cf1a 100644 (file)
@@ -834,12 +834,12 @@ class AsyncAdapt_asyncpg_connection(
 
         return prepared_stmt, attributes
 
-    def _handle_exception(self, error: Exception) -> NoReturn:
-        if self._connection.is_closed():
-            self._transaction = None
-
+    @classmethod
+    def _handle_exception_no_connection(
+        cls, dbapi: Any, error: Exception
+    ) -> NoReturn:
         if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
-            exception_mapping = self.dbapi._asyncpg_error_translate
+            exception_mapping = dbapi._asyncpg_error_translate
 
             for super_ in type(error).__mro__:
                 if super_ in exception_mapping:
@@ -848,10 +848,13 @@ class AsyncAdapt_asyncpg_connection(
                         message, error
                     )
                     raise translated_error from error
-            else:
-                super()._handle_exception(error)
-        else:
-            super()._handle_exception(error)
+        super()._handle_exception_no_connection(dbapi, error)
+
+    def _handle_exception(self, error: Exception) -> NoReturn:
+        if self._connection.is_closed():
+            self._transaction = None
+
+        super()._handle_exception(error)
 
     @property
     def autocommit(self):
@@ -967,11 +970,13 @@ class AsyncAdapt_asyncpg_dbapi(AsyncAdapt_dbapi_module):
             "prepared_statement_name_func", None
         )
 
-        return AsyncAdapt_asyncpg_connection(
-            self,
-            await_(creator_fn(*arg, **kw)),
-            prepared_statement_cache_size=prepared_statement_cache_size,
-            prepared_statement_name_func=prepared_statement_name_func,
+        return await_(
+            AsyncAdapt_asyncpg_connection.create(
+                self,
+                creator_fn(*arg, **kw),
+                prepared_statement_cache_size=prepared_statement_cache_size,
+                prepared_statement_name_func=prepared_statement_name_func,
+            )
         )
 
     class Error(AsyncAdapt_Error):
index 9a877702064550f561cce0498f597bd9a4e169e5..f525fe1831efffd2c767d6a18c21ec43968d20c1 100644 (file)
@@ -696,8 +696,8 @@ class PsycopgAdaptDBAPI(AsyncAdapt_dbapi_module):
         creator_fn = kw.pop(
             "async_creator_fn", self.psycopg.AsyncConnection.connect
         )
-        return AsyncAdapt_psycopg_connection(
-            self, await_(creator_fn(*arg, **kw))
+        return await_(
+            AsyncAdapt_psycopg_connection.create(self, creator_fn(*arg, **kw))
         )
 
 
index ad0cd89f60d4451f3ae5d120ea195cecf4105346..79b26d219f204c45f4f226c8e39951082fb6e3b5 100644 (file)
@@ -177,14 +177,17 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
         except Exception as error:
             self._handle_exception(error)
 
-    def _handle_exception(self, error: Exception) -> NoReturn:
+    @classmethod
+    def _handle_exception_no_connection(
+        cls, dbapi: Any, error: Exception
+    ) -> NoReturn:
         if isinstance(error, ValueError) and error.args[0].lower() in (
             "no active connection",
             "connection closed",
         ):
-            raise self.dbapi.sqlite.OperationalError(error.args[0]) from error
+            raise dbapi.sqlite.OperationalError(error.args[0]) from error
         else:
-            super()._handle_exception(error)
+            super()._handle_exception_no_connection(dbapi, error)
 
 
 class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
index a8d2e2ce3cbd819ae5b0584ce59c77496946c74f..49399f8e5ecec98241847cd37373d96be0c51845 100644 (file)
@@ -909,6 +909,39 @@ class AsyncEngineTest(EngineFixture):
                 # because the cursor should be closed
                 await driver_cursor.execute(select_one_sql)
 
+    @async_test
+    async def test_async_creator_handle_error(self, async_testing_engine):
+        """test for #11956"""
+
+        existing_creator = testing.db.pool._creator
+
+        def create_and_break():
+            sync_conn = existing_creator()
+            cursor = sync_conn.cursor()
+
+            # figure out a way to get a native driver exception.  This really
+            # only applies to asyncpg where we rewrite the exception
+            # hierarchy with our own emulated exception; other backends raise
+            # standard DBAPI exceptions (with some buggy cases here and there
+            # which they miss) even though they are async
+            try:
+                cursor.execute("this will raise an error")
+            except Exception as possibly_emulated_error:
+                if isinstance(
+                    possibly_emulated_error, exc.EmulatedDBAPIException
+                ):
+                    raise possibly_emulated_error.driver_exception
+                else:
+                    raise possibly_emulated_error
+
+        async def async_creator():
+            return await greenlet_spawn(create_and_break)
+
+        engine = async_testing_engine(options={"async_creator": async_creator})
+
+        with expect_raises(exc.DBAPIError):
+            await engine.connect()
+
 
 class AsyncCreatePoolTest(fixtures.TestBase):
     @config.fixture