]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
compatibility with typing_extensions 4.13 and type statement
authorDaraan <github.blurry@9ox.net>
Wed, 26 Mar 2025 18:27:46 +0000 (14:27 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 26 Mar 2025 23:36:23 +0000 (00:36 +0100)
Fixed regression caused by ``typing_extension==4.13.0`` that introduced
a different implementation for ``TypeAliasType`` while SQLAlchemy assumed
that it would be equivalent to the ``typing`` version.

Added test regarding generic TypeAliasType

Fixes: #12473
Closes: #12472
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12472
Pull-request-sha: 8861a5acfb8e81663413ff144b41abf64779b6fd

Change-Id: I053019a222546a625ed6d588314ae9f5b34c2f8a

doc/build/changelog/unreleased_20/12473.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/util/typing.py
test/base/test_typing_utils.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/12473.rst b/doc/build/changelog/unreleased_20/12473.rst
new file mode 100644 (file)
index 0000000..5127d92
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 12473
+
+    Fixed regression caused by ``typing_extension==4.13.0`` that introduced
+    a different implementation for ``TypeAliasType`` while SQLAlchemy assumed
+    that it would be equivalent to the ``typing`` version.
index f3cec699b8de7f3c9427c0e5c634e9c06b8592e5..81a6d18ce9deac3cc27fcdf97a1704f30651a9d9 100644 (file)
@@ -1233,7 +1233,7 @@ class registry:
 
                 search = (
                     (python_type, python_type_type),
-                    *((lt, python_type_type) for lt in LITERAL_TYPES),  # type: ignore[arg-type] # noqa: E501
+                    *((lt, python_type_type) for lt in LITERAL_TYPES),
                 )
             else:
                 python_type_type = python_type.__origin__
index a1fb5920b9533196ddf99fe980d514175184128c..dee25a71d0c5648c3ece9770f43a32b7722dd997 100644 (file)
@@ -34,6 +34,8 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
+import typing_extensions
+
 from . import compat
 
 if True:  # zimports removes the tailing comments
@@ -68,10 +70,6 @@ _VT_co = TypeVar("_VT_co", covariant=True)
 
 TupleAny = Tuple[Any, ...]
 
-# typing_extensions.Literal is different from typing.Literal until
-# Python 3.10.1
-LITERAL_TYPES = frozenset([typing.Literal, Literal])
-
 
 if compat.py310:
     # why they took until py310 to put this in stdlib is beyond me,
@@ -331,7 +329,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
 
 
 def is_pep593(type_: Optional[Any]) -> bool:
-    return type_ is not None and get_origin(type_) iAnnotated
+    return type_ is not None and get_origin(type_) in _type_tuples.Annotated
 
 
 def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
@@ -341,14 +339,14 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
 
 
 def is_literal(type_: Any) -> bool:
-    return get_origin(type_) in LITERAL_TYPES
+    return get_origin(type_) in _type_tuples.Literal
 
 
 def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
     return hasattr(type_, "__supertype__")
     # doesn't work in 3.9, 3.8, 3.7 as it passes a closure, not an
     # object instance
-    # return isinstance(type_, NewType)
+    # isinstance(type, type_instances.NewType)
 
 
 def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
@@ -356,7 +354,13 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
 
 
 def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
-    return isinstance(type_, TypeAliasType)
+    # NOTE: a generic TAT does not instance check as TypeAliasType outside of
+    # python 3.10. For sqlalchemy use cases it's fine to consider it a TAT
+    # though.
+    # NOTE: things seems to work also without this additional check
+    if is_generic(type_):
+        return is_pep695(type_.__origin__)
+    return isinstance(type_, _type_instances.TypeAliasType)
 
 
 def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
@@ -368,15 +372,15 @@ def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
     """
     _seen = set()
 
-    def recursive_value(type_):
-        if type_ in _seen:
+    def recursive_value(inner_type):
+        if inner_type in _seen:
             # recursion are not supported (at least it's flagged as
             # an error by pyright). Just avoid infinite loop
-            return type_
-        _seen.add(type_)
-        if not is_pep695(type_):
-            return type_
-        value = type_.__value__
+            return inner_type
+        _seen.add(inner_type)
+        if not is_pep695(inner_type):
+            return inner_type
+        value = inner_type.__value__
         if not is_union(value):
             return value
         return [recursive_value(t) for t in value.__args__]
@@ -403,7 +407,7 @@ def is_fwd_ref(
 ) -> TypeGuard[ForwardRef]:
     if check_for_plain_string and isinstance(type_, str):
         return True
-    elif isinstance(type_, ForwardRef):
+    elif isinstance(type_, _type_instances.ForwardRef):
         return True
     elif check_generic and is_generic(type_):
         return any(
@@ -677,3 +681,30 @@ class CallableReference(Generic[_FN]):
         def __set__(self, instance: Any, value: _FN) -> None: ...
 
         def __delete__(self, instance: Any) -> None: ...
+
+
+class _TypingInstances:
+    def __getattr__(self, key: str) -> tuple[type, ...]:
+        types = tuple(
+            {
+                t
+                for t in [
+                    getattr(typing, key, None),
+                    getattr(typing_extensions, key, None),
+                ]
+                if t is not None
+            }
+        )
+        if not types:
+            raise AttributeError(key)
+        self.__dict__[key] = types
+        return types
+
+
+_type_tuples = _TypingInstances()
+if TYPE_CHECKING:
+    _type_instances = typing_extensions
+else:
+    _type_instances = _type_tuples
+
+LITERAL_TYPES = _type_tuples.Literal
index 6cddef6508c03de3042047c60dff3f5f030988e4..7a6aca3c857b2ac033c5a8fccd2ad223aff4455a 100644 (file)
@@ -38,63 +38,144 @@ def null_union_types():
     return res
 
 
+def generic_unions():
+    # remove new-style unions `int | str` that are not generic
+    res = union_types() + null_union_types()
+    if py310:
+        new_ut = type(int | str)
+        res = [t for t in res if not isinstance(t, new_ut)]
+    return res
+
+
 def make_fw_ref(anno: str) -> typing.ForwardRef:
     return typing.Union[anno]
 
 
-TA_int = typing_extensions.TypeAliasType("TA_int", int)
-TA_union = typing_extensions.TypeAliasType("TA_union", typing.Union[int, str])
-TA_null_union = typing_extensions.TypeAliasType(
-    "TA_null_union", typing.Union[int, str, None]
+TypeAliasType = getattr(
+    typing, "TypeAliasType", typing_extensions.TypeAliasType
 )
-TA_null_union2 = typing_extensions.TypeAliasType(
+
+TA_int = TypeAliasType("TA_int", int)
+TAext_int = typing_extensions.TypeAliasType("TAext_int", int)
+TA_union = TypeAliasType("TA_union", typing.Union[int, str])
+TAext_union = typing_extensions.TypeAliasType(
+    "TAext_union", typing.Union[int, str]
+)
+TA_null_union = TypeAliasType("TA_null_union", typing.Union[int, str, None])
+TAext_null_union = typing_extensions.TypeAliasType(
+    "TAext_null_union", typing.Union[int, str, None]
+)
+TA_null_union2 = TypeAliasType(
     "TA_null_union2", typing.Union[int, str, "None"]
 )
-TA_null_union3 = typing_extensions.TypeAliasType(
+TAext_null_union2 = typing_extensions.TypeAliasType(
+    "TAext_null_union2", typing.Union[int, str, "None"]
+)
+TA_null_union3 = TypeAliasType(
     "TA_null_union3", typing.Union[int, "typing.Union[None, bool]"]
 )
-TA_null_union4 = typing_extensions.TypeAliasType(
+TAext_null_union3 = typing_extensions.TypeAliasType(
+    "TAext_null_union3", typing.Union[int, "typing.Union[None, bool]"]
+)
+TA_null_union4 = TypeAliasType(
     "TA_null_union4", typing.Union[int, "TA_null_union2"]
 )
-TA_union_ta = typing_extensions.TypeAliasType(
-    "TA_union_ta", typing.Union[TA_int, str]
+TAext_null_union4 = typing_extensions.TypeAliasType(
+    "TAext_null_union4", typing.Union[int, "TAext_null_union2"]
+)
+TA_union_ta = TypeAliasType("TA_union_ta", typing.Union[TA_int, str])
+TAext_union_ta = typing_extensions.TypeAliasType(
+    "TAext_union_ta", typing.Union[TAext_int, str]
 )
-TA_null_union_ta = typing_extensions.TypeAliasType(
+TA_null_union_ta = TypeAliasType(
     "TA_null_union_ta", typing.Union[TA_null_union, float]
 )
-TA_list = typing_extensions.TypeAliasType(
+TAext_null_union_ta = typing_extensions.TypeAliasType(
+    "TAext_null_union_ta", typing.Union[TAext_null_union, float]
+)
+TA_list = TypeAliasType(
     "TA_list", typing.Union[int, str, typing.List["TA_list"]]
 )
+TAext_list = typing_extensions.TypeAliasType(
+    "TAext_list", typing.Union[int, str, typing.List["TAext_list"]]
+)
 # these below not valid. Verify that it does not cause exceptions in any case
-TA_recursive = typing_extensions.TypeAliasType(
-    "TA_recursive", typing.Union["TA_recursive", str]
+TA_recursive = TypeAliasType("TA_recursive", typing.Union["TA_recursive", str])
+TAext_recursive = typing_extensions.TypeAliasType(
+    "TAext_recursive", typing.Union["TAext_recursive", str]
 )
-TA_null_recursive = typing_extensions.TypeAliasType(
+TA_null_recursive = TypeAliasType(
     "TA_null_recursive", typing.Union[TA_recursive, None]
 )
-TA_recursive_a = typing_extensions.TypeAliasType(
+TAext_null_recursive = typing_extensions.TypeAliasType(
+    "TAext_null_recursive", typing.Union[TAext_recursive, None]
+)
+TA_recursive_a = TypeAliasType(
     "TA_recursive_a", typing.Union["TA_recursive_b", int]
 )
-TA_recursive_b = typing_extensions.TypeAliasType(
+TAext_recursive_a = typing_extensions.TypeAliasType(
+    "TAext_recursive_a", typing.Union["TAext_recursive_b", int]
+)
+TA_recursive_b = TypeAliasType(
     "TA_recursive_b", typing.Union["TA_recursive_a", str]
 )
+TAext_recursive_b = typing_extensions.TypeAliasType(
+    "TAext_recursive_b", typing.Union["TAext_recursive_a", str]
+)
+TA_generic = TypeAliasType("TA_generic", typing.List[TV], type_params=(TV,))
+TAext_generic = typing_extensions.TypeAliasType(
+    "TAext_generic", typing.List[TV], type_params=(TV,)
+)
+TA_generic_typed = TA_generic[int]
+TAext_generic_typed = TAext_generic[int]
+TA_generic_null = TypeAliasType(
+    "TA_generic_null", typing.Union[typing.List[TV], None], type_params=(TV,)
+)
+TAext_generic_null = typing_extensions.TypeAliasType(
+    "TAext_generic_null",
+    typing.Union[typing.List[TV], None],
+    type_params=(TV,),
+)
+TA_generic_null_typed = TA_generic_null[str]
+TAext_generic_null_typed = TAext_generic_null[str]
 
 
 def type_aliases():
     return [
         TA_int,
+        TAext_int,
         TA_union,
+        TAext_union,
         TA_null_union,
+        TAext_null_union,
         TA_null_union2,
+        TAext_null_union2,
         TA_null_union3,
+        TAext_null_union3,
         TA_null_union4,
+        TAext_null_union4,
         TA_union_ta,
+        TAext_union_ta,
         TA_null_union_ta,
+        TAext_null_union_ta,
         TA_list,
+        TAext_list,
         TA_recursive,
+        TAext_recursive,
         TA_null_recursive,
+        TAext_null_recursive,
         TA_recursive_a,
+        TAext_recursive_a,
         TA_recursive_b,
+        TAext_recursive_b,
+        TA_generic,
+        TAext_generic,
+        TA_generic_typed,
+        TAext_generic_typed,
+        TA_generic_null,
+        TAext_generic_null,
+        TA_generic_null_typed,
+        TAext_generic_null_typed,
     ]
 
 
@@ -143,11 +224,14 @@ def exec_code(code: str, *vars: str) -> typing.Any:
 
 class TestTestingThings(fixtures.TestBase):
     def test_unions_are_the_same(self):
+        # the point of this test is to reduce the cases to test since
+        # some symbols are the same in typing and typing_extensions.
+        # If a test starts failing then additional cases should be added,
+        # similar to what it's done for TypeAliasType
+
         # no need to test typing_extensions.Union, typing_extensions.Optional
         is_(typing.Union, typing_extensions.Union)
         is_(typing.Optional, typing_extensions.Optional)
-        if py312:
-            is_(typing.TypeAliasType, typing_extensions.TypeAliasType)
 
     def test_make_union(self):
         v = int, str
@@ -221,8 +305,19 @@ class TestTyping(fixtures.TestBase):
             eq_(sa_typing.is_generic(t), False)
             eq_(sa_typing.is_generic(t[int]), True)
 
+        generics = [
+            TA_generic_typed,
+            TAext_generic_typed,
+            TA_generic_null_typed,
+            TAext_generic_null_typed,
+            *annotated_l(),
+            *generic_unions(),
+        ]
+
         for t in all_types():
-            eq_(sa_typing.is_literal(t), False)
+            # use is since union compare equal between new/old style
+            exp = any(t is k for k in generics)
+            eq_(sa_typing.is_generic(t), exp, t)
 
     def test_is_pep695(self):
         eq_(sa_typing.is_pep695(str), False)
@@ -249,41 +344,100 @@ class TestTyping(fixtures.TestBase):
             sa_typing.pep695_values(typing.Union[int, TA_int]),
             {typing.Union[int, TA_int]},
         )
+        eq_(
+            sa_typing.pep695_values(typing.Union[int, TAext_int]),
+            {typing.Union[int, TAext_int]},
+        )
 
         eq_(sa_typing.pep695_values(TA_int), {int})
+        eq_(sa_typing.pep695_values(TAext_int), {int})
         eq_(sa_typing.pep695_values(TA_union), {int, str})
+        eq_(sa_typing.pep695_values(TAext_union), {int, str})
         eq_(sa_typing.pep695_values(TA_null_union), {int, str, None})
+        eq_(sa_typing.pep695_values(TAext_null_union), {int, str, None})
         eq_(sa_typing.pep695_values(TA_null_union2), {int, str, None})
+        eq_(sa_typing.pep695_values(TAext_null_union2), {int, str, None})
         eq_(
             sa_typing.pep695_values(TA_null_union3),
             {int, typing.ForwardRef("typing.Union[None, bool]")},
         )
+        eq_(
+            sa_typing.pep695_values(TAext_null_union3),
+            {int, typing.ForwardRef("typing.Union[None, bool]")},
+        )
         eq_(
             sa_typing.pep695_values(TA_null_union4),
             {int, typing.ForwardRef("TA_null_union2")},
         )
+        eq_(
+            sa_typing.pep695_values(TAext_null_union4),
+            {int, typing.ForwardRef("TAext_null_union2")},
+        )
         eq_(sa_typing.pep695_values(TA_union_ta), {int, str})
+        eq_(sa_typing.pep695_values(TAext_union_ta), {int, str})
         eq_(sa_typing.pep695_values(TA_null_union_ta), {int, str, None, float})
+        eq_(
+            sa_typing.pep695_values(TAext_null_union_ta),
+            {int, str, None, float},
+        )
         eq_(
             sa_typing.pep695_values(TA_list),
             {int, str, typing.List[typing.ForwardRef("TA_list")]},
         )
+        eq_(
+            sa_typing.pep695_values(TAext_list),
+            {int, str, typing.List[typing.ForwardRef("TAext_list")]},
+        )
         eq_(
             sa_typing.pep695_values(TA_recursive),
             {typing.ForwardRef("TA_recursive"), str},
         )
+        eq_(
+            sa_typing.pep695_values(TAext_recursive),
+            {typing.ForwardRef("TAext_recursive"), str},
+        )
         eq_(
             sa_typing.pep695_values(TA_null_recursive),
             {typing.ForwardRef("TA_recursive"), str, None},
         )
+        eq_(
+            sa_typing.pep695_values(TAext_null_recursive),
+            {typing.ForwardRef("TAext_recursive"), str, None},
+        )
         eq_(
             sa_typing.pep695_values(TA_recursive_a),
             {typing.ForwardRef("TA_recursive_b"), int},
         )
+        eq_(
+            sa_typing.pep695_values(TAext_recursive_a),
+            {typing.ForwardRef("TAext_recursive_b"), int},
+        )
         eq_(
             sa_typing.pep695_values(TA_recursive_b),
             {typing.ForwardRef("TA_recursive_a"), str},
         )
+        eq_(
+            sa_typing.pep695_values(TAext_recursive_b),
+            {typing.ForwardRef("TAext_recursive_a"), str},
+        )
+        # generics
+        eq_(sa_typing.pep695_values(TA_generic), {typing.List[TV]})
+        eq_(sa_typing.pep695_values(TAext_generic), {typing.List[TV]})
+        eq_(sa_typing.pep695_values(TA_generic_typed), {typing.List[TV]})
+        eq_(sa_typing.pep695_values(TAext_generic_typed), {typing.List[TV]})
+        eq_(sa_typing.pep695_values(TA_generic_null), {None, typing.List[TV]})
+        eq_(
+            sa_typing.pep695_values(TAext_generic_null),
+            {None, typing.List[TV]},
+        )
+        eq_(
+            sa_typing.pep695_values(TA_generic_null_typed),
+            {None, typing.List[TV]},
+        )
+        eq_(
+            sa_typing.pep695_values(TAext_generic_null_typed),
+            {None, typing.List[TV]},
+        )
 
     def test_is_fwd_ref(self):
         eq_(sa_typing.is_fwd_ref(int), False)
@@ -346,6 +500,10 @@ class TestTyping(fixtures.TestBase):
             sa_typing.make_union_type(bool, TA_int, NT_str),
             typing.Union[bool, TA_int, NT_str],
         )
+        eq_(
+            sa_typing.make_union_type(bool, TAext_int, NT_str),
+            typing.Union[bool, TAext_int, NT_str],
+        )
 
     def test_includes_none(self):
         eq_(sa_typing.includes_none(None), True)
@@ -359,11 +517,12 @@ class TestTyping(fixtures.TestBase):
             eq_(sa_typing.includes_none(t), True, str(t))
 
         # TODO: these are false negatives
-        false_negative = {
+        false_negatives = {
             TA_null_union4,  # does not evaluate FW ref
+            TAext_null_union4,  # does not evaluate FW ref
         }
         for t in type_aliases() + new_types():
-            if t in false_negative:
+            if t in false_negatives:
                 exp = False
             else:
                 exp = "null" in t.__name__
@@ -378,6 +537,9 @@ class TestTyping(fixtures.TestBase):
         # nested things
         eq_(sa_typing.includes_none(typing.Union[int, "None"]), True)
         eq_(sa_typing.includes_none(typing.Union[bool, TA_null_union]), True)
+        eq_(
+            sa_typing.includes_none(typing.Union[bool, TAext_null_union]), True
+        )
         eq_(sa_typing.includes_none(typing.Union[bool, NT_null]), True)
         # nested fw
         eq_(
@@ -397,6 +559,10 @@ class TestTyping(fixtures.TestBase):
         eq_(
             sa_typing.includes_none(typing.Union[bool, "TA_null_union"]), False
         )
+        eq_(
+            sa_typing.includes_none(typing.Union[bool, "TAext_null_union"]),
+            False,
+        )
         eq_(sa_typing.includes_none(typing.Union[bool, "NT_null"]), False)
 
     def test_is_union(self):
@@ -405,3 +571,26 @@ class TestTyping(fixtures.TestBase):
             eq_(sa_typing.is_union(t), True)
         for t in type_aliases() + new_types() + annotated_l():
             eq_(sa_typing.is_union(t), False)
+
+    def test_TypingInstances(self):
+        is_(sa_typing._type_tuples, sa_typing._type_instances)
+        is_(
+            isinstance(sa_typing._type_instances, sa_typing._TypingInstances),
+            True,
+        )
+
+        # cached
+        is_(
+            sa_typing._type_instances.Literal,
+            sa_typing._type_instances.Literal,
+        )
+
+        for k in ["Literal", "Annotated", "TypeAliasType"]:
+            types = set()
+            ti = getattr(sa_typing._type_instances, k)
+            for lib in [typing, typing_extensions]:
+                lt = getattr(lib, k, None)
+                if lt is not None:
+                    types.add(lt)
+                    is_(lt in ti, True)
+            eq_(len(ti), len(types), k)
index d7d9414661c01071ea3430b2d78d3d7276062fc6..f0b3e81fd755c0874d5ee3fcfadcd80e78fce901 100644 (file)
@@ -105,6 +105,8 @@ from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
 
+TV = typing.TypeVar("TV")
+
 
 class _SomeDict1(TypedDict):
     type: Literal["1"]
@@ -136,7 +138,16 @@ if compat.py310:
     )
     _JsonPep695 = TypeAliasType("_JsonPep695", _JsonPep604)
 
+TypingTypeAliasType = getattr(typing, "TypeAliasType", TypeAliasType)
+
 _StrPep695 = TypeAliasType("_StrPep695", str)
+_TypingStrPep695 = TypingTypeAliasType("_TypingStrPep695", str)
+_GenericPep695 = TypeAliasType("_GenericPep695", List[TV], type_params=(TV,))
+_TypingGenericPep695 = TypingTypeAliasType(
+    "_TypingGenericPep695", List[TV], type_params=(TV,)
+)
+_GenericPep695Typed = _GenericPep695[int]
+_TypingGenericPep695Typed = _TypingGenericPep695[int]
 _UnionPep695 = TypeAliasType("_UnionPep695", Union[_SomeDict1, _SomeDict2])
 strtypalias_keyword = TypeAliasType(
     "strtypalias_keyword", Annotated[str, mapped_column(info={"hi": "there"})]
@@ -151,6 +162,9 @@ strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})]
 _Literal695 = TypeAliasType(
     "_Literal695", Literal["to-do", "in-progress", "done"]
 )
+_TypingLiteral695 = TypingTypeAliasType(
+    "_TypingLiteral695", Literal["to-do", "in-progress", "done"]
+)
 _RecursiveLiteral695 = TypeAliasType("_RecursiveLiteral695", _Literal695)
 
 
@@ -1093,20 +1107,52 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             ):
                 declare()
 
+    @testing.variation(
+        "type_",
+        [
+            "str_extension",
+            "str_typing",
+            "generic_extension",
+            "generic_typing",
+            "generic_typed_extension",
+            "generic_typed_typing",
+        ],
+    )
     @testing.requires.python312
     def test_pep695_typealias_as_typemap_keys(
-        self, decl_base: Type[DeclarativeBase]
+        self, decl_base: Type[DeclarativeBase], type_
     ):
         """test #10807"""
 
         decl_base.registry.update_type_annotation_map(
-            {_UnionPep695: JSON, _StrPep695: String(30)}
+            {
+                _UnionPep695: JSON,
+                _StrPep695: String(30),
+                _TypingStrPep695: String(30),
+                _GenericPep695: String(30),
+                _TypingGenericPep695: String(30),
+                _GenericPep695Typed: String(30),
+                _TypingGenericPep695Typed: String(30),
+            }
         )
 
         class Test(decl_base):
             __tablename__ = "test"
             id: Mapped[int] = mapped_column(primary_key=True)
-            data: Mapped[_StrPep695]
+            if type_.str_extension:
+                data: Mapped[_StrPep695]
+            elif type_.str_typing:
+                data: Mapped[_TypingStrPep695]
+            elif type_.generic_extension:
+                data: Mapped[_GenericPep695]
+            elif type_.generic_typing:
+                data: Mapped[_TypingGenericPep695]
+            elif type_.generic_typed_extension:
+                data: Mapped[_GenericPep695Typed]
+            elif type_.generic_typed_typing:
+                data: Mapped[_TypingGenericPep695Typed]
+            else:
+                type_.fail()
             structure: Mapped[_UnionPep695]
 
         eq_(Test.__table__.c.data.type._type_affinity, String)
@@ -1163,7 +1209,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             eq_(MyClass.data_one.type.length, None)
 
-    @testing.variation("type_", ["literal", "recursive", "not_literal"])
+    @testing.variation(
+        "type_",
+        [
+            "literal",
+            "literal_typing",
+            "recursive",
+            "not_literal",
+            "not_literal_typing",
+            "generic",
+            "generic_typing",
+            "generic_typed",
+            "generic_typed_typing",
+        ],
+    )
     @testing.combinations(True, False, argnames="in_map")
     @testing.requires.python312
     def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map):
