From: Mike Bayer Date: Sat, 20 Sep 2025 20:22:28 +0000 (-0400) Subject: Improve asyncpg exception hierarchy and asyncio hierarchies overall X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=89a4174a8de0815605874a61a38639658aeb7eab;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve asyncpg exception hierarchy and asyncio hierarchies overall The "emulated" exception hierarchies for the asyncio drivers such as asyncpg, aiomysql, aioodbc, etc. have been standardized on a common base :class:`.EmulatedDBAPIException`, which is now what's available from the :attr:`.StatementException.orig` attribute on a SQLAlchemy :class:`.DBAPIException` object. Within :class:`.EmulatedDBAPIException` and the subclasses in its hiearchy, the original driver-level exception is also now avaliable via the :attr:`.EmulatedDBAPIException.orig` attribute, and is also available from :class:`.DBAPIException` directly using the :attr:`.DBAPIException.driver_exception` attribute. Added additional emulated error classes for the subclasses of ``asyncpg.exception.IntegrityError`` including ``RestrictViolationError``, ``NotNullViolationError``, ``ForeignKeyViolationError``, ``UniqueViolationError`` ``CheckViolationError``, ``ExclusionViolationError``. These exceptions are not directly thrown by SQLAlchemy's asyncio emulation, however are available from the newly added :attr:`.DBAPIException.driver_exception` attribute when a :class:`.IntegrityError` is caught. Fixes: #8047 Change-Id: I6a34e85b055265c087b0615f7c573be8582b3486 --- diff --git a/doc/build/changelog/unreleased_21/8047.rst b/doc/build/changelog/unreleased_21/8047.rst new file mode 100644 index 0000000000..2a7b4e9dc0 --- /dev/null +++ b/doc/build/changelog/unreleased_21/8047.rst @@ -0,0 +1,28 @@ +.. change:: + :tags: feature, asyncio + :tickets: 8047 + + The "emulated" exception hierarchies for the asyncio + drivers such as asyncpg, aiomysql, aioodbc, etc. have been standardized + on a common base :class:`.EmulatedDBAPIException`, which is now what's + available from the :attr:`.StatementException.orig` attribute on a + SQLAlchemy :class:`.DBAPIError` object. Within :class:`.EmulatedDBAPIException` + and the subclasses in its hiearchy, the original driver-level exception is + also now avaliable via the :attr:`.EmulatedDBAPIException.orig` attribute, + and is also available from :class:`.DBAPIError` directly using the + :attr:`.DBAPIError.driver_exception` attribute. + + + +.. change:: + :tags: feature, postgresql + :tickets: 8047 + + Added additional emulated error classes for the subclasses of + ``asyncpg.exception.IntegrityError`` including ``RestrictViolationError``, + ``NotNullViolationError``, ``ForeignKeyViolationError``, + ``UniqueViolationError`` ``CheckViolationError``, + ``ExclusionViolationError``. These exceptions are not directly thrown by + SQLAlchemy's asyncio emulation, however are available from the + newly added :attr:`.DBAPIError.driver_exception` attribute when a + :class:`.IntegrityError` is caught. diff --git a/doc/build/core/internals.rst b/doc/build/core/internals.rst index 5146ef4af4..eeb2800fdc 100644 --- a/doc/build/core/internals.rst +++ b/doc/build/core/internals.rst @@ -39,7 +39,6 @@ Some key internal constructs are listed here. .. autoclass:: sqlalchemy.engine.default.DefaultExecutionContext :members: - .. autoclass:: sqlalchemy.engine.ExecutionContext :members: diff --git a/lib/sqlalchemy/connectors/aioodbc.py b/lib/sqlalchemy/connectors/aioodbc.py index 57a16d7201..39f45dc265 100644 --- a/lib/sqlalchemy/connectors/aioodbc.py +++ b/lib/sqlalchemy/connectors/aioodbc.py @@ -14,6 +14,7 @@ from .asyncio import AsyncAdapt_dbapi_connection from .asyncio import AsyncAdapt_dbapi_cursor from .asyncio import AsyncAdapt_dbapi_ss_cursor from .pyodbc import PyODBCConnector +from ..connectors.asyncio import AsyncAdapt_dbapi_module from ..util.concurrency import await_ if TYPE_CHECKING: @@ -92,8 +93,9 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection): super().close() -class AsyncAdapt_aioodbc_dbapi: +class AsyncAdapt_aioodbc_dbapi(AsyncAdapt_dbapi_module): def __init__(self, aioodbc, pyodbc): + super().__init__(aioodbc, dbapi_module=pyodbc) self.aioodbc = aioodbc self.pyodbc = pyodbc self.paramstyle = pyodbc.paramstyle diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index c29aa3f69d..29ca0fc98f 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -12,6 +12,7 @@ from __future__ import annotations import asyncio import collections import sys +import types from typing import Any from typing import AsyncIterator from typing import Deque @@ -25,6 +26,7 @@ from typing import Type from typing import TYPE_CHECKING from ..engine import AdaptedConnection +from ..exc import EmulatedDBAPIException from ..util import EMPTY_DICT from ..util.concurrency import await_ from ..util.concurrency import in_greenlet @@ -123,6 +125,32 @@ class AsyncAdapt_dbapi_module: def __getattr__(self, key: str) -> Any: ... + def __init__( + self, + driver: types.ModuleType, + *, + dbapi_module: types.ModuleType | None = None, + ): + self.driver = driver + self.dbapi_module = dbapi_module + + @property + def exceptions_module(self) -> types.ModuleType: + """Return the module which we think will have the exception hierarchy. + + For an asyncio driver that wraps a plain DBAPI like aiomysql, + aioodbc, aiosqlite, etc. these exceptions will be from the + dbapi_module. For a "pure" driver like asyncpg these will come + from the driver module. + + .. versionadded:: 2.1 + + """ + if self.dbapi_module is not None: + return self.dbapi_module + else: + return self.driver + class AsyncAdapt_dbapi_cursor: server_side = False @@ -416,3 +444,12 @@ class AsyncAdapt_terminate: def _terminate_force_close(self) -> None: """Terminate the connection""" raise NotImplementedError + + +class AsyncAdapt_Error(EmulatedDBAPIException): + """Provide for the base of DBAPI ``Error`` base class for dialects + that need to emulate the DBAPI exception hierarchy. + + .. versionadded:: 2.1 + + """ diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 9c043850e4..f630773318 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -112,6 +112,7 @@ class AsyncAdapt_aiomysql_connection( class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module): def __init__(self, aiomysql: ModuleType, pymysql: ModuleType): + super().__init__(aiomysql, dbapi_module=pymysql) self.aiomysql = aiomysql self.pymysql = pymysql self.paramstyle = "format" diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 22a60a099a..952ea171e7 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -122,6 +122,7 @@ class AsyncAdapt_asyncmy_connection( class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module): def __init__(self, asyncmy: ModuleType): + super().__init__(asyncmy) self.asyncmy = asyncmy self.paramstyle = "format" self._init_dbapi_attributes() diff --git a/lib/sqlalchemy/dialects/oracle/oracledb.py b/lib/sqlalchemy/dialects/oracle/oracledb.py index a35fa9255c..7c4a56ff37 100644 --- a/lib/sqlalchemy/dialects/oracle/oracledb.py +++ b/lib/sqlalchemy/dialects/oracle/oracledb.py @@ -587,6 +587,7 @@ from . import cx_oracle as _cx_oracle from ... import exc from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...engine import default from ...util import await_ @@ -838,8 +839,9 @@ class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): return await_(self._connection.tpc_rollback(*args, **kwargs)) -class OracledbAdaptDBAPI: +class OracledbAdaptDBAPI(AsyncAdapt_dbapi_module): def __init__(self, oracledb) -> None: + super().__init__(oracledb) self.oracledb = oracledb for k, v in self.oracledb.__dict__.items(): diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 1571257884..51bc8b11bd 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -217,7 +217,9 @@ from ... import exc from ... import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdapt_Error from ...connectors.asyncio import AsyncAdapt_terminate from ...engine import processors from ...sql import sqltypes @@ -841,11 +843,9 @@ class AsyncAdapt_asyncpg_connection( for super_ in type(error).__mro__: if super_ in exception_mapping: + message = error.args[0] translated_error = exception_mapping[super_]( - "%s: %s" % (type(error), error) - ) - translated_error.pgcode = translated_error.sqlstate = ( - getattr(error, "sqlstate", None) + message, error ) raise translated_error from error else: @@ -952,8 +952,9 @@ class AsyncAdapt_asyncpg_connection( return None -class AsyncAdapt_asyncpg_dbapi: +class AsyncAdapt_asyncpg_dbapi(AsyncAdapt_dbapi_module): def __init__(self, asyncpg): + super().__init__(asyncpg) self.asyncpg = asyncpg self.paramstyle = "numeric_dollar" @@ -973,10 +974,20 @@ class AsyncAdapt_asyncpg_dbapi: prepared_statement_name_func=prepared_statement_name_func, ) - class Error(Exception): - pass + class Error(AsyncAdapt_Error): + + pgcode: str | None + + sqlstate: str | None + + detail: str | None - class Warning(Exception): # noqa + def __init__(self, message, error=None): + super().__init__(message, error) + self.detail = getattr(error, "detail", None) + self.pgcode = self.sqlstate = getattr(error, "sqlstate", None) + + class Warning(AsyncAdapt_Error): # noqa pass class InterfaceError(Error): @@ -997,6 +1008,24 @@ class AsyncAdapt_asyncpg_dbapi: class IntegrityError(DatabaseError): pass + class RestrictViolationError(IntegrityError): + pass + + class NotNullViolationError(IntegrityError): + pass + + class ForeignKeyViolationError(IntegrityError): + pass + + class UniqueViolationError(IntegrityError): + pass + + class CheckViolationError(IntegrityError): + pass + + class ExclusionViolationError(IntegrityError): + pass + class DataError(DatabaseError): pass @@ -1007,7 +1036,7 @@ class AsyncAdapt_asyncpg_dbapi: pass class InvalidCachedStatementError(NotSupportedError): - def __init__(self, message): + def __init__(self, message, error=None): super().__init__( message + " (SQLAlchemy asyncpg dialect will now invalidate " "all prepared caches in response to this exception)", @@ -1030,6 +1059,12 @@ class AsyncAdapt_asyncpg_dbapi: asyncpg.exceptions.InterfaceError: self.InterfaceError, asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501 asyncpg.exceptions.InternalServerError: self.InternalServerError, + asyncpg.exceptions.RestrictViolationError: self.RestrictViolationError, # noqa: E501 + asyncpg.exceptions.NotNullViolationError: self.NotNullViolationError, # noqa: E501 + asyncpg.exceptions.ForeignKeyViolationError: self.ForeignKeyViolationError, # noqa: E501 + asyncpg.exceptions.UniqueViolationError: self.UniqueViolationError, + asyncpg.exceptions.CheckViolationError: self.CheckViolationError, + asyncpg.exceptions.ExclusionViolationError: self.ExclusionViolationError, # noqa: E501 } def Binary(self, value): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 966195752d..9a87770206 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -115,6 +115,7 @@ from .types import CITEXT from ... import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor from ...sql import sqltypes from ...util.concurrency import await_ @@ -681,8 +682,9 @@ class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection): return AsyncAdapt_psycopg_cursor(self) -class PsycopgAdaptDBAPI: +class PsycopgAdaptDBAPI(AsyncAdapt_dbapi_module): def __init__(self, psycopg, ExecStatus) -> None: + super().__init__(psycopg) self.psycopg = psycopg self.ExecStatus = ExecStatus diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index cf8726c1f3..ad0cd89f60 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -189,6 +189,7 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection): class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module): def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType): + super().__init__(aiosqlite, dbapi_module=sqlite) self.aiosqlite = aiosqlite self.sqlite = sqlite self.paramstyle = "qmark" diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 6740d0b9af..e2bf6d5fe8 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -115,6 +115,44 @@ class SQLAlchemyError(HasDescriptionCode, Exception): return self._sql_message() +class EmulatedDBAPIException(Exception): + """Serves as the base of the DBAPI ``Error`` class for dialects where + a DBAPI exception hierrchy needs to be emulated. + + The current example is the asyncpg dialect. + + .. versionadded:: 2.1 + + """ + + orig: Exception | None + + def __init__(self, message: str, orig: Exception | None = None): + # we accept None for Exception since all DBAPI.Error objects + # need to support construction with a message alone + super().__init__(message) + self.orig = orig + + @property + def driver_exception(self) -> Exception: + """The original driver exception that was raised. + + This exception object will always originate from outside of + SQLAlchemy. + + """ + + if self.orig is None: + raise ValueError( + "No original exception is present. Was this " + "EmulatedDBAPIException constructed without a driver error?" + ) + return self.orig + + def __reduce__(self) -> Any: + return self.__class__, (self.args[0], self.orig) + + class ArgumentError(SQLAlchemyError): """Raised when an invalid or conflicting function argument is supplied. @@ -463,6 +501,12 @@ class StatementError(SQLAlchemyError): orig: Optional[BaseException] = None """The original exception that was thrown. + .. seealso:: + + :attr:`.DBAPIError.driver_exception` - a more specific attribute that + is guaranteed to return the exception object raised by the third + party driver in use, even when using asyncio. + """ ismulti: Optional[bool] = None @@ -555,6 +599,8 @@ class DBAPIError(StatementError): code = "dbapi" + orig: Optional[Exception] + @overload @classmethod def instance( @@ -712,6 +758,42 @@ class DBAPIError(StatementError): ) self.connection_invalidated = connection_invalidated + @property + def driver_exception(self) -> Exception: + """The exception object originating from the driver (DBAPI) outside + of SQLAlchemy. + + In the case of some asyncio dialects, special steps are taken to + resolve the exception to what the third party driver has raised, even + for SQLAlchemy dialects that include an "emulated" DBAPI exception + hierarchy. + + For non-asyncio dialects, this attribute will be the same attribute + as the :attr:`.StatementError.orig` attribute. + + For an asyncio dialect provided by SQLAlchemy, depending on if the + dialect provides an "emulated" exception hierarchy or if the underlying + DBAPI raises DBAPI-style exceptions, it will refer to either the + :attr:`.EmulatedDBAPIException.driver_exception` attribute on the + :class:`.EmulatedDBAPIException` that's thrown (such as when using + asyncpg), or to the actual exception object thrown by the + third party driver. + + .. versionadded:: 2.1 + + """ + + if self.orig is None: + raise ValueError( + "No original exception is present. Was this " + "DBAPIError constructed without a driver error?" + ) + + if isinstance(self.orig, EmulatedDBAPIException): + return self.orig.driver_exception + else: + return self.orig + class InterfaceError(DBAPIError): """Wraps a DB-API InterfaceError.""" diff --git a/test/base/test_except.py b/test/base/test_except.py index 9353d28ab9..2a45cdcc95 100644 --- a/test/base/test_except.py +++ b/test/base/test_except.py @@ -419,12 +419,24 @@ def details(cls): return inst +class EqException(Exception): + def __init__(self, msg): + self.msg = msg + + def __eq__(self, other): + return isinstance(other, EqException) and other.msg == self.msg + + ALL_EXC = [ ( [sa_exceptions.SQLAlchemyError], [lambda cls: cls(1, 2, code="42")], ), ([sa_exceptions.ObjectNotExecutableError], [lambda cls: cls("xx")]), + ( + [sa_exceptions.EmulatedDBAPIException], + [lambda cls: cls("xx", EqException("original"))], + ), ( [ sa_exceptions.ArgumentError, diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index ab1491fd69..01bb1fe94b 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -30,6 +30,7 @@ from sqlalchemy import text from sqlalchemy import TypeDecorator from sqlalchemy import util from sqlalchemy import VARCHAR +from sqlalchemy.connectors.asyncio import AsyncAdapt_dbapi_module from sqlalchemy.engine import BindTyping from sqlalchemy.engine import default from sqlalchemy.engine.base import Connection @@ -51,6 +52,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import ne_ from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.provision import normalize_sequence @@ -385,7 +387,6 @@ class ExecuteTest(fixtures.TablesTest): def test_exception_wrapping_dbapi(self): with testing.db.connect() as conn: - # engine does not have exec_driver_sql assert_raises_message( tsa.exc.DBAPIError, r"not_a_valid_statement", @@ -393,6 +394,34 @@ class ExecuteTest(fixtures.TablesTest): "not_a_valid_statement", ) + def test_exception_wrapping_orig_accessors(self): + de = None + + with testing.db.connect() as conn: + try: + conn.exec_driver_sql("not_a_valid_statement") + except tsa.exc.DBAPIError as de_caught: + de = de_caught + + assert isinstance(de.orig, conn.dialect.dbapi.Error) + + # get the driver module name, the one which we know will provide + # for exceptions + top_level_dbapi_module = conn.dialect.dbapi + if isinstance(top_level_dbapi_module, AsyncAdapt_dbapi_module): + driver_module = top_level_dbapi_module.exceptions_module + else: + driver_module = top_level_dbapi_module + top_level_dbapi_module = driver_module.__name__.split(".")[0] + + # check that it's not us + ne_(top_level_dbapi_module, "sqlalchemy") + + # then make sure driver_exception is from that module + assert type(de.driver_exception).__module__.startswith( + top_level_dbapi_module + ) + @testing.requires.sqlite def test_exception_wrapping_non_dbapi_error(self): e = create_engine("sqlite://")