From: Mike Bayer Date: Sat, 31 Aug 2024 16:56:00 +0000 (-0400) Subject: re-process args for builtin generic types X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f746fd78e303352d426a15c1f76ee835ce399d44;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git re-process args for builtin generic types Improvements to the ORM annotated declarative type map lookup dealing with composed types such as ``dict[str, Any]`` linking to JSON (or others) with or without "future annotations" mode. There's apparently a big incompatiblity in types from typing vs. Python builtins in the way they genericize. The typing library makes it very difficult to distinguish between the two sets of types. This patch is a bit slash and burn to work around all this. These should likely be reported as bugs in the Python standard library if they aren't already. Fixes: #11814 Change-Id: I56a62701d5e883be04df7f45fd9429bb9c1c9a6f --- diff --git a/doc/build/changelog/unreleased_20/11814.rst b/doc/build/changelog/unreleased_20/11814.rst new file mode 100644 index 0000000000..a9feecb28c --- /dev/null +++ b/doc/build/changelog/unreleased_20/11814.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 11814 + + Improvements to the ORM annotated declarative type map lookup dealing with + composed types such as ``dict[str, Any]`` linking to JSON (or others) with + or without "future annotations" mode. + + diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 90396403c2..271c61a8b6 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -431,7 +431,7 @@ class _ImperativeMapperConfig(_MapperConfig): class _CollectedAnnotation(NamedTuple): raw_annotation: _AnnotationScanType mapped_container: Optional[Type[Mapped[Any]]] - extracted_mapped_annotation: Union[Type[Any], str] + extracted_mapped_annotation: Union[_AnnotationScanType, str] is_dataclass: bool attr_value: Any originating_module: str diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index d4dff11e45..6d6fc14715 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -90,6 +90,7 @@ from ..util.typing import ( de_stringify_union_elements as _de_stringify_union_elements, ) from ..util.typing import eval_name_only as _eval_name_only +from ..util.typing import fixup_container_fwd_refs from ..util.typing import is_origin_of_cls from ..util.typing import Literal from ..util.typing import TupleAny @@ -2323,7 +2324,7 @@ def _extract_mapped_subtype( is_dataclass_field: bool, expect_mapped: bool = True, raiseerr: bool = True, -) -> Optional[Tuple[Union[type, str], Optional[type]]]: +) -> Optional[Tuple[Union[_AnnotationScanType, str], Optional[type]]]: """given an annotation, figure out if it's ``Mapped[something]`` and if so, return the ``something`` part. @@ -2409,7 +2410,11 @@ def _extract_mapped_subtype( "Expected sub-type for Mapped[] annotation" ) - return annotated.__args__[0], annotated.__origin__ + return ( + # fix dict/list/set args to be ForwardRef, see #11814 + fixup_container_fwd_refs(annotated.__args__[0]), + annotated.__origin__, + ) def _mapper_property_as_plain_name(prop: Type[Any]) -> str: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index cfc3a26a97..f4f14e1b56 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -190,9 +190,50 @@ def de_stringify_annotation( ) return _copy_generic_annotation_with(annotation, elements) + return annotation # type: ignore +def fixup_container_fwd_refs( + type_: _AnnotationScanType, +) -> _AnnotationScanType: + """Correct dict['x', 'y'] into dict[ForwardRef('x'), ForwardRef('y')] + and similar for list, set + + """ + if ( + is_generic(type_) + and type_.__origin__ + in ( + dict, + set, + list, + collections_abc.MutableSet, + collections_abc.MutableMapping, + collections_abc.MutableSequence, + collections_abc.Mapping, + collections_abc.Sequence, + ) + # fight, kick and scream to struggle to tell the difference between + # dict[] and typing.Dict[] which DO NOT compare the same and DO NOT + # behave the same yet there is NO WAY to distinguish between which type + # it is using public attributes + and not re.match( + "typing.(?:Dict|List|Set|.*Mapping|.*Sequence|.*Set)", repr(type_) + ) + ): + # compat with py3.10 and earlier + return type_.__origin__.__class_getitem__( # type: ignore + tuple( + [ + ForwardRef(elem) if isinstance(elem, str) else elem + for elem in type_.__args__ + ] + ) + ) + return type_ + + def _copy_generic_annotation_with( annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...] ) -> Type[_T]: diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 5dca5e246c..e9b74b0d93 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -1420,21 +1420,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): (str, str), ), id_="sa", + argnames="container_typ,args", ) - def test_extract_generic_from_pep593(self, container_typ, args): - """test #9099""" + @testing.variation("style", ["pep593", "alias", "direct"]) + def test_extract_composed(self, container_typ, args, style): + """test #9099 (pep593) + + test #11814 + + """ global TestType - TestType = Annotated[container_typ[args], 0] + + if style.pep593: + TestType = Annotated[container_typ[args], 0] + elif style.alias: + TestType = container_typ[args] + elif style.direct: + TestType = container_typ + double_strings = args == (str, str) class Base(DeclarativeBase): - type_annotation_map = {TestType: JSON()} + if style.direct: + if double_strings: + type_annotation_map = {TestType[str, str]: JSON()} + else: + type_annotation_map = {TestType[str]: JSON()} + else: + type_annotation_map = {TestType: JSON()} class MyClass(Base): __tablename__ = "my_table" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[TestType] = mapped_column() + + if style.direct: + if double_strings: + data: Mapped[TestType[str, str]] = mapped_column() + else: + data: Mapped[TestType[str]] = mapped_column() + else: + data: Mapped[TestType] = mapped_column() is_(MyClass.__table__.c.data.type._type_affinity, JSON) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 25200514dc..5060ac6131 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1411,21 +1411,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): (str, str), ), id_="sa", + argnames="container_typ,args", ) - def test_extract_generic_from_pep593(self, container_typ, args): - """test #9099""" + @testing.variation("style", ["pep593", "alias", "direct"]) + def test_extract_composed(self, container_typ, args, style): + """test #9099 (pep593) + + test #11814 + + """ global TestType - TestType = Annotated[container_typ[args], 0] + + if style.pep593: + TestType = Annotated[container_typ[args], 0] + elif style.alias: + TestType = container_typ[args] + elif style.direct: + TestType = container_typ + double_strings = args == (str, str) class Base(DeclarativeBase): - type_annotation_map = {TestType: JSON()} + if style.direct: + if double_strings: + type_annotation_map = {TestType[str, str]: JSON()} + else: + type_annotation_map = {TestType[str]: JSON()} + else: + type_annotation_map = {TestType: JSON()} class MyClass(Base): __tablename__ = "my_table" id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[TestType] = mapped_column() + + if style.direct: + if double_strings: + data: Mapped[TestType[str, str]] = mapped_column() + else: + data: Mapped[TestType[str]] = mapped_column() + else: + data: Mapped[TestType] = mapped_column() is_(MyClass.__table__.c.data.type._type_affinity, JSON)