@@ -1178,8 +1237,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     status: Mapped[_RecursiveLiteral695]  # noqa: F821
                 elif type_.literal:
                     status: Mapped[_Literal695]  # noqa: F821
+                elif type_.literal_typing:
+                    status: Mapped[_TypingLiteral695]  # noqa: F821
                 elif type_.not_literal:
                     status: Mapped[_StrPep695]  # noqa: F821
+                elif type_.not_literal_typing:
+                    status: Mapped[_TypingStrPep695]  # noqa: F821
+                elif type_.generic:
+                    status: Mapped[_GenericPep695]  # noqa: F821
+                elif type_.generic_typing:
+                    status: Mapped[_TypingGenericPep695]  # noqa: F821
+                elif type_.generic_typed:
+                    status: Mapped[_GenericPep695Typed]  # noqa: F821
+                elif type_.generic_typed_typing:
+                    status: Mapped[_TypingGenericPep695Typed]  # noqa: F821
                 else:
                     type_.fail()
 
@@ -1189,11 +1260,17 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             decl_base.registry.update_type_annotation_map(
                 {
                     _Literal695: Enum(enum.Enum),  # noqa: F821
+                    _TypingLiteral695: Enum(enum.Enum),  # noqa: F821
                     _RecursiveLiteral695: Enum(enum.Enum),  # noqa: F821
                     _StrPep695: Enum(enum.Enum),  # noqa: F821
+                    _TypingStrPep695: Enum(enum.Enum),  # noqa: F821
+                    _GenericPep695: Enum(enum.Enum),  # noqa: F821
+                    _TypingGenericPep695: Enum(enum.Enum),  # noqa: F821
+                    _GenericPep695Typed: Enum(enum.Enum),  # noqa: F821
+                    _TypingGenericPep695Typed: Enum(enum.Enum),  # noqa: F821
                 }
             )
