]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve asyncpg exception hierarchy and asyncio hierarchies overall
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 20 Sep 2025 20:22:28 +0000 (16:22 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Sep 2025 13:49:19 +0000 (09:49 -0400)
The "emulated" exception hierarchies for the asyncio
drivers such as asyncpg, aiomysql, aioodbc, etc. have been standardized
on a common base :class:`.EmulatedDBAPIException`, which is now what's
available from the :attr:`.StatementException.orig` attribute on a
SQLAlchemy :class:`.DBAPIException` object.   Within :class:`.EmulatedDBAPIException`
and the subclasses in its hiearchy, the original driver-level exception is
also now avaliable via the :attr:`.EmulatedDBAPIException.orig` attribute,
and is also available from :class:`.DBAPIException` directly using the
:attr:`.DBAPIException.driver_exception` attribute.

Added additional emulated error classes for the subclasses of
``asyncpg.exception.IntegrityError`` including ``RestrictViolationError``,
``NotNullViolationError``, ``ForeignKeyViolationError``,
``UniqueViolationError`` ``CheckViolationError``,
``ExclusionViolationError``.  These exceptions are not directly thrown by
SQLAlchemy's asyncio emulation, however are available from the
newly added :attr:`.DBAPIException.driver_exception` attribute when a
:class:`.IntegrityError` is caught.

Fixes: #8047
Change-Id: I6a34e85b055265c087b0615f7c573be8582b3486

13 files changed:
doc/build/changelog/unreleased_21/8047.rst [new file with mode: 0644]
doc/build/core/internals.rst
lib/sqlalchemy/connectors/aioodbc.py
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/exc.py
test/base/test_except.py
test/engine/test_execute.py

diff --git a/doc/build/changelog/unreleased_21/8047.rst b/doc/build/changelog/unreleased_21/8047.rst
new file mode 100644 (file)
index 0000000..2a7b4e9
--- /dev/null
@@ -0,0 +1,28 @@
+.. change::
+    :tags: feature, asyncio
+    :tickets: 8047
+
+    The "emulated" exception hierarchies for the asyncio
+    drivers such as asyncpg, aiomysql, aioodbc, etc. have been standardized
+    on a common base :class:`.EmulatedDBAPIException`, which is now what's
+    available from the :attr:`.StatementException.orig` attribute on a
+    SQLAlchemy :class:`.DBAPIError` object.   Within :class:`.EmulatedDBAPIException`
+    and the subclasses in its hiearchy, the original driver-level exception is
+    also now avaliable via the :attr:`.EmulatedDBAPIException.orig` attribute,
+    and is also available from :class:`.DBAPIError` directly using the
+    :attr:`.DBAPIError.driver_exception` attribute.
+
+
+
+.. change::
+    :tags: feature, postgresql
+    :tickets: 8047
+
+    Added additional emulated error classes for the subclasses of
+    ``asyncpg.exception.IntegrityError`` including ``RestrictViolationError``,
+    ``NotNullViolationError``, ``ForeignKeyViolationError``,
+    ``UniqueViolationError`` ``CheckViolationError``,
+    ``ExclusionViolationError``.  These exceptions are not directly thrown by
+    SQLAlchemy's asyncio emulation, however are available from the
+    newly added :attr:`.DBAPIError.driver_exception` attribute when a
+    :class:`.IntegrityError` is caught.
index 5146ef4af43907fb9cf07d25946c00c5dc08435e..eeb2800fdc6a03ed21ad9b51f8e1a4383a9f42c2 100644 (file)
@@ -39,7 +39,6 @@ Some key internal constructs are listed here.
 .. autoclass:: sqlalchemy.engine.default.DefaultExecutionContext
     :members:
 
-
 .. autoclass:: sqlalchemy.engine.ExecutionContext
     :members:
 
index 57a16d720182ad9a85152a98db1273252e75e940..39f45dc26531a3794f380c2ddbda11e8eaaad427 100644 (file)
@@ -14,6 +14,7 @@ from .asyncio import AsyncAdapt_dbapi_connection
 from .asyncio import AsyncAdapt_dbapi_cursor
 from .asyncio import AsyncAdapt_dbapi_ss_cursor
 from .pyodbc import PyODBCConnector
+from ..connectors.asyncio import AsyncAdapt_dbapi_module
 from ..util.concurrency import await_
 
 if TYPE_CHECKING:
@@ -92,8 +93,9 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection):
             super().close()
 
 
