]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
re-process args for builtin generic types
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 31 Aug 2024 16:56:00 +0000 (12:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 1 Sep 2024 20:42:52 +0000 (16:42 -0400)
Improvements to the ORM annotated declarative type map lookup dealing with
composed types such as ``dict[str, Any]`` linking to JSON (or others) with
or without "future annotations" mode.

There's apparently a big incompatiblity in types from typing vs.
Python builtins in the way they genericize.   The typing library makes
it very difficult to distinguish between the two sets of types.  This
patch is a bit slash and burn to work around all this.   These should
likely be reported as bugs in the Python standard library if they
aren't already.

Fixes: #11814
Change-Id: I56a62701d5e883be04df7f45fd9429bb9c1c9a6f
(cherry picked from commit f746fd78e303352d426a15c1f76ee835ce399d44)

doc/build/changelog/unreleased_20/11814.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/11814.rst b/doc/build/changelog/unreleased_20/11814.rst
new file mode 100644 (file)
index 0000000..a9feecb
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11814
+
+    Improvements to the ORM annotated declarative type map lookup dealing with
+    composed types such as ``dict[str, Any]`` linking to JSON (or others) with
+    or without "future annotations" mode.
+
+
index 1203c9cb36a7c97e80416c44d7ada119b313c1b5..d43fbffc576c9c47419bc0007136857151b9f9bd 100644 (file)
@@ -431,7 +431,7 @@ class _ImperativeMapperConfig(_MapperConfig):
 class _CollectedAnnotation(NamedTuple):
     raw_annotation: _AnnotationScanType
     mapped_container: Optional[Type[Mapped[Any]]]
-    extracted_mapped_annotation: Union[Type[Any], str]
+    extracted_mapped_annotation: Union[_AnnotationScanType, str]
     is_dataclass: bool
     attr_value: Any
     originating_module: str
index 0b4ad88ed8b92e7f076bf9b176b651b58baea4cb..2b4ac3c9d7c9dd6c792f398551f36a18bd3ba652 100644 (file)
@@ -89,6 +89,7 @@ from ..util.typing import (
     de_stringify_union_elements as _de_stringify_union_elements,
 )
 from ..util.typing import eval_name_only as _eval_name_only
+from ..util.typing import fixup_container_fwd_refs
 from ..util.typing import is_origin_of_cls
 from ..util.typing import Literal
 from ..util.typing import Protocol
@@ -2321,7 +2322,7 @@ def _extract_mapped_subtype(
     is_dataclass_field: bool,
     expect_mapped: bool = True,
     raiseerr: bool = True,
-) -> Optional[Tuple[Union[type, str], Optional[type]]]:
+) -> Optional[Tuple[Union[_AnnotationScanType, str], Optional[type]]]:
     """given an annotation, figure out if it's ``Mapped[something]`` and if
     so, return the ``something`` part.
 
@@ -2407,7 +2408,11 @@ def _extract_mapped_subtype(
                 "Expected sub-type for Mapped[] annotation"
             )
 
-        return annotated.__args__[0], annotated.__origin__
+        return (
+            # fix dict/list/set args to be ForwardRef, see #11814
+            fixup_container_fwd_refs(annotated.__args__[0]),
+            annotated.__origin__,
+        )
 
 
 def _mapper_property_as_plain_name(prop: Type[Any]) -> str:
index 64619957a6bbebc531a34f716323220f967d6dfb..81e77f629f7b7ad88d51ba0d697988aa1f59cff7 100644 (file)
@@ -188,9 +188,50 @@ def de_stringify_annotation(
         )
 
         return _copy_generic_annotation_with(annotation, elements)
+
     return annotation  # type: ignore
 
 
+def fixup_container_fwd_refs(
+    type_: _AnnotationScanType,
+) -> _AnnotationScanType:
+    """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')]
+    and similar for list, set
+
+    """
+    if (
+        is_generic(type_)
+        and type_.__origin__
+        in (
+            dict,
+            set,
+            list,
+            collections_abc.MutableSet,
+            collections_abc.MutableMapping,
+            collections_abc.MutableSequence,
+            collections_abc.Mapping,
+            collections_abc.Sequence,
+        )
+        # fight, kick and scream to struggle to tell the difference between
+        # dict[] and typing.Dict[] which DO NOT compare the same and DO NOT
+        # behave the same yet there is NO WAY to distinguish between which type
+        # it is using public attributes
+        and not re.match(
+            "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_)
+        )
+    ):
+        # compat with py3.10 and earlier
+        return type_.__origin__.__class_getitem__(  # type: ignore
+            tuple(
+                [
+                    ForwardRef(elem) if isinstance(elem, str) else elem
+                    for elem in type_.__args__
+                ]
+            )
+        )
+    return type_
+
+
 def _copy_generic_annotation_with(
     annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
 ) -> Type[_T]:
index 2bdf340d4c077f7c4edfc5c58bb1e7bdd830ac02..765318cfa28ecfc50f9cc5d5c4211fbf1f1d9b67 100644 (file)
@@ -1420,21 +1420,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             (str, str),
         ),
         id_="sa",
+        argnames="container_typ,args",
     )
-    def test_extract_generic_from_pep593(self, container_typ, args):
-        """test #9099"""
+    @testing.variation("style", ["pep593", "alias", "direct"])
+    def test_extract_composed(self, container_typ, args, style):
+        """test #9099 (pep593)
+
+        test #11814
+
+        """
 
         global TestType