-            if type_.literal:
+            if type_.literal or type_.literal_typing:
                 Foo = declare()
                 col = Foo.__table__.c.status
                 is_true(isinstance(col.type, Enum))
index cb7712862d03cab2b0da90140d004ee2217fa3c7..748ad03f7ab73305bffa28a28697c681718caf3c 100644 (file)
@@ -96,6 +96,8 @@ from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
 
+TV = typing.TypeVar("TV")
+
 
 class _SomeDict1(TypedDict):
     type: Literal["1"]
@@ -127,7 +129,16 @@ if compat.py310:
     )
     _JsonPep695 = TypeAliasType("_JsonPep695", _JsonPep604)
 
+TypingTypeAliasType = getattr(typing, "TypeAliasType", TypeAliasType)
+
 _StrPep695 = TypeAliasType("_StrPep695", str)
+_TypingStrPep695 = TypingTypeAliasType("_TypingStrPep695", str)
+_GenericPep695 = TypeAliasType("_GenericPep695", List[TV], type_params=(TV,))
+_TypingGenericPep695 = TypingTypeAliasType(
+    "_TypingGenericPep695", List[TV], type_params=(TV,)
+)
+_GenericPep695Typed = _GenericPep695[int]
+_TypingGenericPep695Typed = _TypingGenericPep695[int]
 _UnionPep695 = TypeAliasType("_UnionPep695", Union[_SomeDict1, _SomeDict2])
 strtypalias_keyword = TypeAliasType(
     "strtypalias_keyword", Annotated[str, mapped_column(info={"hi": "there"})]
@@ -142,6 +153,9 @@ strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})]
 _Literal695 = TypeAliasType(
     "_Literal695", Literal["to-do", "in-progress", "done"]
 )
