From: Mike Bayer Date: Mon, 27 Mar 2023 14:19:32 +0000 (-0400) Subject: check for recursion with container types X-Git-Tag: rel_2_0_8~14^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=2cec50201b8a2ddc0f678be7413ec532616c5c90;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git check for recursion with container types Fixed issue in ORM Annotated Declarative where using a recursive type (e.g. using a nested Dict type) would result in a recursion overflow in the ORM's annotation resolution logic, even if this datatype were not necessary to map the column. Fixes: #9553 Change-Id: Ied99dc0d47276c6e9c23fa9df5fc65f7736d65cf --- diff --git a/doc/build/changelog/unreleased_20/9553.rst b/doc/build/changelog/unreleased_20/9553.rst new file mode 100644 index 0000000000..bdbcfc21b7 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9553.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 9553 + + Fixed issue in ORM Annotated Declarative where using a recursive type (e.g. + using a nested Dict type) would result in a recursion overflow in the ORM's + annotation resolution logic, even if this datatype were not necessary to + map the column. diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 24d8dd2dc1..9c38ae3443 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -23,6 +23,7 @@ from typing import NewType from typing import NoReturn from typing import Optional from typing import overload +from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -128,6 +129,7 @@ def de_stringify_annotation( *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, + _already_seen: Optional[Set[Any]] = None, ) -> Type[Any]: """Resolve annotations that may be string based into real objects. @@ -137,17 +139,15 @@ def de_stringify_annotation( etc. """ - # looked at typing.get_type_hints(), looked at pydantic. We need much # less here, and we here try to not use any private typing internals # or construct ForwardRef objects which is documented as something # that should be avoided. - if ( - is_fwd_ref(annotation) - and not cast(ForwardRef, annotation).__forward_evaluated__ - ): - annotation = cast(ForwardRef, annotation).__forward_arg__ + original_annotation = annotation + + if is_fwd_ref(annotation) and not annotation.__forward_evaluated__: + annotation = annotation.__forward_arg__ if isinstance(annotation, str): if str_cleanup_fn: @@ -162,6 +162,19 @@ def de_stringify_annotation( and is_generic(annotation) and not is_literal(annotation) ): + + if _already_seen is None: + _already_seen = set() + + if annotation in _already_seen: + # only occurs recursively. outermost return type + # will always be Type. + # the element here will be either ForwardRef or + # Optional[ForwardRef] + return original_annotation # type: ignore + else: + _already_seen.add(annotation) + elements = tuple( de_stringify_annotation( cls, @@ -170,6 +183,7 @@ def de_stringify_annotation( locals_, str_cleanup_fn=str_cleanup_fn, include_generic=include_generic, + _already_seen=_already_seen, ) for elem in annotation.__args__ ) @@ -301,7 +315,7 @@ def flatten_newtype(type_: NewType) -> Type[Any]: def is_fwd_ref( type_: _AnnotationScanType, check_generic: bool = False -) -> bool: +) -> TypeGuard[ForwardRef]: if isinstance(type_, ForwardRef): return True elif check_generic and is_generic(type_): @@ -336,7 +350,7 @@ def de_optionalize_union_types( """ if is_fwd_ref(type_): - return de_optionalize_fwd_ref_union_types(cast(ForwardRef, type_)) + return de_optionalize_fwd_ref_union_types(type_) elif is_optional(type_): typ = set(type_.__args__) diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 307dbc157a..cd6b86e5fd 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -1183,6 +1183,65 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_true(A.__table__.c.json.nullable) + @testing.variation("optional", [True, False]) + @testing.variation("provide_type", [True, False]) + @testing.variation("add_to_type_map", [True, False]) + def test_recursive_type( + self, decl_base, optional, provide_type, add_to_type_map + ): + """test #9553""" + + global T + + T = Dict[str, Optional["T"]] + + if not provide_type and not add_to_type_map: + with expect_raises_message( + sa_exc.ArgumentError, + r"Could not locate SQLAlchemy.*" r".*ForwardRef\('T'\).*", + ): + + class TypeTest(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + if optional: + type_test: Mapped[Optional[T]] = mapped_column() + else: + type_test: Mapped[T] = mapped_column() + + return + + else: + if add_to_type_map: + decl_base.registry.update_type_annotation_map({T: JSON()}) + + class TypeTest(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if add_to_type_map: + if optional: + type_test: Mapped[Optional[T]] = mapped_column() + else: + type_test: Mapped[T] = mapped_column() + else: + if optional: + type_test: Mapped[Optional[T]] = mapped_column(JSON()) + else: + type_test: Mapped[T] = mapped_column(JSON()) + + if optional: + is_(TypeTest.__table__.c.type_test.nullable, True) + else: + is_(TypeTest.__table__.c.type_test.nullable, False) + + self.assert_compile( + select(TypeTest), + "SELECT my_table.id, my_table.type_test FROM my_table", + ) + def test_missing_mapped_lhs(self, decl_base): with expect_annotation_syntax_error("User.name"): diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 762c879e6c..98c496a81f 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1174,6 +1174,65 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_true(A.__table__.c.json.nullable) + @testing.variation("optional", [True, False]) + @testing.variation("provide_type", [True, False]) + @testing.variation("add_to_type_map", [True, False]) + def test_recursive_type( + self, decl_base, optional, provide_type, add_to_type_map + ): + """test #9553""" + + global T + + T = Dict[str, Optional["T"]] + + if not provide_type and not add_to_type_map: + with expect_raises_message( + sa_exc.ArgumentError, + r"Could not locate SQLAlchemy.*" r".*ForwardRef\('T'\).*", + ): + + class TypeTest(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + if optional: + type_test: Mapped[Optional[T]] = mapped_column() + else: + type_test: Mapped[T] = mapped_column() + + return + + else: + if add_to_type_map: + decl_base.registry.update_type_annotation_map({T: JSON()}) + + class TypeTest(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + if add_to_type_map: + if optional: + type_test: Mapped[Optional[T]] = mapped_column() + else: + type_test: Mapped[T] = mapped_column() + else: + if optional: + type_test: Mapped[Optional[T]] = mapped_column(JSON()) + else: + type_test: Mapped[T] = mapped_column(JSON()) + + if optional: + is_(TypeTest.__table__.c.type_test.nullable, True) + else: + is_(TypeTest.__table__.c.type_test.nullable, False) + + self.assert_compile( + select(TypeTest), + "SELECT my_table.id, my_table.type_test FROM my_table", + ) + def test_missing_mapped_lhs(self, decl_base): with expect_annotation_syntax_error("User.name"):