-class AsyncAdapt_aioodbc_dbapi:
+class AsyncAdapt_aioodbc_dbapi(AsyncAdapt_dbapi_module):
     def __init__(self, aioodbc, pyodbc):
+        super().__init__(aioodbc, dbapi_module=pyodbc)
         self.aioodbc = aioodbc
         self.pyodbc = pyodbc
         self.paramstyle = pyodbc.paramstyle
index c29aa3f69dd2bc511f63f59698f4867d4acc8147..29ca0fc98fe44a9e4de38c7998f3a68facecaa57 100644 (file)
@@ -12,6 +12,7 @@ from __future__ import annotations
 import asyncio
 import collections
 import sys
+import types
 from typing import Any
 from typing import AsyncIterator
 from typing import Deque
@@ -25,6 +26,7 @@ from typing import Type
 from typing import TYPE_CHECKING
 
 from ..engine import AdaptedConnection
+from ..exc import EmulatedDBAPIException
 from ..util import EMPTY_DICT
 from ..util.concurrency import await_
 from ..util.concurrency import in_greenlet
@@ -123,6 +125,32 @@ class AsyncAdapt_dbapi_module:
 
         def __getattr__(self, key: str) -> Any: ...
 
+    def __init__(
+        self,
+        driver: types.ModuleType,
+        *,
+        dbapi_module: types.ModuleType | None = None,
+    ):
+        self.driver = driver
+        self.dbapi_module = dbapi_module
+
+    @property
+    def exceptions_module(self) -> types.ModuleType:
+        """Return the module which we think will have the exception hierarchy.
+
+        For an asyncio driver that wraps a plain DBAPI like aiomysql,
+        aioodbc, aiosqlite, etc. these exceptions will be from the
+        dbapi_module.  For a "pure" driver like asyncpg these will come
+        from the driver module.
+
+        .. versionadded:: 2.1
+
+        """
+        if self.dbapi_module is not None:
+            return self.dbapi_module
+        else:
+            return self.driver
+
 
 class AsyncAdapt_dbapi_cursor:
     server_side = False
@@ -416,3 +444,12 @@ class AsyncAdapt_terminate:
     def _terminate_force_close(self) -> None:
         """Terminate the connection"""
         raise NotImplementedError