+_TypingLiteral695 = TypingTypeAliasType(
+    "_TypingLiteral695", Literal["to-do", "in-progress", "done"]
+)
 _RecursiveLiteral695 = TypeAliasType("_RecursiveLiteral695", _Literal695)
 
 
@@ -1084,20 +1098,52 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             ):
                 declare()
 
+    @testing.variation(
+        "type_",
+        [
+            "str_extension",
+            "str_typing",
+            "generic_extension",
+            "generic_typing",
+            "generic_typed_extension",
+            "generic_typed_typing",
+        ],
+    )
     @testing.requires.python312
     def test_pep695_typealias_as_typemap_keys(
-        self, decl_base: Type[DeclarativeBase]
+        self, decl_base: Type[DeclarativeBase], type_
     ):
         """test #10807"""
 
         decl_base.registry.update_type_annotation_map(
-            {_UnionPep695: JSON, _StrPep695: String(30)}
+            {
+                _UnionPep695: JSON,
+                _StrPep695: String(30),
+                _TypingStrPep695: String(30),
+                _GenericPep695: String(30),
+                _TypingGenericPep695: String(30),
+                _GenericPep695Typed: String(30),
+                _TypingGenericPep695Typed: String(30),
+            }
         )
 
         class Test(decl_base):
             __tablename__ = "test"
             id: Mapped[int] = mapped_column(primary_key=True)
