--- /dev/null
+.. change::
+ :tags: bug, asyncio
+ :tickets: 11956
+
+ Refactored all asyncio dialects so that exceptions which occur on failed
+ connection attempts are appropriately wrapped with SQLAlchemy exception
+ objects, allowing for consistent error handling.
def connect(self, *arg, **kw):
creator_fn = kw.pop("async_creator_fn", self.aioodbc.connect)
- return AsyncAdapt_aioodbc_connection(
- self,
- await_(creator_fn(*arg, **kw)),
+ return await_(
+ AsyncAdapt_aioodbc_connection.create(
+ self,
+ creator_fn(*arg, **kw),
+ )
)
import types
from typing import Any
from typing import AsyncIterator
+from typing import Awaitable
from typing import Deque
from typing import Iterator
from typing import NoReturn
_connection: AsyncIODBAPIConnection
+ @classmethod
+ async def create(
+ cls,
+ dbapi: Any,
+ connection_awaitable: Awaitable[AsyncIODBAPIConnection],
+ **kw: Any,
+ ) -> Self:
+ try:
+ connection = await connection_awaitable
+ except Exception as error:
+ cls._handle_exception_no_connection(dbapi, error)
+ else:
+ return cls(dbapi, connection, **kw)
+
def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection):
self.dbapi = dbapi
self._connection = connection
cursor.execute(operation, parameters)
return cursor
- def _handle_exception(self, error: Exception) -> NoReturn:
+ @classmethod
+ def _handle_exception_no_connection(
+ cls, dbapi: Any, error: Exception
+ ) -> NoReturn:
exc_info = sys.exc_info()
raise error.with_traceback(exc_info[2])
+ def _handle_exception(self, error: Exception) -> NoReturn:
+ self._handle_exception_no_connection(self.dbapi, error)
+
def rollback(self) -> None:
try:
await_(self._connection.rollback())
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection:
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
- return AsyncAdapt_aiomysql_connection(
- self,
- await_(creator_fn(*arg, **kw)),
+ return await_(
+ AsyncAdapt_aiomysql_connection.create(
+ self,
+ creator_fn(*arg, **kw),
+ )
)
def _init_cursors_subclasses(
_cursor_cls = AsyncAdapt_asyncmy_cursor
_ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor
- def _handle_exception(self, error: Exception) -> NoReturn:
+ @classmethod
+ def _handle_exception_no_connection(
+ cls, dbapi: Any, error: Exception
+ ) -> NoReturn:
if isinstance(error, AttributeError):
- raise self.dbapi.InternalError(
+ raise dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection:
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
- return AsyncAdapt_asyncmy_connection(
- self,
- await_(creator_fn(*arg, **kw)),
+ return await_(
+ AsyncAdapt_asyncmy_connection.create(
+ self,
+ creator_fn(*arg, **kw),
+ )
)
def connect(self, *arg, **kw):
creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)
- return AsyncAdapt_oracledb_connection(
- self, await_(creator_fn(*arg, **kw))
+ return await_(
+ AsyncAdapt_oracledb_connection.create(self, creator_fn(*arg, **kw))
)
return prepared_stmt, attributes
- def _handle_exception(self, error: Exception) -> NoReturn:
- if self._connection.is_closed():
- self._transaction = None
-
+ @classmethod
+ def _handle_exception_no_connection(
+ cls, dbapi: Any, error: Exception
+ ) -> NoReturn:
if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
- exception_mapping = self.dbapi._asyncpg_error_translate
+ exception_mapping = dbapi._asyncpg_error_translate
for super_ in type(error).__mro__:
if super_ in exception_mapping:
message, error
)
raise translated_error from error
- else:
- super()._handle_exception(error)
- else:
- super()._handle_exception(error)
+ super()._handle_exception_no_connection(dbapi, error)
+
+ def _handle_exception(self, error: Exception) -> NoReturn:
+ if self._connection.is_closed():
+ self._transaction = None
+
+ super()._handle_exception(error)
@property
def autocommit(self):
"prepared_statement_name_func", None
)
- return AsyncAdapt_asyncpg_connection(
- self,
- await_(creator_fn(*arg, **kw)),
- prepared_statement_cache_size=prepared_statement_cache_size,
- prepared_statement_name_func=prepared_statement_name_func,
+ return await_(
+ AsyncAdapt_asyncpg_connection.create(
+ self,
+ creator_fn(*arg, **kw),
+ prepared_statement_cache_size=prepared_statement_cache_size,
+ prepared_statement_name_func=prepared_statement_name_func,
+ )
)
class Error(AsyncAdapt_Error):
creator_fn = kw.pop(
"async_creator_fn", self.psycopg.AsyncConnection.connect
)
- return AsyncAdapt_psycopg_connection(
- self, await_(creator_fn(*arg, **kw))
+ return await_(
+ AsyncAdapt_psycopg_connection.create(self, creator_fn(*arg, **kw))
)
except Exception as error:
self._handle_exception(error)
- def _handle_exception(self, error: Exception) -> NoReturn:
+ @classmethod
+ def _handle_exception_no_connection(
+ cls, dbapi: Any, error: Exception
+ ) -> NoReturn:
if isinstance(error, ValueError) and error.args[0].lower() in (
"no active connection",
"connection closed",
):
- raise self.dbapi.sqlite.OperationalError(error.args[0]) from error
+ raise dbapi.sqlite.OperationalError(error.args[0]) from error
else:
- super()._handle_exception(error)
+ super()._handle_exception_no_connection(dbapi, error)
class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
# because the cursor should be closed
await driver_cursor.execute(select_one_sql)
+ @async_test
+ async def test_async_creator_handle_error(self, async_testing_engine):
+ """test for #11956"""
+
+ existing_creator = testing.db.pool._creator
+
+ def create_and_break():
+ sync_conn = existing_creator()
+ cursor = sync_conn.cursor()
+
+ # figure out a way to get a native driver exception. This really
+ # only applies to asyncpg where we rewrite the exception
+ # hierarchy with our own emulated exception; other backends raise
+ # standard DBAPI exceptions (with some buggy cases here and there
+ # which they miss) even though they are async
+ try:
+ cursor.execute("this will raise an error")
+ except Exception as possibly_emulated_error:
+ if isinstance(
+ possibly_emulated_error, exc.EmulatedDBAPIException
+ ):
+ raise possibly_emulated_error.driver_exception
+ else:
+ raise possibly_emulated_error
+
+ async def async_creator():
+ return await greenlet_spawn(create_and_break)
+
+ engine = async_testing_engine(options={"async_creator": async_creator})
+
+ with expect_raises(exc.DBAPIError):
+ await engine.connect()
+
class AsyncCreatePoolTest(fixtures.TestBase):
@config.fixture