]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve pep-695 inference including Enum support
authorAlc-Alc <alc@localhost>
Thu, 25 Apr 2024 19:42:34 +0000 (15:42 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 May 2024 22:45:47 +0000 (18:45 -0400)
Fixed issue in ORM Annotated Declarative where typing issue where literals
defined using :pep:`695` type aliases would not work with inference of
:class:`.Enum` datatypes. Pull request courtesy of Alc-Alc.

Fixes: #11305
Closes: #11313
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11313
Pull-request-sha: 090f0d865c4129cffffbce6a6ce3db9b91602460

Change-Id: Iac63302ad74fd7018a34a50c80ec3aeb87dc94a4

doc/build/changelog/unreleased_20/11305.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/11305.rst b/doc/build/changelog/unreleased_20/11305.rst
new file mode 100644 (file)
index 0000000..0a022c0
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11305
+
+    Fixed issue in ORM Annotated Declarative where typing issue where literals
+    defined using :pep:`695` type aliases would not work with inference of
+    :class:`.Enum` datatypes. Pull request courtesy of Alc-Alc.
index 72dded0e0932b9526b9e07eb5adb89cbea732fa1..3c26a17036af8aa5c1ba9fbdb2b653bce52592c5 100644 (file)
@@ -1232,31 +1232,39 @@ class registry:
     def _resolve_type(
         self, python_type: _MatchedOnType
     ) -> Optional[sqltypes.TypeEngine[Any]]:
-        search: Iterable[Tuple[_MatchedOnType, Type[Any]]]
+
+        python_type_to_check = python_type
+        while is_pep695(python_type_to_check):
+            python_type_to_check = python_type_to_check.__value__
+
+        check_is_pt = python_type is python_type_to_check
+
         python_type_type: Type[Any]
+        search: Iterable[Tuple[_MatchedOnType, Type[Any]]]
 
-        if is_generic(python_type):
-            if is_literal(python_type):
-                python_type_type = cast("Type[Any]", python_type)
+        if is_generic(python_type_to_check):
+            if is_literal(python_type_to_check):
+                python_type_type = cast("Type[Any]", python_type_to_check)
 
                 search = (  # type: ignore[assignment]
                     (python_type, python_type_type),
                     (Literal, python_type_type),
                 )
             else:
-                python_type_type = python_type.__origin__
+                python_type_type = python_type_to_check.__origin__
                 search = ((python_type, python_type_type),)
-        elif is_newtype(python_type):
-            python_type_type = flatten_newtype(python_type)
-            search = ((python_type, python_type_type),)
-        elif is_pep695(python_type):
-            python_type_type = python_type.__value__
-            flattened = None
+        elif is_newtype(python_type_to_check):
+            python_type_type = flatten_newtype(python_type_to_check)
             search = ((python_type, python_type_type),)
+        elif isinstance(python_type_to_check, type):
+            python_type_type = python_type_to_check
+            search = (
+                (pt if check_is_pt else python_type, pt)
+                for pt in python_type_type.__mro__
+            )
         else:
-            python_type_type = cast("Type[Any]", python_type)
-            flattened = None
-            search = ((pt, pt) for pt in python_type_type.__mro__)
+            python_type_type = python_type_to_check  # type: ignore[assignment]
+            search = ((python_type, python_type_type),)
 
         for pt, flattened in search:
             # we search through full __mro__ for types.  however...
index 60f71947e0d6649cc3fb944cd7562e9b57b18bfd..5dca5e246c35e1d2dbacd4c08a1a1fd2114c68f8 100644 (file)
@@ -111,8 +111,13 @@ _UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2]
 
 _StrTypeAlias: TypeAlias = str
 
-_StrPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
-_UnionPep695: TypeAlias = str
+_StrPep695: TypeAlias = str
+_UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
+
+_Literal695: TypeAlias = Literal["to-do", "in-progress", "done"]
+_Recursive695_0: TypeAlias = _Literal695
+_Recursive695_1: TypeAlias = _Recursive695_0
+_Recursive695_2: TypeAlias = _Recursive695_1
 
 if compat.py312:
     exec(
@@ -126,6 +131,11 @@ strtypalias_tat: typing.TypeAliasType = Annotated[
     str, mapped_column(info={"hi": "there"})]
 
 strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})]