-            data: Mapped[_StrPep695]
+            if type_.str_extension:
+                data: Mapped[_StrPep695]
+            elif type_.str_typing:
+                data: Mapped[_TypingStrPep695]
+            elif type_.generic_extension:
+                data: Mapped[_GenericPep695]
+            elif type_.generic_typing:
+                data: Mapped[_TypingGenericPep695]
+            elif type_.generic_typed_extension:
+                data: Mapped[_GenericPep695Typed]
+            elif type_.generic_typed_typing:
+                data: Mapped[_TypingGenericPep695Typed]
+            else:
+                type_.fail()
             structure: Mapped[_UnionPep695]
 
         eq_(Test.__table__.c.data.type._type_affinity, String)
@@ -1154,7 +1200,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             eq_(MyClass.data_one.type.length, None)
 
-    @testing.variation("type_", ["literal", "recursive", "not_literal"])
+    @testing.variation(
+        "type_",
+        [
+            "literal",
+            "literal_typing",
+            "recursive",
+            "not_literal",
+            "not_literal_typing",
+            "generic",
+            "generic_typing",
+            "generic_typed",
+            "generic_typed_typing",
+        ],
+    )
     @testing.combinations(True, False, argnames="in_map")
     @testing.requires.python312
     def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map):
@@ -1169,8 +1228,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     status: Mapped[_RecursiveLiteral695]  # noqa: F821
                 elif type_.literal:
                     status: Mapped[_Literal695]  # noqa: F821