-        TestType = Annotated[container_typ[args], 0]
+
+        if style.pep593:
+            TestType = Annotated[container_typ[args], 0]
+        elif style.alias:
+            TestType = container_typ[args]
+        elif style.direct:
+            TestType = container_typ
+            double_strings = args == (str, str)
 
         class Base(DeclarativeBase):
-            type_annotation_map = {TestType: JSON()}
+            if style.direct:
+                if double_strings:
+                    type_annotation_map = {TestType[str, str]: JSON()}
+                else:
+                    type_annotation_map = {TestType[str]: JSON()}
+            else:
+                type_annotation_map = {TestType: JSON()}
 
         class MyClass(Base):
             __tablename__ = "my_table"
 
             id: Mapped[int] = mapped_column(primary_key=True)
-            data: Mapped[TestType] = mapped_column()
+
+            if style.direct:
+                if double_strings:
+                    data: Mapped[TestType[str, str]] = mapped_column()
+                else:
+                    data: Mapped[TestType[str]] = mapped_column()
+            else:
+                data: Mapped[TestType] = mapped_column()
 
         is_(MyClass.__table__.c.data.type._type_affinity, JSON)
 
index 6fb792b0ba0541d29226ce465a0dfe430042161b..8b10118f4c92c2d088ee016d34f9b2c16106ec61 100644 (file)
@@ -1411,21 +1411,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             (str, str),
         ),
         id_="sa",
+        argnames="container_typ,args",
     )
-    def test_extract_generic_from_pep593(self, container_typ, args):
-        """test #9099"""
+    @testing.variation("style", ["pep593", "alias", "direct"])
+    def test_extract_composed(self, container_typ, args, style):
+        """test #9099 (pep593)
+
+        test #11814
+
+        """
 
         global TestType
-        TestType = Annotated[container_typ[args], 0]
+
+        if style.pep593:
+            TestType = Annotated[container_typ[args], 0]
+        elif style.alias:
+            TestType = container_typ[args]
+        elif style.direct:
+            TestType = container_typ
+            double_strings = args == (str, str)
 
         class Base(DeclarativeBase):
-            type_annotation_map = {TestType: JSON()}
+            if style.direct:
+                if double_strings:
+                    type_annotation_map = {TestType[str, str]: JSON()}
+                else:
+                    type_annotation_map = {TestType[str]: JSON()}
+            else:
+                type_annotation_map = {TestType: JSON()}
 
         class MyClass(Base):
             __tablename__ = "my_table"
 
             id: Mapped[int] = mapped_column(primary_key=True)
-            data: Mapped[TestType] = mapped_column()
+
+            if style.direct:
+                if double_strings:
+                    data: Mapped[TestType[str, str]] = mapped_column()
+                else:
+                    data: Mapped[TestType[str]] = mapped_column()
+            else:
+                data: Mapped[TestType] = mapped_column()
 
         is_(MyClass.__table__.c.data.type._type_affinity, JSON)