From 7d13f705307bf62560fc831f6f049a425d411374 Mon Sep 17 00:00:00 2001 From: Frederik Aalund Date: Mon, 30 Jan 2023 16:19:52 +0100 Subject: [PATCH] Add support for typing.Literal in Mapped Fixes: #9187 --- lib/sqlalchemy/orm/decl_base.py | 3 +- lib/sqlalchemy/util/typing.py | 12 ++++- test/orm/declarative/test_typed_mapping.py | 55 ++++++++++++++++++++++ 3 files changed, 67 insertions(+), 3 deletions(-) diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 9e8b023597..e3d7a61a89 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -66,6 +66,7 @@ from ..util import topological from ..util.typing import _AnnotationScanType from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref +from ..util.typing import is_literal from ..util.typing import Protocol from ..util.typing import TypedDict from ..util.typing import typing_get_args @@ -1157,7 +1158,7 @@ class _ClassScanMapperConfig(_MapperConfig): extracted_mapped_annotation, mapped_container = extracted - if attr_value is None: + if attr_value is None and not is_literal(extracted_mapped_annotation): for elem in typing_get_args(extracted_mapped_annotation): if isinstance(elem, str) or is_fwd_ref( elem, check_generic=True diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index e1670ed21b..d4ba94fb5a 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -150,7 +150,7 @@ def de_stringify_annotation( annotation = eval_expression(annotation, originating_module) - if include_generic and is_generic(annotation): + if include_generic and is_generic(annotation, include_literal=False): elements = tuple( de_stringify_annotation( cls, @@ -246,8 +246,16 @@ def de_stringify_union_elements( def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated +def is_literal(type_: _AnnotationScanType) -> bool: + return get_origin(type_) is Literal -def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: +def is_generic( + type_: _AnnotationScanType, + *, + include_literal: bool = True +) -> TypeGuard[GenericProtocol[Any]]: + if not include_literal and is_literal(type_): + return False return hasattr(type_, "__args__") and hasattr(type_, "__origin__") diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 8838afd0ff..ebdb7aa0a4 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -16,6 +16,9 @@ from typing import TypeVar from typing import Union import uuid +from typing_extensions import get_args as get_args # 3.10 +from typing_extensions import Literal as Literal # 3.8 + from sqlalchemy import BIGINT from sqlalchemy import BigInteger from sqlalchemy import Column @@ -1340,6 +1343,58 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(isinstance(MyClass.__table__.c.data.type, String)) eq_(MyClass.__table__.c.data.type.length, 42) + def test_string_literal(self, decl_base): + """test #9187.""" + class LiteralSqlType(types.TypeDecorator): + impl = types.String + cache_ok = True + + def __init__(self, literal_type: Any) -> None: + super().__init__() + self._possible_values = get_args(literal_type) + + def process_bind_param( + self, value: Optional[str], dialect + ) -> Optional[str]: + if value not in self._possible_values: + raise ValueError( + f"Invalid literal value '{value}'. Value must be one of" + f" {self._possible_values}." + ) + return value + + Status = Literal["to-do", "in-progress", "done"] + Base = declarative_base() + BaseWithMap = declarative_base( + type_annotation_map={Status: LiteralSqlType(Status)} + ) + + class Foo(Base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] = mapped_column(LiteralSqlType(Status)) + + class Bar(Base): + __tablename__ = "bartable" + + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Annotated[Status, mapped_column(LiteralSqlType(Status))]] + + class Baz(BaseWithMap): + __tablename__ = "baztable" + + id: Mapped[int] = mapped_column(primary_key=True) + status: Mapped[Status] + + is_true(isinstance(Foo.__table__.c.status.type, LiteralSqlType)) + is_true(isinstance(Bar.__table__.c.status.type, LiteralSqlType)) + is_true(isinstance(Baz.__table__.c.status.type, LiteralSqlType)) + eq_(Foo.__table__.c.status.type._possible_values, ("to-do", "in-progress", "done")) + eq_(Bar.__table__.c.status.type._possible_values, ("to-do", "in-progress", "done")) + eq_(Baz.__table__.c.status.type._possible_values, ("to-do", "in-progress", "done")) + + class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" -- 2.47.3