From: Daraan Date: Wed, 26 Mar 2025 18:27:46 +0000 (-0400) Subject: compatibility with typing_extensions 4.13 and type statement X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=690e754b653b79db847458ebf500cc7a34f4c62f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git compatibility with typing_extensions 4.13 and type statement Fixed regression caused by ``typing_extension==4.13.0`` that introduced a different implementation for ``TypeAliasType`` while SQLAlchemy assumed that it would be equivalent to the ``typing`` version. Added test regarding generic TypeAliasType Fixes: #12473 Closes: #12472 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12472 Pull-request-sha: 8861a5acfb8e81663413ff144b41abf64779b6fd Change-Id: I053019a222546a625ed6d588314ae9f5b34c2f8a --- diff --git a/doc/build/changelog/unreleased_20/12473.rst b/doc/build/changelog/unreleased_20/12473.rst new file mode 100644 index 0000000000..5127d92dd2 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12473.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, typing + :tickets: 12473 + + Fixed regression caused by ``typing_extension==4.13.0`` that introduced + a different implementation for ``TypeAliasType`` while SQLAlchemy assumed + that it would be equivalent to the ``typing`` version. diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index f3cec699b8..81a6d18ce9 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -1233,7 +1233,7 @@ class registry: search = ( (python_type, python_type_type), - *((lt, python_type_type) for lt in LITERAL_TYPES), # type: ignore[arg-type] # noqa: E501 + *((lt, python_type_type) for lt in LITERAL_TYPES), ) else: python_type_type = python_type.__origin__ diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index a1fb5920b9..dee25a71d0 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -34,6 +34,8 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +import typing_extensions + from . import compat if True: # zimports removes the tailing comments @@ -68,10 +70,6 @@ _VT_co = TypeVar("_VT_co", covariant=True) TupleAny = Tuple[Any, ...] -# typing_extensions.Literal is different from typing.Literal until -# Python 3.10.1 -LITERAL_TYPES = frozenset([typing.Literal, Literal]) - if compat.py310: # why they took until py310 to put this in stdlib is beyond me, @@ -331,7 +329,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str: def is_pep593(type_: Optional[Any]) -> bool: - return type_ is not None and get_origin(type_) is Annotated + return type_ is not None and get_origin(type_) in _type_tuples.Annotated def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]: @@ -341,14 +339,14 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]: def is_literal(type_: Any) -> bool: - return get_origin(type_) in LITERAL_TYPES + return get_origin(type_) in _type_tuples.Literal def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: return hasattr(type_, "__supertype__") # doesn't work in 3.9, 3.8, 3.7 as it passes a closure, not an # object instance - # return isinstance(type_, NewType) + # isinstance(type, type_instances.NewType) def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: @@ -356,7 +354,13 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]: - return isinstance(type_, TypeAliasType) + # NOTE: a generic TAT does not instance check as TypeAliasType outside of + # python 3.10. For sqlalchemy use cases it's fine to consider it a TAT + # though. + # NOTE: things seems to work also without this additional check + if is_generic(type_): + return is_pep695(type_.__origin__) + return isinstance(type_, _type_instances.TypeAliasType) def pep695_values(type_: _AnnotationScanType) -> Set[Any]: @@ -368,15 +372,15 @@ def pep695_values(type_: _AnnotationScanType) -> Set[Any]: """ _seen = set() - def recursive_value(type_): - if type_ in _seen: + def recursive_value(inner_type): + if inner_type in _seen: # recursion are not supported (at least it's flagged as # an error by pyright). Just avoid infinite loop - return type_ - _seen.add(type_) - if not is_pep695(type_): - return type_ - value = type_.__value__ + return inner_type + _seen.add(inner_type) + if not is_pep695(inner_type): + return inner_type + value = inner_type.__value__ if not is_union(value): return value return [recursive_value(t) for t in value.__args__] @@ -403,7 +407,7 @@ def is_fwd_ref( ) -> TypeGuard[ForwardRef]: if check_for_plain_string and isinstance(type_, str): return True - elif isinstance(type_, ForwardRef): + elif isinstance(type_, _type_instances.ForwardRef): return True elif check_generic and is_generic(type_): return any( @@ -677,3 +681,30 @@ class CallableReference(Generic[_FN]): def __set__(self, instance: Any, value: _FN) -> None: ... def __delete__(self, instance: Any) -> None: ... + + +class _TypingInstances: + def __getattr__(self, key: str) -> tuple[type, ...]: + types = tuple( + { + t + for t in [ + getattr(typing, key, None), + getattr(typing_extensions, key, None), + ] + if t is not None + } + ) + if not types: + raise AttributeError(key) + self.__dict__[key] = types + return types + + +_type_tuples = _TypingInstances() +if TYPE_CHECKING: + _type_instances = typing_extensions +else: + _type_instances = _type_tuples + +LITERAL_TYPES = _type_tuples.Literal diff --git a/test/base/test_typing_utils.py b/test/base/test_typing_utils.py index 6cddef6508..7a6aca3c85 100644 --- a/test/base/test_typing_utils.py +++ b/test/base/test_typing_utils.py @@ -38,63 +38,144 @@ def null_union_types(): return res +def generic_unions(): + # remove new-style unions `int | str` that are not generic + res = union_types() + null_union_types() + if py310: + new_ut = type(int | str) + res = [t for t in res if not isinstance(t, new_ut)] + return res + + def make_fw_ref(anno: str) -> typing.ForwardRef: return typing.Union[anno] -TA_int = typing_extensions.TypeAliasType("TA_int", int) -TA_union = typing_extensions.TypeAliasType("TA_union", typing.Union[int, str]) -TA_null_union = typing_extensions.TypeAliasType( - "TA_null_union", typing.Union[int, str, None] +TypeAliasType = getattr( + typing, "TypeAliasType", typing_extensions.TypeAliasType ) -TA_null_union2 = typing_extensions.TypeAliasType( + +TA_int = TypeAliasType("TA_int", int) +TAext_int = typing_extensions.TypeAliasType("TAext_int", int) +TA_union = TypeAliasType("TA_union", typing.Union[int, str]) +TAext_union = typing_extensions.TypeAliasType( + "TAext_union", typing.Union[int, str] +) +TA_null_union = TypeAliasType("TA_null_union", typing.Union[int, str, None]) +TAext_null_union = typing_extensions.TypeAliasType( + "TAext_null_union", typing.Union[int, str, None] +) +TA_null_union2 = TypeAliasType( "TA_null_union2", typing.Union[int, str, "None"] ) -TA_null_union3 = typing_extensions.TypeAliasType( +TAext_null_union2 = typing_extensions.TypeAliasType( + "TAext_null_union2", typing.Union[int, str, "None"] +) +TA_null_union3 = TypeAliasType( "TA_null_union3", typing.Union[int, "typing.Union[None, bool]"] ) -TA_null_union4 = typing_extensions.TypeAliasType( +TAext_null_union3 = typing_extensions.TypeAliasType( + "TAext_null_union3", typing.Union[int, "typing.Union[None, bool]"] +) +TA_null_union4 = TypeAliasType( "TA_null_union4", typing.Union[int, "TA_null_union2"] ) -TA_union_ta = typing_extensions.TypeAliasType( - "TA_union_ta", typing.Union[TA_int, str] +TAext_null_union4 = typing_extensions.TypeAliasType( + "TAext_null_union4", typing.Union[int, "TAext_null_union2"] +) +TA_union_ta = TypeAliasType("TA_union_ta", typing.Union[TA_int, str]) +TAext_union_ta = typing_extensions.TypeAliasType( + "TAext_union_ta", typing.Union[TAext_int, str] ) -TA_null_union_ta = typing_extensions.TypeAliasType( +TA_null_union_ta = TypeAliasType( "TA_null_union_ta", typing.Union[TA_null_union, float] ) -TA_list = typing_extensions.TypeAliasType( +TAext_null_union_ta = typing_extensions.TypeAliasType( + "TAext_null_union_ta", typing.Union[TAext_null_union, float] +) +TA_list = TypeAliasType( "TA_list", typing.Union[int, str, typing.List["TA_list"]] ) +TAext_list = typing_extensions.TypeAliasType( + "TAext_list", typing.Union[int, str, typing.List["TAext_list"]] +) # these below not valid. Verify that it does not cause exceptions in any case -TA_recursive = typing_extensions.TypeAliasType( - "TA_recursive", typing.Union["TA_recursive", str] +TA_recursive = TypeAliasType("TA_recursive", typing.Union["TA_recursive", str]) +TAext_recursive = typing_extensions.TypeAliasType( + "TAext_recursive", typing.Union["TAext_recursive", str] ) -TA_null_recursive = typing_extensions.TypeAliasType( +TA_null_recursive = TypeAliasType( "TA_null_recursive", typing.Union[TA_recursive, None] ) -TA_recursive_a = typing_extensions.TypeAliasType( +TAext_null_recursive = typing_extensions.TypeAliasType( + "TAext_null_recursive", typing.Union[TAext_recursive, None] +) +TA_recursive_a = TypeAliasType( "TA_recursive_a", typing.Union["TA_recursive_b", int] ) -TA_recursive_b = typing_extensions.TypeAliasType( +TAext_recursive_a = typing_extensions.TypeAliasType( + "TAext_recursive_a", typing.Union["TAext_recursive_b", int] +) +TA_recursive_b = TypeAliasType( "TA_recursive_b", typing.Union["TA_recursive_a", str] ) +TAext_recursive_b = typing_extensions.TypeAliasType( + "TAext_recursive_b", typing.Union["TAext_recursive_a", str] +) +TA_generic = TypeAliasType("TA_generic", typing.List[TV], type_params=(TV,)) +TAext_generic = typing_extensions.TypeAliasType( + "TAext_generic", typing.List[TV], type_params=(TV,) +) +TA_generic_typed = TA_generic[int] +TAext_generic_typed = TAext_generic[int] +TA_generic_null = TypeAliasType( + "TA_generic_null", typing.Union[typing.List[TV], None], type_params=(TV,) +) +TAext_generic_null = typing_extensions.TypeAliasType( + "TAext_generic_null", + typing.Union[typing.List[TV], None], + type_params=(TV,), +) +TA_generic_null_typed = TA_generic_null[str] +TAext_generic_null_typed = TAext_generic_null[str] def type_aliases(): return [ TA_int, + TAext_int, TA_union, + TAext_union, TA_null_union, + TAext_null_union, TA_null_union2, + TAext_null_union2, TA_null_union3, + TAext_null_union3, TA_null_union4, + TAext_null_union4, TA_union_ta, + TAext_union_ta, TA_null_union_ta, + TAext_null_union_ta, TA_list, + TAext_list, TA_recursive, + TAext_recursive, TA_null_recursive, + TAext_null_recursive, TA_recursive_a, + TAext_recursive_a, TA_recursive_b, + TAext_recursive_b, + TA_generic, + TAext_generic, + TA_generic_typed, + TAext_generic_typed, + TA_generic_null, + TAext_generic_null, + TA_generic_null_typed, + TAext_generic_null_typed, ] @@ -143,11 +224,14 @@ def exec_code(code: str, *vars: str) -> typing.Any: class TestTestingThings(fixtures.TestBase): def test_unions_are_the_same(self): + # the point of this test is to reduce the cases to test since + # some symbols are the same in typing and typing_extensions. + # If a test starts failing then additional cases should be added, + # similar to what it's done for TypeAliasType + # no need to test typing_extensions.Union, typing_extensions.Optional is_(typing.Union, typing_extensions.Union) is_(typing.Optional, typing_extensions.Optional) - if py312: - is_(typing.TypeAliasType, typing_extensions.TypeAliasType) def test_make_union(self): v = int, str @@ -221,8 +305,19 @@ class TestTyping(fixtures.TestBase): eq_(sa_typing.is_generic(t), False) eq_(sa_typing.is_generic(t[int]), True) + generics = [ + TA_generic_typed, + TAext_generic_typed, + TA_generic_null_typed, + TAext_generic_null_typed, + *annotated_l(), + *generic_unions(), + ] + for t in all_types(): - eq_(sa_typing.is_literal(t), False) + # use is since union compare equal between new/old style + exp = any(t is k for k in generics) + eq_(sa_typing.is_generic(t), exp, t) def test_is_pep695(self): eq_(sa_typing.is_pep695(str), False) @@ -249,41 +344,100 @@ class TestTyping(fixtures.TestBase): sa_typing.pep695_values(typing.Union[int, TA_int]), {typing.Union[int, TA_int]}, ) + eq_( + sa_typing.pep695_values(typing.Union[int, TAext_int]), + {typing.Union[int, TAext_int]}, + ) eq_(sa_typing.pep695_values(TA_int), {int}) + eq_(sa_typing.pep695_values(TAext_int), {int}) eq_(sa_typing.pep695_values(TA_union), {int, str}) + eq_(sa_typing.pep695_values(TAext_union), {int, str}) eq_(sa_typing.pep695_values(TA_null_union), {int, str, None}) + eq_(sa_typing.pep695_values(TAext_null_union), {int, str, None}) eq_(sa_typing.pep695_values(TA_null_union2), {int, str, None}) + eq_(sa_typing.pep695_values(TAext_null_union2), {int, str, None}) eq_( sa_typing.pep695_values(TA_null_union3), {int, typing.ForwardRef("typing.Union[None, bool]")}, ) + eq_( + sa_typing.pep695_values(TAext_null_union3), + {int, typing.ForwardRef("typing.Union[None, bool]")}, + ) eq_( sa_typing.pep695_values(TA_null_union4), {int, typing.ForwardRef("TA_null_union2")}, ) + eq_( + sa_typing.pep695_values(TAext_null_union4), + {int, typing.ForwardRef("TAext_null_union2")}, + ) eq_(sa_typing.pep695_values(TA_union_ta), {int, str}) + eq_(sa_typing.pep695_values(TAext_union_ta), {int, str}) eq_(sa_typing.pep695_values(TA_null_union_ta), {int, str, None, float}) + eq_( + sa_typing.pep695_values(TAext_null_union_ta), + {int, str, None, float}, + ) eq_( sa_typing.pep695_values(TA_list), {int, str, typing.List[typing.ForwardRef("TA_list")]}, ) + eq_( + sa_typing.pep695_values(TAext_list), + {int, str, typing.List[typing.ForwardRef("TAext_list")]}, + ) eq_( sa_typing.pep695_values(TA_recursive), {typing.ForwardRef("TA_recursive"), str}, ) + eq_( + sa_typing.pep695_values(TAext_recursive), + {typing.ForwardRef("TAext_recursive"), str}, + ) eq_( sa_typing.pep695_values(TA_null_recursive), {typing.ForwardRef("TA_recursive"), str, None}, ) + eq_( + sa_typing.pep695_values(TAext_null_recursive), + {typing.ForwardRef("TAext_recursive"), str, None}, + ) eq_( sa_typing.pep695_values(TA_recursive_a), {typing.ForwardRef("TA_recursive_b"), int}, ) + eq_( + sa_typing.pep695_values(TAext_recursive_a), + {typing.ForwardRef("TAext_recursive_b"), int}, + ) eq_( sa_typing.pep695_values(TA_recursive_b), {typing.ForwardRef("TA_recursive_a"), str}, ) + eq_( + sa_typing.pep695_values(TAext_recursive_b), + {typing.ForwardRef("TAext_recursive_a"), str}, + ) + # generics + eq_(sa_typing.pep695_values(TA_generic), {typing.List[TV]}) + eq_(sa_typing.pep695_values(TAext_generic), {typing.List[TV]}) + eq_(sa_typing.pep695_values(TA_generic_typed), {typing.List[TV]}) + eq_(sa_typing.pep695_values(TAext_generic_typed), {typing.List[TV]}) + eq_(sa_typing.pep695_values(TA_generic_null), {None, typing.List[TV]}) + eq_( + sa_typing.pep695_values(TAext_generic_null), + {None, typing.List[TV]}, + ) + eq_( + sa_typing.pep695_values(TA_generic_null_typed), + {None, typing.List[TV]}, + ) + eq_( + sa_typing.pep695_values(TAext_generic_null_typed), + {None, typing.List[TV]}, + ) def test_is_fwd_ref(self): eq_(sa_typing.is_fwd_ref(int), False) @@ -346,6 +500,10 @@ class TestTyping(fixtures.TestBase): sa_typing.make_union_type(bool, TA_int, NT_str), typing.Union[bool, TA_int, NT_str], ) + eq_( + sa_typing.make_union_type(bool, TAext_int, NT_str), + typing.Union[bool, TAext_int, NT_str], + ) def test_includes_none(self): eq_(sa_typing.includes_none(None), True) @@ -359,11 +517,12 @@ class TestTyping(fixtures.TestBase): eq_(sa_typing.includes_none(t), True, str(t)) # TODO: these are false negatives - false_negative = { + false_negatives = { TA_null_union4, # does not evaluate FW ref + TAext_null_union4, # does not evaluate FW ref } for t in type_aliases() + new_types(): - if t in false_negative: + if t in false_negatives: exp = False else: exp = "null" in t.__name__ @@ -378,6 +537,9 @@ class TestTyping(fixtures.TestBase): # nested things eq_(sa_typing.includes_none(typing.Union[int, "None"]), True) eq_(sa_typing.includes_none(typing.Union[bool, TA_null_union]), True) + eq_( + sa_typing.includes_none(typing.Union[bool, TAext_null_union]), True + ) eq_(sa_typing.includes_none(typing.Union[bool, NT_null]), True) # nested fw eq_( @@ -397,6 +559,10 @@ class TestTyping(fixtures.TestBase): eq_( sa_typing.includes_none(typing.Union[bool, "TA_null_union"]), False ) + eq_( + sa_typing.includes_none(typing.Union[bool, "TAext_null_union"]), + False, + ) eq_(sa_typing.includes_none(typing.Union[bool, "NT_null"]), False) def test_is_union(self): @@ -405,3 +571,26 @@ class TestTyping(fixtures.TestBase): eq_(sa_typing.is_union(t), True) for t in type_aliases() + new_types() + annotated_l(): eq_(sa_typing.is_union(t), False) + + def test_TypingInstances(self): + is_(sa_typing._type_tuples, sa_typing._type_instances) + is_( + isinstance(sa_typing._type_instances, sa_typing._TypingInstances), + True, + ) + + # cached + is_( + sa_typing._type_instances.Literal, + sa_typing._type_instances.Literal, + ) + + for k in ["Literal", "Annotated", "TypeAliasType"]: + types = set() + ti = getattr(sa_typing._type_instances, k) + for lib in [typing, typing_extensions]: + lt = getattr(lib, k, None) + if lt is not None: + types.add(lt) + is_(lt in ti, True) + eq_(len(ti), len(types), k) diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index d7d9414661..f0b3e81fd7 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -105,6 +105,8 @@ from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated +TV = typing.TypeVar("TV") + class _SomeDict1(TypedDict): type: Literal["1"] @@ -136,7 +138,16 @@ if compat.py310: ) _JsonPep695 = TypeAliasType("_JsonPep695", _JsonPep604) +TypingTypeAliasType = getattr(typing, "TypeAliasType", TypeAliasType) + _StrPep695 = TypeAliasType("_StrPep695", str) +_TypingStrPep695 = TypingTypeAliasType("_TypingStrPep695", str) +_GenericPep695 = TypeAliasType("_GenericPep695", List[TV], type_params=(TV,)) +_TypingGenericPep695 = TypingTypeAliasType( + "_TypingGenericPep695", List[TV], type_params=(TV,) +) +_GenericPep695Typed = _GenericPep695[int] +_TypingGenericPep695Typed = _TypingGenericPep695[int] _UnionPep695 = TypeAliasType("_UnionPep695", Union[_SomeDict1, _SomeDict2]) strtypalias_keyword = TypeAliasType( "strtypalias_keyword", Annotated[str, mapped_column(info={"hi": "there"})] @@ -151,6 +162,9 @@ strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})] _Literal695 = TypeAliasType( "_Literal695", Literal["to-do", "in-progress", "done"] ) +_TypingLiteral695 = TypingTypeAliasType( + "_TypingLiteral695", Literal["to-do", "in-progress", "done"] +) _RecursiveLiteral695 = TypeAliasType("_RecursiveLiteral695", _Literal695) @@ -1093,20 +1107,52 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ): declare() + @testing.variation( + "type_", + [ + "str_extension", + "str_typing", + "generic_extension", + "generic_typing", + "generic_typed_extension", + "generic_typed_typing", + ], + ) @testing.requires.python312 def test_pep695_typealias_as_typemap_keys( - self, decl_base: Type[DeclarativeBase] + self, decl_base: Type[DeclarativeBase], type_ ): """test #10807""" decl_base.registry.update_type_annotation_map( - {_UnionPep695: JSON, _StrPep695: String(30)} + { + _UnionPep695: JSON, + _StrPep695: String(30), + _TypingStrPep695: String(30), + _GenericPep695: String(30), + _TypingGenericPep695: String(30), + _GenericPep695Typed: String(30), + _TypingGenericPep695Typed: String(30), + } ) class Test(decl_base): __tablename__ = "test" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[_StrPep695] + if type_.str_extension: + data: Mapped[_StrPep695] + elif type_.str_typing: + data: Mapped[_TypingStrPep695] + elif type_.generic_extension: + data: Mapped[_GenericPep695] + elif type_.generic_typing: + data: Mapped[_TypingGenericPep695] + elif type_.generic_typed_extension: + data: Mapped[_GenericPep695Typed] + elif type_.generic_typed_typing: + data: Mapped[_TypingGenericPep695Typed] + else: + type_.fail() structure: Mapped[_UnionPep695] eq_(Test.__table__.c.data.type._type_affinity, String) @@ -1163,7 +1209,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: eq_(MyClass.data_one.type.length, None) - @testing.variation("type_", ["literal", "recursive", "not_literal"]) + @testing.variation( + "type_", + [ + "literal", + "literal_typing", + "recursive", + "not_literal", + "not_literal_typing", + "generic", + "generic_typing", + "generic_typed", + "generic_typed_typing", + ], + ) @testing.combinations(True, False, argnames="in_map") @testing.requires.python312 def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map): @@ -1178,8 +1237,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): status: Mapped[_RecursiveLiteral695] # noqa: F821 elif type_.literal: status: Mapped[_Literal695] # noqa: F821 + elif type_.literal_typing: + status: Mapped[_TypingLiteral695] # noqa: F821 elif type_.not_literal: status: Mapped[_StrPep695] # noqa: F821 + elif type_.not_literal_typing: + status: Mapped[_TypingStrPep695] # noqa: F821 + elif type_.generic: + status: Mapped[_GenericPep695] # noqa: F821 + elif type_.generic_typing: + status: Mapped[_TypingGenericPep695] # noqa: F821 + elif type_.generic_typed: + status: Mapped[_GenericPep695Typed] # noqa: F821 + elif type_.generic_typed_typing: + status: Mapped[_TypingGenericPep695Typed] # noqa: F821 else: type_.fail() @@ -1189,11 +1260,17 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): decl_base.registry.update_type_annotation_map( { _Literal695: Enum(enum.Enum), # noqa: F821 + _TypingLiteral695: Enum(enum.Enum), # noqa: F821 _RecursiveLiteral695: Enum(enum.Enum), # noqa: F821 _StrPep695: Enum(enum.Enum), # noqa: F821 + _TypingStrPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695Typed: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695Typed: Enum(enum.Enum), # noqa: F821 } ) - if type_.literal: + if type_.literal or type_.literal_typing: Foo = declare() col = Foo.__table__.c.status is_true(isinstance(col.type, Enum)) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index cb7712862d..748ad03f7a 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -96,6 +96,8 @@ from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated +TV = typing.TypeVar("TV") + class _SomeDict1(TypedDict): type: Literal["1"] @@ -127,7 +129,16 @@ if compat.py310: ) _JsonPep695 = TypeAliasType("_JsonPep695", _JsonPep604) +TypingTypeAliasType = getattr(typing, "TypeAliasType", TypeAliasType) + _StrPep695 = TypeAliasType("_StrPep695", str) +_TypingStrPep695 = TypingTypeAliasType("_TypingStrPep695", str) +_GenericPep695 = TypeAliasType("_GenericPep695", List[TV], type_params=(TV,)) +_TypingGenericPep695 = TypingTypeAliasType( + "_TypingGenericPep695", List[TV], type_params=(TV,) +) +_GenericPep695Typed = _GenericPep695[int] +_TypingGenericPep695Typed = _TypingGenericPep695[int] _UnionPep695 = TypeAliasType("_UnionPep695", Union[_SomeDict1, _SomeDict2]) strtypalias_keyword = TypeAliasType( "strtypalias_keyword", Annotated[str, mapped_column(info={"hi": "there"})] @@ -142,6 +153,9 @@ strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})] _Literal695 = TypeAliasType( "_Literal695", Literal["to-do", "in-progress", "done"] ) +_TypingLiteral695 = TypingTypeAliasType( + "_TypingLiteral695", Literal["to-do", "in-progress", "done"] +) _RecursiveLiteral695 = TypeAliasType("_RecursiveLiteral695", _Literal695) @@ -1084,20 +1098,52 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ): declare() + @testing.variation( + "type_", + [ + "str_extension", + "str_typing", + "generic_extension", + "generic_typing", + "generic_typed_extension", + "generic_typed_typing", + ], + ) @testing.requires.python312 def test_pep695_typealias_as_typemap_keys( - self, decl_base: Type[DeclarativeBase] + self, decl_base: Type[DeclarativeBase], type_ ): """test #10807""" decl_base.registry.update_type_annotation_map( - {_UnionPep695: JSON, _StrPep695: String(30)} + { + _UnionPep695: JSON, + _StrPep695: String(30), + _TypingStrPep695: String(30), + _GenericPep695: String(30), + _TypingGenericPep695: String(30), + _GenericPep695Typed: String(30), + _TypingGenericPep695Typed: String(30), + } ) class Test(decl_base): __tablename__ = "test" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[_StrPep695] + if type_.str_extension: + data: Mapped[_StrPep695] + elif type_.str_typing: + data: Mapped[_TypingStrPep695] + elif type_.generic_extension: + data: Mapped[_GenericPep695] + elif type_.generic_typing: + data: Mapped[_TypingGenericPep695] + elif type_.generic_typed_extension: + data: Mapped[_GenericPep695Typed] + elif type_.generic_typed_typing: + data: Mapped[_TypingGenericPep695Typed] + else: + type_.fail() structure: Mapped[_UnionPep695] eq_(Test.__table__.c.data.type._type_affinity, String) @@ -1154,7 +1200,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: eq_(MyClass.data_one.type.length, None) - @testing.variation("type_", ["literal", "recursive", "not_literal"]) + @testing.variation( + "type_", + [ + "literal", + "literal_typing", + "recursive", + "not_literal", + "not_literal_typing", + "generic", + "generic_typing", + "generic_typed", + "generic_typed_typing", + ], + ) @testing.combinations(True, False, argnames="in_map") @testing.requires.python312 def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map): @@ -1169,8 +1228,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): status: Mapped[_RecursiveLiteral695] # noqa: F821 elif type_.literal: status: Mapped[_Literal695] # noqa: F821 + elif type_.literal_typing: + status: Mapped[_TypingLiteral695] # noqa: F821 elif type_.not_literal: status: Mapped[_StrPep695] # noqa: F821 + elif type_.not_literal_typing: + status: Mapped[_TypingStrPep695] # noqa: F821 + elif type_.generic: + status: Mapped[_GenericPep695] # noqa: F821 + elif type_.generic_typing: + status: Mapped[_TypingGenericPep695] # noqa: F821 + elif type_.generic_typed: + status: Mapped[_GenericPep695Typed] # noqa: F821 + elif type_.generic_typed_typing: + status: Mapped[_TypingGenericPep695Typed] # noqa: F821 else: type_.fail() @@ -1180,11 +1251,17 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): decl_base.registry.update_type_annotation_map( { _Literal695: Enum(enum.Enum), # noqa: F821 + _TypingLiteral695: Enum(enum.Enum), # noqa: F821 _RecursiveLiteral695: Enum(enum.Enum), # noqa: F821 _StrPep695: Enum(enum.Enum), # noqa: F821 + _TypingStrPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695: Enum(enum.Enum), # noqa: F821 + _GenericPep695Typed: Enum(enum.Enum), # noqa: F821 + _TypingGenericPep695Typed: Enum(enum.Enum), # noqa: F821 } ) - if type_.literal: + if type_.literal or type_.literal_typing: Foo = declare() col = Foo.__table__.c.status is_true(isinstance(col.type, Enum))