From 232a92c81f09a44432c267bc8ebd3db9aca78994 Mon Sep 17 00:00:00 2001 From: Frazer McLean Date: Mon, 2 Sep 2024 18:50:21 +0200 Subject: [PATCH] Fix use of typing.Literal on Python 3.8 and 3.9 Fixes: #11820 --- lib/sqlalchemy/util/typing.py | 7 ++++++- .../test_tm_future_annotations_sync.py | 19 +++++++++++++++++++ test/orm/declarative/test_typed_mapping.py | 19 +++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index f4f14e1b56..f66bd1d4b6 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -12,6 +12,7 @@ import builtins import collections.abc as collections_abc import re import sys +import typing from typing import Any from typing import Callable from typing import cast @@ -64,6 +65,10 @@ _VT_co = TypeVar("_VT_co", covariant=True) TupleAny = Tuple[Any, ...] +# typing_extensions.Literal is different from typing.Literal until +# Python 3.10.1 +_LITERAL_TYPES = frozenset([typing.Literal, Literal]) + if compat.py310: # why they took until py310 to put this in stdlib is beyond me, @@ -357,7 +362,7 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]: def is_literal(type_: _AnnotationScanType) -> bool: - return get_origin(type_) is Literal + return get_origin(type_) in _LITERAL_TYPES def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index e9b74b0d93..0fe1d59e15 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -30,6 +30,7 @@ from typing import TypeVar from typing import Union import uuid +import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal from typing_extensions import TypeAlias as TypeAlias @@ -119,6 +120,9 @@ _Recursive695_0: TypeAlias = _Literal695 _Recursive695_1: TypeAlias = _Recursive695_0 _Recursive695_2: TypeAlias = _Recursive695_1 +_TypingLiteral = typing.Literal["a", "b"] +_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] + if compat.py312: exec( """ @@ -897,6 +901,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(col.type.enums, ["to-do", "in-progress", "done"]) is_(col.type.native_enum, False) + def test_typing_literal_identity(self, decl_base): + """See issue #11820""" + + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + t: Mapped[_TypingLiteral] + te: Mapped[_TypingExtensionsLiteral] + + for col in (Foo.__table__.c.t, Foo.__table__.c.te): + is_true(isinstance(col.type, Enum)) + eq_(col.type.enums, ["a", "b"]) + is_(col.type.native_enum, False) + @testing.requires.python310 def test_we_got_all_attrs_test_annotated(self): argnames = _py_inspect.getfullargspec(mapped_column) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 5060ac6131..a5c2504164 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -21,6 +21,7 @@ from typing import TypeVar from typing import Union import uuid +import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal from typing_extensions import TypeAlias as TypeAlias @@ -110,6 +111,9 @@ _Recursive695_0: TypeAlias = _Literal695 _Recursive695_1: TypeAlias = _Recursive695_0 _Recursive695_2: TypeAlias = _Recursive695_1 +_TypingLiteral = typing.Literal["a", "b"] +_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] + if compat.py312: exec( """ @@ -888,6 +892,21 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(col.type.enums, ["to-do", "in-progress", "done"]) is_(col.type.native_enum, False) + def test_typing_literal_identity(self, decl_base): + """See issue #11820""" + + class Foo(decl_base): + __tablename__ = "footable" + + id: Mapped[int] = mapped_column(primary_key=True) + t: Mapped[_TypingLiteral] + te: Mapped[_TypingExtensionsLiteral] + + for col in (Foo.__table__.c.t, Foo.__table__.c.te): + is_true(isinstance(col.type, Enum)) + eq_(col.type.enums, ["a", "b"]) + is_(col.type.native_enum, False) + @testing.requires.python310 def test_we_got_all_attrs_test_annotated(self): argnames = _py_inspect.getfullargspec(mapped_column) -- 2.47.3