+
+type _Literal695 = Literal["to-do", "in-progress", "done"]
+type _Recursive695_0 = _Literal695
+type _Recursive695_1 = _Recursive695_0
+type _Recursive695_2 = _Recursive695_1
 """,
         globals(),
     )
@@ -838,9 +848,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         class Test(decl_base):
             __tablename__ = "test"
             id: Mapped[int] = mapped_column(primary_key=True)
-            data: Mapped[_StrPep695]  # type: ignore
-            structure: Mapped[_UnionPep695]  # type: ignore
+            data: Mapped[_StrPep695]
+            structure: Mapped[_UnionPep695]
 
+        eq_(Test.__table__.c.data.type._type_affinity, String)
         eq_(Test.__table__.c.data.type.length, 30)
         is_(Test.__table__.c.structure.type._type_affinity, JSON)
 
@@ -870,6 +881,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         eq_(MyClass.data_one.expression.info, {"hi": "there"})
 
+    @testing.requires.python312
+    def test_pep695_literal_defaults_to_enum(self, decl_base):
+        """test #11305."""
+
+        class Foo(decl_base):
+            __tablename__ = "footable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            status: Mapped[_Literal695]
+            r2: Mapped[_Recursive695_2]
+
+        for col in (Foo.__table__.c.status, Foo.__table__.c.r2):
+            is_true(isinstance(col.type, Enum))
+            eq_(col.type.enums, ["to-do", "in-progress", "done"])
+            is_(col.type.native_enum, False)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)
index a1af50cbadb4c9acfc5dfb5093833389b12886c1..25200514dc3cce94f603cbd042e20128b8a319b9 100644 (file)
@@ -102,8 +102,13 @@ _UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2]
 
 _StrTypeAlias: TypeAlias = str
 
-_StrPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
-_UnionPep695: TypeAlias = str
+_StrPep695: TypeAlias = str
+_UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
+
+_Literal695: TypeAlias = Literal["to-do", "in-progress", "done"]
+_Recursive695_0: TypeAlias = _Literal695
+_Recursive695_1: TypeAlias = _Recursive695_0
+_Recursive695_2: TypeAlias = _Recursive695_1
 
 if compat.py312:
     exec(
@@ -117,6 +122,11 @@ strtypalias_tat: typing.TypeAliasType = Annotated[
     str, mapped_column(info={"hi": "there"})]
 
 strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})]
+
+type _Literal695 = Literal["to-do", "in-progress", "done"]
+type _Recursive695_0 = _Literal695
+type _Recursive695_1 = _Recursive695_0
+type _Recursive695_2 = _Recursive695_1
 """,
         globals(),
     )
@@ -829,9 +839,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         class Test(decl_base):
             __tablename__ = "test"
             id: Mapped[int] = mapped_column(primary_key=True)
-            data: Mapped[_StrPep695]  # type: ignore
-            structure: Mapped[_UnionPep695]  # type: ignore
+            data: Mapped[_StrPep695]
+            structure: Mapped[_UnionPep695]
 
+        eq_(Test.__table__.c.data.type._type_affinity, String)
         eq_(Test.__table__.c.data.type.length, 30)
         is_(Test.__table__.c.structure.type._type_affinity, JSON)
 
@@ -861,6 +872,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         eq_(MyClass.data_one.expression.info, {"hi": "there"})
 
+    @testing.requires.python312
+    def test_pep695_literal_defaults_to_enum(self, decl_base):
+        """test #11305."""
+
+        class Foo(decl_base):
+            __tablename__ = "footable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            status: Mapped[_Literal695]
+            r2: Mapped[_Recursive695_2]
+
+        for col in (Foo.__table__.c.status, Foo.__table__.c.r2):
+            is_true(isinstance(col.type, Enum))
+            eq_(col.type.enums, ["to-do", "in-progress", "done"])
+            is_(col.type.native_enum, False)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)