From: Mike Bayer Date: Mon, 16 Jan 2023 15:31:39 +0000 (-0500) Subject: dont assume copy_with() on builtins list, dict, etc; improve error msg. X-Git-Tag: rel_2_0_0rc3~8^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=046272e06aa3284a87e0dd1f90d2242fb434de10;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git dont assume copy_with() on builtins list, dict, etc; improve error msg. Fixed issue where using an ``Annotated`` type in the ``type_annotation_map`` which itself contained a plain container type (e.g. ``list``, ``dict``) generic type as the target type would produce an internal error where the ORM were trying to interpret the ``Annotated`` instance. Added an error message when a :func:`_orm.relationship` is mapped against an abstract container type, such as ``Mapped[Sequence[B]]``, without providing the :paramref:`_orm.relationship.container_class` parameter which is necessary when the type is abstract. Previously the the abstract container would attempt to be instantiated and fail. Fixes: #9099 Fixes: #9100 Change-Id: I18aa6abd5451c5ac75a9caed8441ff0cd8f44589 --- diff --git a/doc/build/changelog/unreleased_20/9099.rst b/doc/build/changelog/unreleased_20/9099.rst new file mode 100644 index 0000000000..c5e997235b --- /dev/null +++ b/doc/build/changelog/unreleased_20/9099.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm + :tickets: 9099 + + Fixed issue where using a pep-593 ``Annotated`` type in the + :paramref:`_orm.registry.type_annotation_map` which itself contained a + generic plain container or ``collections.abc`` type (e.g. ``list``, + ``dict``, ``collections.abc.Sequence``, etc. ) as the target type would + produce an internal error when the ORM were trying to interpret the + ``Annotated`` instance. + + diff --git a/doc/build/changelog/unreleased_20/9100.rst b/doc/build/changelog/unreleased_20/9100.rst new file mode 100644 index 0000000000..4c3c65ad59 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9100.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, orm + :tickets: 9100 + + Added an error message when a :func:`_orm.relationship` is mapped against + an abstract container type, such as ``Mapped[Sequence[B]]``, without + providing the :paramref:`_orm.relationship.container_class` parameter which + is necessary when the type is abstract. Previously the the abstract + container would attempt to be instantiated at a later step and fail. + + diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index bfd39c3694..66d3a60355 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -18,6 +18,7 @@ from __future__ import annotations import collections from collections import abc import dataclasses +import inspect as _py_inspect import re import typing from typing import Any @@ -1768,7 +1769,18 @@ class RelationshipProperty( arg_origin, abc.Collection ): if self.collection_class is None: + if _py_inspect.isabstract(arg_origin): + raise sa_exc.ArgumentError( + f"Collection annotation type {arg_origin} cannot " + "be instantiated; please provide an explicit " + "'collection_class' parameter " + "(e.g. list, set, etc.) to the " + "relationship() function to accompany this " + "annotation" + ) + self.collection_class = arg_origin + elif not is_write_only and not is_dynamic: self.uselist = False diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index b1ef87db17..e1670ed21b 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -97,8 +97,12 @@ class GenericProtocol(Protocol[_T]): __args__: Tuple[_AnnotationScanType, ...] __origin__: Type[_T] - def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]: - ... + # Python's builtin _GenericAlias has this method, however builtins like + # list, dict, etc. do not, even though they have ``__origin__`` and + # ``__args__`` + # + # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]: + # ... class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): @@ -158,10 +162,21 @@ def de_stringify_annotation( for elem in annotation.__args__ ) - return annotation.copy_with(elements) + return _copy_generic_annotation_with(annotation, elements) return annotation # type: ignore +def _copy_generic_annotation_with( + annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...] +) -> Type[_T]: + if hasattr(annotation, "copy_with"): + # List, Dict, etc. real generics + return annotation.copy_with(elements) # type: ignore + else: + # Python builtins list, dict, etc. + return annotation.__origin__[elements] # type: ignore + + def eval_expression(expression: str, module_name: str) -> Any: try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 787e09aac9..ae10f7d8e4 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -7,10 +7,12 @@ Do not edit manually, any change will be lost. from __future__ import annotations +import collections.abc import dataclasses import datetime from decimal import Decimal import enum +import typing from typing import Any from typing import ClassVar from typing import Dict @@ -703,6 +705,48 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) ) + @testing.combinations( + (collections.abc.Sequence, (str,), testing.requires.python310), + (collections.abc.MutableSequence, (str,), testing.requires.python310), + (collections.abc.Mapping, (str, str), testing.requires.python310), + ( + collections.abc.MutableMapping, + (str, str), + testing.requires.python310, + ), + (typing.Mapping, (str, str), testing.requires.python310), + (typing.MutableMapping, (str, str), testing.requires.python310), + (typing.Sequence, (str,)), + (typing.MutableSequence, (str,)), + (list, (str,), testing.requires.python310), + ( + List, + (str,), + ), + (dict, (str, str), testing.requires.python310), + ( + Dict, + (str, str), + ), + id_="sa", + ) + def test_extract_generic_from_pep593(self, container_typ, args): + """test #9099""" + + global TestType + TestType = Annotated[container_typ[args], 0] + + class Base(DeclarativeBase): + type_annotation_map = {TestType: JSON()} + + class MyClass(Base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[TestType] = mapped_column() + + is_(MyClass.__table__.c.data.type._type_affinity, JSON) + @testing.combinations( ("default", lambda ctx: 10), ("default", func.foo()), @@ -1516,6 +1560,94 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): ): registry.configure() + @testing.variation( + "datatype", + [ + "typing_sequence", + ("collections_sequence", testing.requires.python310), + "typing_mutable_sequence", + ("collections_mutable_sequence", testing.requires.python310), + ], + ) + @testing.variation("include_explicit", [True, False]) + def test_relationship_abstract_cls_error( + self, decl_base, datatype, include_explicit + ): + """test #9100""" + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + data: Mapped[str] + + if include_explicit: + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + # note this can be done more succinctly by assigning to + # an interim type, however making it explicit here + # allows us to further test de-stringifying of these + # collection types + if datatype.typing_sequence: + bs: Mapped[typing.Sequence[B]] = relationship( + collection_class=list + ) + elif datatype.collections_sequence: + bs: Mapped[collections.abc.Sequence[B]] = relationship( + collection_class=list + ) + elif datatype.typing_mutable_sequence: + bs: Mapped[typing.MutableSequence[B]] = relationship( + collection_class=list + ) + elif datatype.collections_mutable_sequence: + bs: Mapped[ + collections.abc.MutableSequence[B] + ] = relationship(collection_class=list) + else: + datatype.fail() + + decl_base.registry.configure() + self.assert_compile( + select(A).join(A.bs), + "SELECT a.id FROM a JOIN b ON a.id = b.a_id", + ) + else: + with expect_raises_message( + sa_exc.ArgumentError, + r"Collection annotation type " + r".*Sequence.* cannot be " + r"instantiated; please provide an explicit " + r"'collection_class' parameter \(e.g. list, set, etc.\) to " + r"the relationship\(\) function to accompany this annotation", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + if datatype.typing_sequence: + bs: Mapped[typing.Sequence[B]] = relationship() + elif datatype.collections_sequence: + bs: Mapped[ + collections.abc.Sequence[B] + ] = relationship() + elif datatype.typing_mutable_sequence: + bs: Mapped[typing.MutableSequence[B]] = relationship() + elif datatype.collections_mutable_sequence: + bs: Mapped[ + collections.abc.MutableSequence[B] + ] = relationship() + else: + datatype.fail() + + decl_base.registry.configure() + def test_14_style_anno_accepted_w_allow_unmapped(self): """test for #8692""" diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index da73104e26..8838afd0ff 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1,7 +1,9 @@ +import collections.abc import dataclasses import datetime from decimal import Decimal import enum +import typing from typing import Any from typing import ClassVar from typing import Dict @@ -694,6 +696,48 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) ) + @testing.combinations( + (collections.abc.Sequence, (str,), testing.requires.python310), + (collections.abc.MutableSequence, (str,), testing.requires.python310), + (collections.abc.Mapping, (str, str), testing.requires.python310), + ( + collections.abc.MutableMapping, + (str, str), + testing.requires.python310, + ), + (typing.Mapping, (str, str), testing.requires.python310), + (typing.MutableMapping, (str, str), testing.requires.python310), + (typing.Sequence, (str,)), + (typing.MutableSequence, (str,)), + (list, (str,), testing.requires.python310), + ( + List, + (str,), + ), + (dict, (str, str), testing.requires.python310), + ( + Dict, + (str, str), + ), + id_="sa", + ) + def test_extract_generic_from_pep593(self, container_typ, args): + """test #9099""" + + global TestType + TestType = Annotated[container_typ[args], 0] + + class Base(DeclarativeBase): + type_annotation_map = {TestType: JSON()} + + class MyClass(Base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[TestType] = mapped_column() + + is_(MyClass.__table__.c.data.type._type_affinity, JSON) + @testing.combinations( ("default", lambda ctx: 10), ("default", func.foo()), @@ -1507,6 +1551,94 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): ): registry.configure() + @testing.variation( + "datatype", + [ + "typing_sequence", + ("collections_sequence", testing.requires.python310), + "typing_mutable_sequence", + ("collections_mutable_sequence", testing.requires.python310), + ], + ) + @testing.variation("include_explicit", [True, False]) + def test_relationship_abstract_cls_error( + self, decl_base, datatype, include_explicit + ): + """test #9100""" + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + data: Mapped[str] + + if include_explicit: + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + # note this can be done more succinctly by assigning to + # an interim type, however making it explicit here + # allows us to further test de-stringifying of these + # collection types + if datatype.typing_sequence: + bs: Mapped[typing.Sequence[B]] = relationship( + collection_class=list + ) + elif datatype.collections_sequence: + bs: Mapped[collections.abc.Sequence[B]] = relationship( + collection_class=list + ) + elif datatype.typing_mutable_sequence: + bs: Mapped[typing.MutableSequence[B]] = relationship( + collection_class=list + ) + elif datatype.collections_mutable_sequence: + bs: Mapped[ + collections.abc.MutableSequence[B] + ] = relationship(collection_class=list) + else: + datatype.fail() + + decl_base.registry.configure() + self.assert_compile( + select(A).join(A.bs), + "SELECT a.id FROM a JOIN b ON a.id = b.a_id", + ) + else: + with expect_raises_message( + sa_exc.ArgumentError, + r"Collection annotation type " + r".*Sequence.* cannot be " + r"instantiated; please provide an explicit " + r"'collection_class' parameter \(e.g. list, set, etc.\) to " + r"the relationship\(\) function to accompany this annotation", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + if datatype.typing_sequence: + bs: Mapped[typing.Sequence[B]] = relationship() + elif datatype.collections_sequence: + bs: Mapped[ + collections.abc.Sequence[B] + ] = relationship() + elif datatype.typing_mutable_sequence: + bs: Mapped[typing.MutableSequence[B]] = relationship() + elif datatype.collections_mutable_sequence: + bs: Mapped[ + collections.abc.MutableSequence[B] + ] = relationship() + else: + datatype.fail() + + decl_base.registry.configure() + def test_14_style_anno_accepted_w_allow_unmapped(self): """test for #8692"""