+                elif type_.literal_typing:
+                    status: Mapped[_TypingLiteral695]  # noqa: F821
                 elif type_.not_literal:
                     status: Mapped[_StrPep695]  # noqa: F821
+                elif type_.not_literal_typing:
+                    status: Mapped[_TypingStrPep695]  # noqa: F821
+                elif type_.generic:
+                    status: Mapped[_GenericPep695]  # noqa: F821
+                elif type_.generic_typing:
+                    status: Mapped[_TypingGenericPep695]  # noqa: F821
+                elif type_.generic_typed:
+                    status: Mapped[_GenericPep695Typed]  # noqa: F821
+                elif type_.generic_typed_typing:
+                    status: Mapped[_TypingGenericPep695Typed]  # noqa: F821
                 else:
                     type_.fail()
 
@@ -1180,11 +1251,17 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             decl_base.registry.update_type_annotation_map(
                 {
                     _Literal695: Enum(enum.Enum),  # noqa: F821
+                    _TypingLiteral695: Enum(enum.Enum),  # noqa: F821
                     _RecursiveLiteral695: Enum(enum.Enum),  # noqa: F821
                     _StrPep695: Enum(enum.Enum),  # noqa: F821
+                    _TypingStrPep695: Enum(enum.Enum),  # noqa: F821
+                    _GenericPep695: Enum(enum.Enum),  # noqa: F821
+                    _TypingGenericPep695: Enum(enum.Enum),  # noqa: F821
+                    _GenericPep695Typed: Enum(enum.Enum),  # noqa: F821
+                    _TypingGenericPep695Typed: Enum(enum.Enum),  # noqa: F821
                 }
             )
-            if type_.literal:
+            if type_.literal or type_.literal_typing:
                 Foo = declare()
                 col = Foo.__table__.c.status
                 is_true(isinstance(col.type, Enum))