]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
check for recursion with container types
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Mar 2023 14:19:32 +0000 (10:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Mar 2023 20:01:16 +0000 (16:01 -0400)
Fixed issue in ORM Annotated Declarative where using a recursive type (e.g.
using a nested Dict type) would result in a recursion overflow in the ORM's
annotation resolution logic, even if this datatype were not necessary to
map the column.

Fixes: #9553
Change-Id: Ied99dc0d47276c6e9c23fa9df5fc65f7736d65cf

doc/build/changelog/unreleased_20/9553.rst [new file with mode: 0644]
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/9553.rst b/doc/build/changelog/unreleased_20/9553.rst
new file mode 100644 (file)
index 0000000..bdbcfc2
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9553
+
+    Fixed issue in ORM Annotated Declarative where using a recursive type (e.g.
+    using a nested Dict type) would result in a recursion overflow in the ORM's
+    annotation resolution logic, even if this datatype were not necessary to
+    map the column.
index 24d8dd2dc11e1d8c921fb49718a67b096c252566..9c38ae34435fa38abe7e1c93feba16c4c868d79f 100644 (file)
@@ -23,6 +23,7 @@ from typing import NewType
 from typing import NoReturn
 from typing import Optional
 from typing import overload
+from typing import Set
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
@@ -128,6 +129,7 @@ def de_stringify_annotation(
     *,
     str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
     include_generic: bool = False,
+    _already_seen: Optional[Set[Any]] = None,
 ) -> Type[Any]:
     """Resolve annotations that may be string based into real objects.
 
@@ -137,17 +139,15 @@ def de_stringify_annotation(
     etc.
 
     """
-
     # looked at typing.get_type_hints(), looked at pydantic.  We need much
     # less here, and we here try to not use any private typing internals
     # or construct ForwardRef objects which is documented as something
     # that should be avoided.
 
-    if (
-        is_fwd_ref(annotation)
-        and not cast(ForwardRef, annotation).__forward_evaluated__
-    ):
-        annotation = cast(ForwardRef, annotation).__forward_arg__
+    original_annotation = annotation
+
+    if is_fwd_ref(annotation) and not annotation.__forward_evaluated__:
+        annotation = annotation.__forward_arg__
 
     if isinstance(annotation, str):
         if str_cleanup_fn:
@@ -162,6 +162,19 @@ def de_stringify_annotation(
         and is_generic(annotation)
         and not is_literal(annotation)
     ):
+
+        if _already_seen is None:
+            _already_seen = set()
+
+        if annotation in _already_seen:
+            # only occurs recursively.  outermost return type
+            # will always be Type.
+            # the element here will be either ForwardRef or
+            # Optional[ForwardRef]
+            return original_annotation  # type: ignore
+        else:
+            _already_seen.add(annotation)
+
         elements = tuple(
             de_stringify_annotation(
                 cls,
@@ -170,6 +183,7 @@ def de_stringify_annotation(
                 locals_,
                 str_cleanup_fn=str_cleanup_fn,
                 include_generic=include_generic,
+                _already_seen=_already_seen,
             )
             for elem in annotation.__args__
         )
@@ -301,7 +315,7 @@ def flatten_newtype(type_: NewType) -> Type[Any]:
 
 def is_fwd_ref(
     type_: _AnnotationScanType, check_generic: bool = False
-) -> bool:
+) -> TypeGuard[ForwardRef]:
     if isinstance(type_, ForwardRef):
         return True
     elif check_generic and is_generic(type_):
@@ -336,7 +350,7 @@ def de_optionalize_union_types(
     """
 
     if is_fwd_ref(type_):
-        return de_optionalize_fwd_ref_union_types(cast(ForwardRef, type_))
+        return de_optionalize_fwd_ref_union_types(type_)
 
     elif is_optional(type_):
         typ = set(type_.__args__)
index 307dbc157a1ad9e76ba78fbc30d73cdcc92f8ecc..cd6b86e5fd611cd0db4ddfcc5f815840ea80b791 100644 (file)
@@ -1183,6 +1183,65 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_true(A.__table__.c.json.nullable)
 
+    @testing.variation("optional", [True, False])
+    @testing.variation("provide_type", [True, False])
+    @testing.variation("add_to_type_map", [True, False])
+    def test_recursive_type(
+        self, decl_base, optional, provide_type, add_to_type_map
+    ):
+        """test #9553"""
+
+        global T
+
+        T = Dict[str, Optional["T"]]
+
+        if not provide_type and not add_to_type_map:
+            with expect_raises_message(
+                sa_exc.ArgumentError,
+                r"Could not locate SQLAlchemy.*" r".*ForwardRef\('T'\).*",
+            ):
+
+                class TypeTest(decl_base):
+                    __tablename__ = "my_table"
+
+                    id: Mapped[int] = mapped_column(primary_key=True)
+                    if optional:
+                        type_test: Mapped[Optional[T]] = mapped_column()
+                    else:
+                        type_test: Mapped[T] = mapped_column()
+
+            return
+
+        else:
+            if add_to_type_map:
+                decl_base.registry.update_type_annotation_map({T: JSON()})
+
+            class TypeTest(decl_base):
+                __tablename__ = "my_table"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+                if add_to_type_map:
+                    if optional:
+                        type_test: Mapped[Optional[T]] = mapped_column()
+                    else:
+                        type_test: Mapped[T] = mapped_column()
+                else:
+                    if optional:
+                        type_test: Mapped[Optional[T]] = mapped_column(JSON())
+                    else:
+                        type_test: Mapped[T] = mapped_column(JSON())
+
+        if optional:
+            is_(TypeTest.__table__.c.type_test.nullable, True)
+        else:
+            is_(TypeTest.__table__.c.type_test.nullable, False)
+
+        self.assert_compile(
+            select(TypeTest),
+            "SELECT my_table.id, my_table.type_test FROM my_table",
+        )
+
     def test_missing_mapped_lhs(self, decl_base):
         with expect_annotation_syntax_error("User.name"):
 
index 762c879e6cc63235b8f7aaf8209577f3c7436ef3..98c496a81f979fd3134769f73d59440e63c85639 100644 (file)
@@ -1174,6 +1174,65 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_true(A.__table__.c.json.nullable)
 
+    @testing.variation("optional", [True, False])
+    @testing.variation("provide_type", [True, False])
+    @testing.variation("add_to_type_map", [True, False])
+    def test_recursive_type(
+        self, decl_base, optional, provide_type, add_to_type_map
+    ):
+        """test #9553"""
+
+        global T
+
+        T = Dict[str, Optional["T"]]
+
+        if not provide_type and not add_to_type_map:
+            with expect_raises_message(
+                sa_exc.ArgumentError,
+                r"Could not locate SQLAlchemy.*" r".*ForwardRef\('T'\).*",
+            ):
+
+                class TypeTest(decl_base):
+                    __tablename__ = "my_table"
+
+                    id: Mapped[int] = mapped_column(primary_key=True)
+                    if optional:
+                        type_test: Mapped[Optional[T]] = mapped_column()
+                    else:
+                        type_test: Mapped[T] = mapped_column()
+
+            return
+
+        else:
+            if add_to_type_map:
+                decl_base.registry.update_type_annotation_map({T: JSON()})
+
+            class TypeTest(decl_base):
+                __tablename__ = "my_table"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+                if add_to_type_map:
+                    if optional:
+                        type_test: Mapped[Optional[T]] = mapped_column()
+                    else:
+                        type_test: Mapped[T] = mapped_column()
+                else:
+                    if optional:
+                        type_test: Mapped[Optional[T]] = mapped_column(JSON())
+                    else:
+                        type_test: Mapped[T] = mapped_column(JSON())
+
+        if optional:
+            is_(TypeTest.__table__.c.type_test.nullable, True)
+        else:
+            is_(TypeTest.__table__.c.type_test.nullable, False)
+
+        self.assert_compile(
+            select(TypeTest),
+            "SELECT my_table.id, my_table.type_test FROM my_table",
+        )
+
     def test_missing_mapped_lhs(self, decl_base):
         with expect_annotation_syntax_error("User.name"):