From: Mike Bayer Date: Sat, 31 Aug 2024 16:56:00 +0000 (-0400) Subject: re-process args for builtin generic types X-Git-Tag: rel_2_0_33~6 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=42ec1f70138d51dd7e61578453faa0f4d47f6ec3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git re-process args for builtin generic types Improvements to the ORM annotated declarative type map lookup dealing with composed types such as ``dict[str, Any]`` linking to JSON (or others) with or without "future annotations" mode. There's apparently a big incompatiblity in types from typing vs. Python builtins in the way they genericize. The typing library makes it very difficult to distinguish between the two sets of types. This patch is a bit slash and burn to work around all this. These should likely be reported as bugs in the Python standard library if they aren't already. Fixes: #11814 Change-Id: I56a62701d5e883be04df7f45fd9429bb9c1c9a6f (cherry picked from commit f746fd78e303352d426a15c1f76ee835ce399d44) --- diff --git a/doc/build/changelog/unreleased_20/11814.rst b/doc/build/changelog/unreleased_20/11814.rst new file mode 100644 index 0000000000..a9feecb28c --- /dev/null +++ b/doc/build/changelog/unreleased_20/11814.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 11814 + + Improvements to the ORM annotated declarative type map lookup dealing with + composed types such as ``dict[str, Any]`` linking to JSON (or others) with + or without "future annotations" mode. + + diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 1203c9cb36..d43fbffc57 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -431,7 +431,7 @@ class _ImperativeMapperConfig(_MapperConfig): class _CollectedAnnotation(NamedTuple): raw_annotation: _AnnotationScanType mapped_container: Optional[Type[Mapped[Any]]] - extracted_mapped_annotation: Union[Type[Any], str] + extracted_mapped_annotation: Union[_AnnotationScanType, str] is_dataclass: bool attr_value: Any originating_module: str diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0b4ad88ed8..2b4ac3c9d7 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -89,6 +89,7 @@ from ..util.typing import ( de_stringify_union_elements as _de_stringify_union_elements, ) from ..util.typing import eval_name_only as _eval_name_only +from ..util.typing import fixup_container_fwd_refs from ..util.typing import is_origin_of_cls from ..util.typing import Literal from ..util.typing import Protocol @@ -2321,7 +2322,7 @@ def _extract_mapped_subtype( is_dataclass_field: bool, expect_mapped: bool = True, raiseerr: bool = True, -) -> Optional[Tuple[Union[type, str], Optional[type]]]: +) -> Optional[Tuple[Union[_AnnotationScanType, str], Optional[type]]]: """given an annotation, figure out if it's ``Mapped[something]`` and if so, return the ``something`` part. @@ -2407,7 +2408,11 @@ def _extract_mapped_subtype( "Expected sub-type for Mapped[] annotation" ) - return annotated.__args__[0], annotated.__origin__ + return ( + # fix dict/list/set args to be ForwardRef, see #11814 + fixup_container_fwd_refs(annotated.__args__[0]), + annotated.__origin__, + ) def _mapper_property_as_plain_name(prop: Type[Any]) -> str: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 64619957a6..81e77f629f 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -188,9 +188,50 @@ def de_stringify_annotation( ) return _copy_generic_annotation_with(annotation, elements) + return annotation # type: ignore +def fixup_container_fwd_refs( + type_: _AnnotationScanType, +) -> _AnnotationScanType: + """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')] + and similar for list, set + + """ + if ( + is_generic(type_) + and type_.__origin__ + in ( + dict, + set, + list, + collections_abc.MutableSet, + collections_abc.MutableMapping, + collections_abc.MutableSequence, + collections_abc.Mapping, + collections_abc.Sequence, + ) + # fight, kick and scream to struggle to tell the difference between + # dict[] and typing.Dict[] which DO NOT compare the same and DO NOT + # behave the same yet there is NO WAY to distinguish between which type + # it is using public attributes + and not re.match( + "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_) + ) + ): + # compat with py3.10 and earlier + return type_.__origin__.__class_getitem__( # type: ignore + tuple( + [ + ForwardRef(elem) if isinstance(elem, str) else elem + for elem in type_.__args__ + ] + ) + ) + return type_ + + def _copy_generic_annotation_with( annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...] ) -> Type[_T]: diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 2bdf340d4c..765318cfa2 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -1420,21 +1420,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): (str, str), ), id_="sa", + argnames="container_typ,args", ) - def test_extract_generic_from_pep593(self, container_typ, args): - """test #9099""" + @testing.variation("style", ["pep593", "alias", "direct"]) + def test_extract_composed(self, container_typ, args, style): + """test #9099 (pep593) + + test #11814 + + """ global TestType - TestType = Annotated[container_typ[args], 0] + + if style.pep593: + TestType = Annotated[container_typ[args], 0] + elif style.alias: + TestType = container_typ[args] + elif style.direct: + TestType = container_typ + double_strings = args == (str, str) class Base(DeclarativeBase): - type_annotation_map = {TestType: JSON()} + if style.direct: + if double_strings: + type_annotation_map = {TestType[str, str]: JSON()} + else: + type_annotation_map = {TestType[str]: JSON()} + else: + type_annotation_map = {TestType: JSON()} class MyClass(Base): __tablename__ = "my_table" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[TestType] = mapped_column() + + if style.direct: + if double_strings: + data: Mapped[TestType[str, str]] = mapped_column() + else: + data: Mapped[TestType[str]] = mapped_column() + else: + data: Mapped[TestType] = mapped_column() is_(MyClass.__table__.c.data.type._type_affinity, JSON) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 6fb792b0ba..8b10118f4c 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1411,21 +1411,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): (str, str), ), id_="sa", + argnames="container_typ,args", ) - def test_extract_generic_from_pep593(self, container_typ, args): - """test #9099""" + @testing.variation("style", ["pep593", "alias", "direct"]) + def test_extract_composed(self, container_typ, args, style): + """test #9099 (pep593) + + test #11814 + + """ global TestType - TestType = Annotated[container_typ[args], 0] + + if style.pep593: + TestType = Annotated[container_typ[args], 0] + elif style.alias: + TestType = container_typ[args] + elif style.direct: + TestType = container_typ + double_strings = args == (str, str) class Base(DeclarativeBase): - type_annotation_map = {TestType: JSON()} + if style.direct: + if double_strings: + type_annotation_map = {TestType[str, str]: JSON()} + else: + type_annotation_map = {TestType[str]: JSON()} + else: + type_annotation_map = {TestType: JSON()} class MyClass(Base): __tablename__ = "my_table" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[TestType] = mapped_column() + + if style.direct: + if double_strings: + data: Mapped[TestType[str, str]] = mapped_column() + else: + data: Mapped[TestType[str]] = mapped_column() + else: + data: Mapped[TestType] = mapped_column() is_(MyClass.__table__.c.data.type._type_affinity, JSON)