From: Federico Caselli Date: Sun, 27 Nov 2022 17:11:34 +0000 (+0100) Subject: Improve support for enum in mapped classes X-Git-Tag: rel_2_0_0b4~23^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9c9fd31bcea3beaed6d14fde639e65f6b43bea09;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve support for enum in mapped classes Add a new system by which TypeEngine objects have some say in how the declarative type registry interprets them. The Enum datatype is the primary target for this but it is hoped the system may be useful for other types as well. Fixes: #8859 Change-Id: I15ac3daee770408b5795746f47c1bbd931b7d26d --- diff --git a/doc/build/changelog/unreleased_20/8859.rst b/doc/build/changelog/unreleased_20/8859.rst new file mode 100644 index 0000000000..85e4be4224 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8859.rst @@ -0,0 +1,16 @@ +.. change:: + :tags: usecase, orm + :tickets: 8859 + + Added support custom user-defined types which extend the Python + ``enum.Enum`` base class to be resolved automatically + to SQLAlchemy :class:`.Enum` SQL types, when using the Annotated + Declarative Table feature. The feature is made possible through new + lookup features added to the ORM type map feature, and includes support + for changing the arguments of the :class:`.Enum` that's generated by + default as well as setting up specific ``enum.Enum`` types within + the map with specific arguments. + + .. seealso:: + + :ref:`orm_declarative_mapped_column_enums` diff --git a/doc/build/orm/declarative_tables.rst b/doc/build/orm/declarative_tables.rst index 475813f819..806a6897f2 100644 --- a/doc/build/orm/declarative_tables.rst +++ b/doc/build/orm/declarative_tables.rst @@ -369,6 +369,106 @@ while still being able to use succinct annotation-only :func:`_orm.mapped_column configurations. There are two more levels of Python-type configurability available beyond this, described in the next two sections. +.. _orm_declarative_mapped_column_enums: + +Using Python ``Enum`` types in the type map +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 2.0.0b4 + +User-defined Python types which derive from the Python built-in ``enum.Enum`` +class are automatically linked to the SQLAlchemy :class:`.Enum` datatype +when used in an ORM declarative mapping:: + + import enum + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + + class SomeClass(Base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] + +In the above example, the mapped attribute ``SomeClass.status`` will be +linked to a :class:`.Column` with the datatype of ``Enum(Status)``. +We can see this for example in the CREATE TABLE output for the PostgreSQL +database: + +.. sourcecode:: sql + + CREATE TYPE status AS ENUM ('PENDING', 'RECEIVED', 'COMPLETED') + + CREATE TABLE some_table ( + id SERIAL NOT NULL, + status status NOT NULL, + PRIMARY KEY (id) + ) + +The entry used in :paramref:`_orm.registry.type_annotation_map` links the +base ``enum.Enum`` Python type to the SQLAlchemy :class:`.Enum` SQL +type, using a special form which indicates to the :class:`.Enum` datatype +that it should automatically configure itself against an arbitrary enumerated +type. This configuration, which is implicit by default, would be indicated +explicitly as:: + + import enum + import sqlalchemy + + + class Base(DeclarativeBase): + type_annotation_map = {enum.Enum: sqlalchemy.Enum(enum.Enum)} + +The resolution logic within Declarative is able to resolve subclasses +of ``enum.Enum``, in the above example the custom ``Status`` enumeration, +to match the ``enum.Enum`` entry in the +:paramref:`_orm.registry.type_annotation_map` dictionary. The :class:`.Enum` +SQL type then knows how to produce a configured version of itself with the +appropriate settings, including default string length. + +In order to modify the configuration of the :class:`.enum.Enum` datatype used +in this mapping, use the above form, indicating additional arguments. For +example, to use "non native enumerations" on all backends, the +:paramref:`.Enum.native_enum` parameter may be set to False for all types:: + + import enum + import sqlalchemy + + + class Base(DeclarativeBase): + type_annotation_map = {enum.Enum: sqlalchemy.Enum(enum.Enum, native_enum=False)} + +To use a specific configuration for a specific ``enum.Enum`` subtype, such +as setting the string length to 50 when using the example ``Status`` +datatype:: + + import enum + import sqlalchemy + + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + + class Base(DeclarativeBase): + type_annotation_map = { + Status: sqlalchemy.Enum(Status, length=50, native_enum=False) + } .. _orm_declarative_mapped_column_type_map_pep593: diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 09397eb653..de6c8794b1 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -14,6 +14,7 @@ import re import typing from typing import Any from typing import Callable +from typing import cast from typing import ClassVar from typing import Dict from typing import FrozenSet @@ -23,6 +24,7 @@ from typing import Mapping from typing import Optional from typing import overload from typing import Set +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -62,6 +64,7 @@ from .state import InstanceState from .. import exc from .. import inspection from .. import util +from ..sql import sqltypes from ..sql.base import _NoArg from ..sql.elements import SQLCoreOperations from ..sql.schema import MetaData @@ -70,6 +73,7 @@ from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing from ..util.typing import CallableReference +from ..util.typing import is_generic from ..util.typing import Literal if TYPE_CHECKING: @@ -80,6 +84,7 @@ if TYPE_CHECKING: from .interfaces import MapperProperty from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument + from ..util.typing import GenericProtocol _T = TypeVar("_T", bound=Any) @@ -1018,6 +1023,37 @@ class registry: } ) + def _resolve_type( + self, python_type: Union[GenericProtocol[Any], Type[Any]] + ) -> Optional[sqltypes.TypeEngine[Any]]: + + search: Tuple[Union[GenericProtocol[Any], Type[Any]], ...] + + if is_generic(python_type): + python_type_type: Type[Any] = python_type.__origin__ + search = (python_type,) + else: + # don't know why is_generic() TypeGuard[GenericProtocol[Any]] + # check above is not sufficient here + python_type_type = cast("Type[Any]", python_type) + search = python_type_type.__mro__ + + for pt in search: + sql_type = self.type_annotation_map.get(pt) + if sql_type is None: + sql_type = sqltypes._type_map_get(pt) # type: ignore # noqa: E501 + + if sql_type is not None: + sql_type_inst = sqltypes.to_instance(sql_type) # type: ignore + + resolved_sql_type = sql_type_inst._resolve_for_python_type( + python_type_type, pt + ) + if resolved_sql_type is not None: + return resolved_sql_type + + return None + @property def mappers(self) -> FrozenSet[Mapper[Any]]: """read only collection of all :class:`_orm.Mapper` objects.""" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 1a5f0bd71d..0b26cb8721 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -45,7 +45,6 @@ from .. import log from .. import util from ..sql import coercions from ..sql import roles -from ..sql import sqltypes from ..sql.base import _NoArg from ..sql.roles import DDLConstraintColumnRole from ..sql.schema import Column @@ -737,10 +736,7 @@ class MappedColumn( for check_type in checks: - if registry.type_annotation_map: - new_sqltype = registry.type_annotation_map.get(check_type) - if new_sqltype is None: - new_sqltype = sqltypes._type_map_get(check_type) # type: ignore # noqa: E501 + new_sqltype = registry._resolve_type(check_type) if new_sqltype is not None: break else: @@ -749,4 +745,4 @@ class MappedColumn( f"type for Python type: {our_type}" ) - self.column.type = sqltypes.to_instance(new_sqltype) + self.column._set_type(new_sqltype) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index a0f56d8391..c239a3a6a9 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -59,6 +59,7 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict +from ..util.typing import GenericProtocol from ..util.typing import Literal if TYPE_CHECKING: @@ -1489,6 +1490,28 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self.enum_class = None return enums, enums + def _resolve_for_literal(self, value: Any) -> Enum: + typ = self._resolve_for_python_type(type(value), type(value)) + assert typ is not None + return typ + + def _resolve_for_python_type( + self, + python_type: Type[Any], + matched_on: Union[GenericProtocol[Any], Type[Any]], + ) -> Optional[Enum]: + if not issubclass(python_type, enum.Enum): + return None + return cast( + Enum, + util.constructor_copy( + self, + self._generic_type_affinity, + python_type, + length=NO_ARG if self.length == 0 else self.length, + ), + ) + def _setup_for_values(self, values, objects, kw): self.enums = list(values) @@ -3674,6 +3697,7 @@ _type_map: Dict[Type[Any], TypeEngine[Any]] = { type(None): NULLTYPE, bytes: LargeBinary(), str: _STRING, + enum.Enum: Enum(enum.Enum), } diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index c3768c6c63..b395e67964 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -35,6 +35,7 @@ from .operators import ColumnOperators from .visitors import Visitable from .. import exc from .. import util +from ..util.typing import flatten_generic from ..util.typing import Protocol from ..util.typing import TypedDict from ..util.typing import TypeGuard @@ -55,6 +56,7 @@ if typing.TYPE_CHECKING: from .sqltypes import STRINGTYPE as STRINGTYPE # noqa: F401 from .sqltypes import TABLEVALUE as TABLEVALUE # noqa: F401 from ..engine.interfaces import Dialect + from ..util.typing import GenericProtocol _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) @@ -712,9 +714,66 @@ class TypeEngine(Visitable, Generic[_T]): .. versionadded:: 1.4.30 or 2.0 + TODO: this should be part of public API + + .. seealso:: + + :meth:`.TypeEngine._resolve_for_python_type` + """ return self + def _resolve_for_python_type( + self: SelfTypeEngine, + python_type: Type[Any], + matched_on: Union[GenericProtocol[Any], Type[Any]], + ) -> Optional[SelfTypeEngine]: + """given a Python type (e.g. ``int``, ``str``, etc. ) return an + instance of this :class:`.TypeEngine` that's appropriate for this type. + + An additional argument ``matched_on`` is passed, which indicates an + entry from the ``__mro__`` of the given ``python_type`` that more + specifically matches how the caller located this :class:`.TypeEngine` + object. Such as, if a lookup of some kind links the ``int`` Python + type to the :class:`.Integer` SQL type, and the original object + was some custom subclass of ``int`` such as ``MyInt(int)``, the + arguments passed would be ``(MyInt, int)``. + + If the given Python type does not correspond to this + :class:`.TypeEngine`, or the Python type is otherwise ambiguous, the + method should return None. + + For simple cases, the method checks that the ``python_type`` + and ``matched_on`` types are the same (i.e. not a subclass), and + returns self; for all other cases, it returns ``None``. + + The initial use case here is for the ORM to link user-defined + Python standard library ``enum.Enum`` classes to the SQLAlchemy + :class:`.Enum` SQL type when constructing ORM Declarative mappings. + + :param python_type: the Python type we want to use + :param matched_on: the Python type that led us to choose this + particular :class:`.TypeEngine` class, which would be a supertype + of ``python_type``. By default, the request is rejected if + ``python_type`` doesn't match ``matched_on`` (None is returned). + + .. versionadded:: 2.0.0b4 + + TODO: this should be part of public API + + .. seealso:: + + :meth:`.TypeEngine._resolve_for_literal` + + """ + + matched_on = flatten_generic(matched_on) + + if python_type is not matched_on: + return None + + return self + @util.ro_memoized_property def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]: """Return a rudimental 'affinity' value expressing the general class diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 749d042481..a75c367764 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -180,9 +180,15 @@ def variation(argname, cases): """ + cases_plus_limitations = [ + entry + if (isinstance(entry, tuple) and len(entry) == 2) + else (entry, None) + for entry in cases + ] case_names = [ argname if c is True else "not_" + argname if c is False else c - for c in cases + for c, l in cases_plus_limitations ] typ = type( @@ -195,8 +201,12 @@ def variation(argname, cases): return combinations( *[ - (casename, typ(casename, argname, case_names)) - for casename in case_names + (casename, typ(casename, argname, case_names), limitation) + if limitation is not None + else (casename, typ(casename, argname, case_names)) + for casename, (case, limitation) in zip( + case_names, cases_plus_limitations + ) ], id_="ia", argnames=argname, diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 0c8e5a6334..b1ef87db17 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -74,7 +74,7 @@ typing_get_origin = get_origin # copied from TypeShed, required in order to implement # MutableMapping.update() -_AnnotationScanType = Union[Type[Any], str, ForwardRef] +_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"] class ArgsTypeProcotol(Protocol): @@ -236,6 +236,15 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: return hasattr(type_, "__args__") and hasattr(type_, "__origin__") +def flatten_generic( + type_: Union[GenericProtocol[Any], Type[Any]] +) -> Type[Any]: + if is_generic(type_): + return type_.__origin__ + else: + return cast("Type[Any]", type_) + + def is_fwd_ref( type_: _AnnotationScanType, check_generic: bool = False ) -> bool: diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 7358f385db..5d1b6b199e 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -10,6 +10,7 @@ from __future__ import annotations import dataclasses import datetime from decimal import Decimal +import enum from typing import Any from typing import ClassVar from typing import Dict @@ -53,11 +54,13 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped from sqlalchemy.orm.collections import attribute_keyed_dict from sqlalchemy.orm.collections import KeyFuncDict from sqlalchemy.schema import CreateTable +from sqlalchemy.sql.sqltypes import Enum from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -1134,6 +1137,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped["fake"] # noqa + @testing.variation("use_callable", [True, False]) + @testing.variation("include_generic", [True, False]) + def test_enum_explicit(self, use_callable, include_generic): + global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + if use_callable: + tam = {FooEnum: Enum(FooEnum, length=500)} + else: + tam = {FooEnum: Enum(FooEnum, length=500)} + if include_generic: + tam[enum.Enum] = Enum(enum.Enum) + Base = declarative_base(type_annotation_map=tam) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 500) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_generic(self): + """test for #8859""" + global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + Base = declarative_base( + type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)} + ) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 42) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_default(self, decl_base): + """test #8859. + + We now have Enum in the default SQL lookup map, in conjunction with + a mechanism that will adapt it for a given enum type. + + This relies on a search through __mro__ for the given type, + which in other tests we ensure does not actually function if + we aren't dealing with Enum (or some other type that allows for + __mro__ lookup) + + """ + global FooEnum + + class FooEnum(enum.Enum): + foo = "foo" + bar_value = "bar" + + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 9) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_type_dont_mis_resolve_on_superclass(self): + """test for #8859. + + For subclasses of a type that's in the map, don't resolve this + by default, even though we do a search through __mro__. + + """ + global int_sub + + class int_sub(int): + pass + + Base = declarative_base( + type_annotation_map={ + int: Integer, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[int_sub] + + @testing.variation( + "dict_key", ["typing", ("plain", testing.requires.python310)] + ) + def test_type_dont_mis_resolve_on_non_generic(self, dict_key): + """test for #8859. + + For a specific generic type with arguments, don't do any MRO + lookup. + + """ + + Base = declarative_base( + type_annotation_map={ + dict: String, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if dict_key.plain: + data: Mapped[dict[str, str]] + elif dict_key.typing: + data: Mapped[Dict[str, str]] + + def test_type_secondary_resolution(self): + class MyString(String): + def _resolve_for_python_type(self, python_type, matched_type): + return String(length=42) + + Base = declarative_base(type_annotation_map={str: MyString}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + is_true(isinstance(MyClass.__table__.c.data.type, String)) + eq_(MyClass.__table__.c.data.type.length, 42) + class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -2200,3 +2355,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase): select(typ).where(typ.key == "x"), "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1", ) + + +class BackendTests(fixtures.TestBase): + __backend__ = True + + @testing.variation("native_enum", [True, False]) + @testing.variation("include_column", [True, False]) + def test_schema_type_actually_works( + self, connection, decl_base, include_column, native_enum + ): + """test that schema type bindings are set up correctly""" + + global Status + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + if not include_column and not native_enum: + decl_base.registry.update_type_annotation_map( + {enum.Enum: Enum(enum.Enum, native_enum=False)} + ) + + class SomeClass(decl_base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if include_column: + status: Mapped[Status] = mapped_column( + Enum(Status, native_enum=bool(native_enum)) + ) + else: + status: Mapped[Status] + + decl_base.metadata.create_all(connection) + + with Session(connection) as sess: + sess.add(SomeClass(id=1, status=Status.RECEIVED)) + sess.commit() + + eq_( + sess.scalars( + select(SomeClass.status).where(SomeClass.id == 1) + ).first(), + Status.RECEIVED, + ) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index ba099412f3..ffa640d4cd 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1,6 +1,7 @@ import dataclasses import datetime from decimal import Decimal +import enum from typing import Any from typing import ClassVar from typing import Dict @@ -44,11 +45,13 @@ from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped from sqlalchemy.orm.collections import attribute_keyed_dict from sqlalchemy.orm.collections import KeyFuncDict from sqlalchemy.schema import CreateTable +from sqlalchemy.sql.sqltypes import Enum from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises @@ -1125,6 +1128,158 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped["fake"] # noqa + @testing.variation("use_callable", [True, False]) + @testing.variation("include_generic", [True, False]) + def test_enum_explicit(self, use_callable, include_generic): + # anno only: global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + if use_callable: + tam = {FooEnum: Enum(FooEnum, length=500)} + else: + tam = {FooEnum: Enum(FooEnum, length=500)} + if include_generic: + tam[enum.Enum] = Enum(enum.Enum) + Base = declarative_base(type_annotation_map=tam) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 500) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_generic(self): + """test for #8859""" + # anno only: global FooEnum + + class FooEnum(enum.Enum): + foo = enum.auto() + bar = enum.auto() + + Base = declarative_base( + type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)} + ) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 42) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_enum_default(self, decl_base): + """test #8859. + + We now have Enum in the default SQL lookup map, in conjunction with + a mechanism that will adapt it for a given enum type. + + This relies on a search through __mro__ for the given type, + which in other tests we ensure does not actually function if + we aren't dealing with Enum (or some other type that allows for + __mro__ lookup) + + """ + # anno only: global FooEnum + + class FooEnum(enum.Enum): + foo = "foo" + bar_value = "bar" + + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[FooEnum] + + is_true(isinstance(MyClass.__table__.c.data.type, Enum)) + eq_(MyClass.__table__.c.data.type.length, 9) + is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + + def test_type_dont_mis_resolve_on_superclass(self): + """test for #8859. + + For subclasses of a type that's in the map, don't resolve this + by default, even though we do a search through __mro__. + + """ + # anno only: global int_sub + + class int_sub(int): + pass + + Base = declarative_base( + type_annotation_map={ + int: Integer, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[int_sub] + + @testing.variation( + "dict_key", ["typing", ("plain", testing.requires.python310)] + ) + def test_type_dont_mis_resolve_on_non_generic(self, dict_key): + """test for #8859. + + For a specific generic type with arguments, don't do any MRO + lookup. + + """ + + Base = declarative_base( + type_annotation_map={ + dict: String, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if dict_key.plain: + data: Mapped[dict[str, str]] + elif dict_key.typing: + data: Mapped[Dict[str, str]] + + def test_type_secondary_resolution(self): + class MyString(String): + def _resolve_for_python_type(self, python_type, matched_type): + return String(length=42) + + Base = declarative_base(type_annotation_map={str: MyString}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + is_true(isinstance(MyClass.__table__.c.data.type, String)) + eq_(MyClass.__table__.c.data.type.length, 42) + class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -2191,3 +2346,51 @@ class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase): select(typ).where(typ.key == "x"), "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1", ) + + +class BackendTests(fixtures.TestBase): + __backend__ = True + + @testing.variation("native_enum", [True, False]) + @testing.variation("include_column", [True, False]) + def test_schema_type_actually_works( + self, connection, decl_base, include_column, native_enum + ): + """test that schema type bindings are set up correctly""" + + # anno only: global Status + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + if not include_column and not native_enum: + decl_base.registry.update_type_annotation_map( + {enum.Enum: Enum(enum.Enum, native_enum=False)} + ) + + class SomeClass(decl_base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if include_column: + status: Mapped[Status] = mapped_column( + Enum(Status, native_enum=bool(native_enum)) + ) + else: + status: Mapped[Status] + + decl_base.metadata.create_all(connection) + + with Session(connection) as sess: + sess.add(SomeClass(id=1, status=Status.RECEIVED)) + sess.commit() + + eq_( + sess.scalars( + select(SomeClass.status).where(SomeClass.id == 1) + ).first(), + Status.RECEIVED, + )