]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix use of typing.Literal on Python 3.8 and 3.9
authorFrazer McLean <frazer@frazermclean.co.uk>
Mon, 2 Sep 2024 16:50:21 +0000 (18:50 +0200)
committerFrazer McLean <frazer@frazermclean.co.uk>
Thu, 5 Sep 2024 10:28:10 +0000 (12:28 +0200)
Fixes: #11820
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

index f4f14e1b56d83d1e330f64d621be50cfa7f750ba..f66bd1d4b6c0b098f91ae78d1ee2dbf5bd53963d 100644 (file)
@@ -12,6 +12,7 @@ import builtins
 import collections.abc as collections_abc
 import re
 import sys
+import typing
 from typing import Any
 from typing import Callable
 from typing import cast
@@ -64,6 +65,10 @@ _VT_co = TypeVar("_VT_co", covariant=True)
 
 TupleAny = Tuple[Any, ...]
 
+# typing_extensions.Literal is different from typing.Literal until
+# Python 3.10.1
+_LITERAL_TYPES = frozenset([typing.Literal, Literal])
+
 
 if compat.py310:
     # why they took until py310 to put this in stdlib is beyond me,
@@ -357,7 +362,7 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
 
 
 def is_literal(type_: _AnnotationScanType) -> bool:
-    return get_origin(type_) is Literal
+    return get_origin(type_) in _LITERAL_TYPES
 
 
 def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
index e9b74b0d93f33cdf1c1160a86077d75bbbdfdf6d..0fe1d59e158746e738da2c3d2501b002cc3c4242 100644 (file)
@@ -30,6 +30,7 @@ from typing import TypeVar
 from typing import Union
 import uuid
 
+import typing_extensions
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
 from typing_extensions import TypeAlias as TypeAlias
@@ -119,6 +120,9 @@ _Recursive695_0: TypeAlias = _Literal695
 _Recursive695_1: TypeAlias = _Recursive695_0
 _Recursive695_2: TypeAlias = _Recursive695_1
 
+_TypingLiteral = typing.Literal["a", "b"]
+_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
+
 if compat.py312:
     exec(
         """
@@ -897,6 +901,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             eq_(col.type.enums, ["to-do", "in-progress", "done"])
             is_(col.type.native_enum, False)
 
+    def test_typing_literal_identity(self, decl_base):
+        """See issue #11820"""
+
+        class Foo(decl_base):
+            __tablename__ = "footable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            t: Mapped[_TypingLiteral]
+            te: Mapped[_TypingExtensionsLiteral]
+
+        for col in (Foo.__table__.c.t, Foo.__table__.c.te):
+            is_true(isinstance(col.type, Enum))
+            eq_(col.type.enums, ["a", "b"])
+            is_(col.type.native_enum, False)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)
index 5060ac613161565780a3aed48610b54f25258891..a5c2504164ee19b46f36650cd780d287677d2476 100644 (file)
@@ -21,6 +21,7 @@ from typing import TypeVar
 from typing import Union
 import uuid
 
+import typing_extensions
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
 from typing_extensions import TypeAlias as TypeAlias
@@ -110,6 +111,9 @@ _Recursive695_0: TypeAlias = _Literal695
 _Recursive695_1: TypeAlias = _Recursive695_0
 _Recursive695_2: TypeAlias = _Recursive695_1
 
+_TypingLiteral = typing.Literal["a", "b"]
+_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
+
 if compat.py312:
     exec(
         """
@@ -888,6 +892,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             eq_(col.type.enums, ["to-do", "in-progress", "done"])
             is_(col.type.native_enum, False)
 
+    def test_typing_literal_identity(self, decl_base):
+        """See issue #11820"""
+
+        class Foo(decl_base):
+            __tablename__ = "footable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            t: Mapped[_TypingLiteral]
+            te: Mapped[_TypingExtensionsLiteral]
+
+        for col in (Foo.__table__.c.t, Foo.__table__.c.te):
+            is_true(isinstance(col.type, Enum))
+            eq_(col.type.enums, ["a", "b"])
+            is_(col.type.native_enum, False)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)