]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve support for enum in mapped classes
authorFederico Caselli <cfederico87@gmail.com>
Sun, 27 Nov 2022 17:11:34 +0000 (18:11 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Nov 2022 22:49:27 +0000 (17:49 -0500)
Add a new system by which TypeEngine objects have some
say in how the declarative type registry interprets them.
The Enum datatype is the primary target for this but it is
hoped the system may be useful for other types as well.

Fixes: #8859
Change-Id: I15ac3daee770408b5795746f47c1bbd931b7d26d

doc/build/changelog/unreleased_20/8859.rst [new file with mode: 0644]
doc/build/orm/declarative_tables.rst
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/8859.rst b/doc/build/changelog/unreleased_20/8859.rst
new file mode 100644 (file)
index 0000000..85e4be4
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 8859
+
+    Added support custom user-defined types which extend the Python
+    ``enum.Enum`` base class to be resolved automatically
+    to SQLAlchemy :class:`.Enum` SQL types, when using the Annotated
+    Declarative Table feature.  The feature is made possible through new
+    lookup features added to the ORM type map feature, and includes support
+    for changing the arguments of the :class:`.Enum` that's generated by
+    default as well as setting up specific ``enum.Enum`` types within
+    the map with specific arguments.
+
+    .. seealso::
+
+        :ref:`orm_declarative_mapped_column_enums`
index 475813f819fee58e983cc1e2dea7fb2c2af9c13e..806a6897f2417ca21f35cb7dca9d44e25de5a9f3 100644 (file)
@@ -369,6 +369,106 @@ while still being able to use succinct annotation-only :func:`_orm.mapped_column
 configurations.  There are two more levels of Python-type configurability
 available beyond this, described in the next two sections.
 
+.. _orm_declarative_mapped_column_enums:
+
+Using Python ``Enum`` types in the type map
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. versionadded:: 2.0.0b4
+
+User-defined Python types which derive from the Python built-in ``enum.Enum``
+class are automatically linked to the SQLAlchemy :class:`.Enum` datatype
+when used in an ORM declarative mapping::
+
+    import enum
+
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+
+
+    class Base(DeclarativeBase):
+        pass
+
+
+    class Status(enum.Enum):
+        PENDING = "pending"
+        RECEIVED = "received"
+        COMPLETED = "completed"
+
+
+    class SomeClass(Base):
+        __tablename__ = "some_table"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        status: Mapped[Status]
+
+In the above example, the mapped attribute ``SomeClass.status`` will be
+linked to a :class:`.Column` with the datatype of ``Enum(Status)``.
+We can see this for example in the CREATE TABLE output for the PostgreSQL
+database:
+
+.. sourcecode:: sql
+
+  CREATE TYPE status AS ENUM ('PENDING', 'RECEIVED', 'COMPLETED')
+
+  CREATE TABLE some_table (
+    id SERIAL NOT NULL,
+    status status NOT NULL,
+    PRIMARY KEY (id)
+  )
+
+The entry used in :paramref:`_orm.registry.type_annotation_map` links the
+base ``enum.Enum`` Python type to the SQLAlchemy :class:`.Enum` SQL
+type, using a special form which indicates to the :class:`.Enum` datatype
+that it should automatically configure itself against an arbitrary enumerated
+type.   This configuration, which is implicit by default, would be indicated
+explicitly as::
+
+    import enum
+    import sqlalchemy
+
+
+    class Base(DeclarativeBase):
+        type_annotation_map = {enum.Enum: sqlalchemy.Enum(enum.Enum)}
+
+The resolution logic within Declarative is able to resolve subclasses
+of ``enum.Enum``, in the above example the custom ``Status`` enumeration,
+to match the ``enum.Enum`` entry in the
+:paramref:`_orm.registry.type_annotation_map` dictionary.  The :class:`.Enum`
+SQL type then knows how to produce a configured version of itself with the
+appropriate settings, including default string length.
+
+In order to modify the configuration of the :class:`.enum.Enum` datatype used
+in this mapping, use the above form, indicating additional arguments. For
+example, to use "non native enumerations" on all backends, the
+:paramref:`.Enum.native_enum` parameter may be set to False for all types::
+
+    import enum
+    import sqlalchemy
+
+
+    class Base(DeclarativeBase):
+        type_annotation_map = {enum.Enum: sqlalchemy.Enum(enum.Enum, native_enum=False)}
+
+To use a specific configuration for a specific ``enum.Enum`` subtype, such
+as setting the string length to 50 when using the example ``Status``
+datatype::
+
+    import enum
+    import sqlalchemy
+
+
+    class Status(enum.Enum):
+        PENDING = "pending"
+        RECEIVED = "received"
+        COMPLETED = "completed"
+
+
+    class Base(DeclarativeBase):
+        type_annotation_map = {
+            Status: sqlalchemy.Enum(Status, length=50, native_enum=False)
+        }
 
 .. _orm_declarative_mapped_column_type_map_pep593:
 
index 09397eb653fca393c3743cd1d6968f1c91153e37..de6c8794b124f4b5d17e837ee77a72e54fd42675 100644 (file)
@@ -14,6 +14,7 @@ import re
 import typing
 from typing import Any
 from typing import Callable
+from typing import cast
 from typing import ClassVar
 from typing import Dict
 from typing import FrozenSet
@@ -23,6 +24,7 @@ from typing import Mapping
 from typing import Optional
 from typing import overload
 from typing import Set
+from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
@@ -62,6 +64,7 @@ from .state import InstanceState
 from .. import exc
 from .. import inspection
 from .. import util
+from ..sql import sqltypes
 from ..sql.base import _NoArg
 from ..sql.elements import SQLCoreOperations
 from ..sql.schema import MetaData
@@ -70,6 +73,7 @@ from ..util import hybridmethod
 from ..util import hybridproperty
 from ..util import typing as compat_typing
 from ..util.typing import CallableReference
+from ..util.typing import is_generic
 from ..util.typing import Literal
 
 if TYPE_CHECKING:
@@ -80,6 +84,7 @@ if TYPE_CHECKING:
     from .interfaces import MapperProperty
     from .state import InstanceState  # noqa
     from ..sql._typing import _TypeEngineArgument
+    from ..util.typing import GenericProtocol
 
 _T = TypeVar("_T", bound=Any)
 
@@ -1018,6 +1023,37 @@ class registry:
             }
         )
 
+    def _resolve_type(
+        self, python_type: Union[GenericProtocol[Any], Type[Any]]
+    ) -> Optional[sqltypes.TypeEngine[Any]]:
+
+        search: Tuple[Union[GenericProtocol[Any], Type[Any]], ...]
+
+        if is_generic(python_type):
+            python_type_type: Type[Any] = python_type.__origin__
+            search = (python_type,)
+        else:
+            # don't know why is_generic() TypeGuard[GenericProtocol[Any]]
+            # check above is not sufficient here
+            python_type_type = cast("Type[Any]", python_type)
+            search = python_type_type.__mro__
+
+        for pt in search:
+            sql_type = self.type_annotation_map.get(pt)
+            if sql_type is None:
+                sql_type = sqltypes._type_map_get(pt)  # type: ignore  # noqa: E501
+
+            if sql_type is not None:
+                sql_type_inst = sqltypes.to_instance(sql_type)  # type: ignore
+
+                resolved_sql_type = sql_type_inst._resolve_for_python_type(
+                    python_type_type, pt
+                )
+                if resolved_sql_type is not None:
+                    return resolved_sql_type
+
+        return None
+
     @property
     def mappers(self) -> FrozenSet[Mapper[Any]]:
         """read only collection of all :class:`_orm.Mapper` objects."""
index 1a5f0bd71d9724ac98912fcaf57b028ffc983369..0b26cb872179c4f4b47be9591f8f8bd7715b9362 100644 (file)
@@ -45,7 +45,6 @@ from .. import log
 from .. import util
 from ..sql import coercions
 from ..sql import roles
-from ..sql import sqltypes
 from ..sql.base import _NoArg
 from ..sql.roles import DDLConstraintColumnRole
 from ..sql.schema import Column
@@ -737,10 +736,7 @@ class MappedColumn(
 
             for check_type in checks:
 
-                if registry.type_annotation_map:
-                    new_sqltype = registry.type_annotation_map.get(check_type)
-                if new_sqltype is None:
-                    new_sqltype = sqltypes._type_map_get(check_type)  # type: ignore  # noqa: E501
+                new_sqltype = registry._resolve_type(check_type)
                 if new_sqltype is not None:
                     break
             else:
@@ -749,4 +745,4 @@ class MappedColumn(
                     f"type for Python type: {our_type}"
                 )
 
-            self.column.type = sqltypes.to_instance(new_sqltype)
+            self.column._set_type(new_sqltype)
index a0f56d8391a32e378c8f77658ccb0e89dcbc31d4..c239a3a6a9e5980c91ae103e2febda83e7bc67d7 100644 (file)
@@ -59,6 +59,7 @@ from .. import util
 from ..engine import processors
 from ..util import langhelpers
 from ..util import OrderedDict
+from ..util.typing import GenericProtocol
 from ..util.typing import Literal
 
 if TYPE_CHECKING:
@@ -1489,6 +1490,28 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             self.enum_class = None
             return enums, enums
 
+    def _resolve_for_literal(self, value: Any) -> Enum:
+        typ = self._resolve_for_python_type(type(value), type(value))
+        assert typ is not None
+        return typ
+
+    def _resolve_for_python_type(
+        self,
+        python_type: Type[Any],
+        matched_on: Union[GenericProtocol[Any], Type[Any]],
+    ) -> Optional[Enum]:
+        if not issubclass(python_type, enum.Enum):
+            return None
+        return cast(
+            Enum,
+            util.constructor_copy(
+                self,
+                self._generic_type_affinity,
+                python_type,
+                length=NO_ARG if self.length == 0 else self.length,
+            ),
+        )
+
     def _setup_for_values(self, values, objects, kw):
         self.enums = list(values)
 
@@ -3674,6 +3697,7 @@ _type_map: Dict[Type[Any], TypeEngine[Any]] = {
     type(None): NULLTYPE,
     bytes: LargeBinary(),
     str: _STRING,
+    enum.Enum: Enum(enum.Enum),
 }
 
 
index c3768c6c63da85f0d7ed427578cad88cea9e84c4..b395e6796416b5ff2d25405436d1d60e3277366c 100644 (file)
@@ -35,6 +35,7 @@ from .operators import ColumnOperators
 from .visitors import Visitable
 from .. import exc
 from .. import util
+from ..util.typing import flatten_generic
 from ..util.typing import Protocol
 from ..util.typing import TypedDict
 from ..util.typing import TypeGuard
@@ -55,6 +56,7 @@ if typing.TYPE_CHECKING:
     from .sqltypes import STRINGTYPE as STRINGTYPE  # noqa: F401
     from .sqltypes import TABLEVALUE as TABLEVALUE  # noqa: F401
     from ..engine.interfaces import Dialect
+    from ..util.typing import GenericProtocol
 
 _T = TypeVar("_T", bound=Any)
 _T_co = TypeVar("_T_co", bound=Any, covariant=True)
@@ -712,9 +714,66 @@ class TypeEngine(Visitable, Generic[_T]):
 
         .. versionadded:: 1.4.30 or 2.0
 
+        TODO: this should be part of public API
+
+        .. seealso::
+
+            :meth:`.TypeEngine._resolve_for_python_type`
+
         """
         return self
 
+    def _resolve_for_python_type(
+        self: SelfTypeEngine,
+        python_type: Type[Any],
+        matched_on: Union[GenericProtocol[Any], Type[Any]],
+    ) -> Optional[SelfTypeEngine]:
+        """given a Python type (e.g. ``int``, ``str``, etc. ) return an
+        instance of this :class:`.TypeEngine` that's appropriate for this type.
+
+        An additional argument ``matched_on`` is passed, which indicates an
+        entry from the ``__mro__`` of the given ``python_type`` that more
+        specifically matches how the caller located this :class:`.TypeEngine`
+        object.   Such as, if a lookup of some kind links the ``int`` Python
+        type to the :class:`.Integer` SQL type, and the original object
+        was some custom subclass of ``int`` such as ``MyInt(int)``, the
+        arguments passed would be ``(MyInt, int)``.
+
+        If the given Python type does not correspond to this
+        :class:`.TypeEngine`, or the Python type is otherwise ambiguous, the
+        method should return None.
+
+        For simple cases, the method checks that the ``python_type``
+        and ``matched_on`` types are the same (i.e. not a subclass), and
+        returns self; for all other cases, it returns ``None``.
+
+        The initial use case here is for the ORM to link user-defined
+        Python standard library ``enum.Enum`` classes to the SQLAlchemy
+        :class:`.Enum` SQL type when constructing ORM Declarative mappings.
+
+        :param python_type: the Python type we want to use
+        :param matched_on: the Python type that led us to choose this
+         particular :class:`.TypeEngine` class, which would be a supertype
+         of ``python_type``.   By default, the request is rejected if
+         ``python_type`` doesn't match ``matched_on`` (None is returned).
+
+        .. versionadded:: 2.0.0b4
+
+        TODO: this should be part of public API
+
+        .. seealso::
+
+            :meth:`.TypeEngine._resolve_for_literal`
+
+        """
+
+        matched_on = flatten_generic(matched_on)
+
+        if python_type is not matched_on:
+            return None
+
+        return self
+
     @util.ro_memoized_property
     def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
         """Return a rudimental 'affinity' value expressing the general class
index 749d042481a03dc381bbc23c1fe75bdb006a5d3a..a75c36776484d35f40276e6fd8d3078cc7da0015 100644 (file)
@@ -180,9 +180,15 @@ def variation(argname, cases):
 
     """
 
+    cases_plus_limitations = [
+        entry
+        if (isinstance(entry, tuple) and len(entry) == 2)
+        else (entry, None)
+        for entry in cases
+    ]
     case_names = [
         argname if c is True else "not_" + argname if c is False else c
-        for c in cases
+        for c, l in cases_plus_limitations
     ]
 
     typ = type(
@@ -195,8 +201,12 @@ def variation(argname, cases):
 
     return combinations(
         *[
-            (casename, typ(casename, argname, case_names))
-            for casename in case_names
+            (casename, typ(casename, argname, case_names), limitation)
+            if limitation is not None
+            else (casename, typ(casename, argname, case_names))
+            for casename, (case, limitation) in zip(
+                case_names, cases_plus_limitations
+            )
         ],
         id_="ia",
         argnames=argname,
index 0c8e5a633447f5bc903a07716e69fb38a3175a6c..b1ef87db1719029e305e0e2770d6d0f8533e9004 100644 (file)
@@ -74,7 +74,7 @@ typing_get_origin = get_origin
 # copied from TypeShed, required in order to implement
 # MutableMapping.update()
 
-_AnnotationScanType = Union[Type[Any], str, ForwardRef]
+_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"]
 
 
 class ArgsTypeProcotol(Protocol):
@@ -236,6 +236,15 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
     return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
 
 
+def flatten_generic(
+    type_: Union[GenericProtocol[Any], Type[Any]]
+) -> Type[Any]:
+    if is_generic(type_):
+        return type_.__origin__
+    else:
+        return cast("Type[Any]", type_)
+
+
 def is_fwd_ref(
     type_: _AnnotationScanType, check_generic: bool = False
 ) -> bool:
index 7358f385db58a420e6b813f5f190ee1d8efac1aa..5d1b6b199e9c3f5b5fc28d614362bdcd10603378 100644 (file)
@@ -10,6 +10,7 @@ from __future__ import annotations
 import dataclasses
 import datetime
 from decimal import Decimal
+import enum
 from typing import Any
 from typing import ClassVar
 from typing import Dict
@@ -53,11 +54,13 @@ from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedAsDataclass
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm import WriteOnlyMapped
 from sqlalchemy.orm.collections import attribute_keyed_dict
 from sqlalchemy.orm.collections import KeyFuncDict
 from sqlalchemy.schema import CreateTable
+from sqlalchemy.sql.sqltypes import Enum
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
@@ -1134,6 +1137,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 id: Mapped[int] = mapped_column(primary_key=True)
                 data: Mapped["fake"]  # noqa
 
+    @testing.variation("use_callable", [True, False])
+    @testing.variation("include_generic", [True, False])
+    def test_enum_explicit(self, use_callable, include_generic):
+        global FooEnum
+
+        class FooEnum(enum.Enum):
+            foo = enum.auto()
+            bar = enum.auto()
+
+        if use_callable:
+            tam = {FooEnum: Enum(FooEnum, length=500)}
+        else:
+            tam = {FooEnum: Enum(FooEnum, length=500)}
+        if include_generic:
+            tam[enum.Enum] = Enum(enum.Enum)
+        Base = declarative_base(type_annotation_map=tam)
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[FooEnum]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+        eq_(MyClass.__table__.c.data.type.length, 500)
+        is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+    def test_enum_generic(self):
+        """test for #8859"""
+        global FooEnum
+
+        class FooEnum(enum.Enum):
+            foo = enum.auto()
+            bar = enum.auto()
+
+        Base = declarative_base(
+            type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)}
+        )
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[FooEnum]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+        eq_(MyClass.__table__.c.data.type.length, 42)
+        is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+    def test_enum_default(self, decl_base):
+        """test #8859.
+
+        We now have Enum in the default SQL lookup map, in conjunction with
+        a mechanism that will adapt it for a given enum type.
+
+        This relies on a search through __mro__ for the given type,
+        which in other tests we ensure does not actually function if
+        we aren't dealing with Enum (or some other type that allows for
+        __mro__ lookup)
+
+        """
+        global FooEnum
+
+        class FooEnum(enum.Enum):
+            foo = "foo"
+            bar_value = "bar"
+
+        class MyClass(decl_base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[FooEnum]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+        eq_(MyClass.__table__.c.data.type.length, 9)
+        is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+    def test_type_dont_mis_resolve_on_superclass(self):
+        """test for #8859.
+
+        For subclasses of a type that's in the map, don't resolve this
+        by default, even though we do a search through __mro__.
+
+        """
+        global int_sub
+
+        class int_sub(int):
+            pass
+
+        Base = declarative_base(
+            type_annotation_map={
+                int: Integer,
+            }
+        )
+
+        with expect_raises_message(
+            sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+        ):
+
+            class MyClass(Base):
+                __tablename__ = "mytable"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped[int_sub]
+
+    @testing.variation(
+        "dict_key", ["typing", ("plain", testing.requires.python310)]
+    )
+    def test_type_dont_mis_resolve_on_non_generic(self, dict_key):
+        """test for #8859.
+
+        For a specific generic type with arguments, don't do any MRO
+        lookup.
+
+        """
+
+        Base = declarative_base(
+            type_annotation_map={
+                dict: String,
+            }
+        )
+
+        with expect_raises_message(
+            sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+        ):
+
+            class MyClass(Base):
+                __tablename__ = "mytable"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+                if dict_key.plain:
+                    data: Mapped[dict[str, str]]
+                elif dict_key.typing:
+                    data: Mapped[Dict[str, str]]
+
+    def test_type_secondary_resolution(self):
+        class MyString(String):
+            def _resolve_for_python_type(self, python_type, matched_type):
+                return String(length=42)
+
+        Base = declarative_base(type_annotation_map={str: MyString})
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, String))
+        eq_(MyClass.__table__.c.data.type.length, 42)
+
 
 class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
@@ -2200,3 +2355,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase):
             select(typ).where(typ.key == "x"),
             "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1",
         )
+
+
+class BackendTests(fixtures.TestBase):
+    __backend__ = True
+
+    @testing.variation("native_enum", [True, False])
+    @testing.variation("include_column", [True, False])
+    def test_schema_type_actually_works(
+        self, connection, decl_base, include_column, native_enum
+    ):
+        """test that schema type bindings are set up correctly"""
+
+        global Status
+
+        class Status(enum.Enum):
+            PENDING = "pending"
+            RECEIVED = "received"
+            COMPLETED = "completed"
+
+        if not include_column and not native_enum:
+            decl_base.registry.update_type_annotation_map(
+                {enum.Enum: Enum(enum.Enum, native_enum=False)}
+            )
+
+        class SomeClass(decl_base):
+            __tablename__ = "some_table"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            if include_column:
+                status: Mapped[Status] = mapped_column(
+                    Enum(Status, native_enum=bool(native_enum))
+                )
+            else:
+                status: Mapped[Status]
+
+        decl_base.metadata.create_all(connection)
+
+        with Session(connection) as sess:
+            sess.add(SomeClass(id=1, status=Status.RECEIVED))
+            sess.commit()
+
+            eq_(
+                sess.scalars(
+                    select(SomeClass.status).where(SomeClass.id == 1)
+                ).first(),
+                Status.RECEIVED,
+            )
index ba099412f3ef204fa656073de458eaa6a805ed66..ffa640d4cd3630ecf7ae4e99881993f6cbe22431 100644 (file)
@@ -1,6 +1,7 @@
 import dataclasses
 import datetime
 from decimal import Decimal
+import enum
 from typing import Any
 from typing import ClassVar
 from typing import Dict
@@ -44,11 +45,13 @@ from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedAsDataclass
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm import WriteOnlyMapped
 from sqlalchemy.orm.collections import attribute_keyed_dict
 from sqlalchemy.orm.collections import KeyFuncDict
 from sqlalchemy.schema import CreateTable
+from sqlalchemy.sql.sqltypes import Enum
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
@@ -1125,6 +1128,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 id: Mapped[int] = mapped_column(primary_key=True)
                 data: Mapped["fake"]  # noqa
 
+    @testing.variation("use_callable", [True, False])
+    @testing.variation("include_generic", [True, False])
+    def test_enum_explicit(self, use_callable, include_generic):
+        # anno only: global FooEnum
+
+        class FooEnum(enum.Enum):
+            foo = enum.auto()
+            bar = enum.auto()
+
+        if use_callable:
+            tam = {FooEnum: Enum(FooEnum, length=500)}
+        else:
+            tam = {FooEnum: Enum(FooEnum, length=500)}
+        if include_generic:
+            tam[enum.Enum] = Enum(enum.Enum)
+        Base = declarative_base(type_annotation_map=tam)
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[FooEnum]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+        eq_(MyClass.__table__.c.data.type.length, 500)
+        is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+    def test_enum_generic(self):
+        """test for #8859"""
+        # anno only: global FooEnum
+
+        class FooEnum(enum.Enum):
+            foo = enum.auto()
+            bar = enum.auto()
+
+        Base = declarative_base(
+            type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)}
+        )
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[FooEnum]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+        eq_(MyClass.__table__.c.data.type.length, 42)
+        is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+    def test_enum_default(self, decl_base):
+        """test #8859.
+
+        We now have Enum in the default SQL lookup map, in conjunction with
+        a mechanism that will adapt it for a given enum type.
+
+        This relies on a search through __mro__ for the given type,
+        which in other tests we ensure does not actually function if
+        we aren't dealing with Enum (or some other type that allows for
+        __mro__ lookup)
+
+        """
+        # anno only: global FooEnum
+
+        class FooEnum(enum.Enum):
+            foo = "foo"
+            bar_value = "bar"
+
+        class MyClass(decl_base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[FooEnum]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, Enum))
+        eq_(MyClass.__table__.c.data.type.length, 9)
+        is_(MyClass.__table__.c.data.type.enum_class, FooEnum)
+
+    def test_type_dont_mis_resolve_on_superclass(self):
+        """test for #8859.
+
+        For subclasses of a type that's in the map, don't resolve this
+        by default, even though we do a search through __mro__.
+
+        """
+        # anno only: global int_sub
+
+        class int_sub(int):
+            pass
+
+        Base = declarative_base(
+            type_annotation_map={
+                int: Integer,
+            }
+        )
+
+        with expect_raises_message(
+            sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+        ):
+
+            class MyClass(Base):
+                __tablename__ = "mytable"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped[int_sub]
+
+    @testing.variation(
+        "dict_key", ["typing", ("plain", testing.requires.python310)]
+    )
+    def test_type_dont_mis_resolve_on_non_generic(self, dict_key):
+        """test for #8859.
+
+        For a specific generic type with arguments, don't do any MRO
+        lookup.
+
+        """
+
+        Base = declarative_base(
+            type_annotation_map={
+                dict: String,
+            }
+        )
+
+        with expect_raises_message(
+            sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type"
+        ):
+
+            class MyClass(Base):
+                __tablename__ = "mytable"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+                if dict_key.plain:
+                    data: Mapped[dict[str, str]]
+                elif dict_key.typing:
+                    data: Mapped[Dict[str, str]]
+
+    def test_type_secondary_resolution(self):
+        class MyString(String):
+            def _resolve_for_python_type(self, python_type, matched_type):
+                return String(length=42)
+
+        Base = declarative_base(type_annotation_map={str: MyString})
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str]
+
+        is_true(isinstance(MyClass.__table__.c.data.type, String))
+        eq_(MyClass.__table__.c.data.type.length, 42)
+
 
 class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
@@ -2191,3 +2346,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase):
             select(typ).where(typ.key == "x"),
             "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1",
         )
+
+
+class BackendTests(fixtures.TestBase):
+    __backend__ = True
+
+    @testing.variation("native_enum", [True, False])
+    @testing.variation("include_column", [True, False])
+    def test_schema_type_actually_works(
+        self, connection, decl_base, include_column, native_enum
+    ):
+        """test that schema type bindings are set up correctly"""
+
+        # anno only: global Status
+
+        class Status(enum.Enum):
+            PENDING = "pending"
+            RECEIVED = "received"
+            COMPLETED = "completed"
+
+        if not include_column and not native_enum:
+            decl_base.registry.update_type_annotation_map(
+                {enum.Enum: Enum(enum.Enum, native_enum=False)}
+            )
+
+        class SomeClass(decl_base):
+            __tablename__ = "some_table"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            if include_column:
+                status: Mapped[Status] = mapped_column(
+                    Enum(Status, native_enum=bool(native_enum))
+                )
+            else:
+                status: Mapped[Status]
+
+        decl_base.metadata.create_all(connection)
+
+        with Session(connection) as sess:
+            sess.add(SomeClass(id=1, status=Status.RECEIVED))
+            sess.commit()
+
+            eq_(
+                sess.scalars(
+                    select(SomeClass.status).where(SomeClass.id == 1)
+                ).first(),
+                Status.RECEIVED,
+            )