+
+
+class AsyncAdapt_Error(EmulatedDBAPIException):
+    """Provide for the base of DBAPI ``Error`` base class for dialects
+    that need to emulate the DBAPI exception hierarchy.
+
+    .. versionadded:: 2.1
+
+    """
index 9c043850e452d65a7194f714521224aa4f2157e5..f630773318d615037d068c9529f50e2f185510e3 100644 (file)
@@ -112,6 +112,7 @@ class AsyncAdapt_aiomysql_connection(
 
 class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
     def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
+        super().__init__(aiomysql, dbapi_module=pymysql)
         self.aiomysql = aiomysql
         self.pymysql = pymysql
         self.paramstyle = "format"
index 22a60a099ab9e6097428b09a3983cd09d4f738c8..952ea171e78c208a6899d15e82f08c9c9339d7ac 100644 (file)
@@ -122,6 +122,7 @@ class AsyncAdapt_asyncmy_connection(
 
 class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
     def __init__(self, asyncmy: ModuleType):
+        super().__init__(asyncmy)
         self.asyncmy = asyncmy
         self.paramstyle = "format"
         self._init_dbapi_attributes()
index a35fa9255c4ef419f9c049742fc68e387b834455..7c4a56ff37bafbc6002d354106eceed9ce17cb77 100644 (file)
@@ -587,6 +587,7 @@ from . import cx_oracle as _cx_oracle
 from ... import exc
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
 from ...engine import default
 from ...util import await_
@@ -838,8 +839,9 @@ class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
         return await_(self._connection.tpc_rollback(*args, **kwargs))
 
 
-class OracledbAdaptDBAPI:
+class OracledbAdaptDBAPI(AsyncAdapt_dbapi_module):
     def __init__(self, oracledb) -> None:
+        super().__init__(oracledb)
         self.oracledb = oracledb
 
         for k, v in self.oracledb.__dict__.items():
index 15712578841c7f7d8e19166de27518cac89c7b91..51bc8b11bd3d9fcc86d69f571c6ab44f1a2eec32 100644 (file)
@@ -217,7 +217,9 @@ from ... import exc
 from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdapt_Error
 from ...connectors.asyncio import AsyncAdapt_terminate
 from ...engine import processors
 from ...sql import sqltypes
@@ -841,11 +843,9 @@ class AsyncAdapt_asyncpg_connection(
 
             for super_ in type(error).__mro__:
                 if super_ in exception_mapping:
+                    message = error.args[0]
                     translated_error = exception_mapping[super_](
-                        "%s: %s" % (type(error), error)
-                    )
-                    translated_error.pgcode = translated_error.sqlstate = (
-                        getattr(error, "sqlstate", None)
+                        message, error
                     )
                     raise translated_error from error
             else:
@@ -952,8 +952,9 @@ class AsyncAdapt_asyncpg_connection(
         return None
 
 
-class AsyncAdapt_asyncpg_dbapi:
+class AsyncAdapt_asyncpg_dbapi(AsyncAdapt_dbapi_module):
     def __init__(self, asyncpg):
+        super().__init__(asyncpg)
         self.asyncpg = asyncpg
         self.paramstyle = "numeric_dollar"
 
@@ -973,10 +974,20 @@ class AsyncAdapt_asyncpg_dbapi:
             prepared_statement_name_func=prepared_statement_name_func,
         )
 
-    class Error(Exception):
-        pass
+    class Error(AsyncAdapt_Error):
+
+        pgcode: str | None
+
+        sqlstate: str | None
+
+        detail: str | None
 
-    class Warning(Exception):  # noqa
+        def __init__(self, message, error=None):
+            super().__init__(message, error)
+            self.detail = getattr(error, "detail", None)
+            self.pgcode = self.sqlstate = getattr(error, "sqlstate", None)
+
+    class Warning(AsyncAdapt_Error):  # noqa
         pass
 
     class InterfaceError(Error):
@@ -997,6 +1008,24 @@ class AsyncAdapt_asyncpg_dbapi:
     class IntegrityError(DatabaseError):
         pass
 
+    class RestrictViolationError(IntegrityError):
+        pass
+
+    class NotNullViolationError(IntegrityError):
+        pass
+
+    class ForeignKeyViolationError(IntegrityError):
+        pass
+
+    class UniqueViolationError(IntegrityError):
+        pass
+
+    class CheckViolationError(IntegrityError):
+        pass
+
+    class ExclusionViolationError(IntegrityError):
+        pass
+
     class DataError(DatabaseError):
         pass
 
@@ -1007,7 +1036,7 @@ class AsyncAdapt_asyncpg_dbapi:
         pass
 
     class InvalidCachedStatementError(NotSupportedError):
-        def __init__(self, message):
+        def __init__(self, message, error=None):
             super().__init__(
                 message + " (SQLAlchemy asyncpg dialect will now invalidate "
                 "all prepared caches in response to this exception)",
@@ -1030,6 +1059,12 @@ class AsyncAdapt_asyncpg_dbapi:
             asyncpg.exceptions.InterfaceError: self.InterfaceError,
             asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError,  # noqa: E501
             asyncpg.exceptions.InternalServerError: self.InternalServerError,
+            asyncpg.exceptions.RestrictViolationError: self.RestrictViolationError,  # noqa: E501
+            asyncpg.exceptions.NotNullViolationError: self.NotNullViolationError,  # noqa: E501
+            asyncpg.exceptions.ForeignKeyViolationError: self.ForeignKeyViolationError,  # noqa: E501
+            asyncpg.exceptions.UniqueViolationError: self.UniqueViolationError,
+            asyncpg.exceptions.CheckViolationError: self.CheckViolationError,
+            asyncpg.exceptions.ExclusionViolationError: self.ExclusionViolationError,  # noqa: E501
         }
 
     def Binary(self, value):
index 966195752d0e9b213388d966d5e1e17904676b1e..9a877702064550f561cce0498f597bd9a4e169e5 100644 (file)
@@ -115,6 +115,7 @@ from .types import CITEXT
 from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
+from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
 from ...sql import sqltypes
 from ...util.concurrency import await_
@@ -681,8 +682,9 @@ class AsyncAdapt_psycopg_connection(AsyncAdapt_dbapi_connection):
             return AsyncAdapt_psycopg_cursor(self)
 
 
-class PsycopgAdaptDBAPI:
+class PsycopgAdaptDBAPI(AsyncAdapt_dbapi_module):
     def __init__(self, psycopg, ExecStatus) -> None:
+        super().__init__(psycopg)
         self.psycopg = psycopg
         self.ExecStatus = ExecStatus
 
index cf8726c1f34b681e269992235151f1921b17e180..ad0cd89f60d4451f3ae5d120ea195cecf4105346 100644 (file)
@@ -189,6 +189,7 @@ class AsyncAdapt_aiosqlite_connection(AsyncAdapt_dbapi_connection):
 
 class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module):
     def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType):
+        super().__init__(aiosqlite, dbapi_module=sqlite)
         self.aiosqlite = aiosqlite
         self.sqlite = sqlite
         self.paramstyle = "qmark"
index 6740d0b9af607dfbd38f0c486df5344094b81970..e2bf6d5fe8c9c396d2aa93dc9ae8b75bc4651d2f 100644 (file)
@@ -115,6 +115,44 @@ class SQLAlchemyError(HasDescriptionCode, Exception):
         return self._sql_message()
 
 
+class EmulatedDBAPIException(Exception):
+    """Serves as the base of the DBAPI ``Error`` class for dialects where
+    a DBAPI exception hierrchy needs to be emulated.
+
+    The current example is the asyncpg dialect.
+
+    .. versionadded:: 2.1
+
+    """
+
+    orig: Exception | None
+
+    def __init__(self, message: str, orig: Exception | None = None):
+        # we accept None for Exception since all DBAPI.Error objects
+        # need to support construction with a message alone
+        super().__init__(message)
+        self.orig = orig
+
+    @property
+    def driver_exception(self) -> Exception:
+        """The original driver exception that was raised.
+
+        This exception object will always originate from outside of
+        SQLAlchemy.
+
+        """
+
+        if self.orig is None:
+            raise ValueError(
+                "No original exception is present.  Was this "
+                "EmulatedDBAPIException constructed without a driver error?"
+            )
+        return self.orig
+
+    def __reduce__(self) -> Any:
+        return self.__class__, (self.args[0], self.orig)
+
+
 class ArgumentError(SQLAlchemyError):
     """Raised when an invalid or conflicting function argument is supplied.
 
@@ -463,6 +501,12 @@ class StatementError(SQLAlchemyError):
     orig: Optional[BaseException] = None
     """The original exception that was thrown.
 
+    .. seealso::
+
+        :attr:`.DBAPIError.driver_exception` - a more specific attribute that
+        is guaranteed to return the exception object raised by the third
+        party driver in use, even when using asyncio.
+
     """
 
     ismulti: Optional[bool] = None
@@ -555,6 +599,8 @@ class DBAPIError(StatementError):
 
     code = "dbapi"
 
+    orig: Optional[Exception]
+
     @overload
     @classmethod
     def instance(
@@ -712,6 +758,42 @@ class DBAPIError(StatementError):
         )
         self.connection_invalidated = connection_invalidated
 
+    @property
+    def driver_exception(self) -> Exception:
+        """The exception object originating from the driver (DBAPI) outside
+        of SQLAlchemy.
+
+        In the case of some asyncio dialects, special steps are taken to
+        resolve the exception to what the third party driver has raised, even
+        for SQLAlchemy dialects that include an "emulated" DBAPI exception
+        hierarchy.
+
+        For non-asyncio dialects, this attribute will be the same attribute
+        as the :attr:`.StatementError.orig` attribute.
+
+        For an asyncio dialect provided by SQLAlchemy, depending on if the
+        dialect provides an "emulated" exception hierarchy or if the underlying
+        DBAPI raises DBAPI-style exceptions, it will refer to either the
+        :attr:`.EmulatedDBAPIException.driver_exception` attribute on the
+        :class:`.EmulatedDBAPIException` that's thrown (such as when using
+        asyncpg), or to the actual exception object thrown by the
+        third party driver.
+
+        .. versionadded:: 2.1
+
+        """
+
+        if self.orig is None:
+            raise ValueError(
+                "No original exception is present.  Was this "
+                "DBAPIError constructed without a driver error?"
+            )
+
+        if isinstance(self.orig, EmulatedDBAPIException):
+            return self.orig.driver_exception
+        else:
+            return self.orig
+
 
 class InterfaceError(DBAPIError):
     """Wraps a DB-API InterfaceError."""
index 9353d28ab9fb43714f4bbddf3924818dcb14e5bc..2a45cdcc95a800b037de5f61a3e255bc3afba607 100644 (file)
@@ -419,12 +419,24 @@ def details(cls):
     return inst
 
 
+class EqException(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __eq__(self, other):
+        return isinstance(other, EqException) and other.msg == self.msg
+
+
 ALL_EXC = [
     (
         [sa_exceptions.SQLAlchemyError],
         [lambda cls: cls(1, 2, code="42")],
     ),
     ([sa_exceptions.ObjectNotExecutableError], [lambda cls: cls("xx")]),
+    (
+        [sa_exceptions.EmulatedDBAPIException],
+        [lambda cls: cls("xx", EqException("original"))],
+    ),
     (
         [
             sa_exceptions.ArgumentError,
index ab1491fd69b0fe45187ec2286298794fbd2c5f4a..01bb1fe94b89b18a917cdb38194d0224e9eba237 100644 (file)
@@ -30,6 +30,7 @@ from sqlalchemy import text
 from sqlalchemy import TypeDecorator
 from sqlalchemy import util
 from sqlalchemy import VARCHAR
+from sqlalchemy.connectors.asyncio import AsyncAdapt_dbapi_module
 from sqlalchemy.engine import BindTyping
 from sqlalchemy.engine import default
 from sqlalchemy.engine.base import Connection
@@ -51,6 +52,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
+from sqlalchemy.testing import ne_
 from sqlalchemy.testing.assertions import expect_deprecated
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.provision import normalize_sequence
@@ -385,7 +387,6 @@ class ExecuteTest(fixtures.TablesTest):
 
     def test_exception_wrapping_dbapi(self):
         with testing.db.connect() as conn:
-            # engine does not have exec_driver_sql
             assert_raises_message(
                 tsa.exc.DBAPIError,
                 r"not_a_valid_statement",
@@ -393,6 +394,34 @@ class ExecuteTest(fixtures.TablesTest):
                 "not_a_valid_statement",
             )
 
+    def test_exception_wrapping_orig_accessors(self):
+        de = None
+
+        with testing.db.connect() as conn:
+            try:
+                conn.exec_driver_sql("not_a_valid_statement")
+            except tsa.exc.DBAPIError as de_caught:
+                de = de_caught
+
+        assert isinstance(de.orig, conn.dialect.dbapi.Error)
+
+        # get the driver module name, the one which we know will provide
+        # for exceptions
+        top_level_dbapi_module = conn.dialect.dbapi
+        if isinstance(top_level_dbapi_module, AsyncAdapt_dbapi_module):
+            driver_module = top_level_dbapi_module.exceptions_module
+        else:
+            driver_module = top_level_dbapi_module
+        top_level_dbapi_module = driver_module.__name__.split(".")[0]
+
+        # check that it's not us
+        ne_(top_level_dbapi_module, "sqlalchemy")
+
+        # then make sure driver_exception is from that module
+        assert type(de.driver_exception).__module__.startswith(
+            top_level_dbapi_module
+        )
+
     @testing.requires.sqlite
     def test_exception_wrapping_non_dbapi_error(self):
         e = create_engine("sqlite://")