]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix use of typing.Literal on Python 3.8 and 3.9
authorFrazer McLean <frazer@frazermclean.co.uk>
Thu, 5 Sep 2024 11:29:47 +0000 (07:29 -0400)
committerMichael Bayer <mike_mp@zzzcomputing.com>
Thu, 12 Sep 2024 12:26:14 +0000 (12:26 +0000)
Fixed issue where it was not possible to use ``typing.Literal`` with
``Mapped[]`` on Python 3.8 and 3.9.  Pull request courtesy Frazer McLean.

Fixes: #11820
Closes: #11825
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11825
Pull-request-sha: e1e50a97d2a6e0e9ef7ba8dc1a5f07d252e79fa4

Change-Id: Idf04326abcba45813ad555127e81d581a0353587

doc/build/changelog/unreleased_20/11820.rst [new file with mode: 0644]
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/11820.rst b/doc/build/changelog/unreleased_20/11820.rst
new file mode 100644 (file)
index 0000000..ae03040
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, orm, typing
+    :tickets: 11814
+
+    Fixed issue where it was not possible to use ``typing.Literal`` with
+    ``Mapped[]`` on Python 3.8 and 3.9.  Pull request courtesy Frazer McLean.
index 7be6589e03d78c55d9249886aee2f8f37faa6878..3366fca4993f61f5956b123e300931e7dec26e26 100644 (file)
@@ -12,6 +12,7 @@ import builtins
 import collections.abc as collections_abc
 import re
 import sys
+import typing
 from typing import Any
 from typing import Callable
 from typing import cast
@@ -64,6 +65,10 @@ _VT_co = TypeVar("_VT_co", covariant=True)
 
 TupleAny = Tuple[Any, ...]
 
+# typing_extensions.Literal is different from typing.Literal until
+# Python 3.10.1
+_LITERAL_TYPES = frozenset([typing.Literal, Literal])
+
 
 if compat.py310:
     # why they took until py310 to put this in stdlib is beyond me,
@@ -358,7 +363,7 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
 
 
 def is_literal(type_: _AnnotationScanType) -> bool:
-    return get_origin(type_) is Literal
+    return get_origin(type_) in _LITERAL_TYPES
 
 
 def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
index eb1e605d10e0af094f0312463b05f1545026f558..e473245b82f78ac678b8376bf902f5e364582ffa 100644 (file)
@@ -30,6 +30,7 @@ from typing import TypeVar
 from typing import Union
 import uuid
 
+import typing_extensions
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
 from typing_extensions import TypeAlias as TypeAlias
@@ -119,6 +120,9 @@ _Recursive695_0: TypeAlias = _Literal695
 _Recursive695_1: TypeAlias = _Recursive695_0
 _Recursive695_2: TypeAlias = _Recursive695_1
 
+_TypingLiteral = typing.Literal["a", "b"]
+_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
+
 if compat.py312:
     exec(
         """
@@ -897,6 +901,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             eq_(col.type.enums, ["to-do", "in-progress", "done"])
             is_(col.type.native_enum, False)
 
+    def test_typing_literal_identity(self, decl_base):
+        """See issue #11820"""
+
+        class Foo(decl_base):
+            __tablename__ = "footable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            t: Mapped[_TypingLiteral]
+            te: Mapped[_TypingExtensionsLiteral]
+
+        for col in (Foo.__table__.c.t, Foo.__table__.c.te):
+            is_true(isinstance(col.type, Enum))
+            eq_(col.type.enums, ["a", "b"])
+            is_(col.type.native_enum, False)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)
index c9eacbae7da81343ee048dee23ddff1238a9c7c9..36adbd197dbdf90cfe9cb5e023d1d706dd19321f 100644 (file)
@@ -21,6 +21,7 @@ from typing import TypeVar
 from typing import Union
 import uuid
 
+import typing_extensions
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
 from typing_extensions import TypeAlias as TypeAlias
@@ -110,6 +111,9 @@ _Recursive695_0: TypeAlias = _Literal695
 _Recursive695_1: TypeAlias = _Recursive695_0
 _Recursive695_2: TypeAlias = _Recursive695_1
 
+_TypingLiteral = typing.Literal["a", "b"]
+_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
+
 if compat.py312:
     exec(
         """
@@ -888,6 +892,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             eq_(col.type.enums, ["to-do", "in-progress", "done"])
             is_(col.type.native_enum, False)
 
+    def test_typing_literal_identity(self, decl_base):
+        """See issue #11820"""
+
+        class Foo(decl_base):
+            __tablename__ = "footable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            t: Mapped[_TypingLiteral]
+            te: Mapped[_TypingExtensionsLiteral]
+
+        for col in (Foo.__table__.c.t, Foo.__table__.c.te):
+            is_true(isinstance(col.type, Enum))
+            eq_(col.type.enums, ["a", "b"])
+            is_(col.type.native_enum, False)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)