From 8e890609eb47f5a273e695154cf143af56807921 Mon Sep 17 00:00:00 2001 From: Frederik Aalund Date: Mon, 30 Jan 2023 11:50:40 -0500 Subject: [PATCH] Add support for typing.Literal in Mapped Added support for :pep:`586` ``Literal`` to be used in the :paramref:`_orm.registry.type_annotation_map` as well as within :class:`.Mapped` constructs. To use custom types such as these, they must appear explicitly within the :paramref:`_orm.registry.type_annotation_map` to be mapped. Pull request courtesy Frederik Aalund. As part of this change, the support for :class:`.sqltypes.Enum` in the :paramref:`_orm.registry.type_annotation_map` has been expanded to include support for ``Literal[]`` types consisting of string values to be used, in addition to ``enum.Enum`` datatypes. If a ``Literal[]`` datatype is used within ``Mapped[]`` that is not linked in :paramref:`_orm.registry.type_annotation_map` to a specific datatype, a :class:`.sqltypes.Enum` will be used by default. Fixed issue involving the use of :class:`.sqltypes.Enum` within the :paramref:`_orm.registry.type_annotation_map` where the :paramref:`_sqltypes.Enum.native_enum` parameter would not be correctly copied to the mapped column datatype, if it were overridden as stated in the documentation to set this parameter to False. Fixes: #9187 Fixes: #9200 Closes: #9191 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9191 Pull-request-sha: 7d13f705307bf62560fc831f6f049a425d411374 Change-Id: Ife3ba2655f4897f806d6a9cf0041c69fd4f39e9d --- doc/build/changelog/unreleased_20/9187.rst | 34 ++ doc/build/orm/declarative_tables.rst | 348 +++++++++++++----- lib/sqlalchemy/orm/decl_api.py | 14 +- lib/sqlalchemy/orm/decl_base.py | 3 +- lib/sqlalchemy/sql/sqltypes.py | 72 +++- lib/sqlalchemy/util/typing.py | 10 +- .../test_tm_future_annotations_sync.py | 347 +++++++++++++---- test/orm/declarative/test_typed_mapping.py | 347 +++++++++++++---- 8 files changed, 923 insertions(+), 252 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9187.rst diff --git a/doc/build/changelog/unreleased_20/9187.rst b/doc/build/changelog/unreleased_20/9187.rst new file mode 100644 index 0000000000..830bb16009 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9187.rst @@ -0,0 +1,34 @@ +.. change:: + :tags: bug, orm + :tickets: 9187 + + Added support for :pep:`586` ``Literal[]`` to be used in the + :paramref:`_orm.registry.type_annotation_map` as well as within + :class:`.Mapped` constructs. To use custom types such as these, they must + appear explicitly within the :paramref:`_orm.registry.type_annotation_map` + to be mapped. Pull request courtesy Frederik Aalund. + + As part of this change, the support for :class:`.sqltypes.Enum` in the + :paramref:`_orm.registry.type_annotation_map` has been expanded to include + support for ``Literal[]`` types consisting of string values to be used, + in addition to ``enum.Enum`` datatypes. If a ``Literal[]`` datatype + is used within ``Mapped[]`` that is not linked in + :paramref:`_orm.registry.type_annotation_map` to a specific datatype, + a :class:`.sqltypes.Enum` will be used by default. + + .. seealso:: + + :ref:`orm_declarative_mapped_column_enums` + + +.. change:: + :tags: bug, orm + :tickets: 9200 + + Fixed issue involving the use of :class:`.sqltypes.Enum` within the + :paramref:`_orm.registry.type_annotation_map` where the + :paramref:`_sqltypes.Enum.native_enum` parameter would not be correctly + copied to the mapped column datatype, if it were overridden + as stated in the documentation to set this parameter to False. + + diff --git a/doc/build/orm/declarative_tables.rst b/doc/build/orm/declarative_tables.rst index a45fdfd8ed..d9a11087d6 100644 --- a/doc/build/orm/declarative_tables.rst +++ b/doc/build/orm/declarative_tables.rst @@ -369,107 +369,6 @@ while still being able to use succinct annotation-only :func:`_orm.mapped_column configurations. There are two more levels of Python-type configurability available beyond this, described in the next two sections. -.. _orm_declarative_mapped_column_enums: - -Using Python ``Enum`` types in the type map -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. versionadded:: 2.0.0b4 - -User-defined Python types which derive from the Python built-in ``enum.Enum`` -class are automatically linked to the SQLAlchemy :class:`.Enum` datatype -when used in an ORM declarative mapping:: - - import enum - - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped - from sqlalchemy.orm import mapped_column - - - class Base(DeclarativeBase): - pass - - - class Status(enum.Enum): - PENDING = "pending" - RECEIVED = "received" - COMPLETED = "completed" - - - class SomeClass(Base): - __tablename__ = "some_table" - - id: Mapped[int] = mapped_column(primary_key=True) - status: Mapped[Status] - -In the above example, the mapped attribute ``SomeClass.status`` will be -linked to a :class:`.Column` with the datatype of ``Enum(Status)``. -We can see this for example in the CREATE TABLE output for the PostgreSQL -database: - -.. sourcecode:: sql - - CREATE TYPE status AS ENUM ('PENDING', 'RECEIVED', 'COMPLETED') - - CREATE TABLE some_table ( - id SERIAL NOT NULL, - status status NOT NULL, - PRIMARY KEY (id) - ) - -The entry used in :paramref:`_orm.registry.type_annotation_map` links the -base ``enum.Enum`` Python type to the SQLAlchemy :class:`.Enum` SQL -type, using a special form which indicates to the :class:`.Enum` datatype -that it should automatically configure itself against an arbitrary enumerated -type. This configuration, which is implicit by default, would be indicated -explicitly as:: - - import enum - import sqlalchemy - - - class Base(DeclarativeBase): - type_annotation_map = {enum.Enum: sqlalchemy.Enum(enum.Enum)} - -The resolution logic within Declarative is able to resolve subclasses -of ``enum.Enum``, in the above example the custom ``Status`` enumeration, -to match the ``enum.Enum`` entry in the -:paramref:`_orm.registry.type_annotation_map` dictionary. The :class:`.Enum` -SQL type then knows how to produce a configured version of itself with the -appropriate settings, including default string length. - -In order to modify the configuration of the :class:`.enum.Enum` datatype used -in this mapping, use the above form, indicating additional arguments. For -example, to use "non native enumerations" on all backends, the -:paramref:`.Enum.native_enum` parameter may be set to False for all types:: - - import enum - import sqlalchemy - - - class Base(DeclarativeBase): - type_annotation_map = {enum.Enum: sqlalchemy.Enum(enum.Enum, native_enum=False)} - -To use a specific configuration for a specific ``enum.Enum`` subtype, such -as setting the string length to 50 when using the example ``Status`` -datatype:: - - import enum - import sqlalchemy - - - class Status(enum.Enum): - PENDING = "pending" - RECEIVED = "received" - COMPLETED = "completed" - - - class Base(DeclarativeBase): - type_annotation_map = { - Status: sqlalchemy.Enum(Status, length=50, native_enum=False) - } - .. _orm_declarative_mapped_column_type_map_pep593: Mapping Multiple Type Configurations to Python Types @@ -739,6 +638,253 @@ adding a ``FOREIGN KEY`` constraint as well as substituting will raise a ``NotImplementedError`` exception at runtime, but may be implemented in future releases. +.. _orm_declarative_mapped_column_enums: + +Using Python ``Enum`` or pep-586 ``Literal`` types in the type map +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 2.0.0b4 - Added ``Enum`` support + +.. versionadded:: 2.0.1 - Added ``Literal`` support + +User-defined Python types which derive from the Python built-in ``enum.Enum`` +as well as the ``typing.Literal`` +class are automatically linked to the SQLAlchemy :class:`.Enum` datatype +when used in an ORM declarative mapping. The example below uses +a custom ``enum.Enum`` within the ``Mapped[]`` constructor:: + + import enum + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + + class SomeClass(Base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] + +In the above example, the mapped attribute ``SomeClass.status`` will be +linked to a :class:`.Column` with the datatype of ``Enum(Status)``. +We can see this for example in the CREATE TABLE output for the PostgreSQL +database: + +.. sourcecode:: sql + + CREATE TYPE status AS ENUM ('PENDING', 'RECEIVED', 'COMPLETED') + + CREATE TABLE some_table ( + id SERIAL NOT NULL, + status status NOT NULL, + PRIMARY KEY (id) + ) + +In a similar way, ``typing.Literal`` may be used instead, using +a ``typing.Literal`` that consists of all strings:: + + + from typing import Literal + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + + Status = Literal["pending", "received", "completed"] + + + class SomeClass(Base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] + +The entries used in :paramref:`_orm.registry.type_annotation_map` link the base +``enum.Enum`` Python type as well as the ``typing.Literal`` type to the +SQLAlchemy :class:`.Enum` SQL type, using a special form which indicates to the +:class:`.Enum` datatype that it should automatically configure itself against +an arbitrary enumerated type. This configuration, which is implicit by default, +would be indicated explicitly as:: + + import enum + import typing + + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase + + + class Base(DeclarativeBase): + type_annotation_map = { + enum.Enum: sqlalchemy.Enum(enum.Enum), + typing.Literal: sqlalchemy.Enum(enum.Enum), + } + +The resolution logic within Declarative is able to resolve subclasses +of ``enum.Enum`` as well as instances of ``typing.Literal`` to match the +``enum.Enum`` or ``typing.Literal`` entry in the +:paramref:`_orm.registry.type_annotation_map` dictionary. The :class:`.Enum` +SQL type then knows how to produce a configured version of itself with the +appropriate settings, including default string length. If a ``typing.Literal`` +that does not consist of only string values is passed, an informative +error is raised. + +Native Enums and Naming ++++++++++++++++++++++++ + +The :paramref:`.sqltypes.Enum.native_enum` parameter refers to if the +:class:`.sqltypes.Enum` datatype should create a so-called "native" +enum, which on MySQL/MariaDB is the ``ENUM`` datatype and on PostgreSQL is +a new ``TYPE`` object created by ``CREATE TYPE``, or a "non-native" enum, +which means that ``VARCHAR`` will be used to create the datatype. For +backends other than MySQL/MariaDB or PostgreSQL, ``VARCHAR`` is used in +all cases (third party dialects may have their own behaviors). + +Because PostgreSQL's ``CREATE TYPE`` requires that there's an explicit name +for the type to be created, special fallback logic exists when working +with implicitly generated :class:`.sqltypes.Enum` without specifying an +explicit :class:`.sqltypes.Enum` datatype within a mapping: + +1. If the :class:`.sqltypes.Enum` is linked to an ``enum.Enum`` object, + the :paramref:`.sqltypes.Enum.native_enum` parameter defaults to + ``True`` and the name of the enum will be taken from the name of the + ``enum.Enum`` datatype. The PostgreSQL backend will assume ``CREATE TYPE`` + with this name. +2. If the :class:`.sqltypes.Enum` is linked to a ``typing.Literal`` object, + the :paramref:`.sqltypes.Enum.native_enum` parameter defaults to + ``False``; no name is generated and ``VARCHAR`` is assumed. + +To use ``typing.Literal`` with a PostgreSQL ``CREATE TYPE`` type, an +explicit :class:`.sqltypes.Enum` must be used, either within the +type map:: + + import enum + import typing + + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase + + Status = Literal["pending", "received", "completed"] + + + class Base(DeclarativeBase): + type_annotation_map = { + Status: sqlalchemy.Enum("pending", "received", "completed", name="status_enum"), + } + +Or alternatively within :func:`_orm.mapped_column`:: + + import enum + import typing + + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase + + Status = Literal["pending", "received", "completed"] + + + class Base(DeclarativeBase): + pass + + + class SomeClass(Base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] = mapped_column( + sqlalchemy.Enum("pending", "received", "completed", name="status_enum") + ) + +Altering the Configuration of the Default Enum ++++++++++++++++++++++++++++++++++++++++++++++++ + +In order to modify the fixed configuration of the :class:`.enum.Enum` datatype +that's generated implicitly, specify new entries in the +:paramref:`_orm.registry.type_annotation_map`, indicating additional arguments. +For example, to use "non native enumerations" unconditionally, the +:paramref:`.Enum.native_enum` parameter may be set to False for all types:: + + import enum + import typing + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase + + + class Base(DeclarativeBase): + type_annotation_map = { + enum.Enum: sqlalchemy.Enum(enum.Enum, native_enum=False), + typing.Literal: sqlalchemy.Enum(enum.Enum, native_enum=False), + } + +.. versionchanged:: 2.0.1 Implemented support for overriding parameters + such as :paramref:`_sqltypes.Enum.native_enum` within the + :class:`_sqltypes.Enum` datatype when establishing the + :paramref:`_orm.registry.type_annotation_map`. Previously, this + functionality was not working. + +To use a specific configuration for a specific ``enum.Enum`` subtype, such +as setting the string length to 50 when using the example ``Status`` +datatype:: + + import enum + import sqlalchemy + from sqlalchemy.orm import DeclarativeBase + + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + + class Base(DeclarativeBase): + type_annotation_map = { + Status: sqlalchemy.Enum(Status, length=50, native_enum=False) + } + +Linking Specific ``enum.Enum`` or ``typing.Literal`` to other datatypes +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +The above examples feature the use of an :class:`_sqltypes.Enum` that is +automatically configuring itself to the arguments / attributes present on +an ``enum.Enum`` or ``typing.Literal`` type object. For use cases where +specific kinds of ``enum.Enum`` or ``typing.Literal`` should be linked to +other types, these specific types may be placed in the type map also. +In the example below, an entry for ``Literal[]`` that contains non-string +types is linked to the :class:`_sqltypes.JSON` datatype:: + + + from typing import Literal + + from sqlalchemy import JSON + from sqlalchemy.orm import DeclarativeBase + + my_literal = Literal[0, 1, True, False, "true", "false"] + + + class Base(DeclarativeBase): + type_annotation_map = {my_literal: JSON} + +In the above configuration, the ``my_literal`` datatype will resolve to a +:class:`._sqltypes.JSON` instance. Other ``Literal`` variants will continue +to resolve to :class:`_sqltypes.Enum` datatypes. + + Dataclass features in ``mapped_column()`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 4f84438330..9b6c864ff7 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -77,6 +77,7 @@ from ..util import typing as compat_typing from ..util.typing import CallableReference from ..util.typing import flatten_newtype from ..util.typing import is_generic +from ..util.typing import is_literal from ..util.typing import is_newtype from ..util.typing import Literal @@ -1218,10 +1219,19 @@ class registry: ) -> Optional[sqltypes.TypeEngine[Any]]: search: Iterable[Tuple[_MatchedOnType, Type[Any]]] + python_type_type: Type[Any] if is_generic(python_type): - python_type_type: Type[Any] = python_type.__origin__ - search = ((python_type, python_type_type),) + if is_literal(python_type): + python_type_type = cast("Type[Any]", python_type) + + search = ( # type: ignore[assignment] + (python_type, python_type_type), + (Literal, python_type_type), + ) + else: + python_type_type = python_type.__origin__ + search = ((python_type, python_type_type),) elif is_newtype(python_type): python_type_type = flatten_newtype(python_type) search = ((python_type, python_type_type),) diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 0462a89456..a858f12cb9 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -66,6 +66,7 @@ from ..util import topological from ..util.typing import _AnnotationScanType from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref +from ..util.typing import is_literal from ..util.typing import Protocol from ..util.typing import TypedDict from ..util.typing import typing_get_args @@ -1165,7 +1166,7 @@ class _ClassScanMapperConfig(_MapperConfig): extracted_mapped_annotation, mapped_container = extracted - if attr_value is None: + if attr_value is None and not is_literal(extracted_mapped_annotation): for elem in typing_get_args(extracted_mapped_annotation): if isinstance(elem, str) or is_fwd_ref( elem, check_generic=True diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 717e6c0b22..b2dcc9b8a2 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -59,7 +59,9 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict +from ..util.typing import is_literal from ..util.typing import Literal +from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -1263,6 +1265,11 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): .. seealso:: + :ref:`orm_declarative_mapped_column_enums` - background on using + the :class:`_sqltypes.Enum` datatype with the ORM's + :ref:`ORM Annotated Declarative ` + feature. + :class:`_postgresql.ENUM` - PostgreSQL-specific type, which has additional functionality. @@ -1504,16 +1511,54 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): matched_on: _MatchedOnType, matched_on_flattened: Type[Any], ) -> Optional[Enum]: - if not issubclass(python_type, enum.Enum): - return None + + # "generic form" indicates we were placed in a type map + # as ``sqlalchemy.Enum(enum.Enum)`` which indicates we need to + # get enumerated values from the datatype + we_are_generic_form = self._enums_argument == [enum.Enum] + + native_enum = None + + if not we_are_generic_form and python_type is matched_on: + # if we have enumerated values, and the incoming python + # type is exactly the one that matched in the type map, + # then we use these enumerated values and dont try to parse + # what's incoming + enum_args = self._enums_argument + + elif is_literal(python_type): + # for a literal, where we need to get its contents, parse it out. + enum_args = typing_get_args(python_type) + bad_args = [arg for arg in enum_args if not isinstance(arg, str)] + if bad_args: + raise exc.ArgumentError( + f"Can't create string-based Enum datatype from non-string " + f"values: {', '.join(repr(x) for x in bad_args)}. Please " + f"provide an explicit Enum datatype for this Python type" + ) + native_enum = False + elif isinstance(python_type, type) and issubclass( + python_type, enum.Enum + ): + # same for an enum.Enum + enum_args = [python_type] + + else: + enum_args = self._enums_argument + + # make a new Enum that looks like this one. + # pop the "name" so that it gets generated based on the enum + # arguments or other rules + kw = self._make_enum_kw({}) + + kw.pop("name", None) + if native_enum is False: + kw["native_enum"] = False + + kw["length"] = NO_ARG if self.length == 0 else self.length return cast( Enum, - util.constructor_copy( - self, - self._generic_type_affinity, - python_type, - length=NO_ARG if self.length == 0 else self.length, - ), + self._generic_type_affinity(_enums=enum_args, **kw), # type: ignore # noqa: E501 ) def _setup_for_values(self, values, objects, kw): @@ -1622,19 +1667,23 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self, self._generic_type_affinity, *args, _disable_warnings=True ) - def adapt_to_emulated(self, impltype, **kw): + def _make_enum_kw(self, kw): kw.setdefault("validate_strings", self.validate_strings) kw.setdefault("name", self.name) - kw["_disable_warnings"] = True kw.setdefault("schema", self.schema) kw.setdefault("inherit_schema", self.inherit_schema) kw.setdefault("metadata", self.metadata) - kw.setdefault("_create_events", False) kw.setdefault("native_enum", self.native_enum) kw.setdefault("values_callable", self.values_callable) kw.setdefault("create_constraint", self.create_constraint) kw.setdefault("length", self.length) kw.setdefault("omit_aliases", self._omit_aliases) + return kw + + def adapt_to_emulated(self, impltype, **kw): + self._make_enum_kw(kw) + kw["_disable_warnings"] = True + kw.setdefault("_create_events", False) assert "_enums" in kw return impltype(**kw) @@ -3702,6 +3751,7 @@ _type_map: Dict[Type[Any], TypeEngine[Any]] = { bytes: LargeBinary(), str: _STRING, enum.Enum: Enum(enum.Enum), + Literal: Enum(enum.Enum), # type: ignore[dict-item] } diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 51e95ecfa2..755185c9b7 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -152,7 +152,11 @@ def de_stringify_annotation( annotation = eval_expression(annotation, originating_module) - if include_generic and is_generic(annotation): + if ( + include_generic + and is_generic(annotation) + and not is_literal(annotation) + ): elements = tuple( de_stringify_annotation( cls, @@ -249,6 +253,10 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated +def is_literal(type_: _AnnotationScanType) -> bool: + return get_origin(type_) is Literal + + def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: return hasattr(type_, "__supertype__") diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 8d3961ef70..307dbc157a 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -26,6 +26,9 @@ from typing import TypeVar from typing import Union import uuid +from typing_extensions import get_args as get_args +from typing_extensions import Literal as Literal + from sqlalchemy import BIGINT from sqlalchemy import BigInteger from sqlalchemy import Column @@ -73,6 +76,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import Variation from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated @@ -1307,19 +1311,117 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped["fake"] # noqa + def test_type_dont_mis_resolve_on_superclass(self): + """test for #8859. + + For subclasses of a type that's in the map, don't resolve this + by default, even though we do a search through __mro__. + + """ + global int_sub + + class int_sub(int): + pass + + Base = declarative_base( + type_annotation_map={ + int: Integer, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[int_sub] + + @testing.variation( + "dict_key", ["typing", ("plain", testing.requires.python310)] + ) + def test_type_dont_mis_resolve_on_non_generic(self, dict_key): + """test for #8859. + + For a specific generic type with arguments, don't do any MRO + lookup. + + """ + + Base = declarative_base( + type_annotation_map={ + dict: String, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if dict_key.plain: + data: Mapped[dict[str, str]] + elif dict_key.typing: + data: Mapped[Dict[str, str]] + + def test_type_secondary_resolution(self): + class MyString(String): + def _resolve_for_python_type( + self, python_type, matched_type, matched_on_flattened + ): + return String(length=42) + + Base = declarative_base(type_annotation_map={str: MyString}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + is_true(isinstance(MyClass.__table__.c.data.type, String)) + eq_(MyClass.__table__.c.data.type.length, 42) + + +class EnumOrLiteralTypeMapTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + @testing.variation("use_callable", [True, False]) @testing.variation("include_generic", [True, False]) - def test_enum_explicit(self, use_callable, include_generic): + @testing.variation("set_native_enum", ["none", True, False]) + def test_enum_explicit( + self, use_callable, include_generic, set_native_enum: Variation + ): global FooEnum class FooEnum(enum.Enum): foo = enum.auto() bar = enum.auto() + kw = {"length": 500} + + if set_native_enum.none: + expected_native_enum = True + elif set_native_enum.set_native_enum: + kw["native_enum"] = True + expected_native_enum = True + elif set_native_enum.not_set_native_enum: + kw["native_enum"] = False + expected_native_enum = False + else: + set_native_enum.fail() + if use_callable: - tam = {FooEnum: Enum(FooEnum, length=500)} + tam = {FooEnum: Enum(FooEnum, **kw)} else: - tam = {FooEnum: Enum(FooEnum, length=500)} + tam = {FooEnum: Enum(FooEnum, **kw)} + if include_generic: tam[enum.Enum] = Enum(enum.Enum) Base = declarative_base(type_annotation_map=tam) @@ -1333,8 +1435,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(isinstance(MyClass.__table__.c.data.type, Enum)) eq_(MyClass.__table__.c.data.type.length, 500) is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + is_(MyClass.__table__.c.data.type.native_enum, expected_native_enum) - def test_enum_generic(self): + @testing.variation("set_native_enum", ["none", True, False]) + def test_enum_generic(self, set_native_enum: Variation): """test for #8859""" global FooEnum @@ -1342,8 +1446,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): foo = enum.auto() bar = enum.auto() + kw = {"length": 42} + + if set_native_enum.none: + expected_native_enum = True + elif set_native_enum.set_native_enum: + kw["native_enum"] = True + expected_native_enum = True + elif set_native_enum.not_set_native_enum: + kw["native_enum"] = False + expected_native_enum = False + else: + set_native_enum.fail() + Base = declarative_base( - type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)} + type_annotation_map={enum.Enum: Enum(enum.Enum, **kw)} ) class MyClass(Base): @@ -1355,6 +1472,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(isinstance(MyClass.__table__.c.data.type, Enum)) eq_(MyClass.__table__.c.data.type.length, 42) is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + is_(MyClass.__table__.c.data.type.native_enum, expected_native_enum) def test_enum_default(self, decl_base): """test #8859. @@ -1384,82 +1502,149 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(MyClass.__table__.c.data.type.length, 9) is_(MyClass.__table__.c.data.type.enum_class, FooEnum) - def test_type_dont_mis_resolve_on_superclass(self): - """test for #8859. + @testing.variation( + "sqltype", ["custom", "base_enum", "specific_enum", "string"] + ) + @testing.variation("indicate_type_explicitly", [True, False]) + def test_pep586_literal( + self, decl_base, sqltype: Variation, indicate_type_explicitly + ): + """test #9187.""" - For subclasses of a type that's in the map, don't resolve this - by default, even though we do a search through __mro__. + global Status - """ - global int_sub + Status = Literal["to-do", "in-progress", "done"] - class int_sub(int): - pass + if sqltype.custom: - Base = declarative_base( - type_annotation_map={ - int: Integer, - } - ) + class LiteralSqlType(types.TypeDecorator): + impl = types.String + cache_ok = True - with expect_raises_message( - sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" - ): + def __init__(self, literal_type: Any) -> None: + super().__init__() + self._possible_values = get_args(literal_type) - class MyClass(Base): - __tablename__ = "mytable" + our_type = mapped_col_type = LiteralSqlType(Status) + elif sqltype.specific_enum: + our_type = mapped_col_type = Enum( + "to-do", "in-progress", "done", native_enum=False + ) + elif sqltype.base_enum: + our_type = Enum(enum.Enum, native_enum=False) + mapped_col_type = Enum( + "to-do", "in-progress", "done", native_enum=False + ) + elif sqltype.string: + our_type = mapped_col_type = String(50) + else: + sqltype.fail() - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[int_sub] + decl_base.registry.update_type_annotation_map({Status: our_type}) - @testing.variation( - "dict_key", ["typing", ("plain", testing.requires.python310)] - ) - def test_type_dont_mis_resolve_on_non_generic(self, dict_key): - """test for #8859. + class Foo(decl_base): + __tablename__ = "footable" - For a specific generic type with arguments, don't do any MRO - lookup. + id: Mapped[int] = mapped_column(primary_key=True) - """ + if indicate_type_explicitly: + status: Mapped[Status] = mapped_column(mapped_col_type) + else: + status: Mapped[Status] - Base = declarative_base( - type_annotation_map={ - dict: String, - } + is_true(isinstance(Foo.__table__.c.status.type, type(our_type))) + + if sqltype.custom: + eq_( + Foo.__table__.c.status.type._possible_values, + ("to-do", "in-progress", "done"), + ) + elif sqltype.specific_enum or sqltype.base_enum: + eq_( + Foo.__table__.c.status.type.enums, + ["to-do", "in-progress", "done"], + ) + is_(Foo.__table__.c.status.type.native_enum, False) + + @testing.variation("indicate_type_explicitly", [True, False]) + def test_pep586_literal_defaults_to_enum( + self, decl_base, indicate_type_explicitly + ): + """test #9187.""" + + global Status + + Status = Literal["to-do", "in-progress", "done"] + + if indicate_type_explicitly: + expected_native_enum = True + else: + expected_native_enum = False + + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if indicate_type_explicitly: + status: Mapped[Status] = mapped_column( + Enum("to-do", "in-progress", "done") + ) + else: + status: Mapped[Status] + + is_true(isinstance(Foo.__table__.c.status.type, Enum)) + + eq_( + Foo.__table__.c.status.type.enums, + ["to-do", "in-progress", "done"], ) + is_(Foo.__table__.c.status.type.native_enum, expected_native_enum) - with expect_raises_message( - sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" - ): + @testing.variation("override_in_type_map", [True, False]) + @testing.variation("indicate_type_explicitly", [True, False]) + def test_pep586_literal_checks_the_arguments( + self, decl_base, indicate_type_explicitly, override_in_type_map + ): + """test #9187.""" - class MyClass(Base): - __tablename__ = "mytable" + global NotReallyStrings - id: Mapped[int] = mapped_column(primary_key=True) + NotReallyStrings = Literal["str1", 17, False] - if dict_key.plain: - data: Mapped[dict[str, str]] - elif dict_key.typing: - data: Mapped[Dict[str, str]] + if override_in_type_map: + decl_base.registry.update_type_annotation_map( + {NotReallyStrings: JSON} + ) - def test_type_secondary_resolution(self): - class MyString(String): - def _resolve_for_python_type( - self, python_type, matched_type, matched_on_flattened + if not override_in_type_map and not indicate_type_explicitly: + with expect_raises_message( + ArgumentError, + "Can't create string-based Enum datatype from non-string " + "values: 17, False. Please provide an explicit Enum " + "datatype for this Python type", ): - return String(length=42) - Base = declarative_base(type_annotation_map={str: MyString}) + class Foo(decl_base): + __tablename__ = "footable" - class MyClass(Base): - __tablename__ = "mytable" + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[NotReallyStrings] - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[str] + else: + # if we override the type in the type_map or mapped_column, + # then we can again use a Literal with non-strings + class Foo(decl_base): + __tablename__ = "footable" - is_true(isinstance(MyClass.__table__.c.data.type, String)) - eq_(MyClass.__table__.c.data.type.length, 42) + id: Mapped[int] = mapped_column(primary_key=True) + + if indicate_type_explicitly: + status: Mapped[NotReallyStrings] = mapped_column(JSON) + else: + status: Mapped[NotReallyStrings] + + is_true(isinstance(Foo.__table__.c.status.type, JSON)) class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): @@ -2622,21 +2807,43 @@ class BackendTests(fixtures.TestBase): @testing.variation("native_enum", [True, False]) @testing.variation("include_column", [True, False]) + @testing.variation("python_type", ["enum", "literal"]) def test_schema_type_actually_works( - self, connection, decl_base, include_column, native_enum + self, + connection, + decl_base, + include_column, + native_enum, + python_type: Variation, ): """test that schema type bindings are set up correctly""" global Status - class Status(enum.Enum): - PENDING = "pending" - RECEIVED = "received" - COMPLETED = "completed" + if python_type.enum: + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + enum_argument = [Status] + test_value = Status.RECEIVED + elif python_type.literal: + Status = Literal[ # type: ignore + "pending", "received", "completed" + ] + enum_argument = ["pending", "received", "completed"] + test_value = "received" + else: + python_type.fail() if not include_column and not native_enum: decl_base.registry.update_type_annotation_map( - {enum.Enum: Enum(enum.Enum, native_enum=False)} + { + enum.Enum: Enum(enum.Enum, native_enum=False), + Literal: Enum(enum.Enum, native_enum=False), + } ) class SomeClass(decl_base): @@ -2646,7 +2853,11 @@ class BackendTests(fixtures.TestBase): if include_column: status: Mapped[Status] = mapped_column( - Enum(Status, native_enum=bool(native_enum)) + Enum( + *enum_argument, + native_enum=bool(native_enum), + name="status", + ) ) else: status: Mapped[Status] @@ -2654,12 +2865,12 @@ class BackendTests(fixtures.TestBase): decl_base.metadata.create_all(connection) with Session(connection) as sess: - sess.add(SomeClass(id=1, status=Status.RECEIVED)) + sess.add(SomeClass(id=1, status=test_value)) sess.commit() eq_( sess.scalars( select(SomeClass.status).where(SomeClass.id == 1) ).first(), - Status.RECEIVED, + test_value, ) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 87fc298629..762c879e6c 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -17,6 +17,9 @@ from typing import TypeVar from typing import Union import uuid +from typing_extensions import get_args as get_args +from typing_extensions import Literal as Literal + from sqlalchemy import BIGINT from sqlalchemy import BigInteger from sqlalchemy import Column @@ -64,6 +67,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import Variation from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated @@ -1298,19 +1302,117 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped["fake"] # noqa + def test_type_dont_mis_resolve_on_superclass(self): + """test for #8859. + + For subclasses of a type that's in the map, don't resolve this + by default, even though we do a search through __mro__. + + """ + # anno only: global int_sub + + class int_sub(int): + pass + + Base = declarative_base( + type_annotation_map={ + int: Integer, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[int_sub] + + @testing.variation( + "dict_key", ["typing", ("plain", testing.requires.python310)] + ) + def test_type_dont_mis_resolve_on_non_generic(self, dict_key): + """test for #8859. + + For a specific generic type with arguments, don't do any MRO + lookup. + + """ + + Base = declarative_base( + type_annotation_map={ + dict: String, + } + ) + + with expect_raises_message( + sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" + ): + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if dict_key.plain: + data: Mapped[dict[str, str]] + elif dict_key.typing: + data: Mapped[Dict[str, str]] + + def test_type_secondary_resolution(self): + class MyString(String): + def _resolve_for_python_type( + self, python_type, matched_type, matched_on_flattened + ): + return String(length=42) + + Base = declarative_base(type_annotation_map={str: MyString}) + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + is_true(isinstance(MyClass.__table__.c.data.type, String)) + eq_(MyClass.__table__.c.data.type.length, 42) + + +class EnumOrLiteralTypeMapTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + @testing.variation("use_callable", [True, False]) @testing.variation("include_generic", [True, False]) - def test_enum_explicit(self, use_callable, include_generic): + @testing.variation("set_native_enum", ["none", True, False]) + def test_enum_explicit( + self, use_callable, include_generic, set_native_enum: Variation + ): # anno only: global FooEnum class FooEnum(enum.Enum): foo = enum.auto() bar = enum.auto() + kw = {"length": 500} + + if set_native_enum.none: + expected_native_enum = True + elif set_native_enum.set_native_enum: + kw["native_enum"] = True + expected_native_enum = True + elif set_native_enum.not_set_native_enum: + kw["native_enum"] = False + expected_native_enum = False + else: + set_native_enum.fail() + if use_callable: - tam = {FooEnum: Enum(FooEnum, length=500)} + tam = {FooEnum: Enum(FooEnum, **kw)} else: - tam = {FooEnum: Enum(FooEnum, length=500)} + tam = {FooEnum: Enum(FooEnum, **kw)} + if include_generic: tam[enum.Enum] = Enum(enum.Enum) Base = declarative_base(type_annotation_map=tam) @@ -1324,8 +1426,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(isinstance(MyClass.__table__.c.data.type, Enum)) eq_(MyClass.__table__.c.data.type.length, 500) is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + is_(MyClass.__table__.c.data.type.native_enum, expected_native_enum) - def test_enum_generic(self): + @testing.variation("set_native_enum", ["none", True, False]) + def test_enum_generic(self, set_native_enum: Variation): """test for #8859""" # anno only: global FooEnum @@ -1333,8 +1437,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): foo = enum.auto() bar = enum.auto() + kw = {"length": 42} + + if set_native_enum.none: + expected_native_enum = True + elif set_native_enum.set_native_enum: + kw["native_enum"] = True + expected_native_enum = True + elif set_native_enum.not_set_native_enum: + kw["native_enum"] = False + expected_native_enum = False + else: + set_native_enum.fail() + Base = declarative_base( - type_annotation_map={enum.Enum: Enum(enum.Enum, length=42)} + type_annotation_map={enum.Enum: Enum(enum.Enum, **kw)} ) class MyClass(Base): @@ -1346,6 +1463,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(isinstance(MyClass.__table__.c.data.type, Enum)) eq_(MyClass.__table__.c.data.type.length, 42) is_(MyClass.__table__.c.data.type.enum_class, FooEnum) + is_(MyClass.__table__.c.data.type.native_enum, expected_native_enum) def test_enum_default(self, decl_base): """test #8859. @@ -1375,82 +1493,149 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(MyClass.__table__.c.data.type.length, 9) is_(MyClass.__table__.c.data.type.enum_class, FooEnum) - def test_type_dont_mis_resolve_on_superclass(self): - """test for #8859. + @testing.variation( + "sqltype", ["custom", "base_enum", "specific_enum", "string"] + ) + @testing.variation("indicate_type_explicitly", [True, False]) + def test_pep586_literal( + self, decl_base, sqltype: Variation, indicate_type_explicitly + ): + """test #9187.""" - For subclasses of a type that's in the map, don't resolve this - by default, even though we do a search through __mro__. + # anno only: global Status - """ - # anno only: global int_sub + Status = Literal["to-do", "in-progress", "done"] - class int_sub(int): - pass + if sqltype.custom: - Base = declarative_base( - type_annotation_map={ - int: Integer, - } - ) + class LiteralSqlType(types.TypeDecorator): + impl = types.String + cache_ok = True - with expect_raises_message( - sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" - ): + def __init__(self, literal_type: Any) -> None: + super().__init__() + self._possible_values = get_args(literal_type) - class MyClass(Base): - __tablename__ = "mytable" + our_type = mapped_col_type = LiteralSqlType(Status) + elif sqltype.specific_enum: + our_type = mapped_col_type = Enum( + "to-do", "in-progress", "done", native_enum=False + ) + elif sqltype.base_enum: + our_type = Enum(enum.Enum, native_enum=False) + mapped_col_type = Enum( + "to-do", "in-progress", "done", native_enum=False + ) + elif sqltype.string: + our_type = mapped_col_type = String(50) + else: + sqltype.fail() - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[int_sub] + decl_base.registry.update_type_annotation_map({Status: our_type}) - @testing.variation( - "dict_key", ["typing", ("plain", testing.requires.python310)] - ) - def test_type_dont_mis_resolve_on_non_generic(self, dict_key): - """test for #8859. + class Foo(decl_base): + __tablename__ = "footable" - For a specific generic type with arguments, don't do any MRO - lookup. + id: Mapped[int] = mapped_column(primary_key=True) - """ + if indicate_type_explicitly: + status: Mapped[Status] = mapped_column(mapped_col_type) + else: + status: Mapped[Status] - Base = declarative_base( - type_annotation_map={ - dict: String, - } + is_true(isinstance(Foo.__table__.c.status.type, type(our_type))) + + if sqltype.custom: + eq_( + Foo.__table__.c.status.type._possible_values, + ("to-do", "in-progress", "done"), + ) + elif sqltype.specific_enum or sqltype.base_enum: + eq_( + Foo.__table__.c.status.type.enums, + ["to-do", "in-progress", "done"], + ) + is_(Foo.__table__.c.status.type.native_enum, False) + + @testing.variation("indicate_type_explicitly", [True, False]) + def test_pep586_literal_defaults_to_enum( + self, decl_base, indicate_type_explicitly + ): + """test #9187.""" + + # anno only: global Status + + Status = Literal["to-do", "in-progress", "done"] + + if indicate_type_explicitly: + expected_native_enum = True + else: + expected_native_enum = False + + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + + if indicate_type_explicitly: + status: Mapped[Status] = mapped_column( + Enum("to-do", "in-progress", "done") + ) + else: + status: Mapped[Status] + + is_true(isinstance(Foo.__table__.c.status.type, Enum)) + + eq_( + Foo.__table__.c.status.type.enums, + ["to-do", "in-progress", "done"], ) + is_(Foo.__table__.c.status.type.native_enum, expected_native_enum) - with expect_raises_message( - sa_exc.ArgumentError, "Could not locate SQLAlchemy Core type" - ): + @testing.variation("override_in_type_map", [True, False]) + @testing.variation("indicate_type_explicitly", [True, False]) + def test_pep586_literal_checks_the_arguments( + self, decl_base, indicate_type_explicitly, override_in_type_map + ): + """test #9187.""" - class MyClass(Base): - __tablename__ = "mytable" + # anno only: global NotReallyStrings - id: Mapped[int] = mapped_column(primary_key=True) + NotReallyStrings = Literal["str1", 17, False] - if dict_key.plain: - data: Mapped[dict[str, str]] - elif dict_key.typing: - data: Mapped[Dict[str, str]] + if override_in_type_map: + decl_base.registry.update_type_annotation_map( + {NotReallyStrings: JSON} + ) - def test_type_secondary_resolution(self): - class MyString(String): - def _resolve_for_python_type( - self, python_type, matched_type, matched_on_flattened + if not override_in_type_map and not indicate_type_explicitly: + with expect_raises_message( + ArgumentError, + "Can't create string-based Enum datatype from non-string " + "values: 17, False. Please provide an explicit Enum " + "datatype for this Python type", ): - return String(length=42) - Base = declarative_base(type_annotation_map={str: MyString}) + class Foo(decl_base): + __tablename__ = "footable" - class MyClass(Base): - __tablename__ = "mytable" + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[NotReallyStrings] - id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[str] + else: + # if we override the type in the type_map or mapped_column, + # then we can again use a Literal with non-strings + class Foo(decl_base): + __tablename__ = "footable" - is_true(isinstance(MyClass.__table__.c.data.type, String)) - eq_(MyClass.__table__.c.data.type.length, 42) + id: Mapped[int] = mapped_column(primary_key=True) + + if indicate_type_explicitly: + status: Mapped[NotReallyStrings] = mapped_column(JSON) + else: + status: Mapped[NotReallyStrings] + + is_true(isinstance(Foo.__table__.c.status.type, JSON)) class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): @@ -2613,21 +2798,43 @@ class BackendTests(fixtures.TestBase): @testing.variation("native_enum", [True, False]) @testing.variation("include_column", [True, False]) + @testing.variation("python_type", ["enum", "literal"]) def test_schema_type_actually_works( - self, connection, decl_base, include_column, native_enum + self, + connection, + decl_base, + include_column, + native_enum, + python_type: Variation, ): """test that schema type bindings are set up correctly""" # anno only: global Status - class Status(enum.Enum): - PENDING = "pending" - RECEIVED = "received" - COMPLETED = "completed" + if python_type.enum: + + class Status(enum.Enum): + PENDING = "pending" + RECEIVED = "received" + COMPLETED = "completed" + + enum_argument = [Status] + test_value = Status.RECEIVED + elif python_type.literal: + Status = Literal[ # type: ignore + "pending", "received", "completed" + ] + enum_argument = ["pending", "received", "completed"] + test_value = "received" + else: + python_type.fail() if not include_column and not native_enum: decl_base.registry.update_type_annotation_map( - {enum.Enum: Enum(enum.Enum, native_enum=False)} + { + enum.Enum: Enum(enum.Enum, native_enum=False), + Literal: Enum(enum.Enum, native_enum=False), + } ) class SomeClass(decl_base): @@ -2637,7 +2844,11 @@ class BackendTests(fixtures.TestBase): if include_column: status: Mapped[Status] = mapped_column( - Enum(Status, native_enum=bool(native_enum)) + Enum( + *enum_argument, + native_enum=bool(native_enum), + name="status", + ) ) else: status: Mapped[Status] @@ -2645,12 +2856,12 @@ class BackendTests(fixtures.TestBase): decl_base.metadata.create_all(connection) with Session(connection) as sess: - sess.add(SomeClass(id=1, status=Status.RECEIVED)) + sess.add(SomeClass(id=1, status=test_value)) sess.commit() eq_( sess.scalars( select(SomeClass.status).where(SomeClass.id == 1) ).first(), - Status.RECEIVED, + test_value, ) -- 2.47.3