]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dont assume copy_with() on builtins list, dict, etc; improve error msg.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Jan 2023 15:31:39 +0000 (10:31 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Jan 2023 19:23:49 +0000 (14:23 -0500)
Fixed issue where using an ``Annotated`` type in the
``type_annotation_map`` which itself contained a plain container type (e.g.
``list``, ``dict``) generic type as the target type would produce an
internal error where the ORM were trying to interpret the ``Annotated``
instance.

Added an error message when a :func:`_orm.relationship` is mapped against
an abstract container type, such as ``Mapped[Sequence[B]]``, without
providing the :paramref:`_orm.relationship.container_class` parameter which
is necessary when the type is abstract. Previously the the abstract
container would attempt to be instantiated and fail.

Fixes: #9099
Fixes: #9100
Change-Id: I18aa6abd5451c5ac75a9caed8441ff0cd8f44589

doc/build/changelog/unreleased_20/9099.rst [new file with mode: 0644]
doc/build/changelog/unreleased_20/9100.rst [new file with mode: 0644]
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/9099.rst b/doc/build/changelog/unreleased_20/9099.rst
new file mode 100644 (file)
index 0000000..c5e9972
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9099
+
+    Fixed issue where using a pep-593 ``Annotated`` type in the
+    :paramref:`_orm.registry.type_annotation_map` which itself contained a
+    generic plain container or ``collections.abc`` type (e.g. ``list``,
+    ``dict``, ``collections.abc.Sequence``, etc. ) as the target type would
+    produce an internal error when the ORM were trying to interpret the
+    ``Annotated`` instance.
+
+
diff --git a/doc/build/changelog/unreleased_20/9100.rst b/doc/build/changelog/unreleased_20/9100.rst
new file mode 100644 (file)
index 0000000..4c3c65a
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9100
+
+    Added an error message when a :func:`_orm.relationship` is mapped against
+    an abstract container type, such as ``Mapped[Sequence[B]]``, without
+    providing the :paramref:`_orm.relationship.container_class` parameter which
+    is necessary when the type is abstract. Previously the the abstract
+    container would attempt to be instantiated at a later step and fail.
+
+
index bfd39c369433517ca900c40d96355d9233c29455..66d3a60355bf366e4c720533bdc1b1f1528c7764 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import annotations
 import collections
 from collections import abc
 import dataclasses
+import inspect as _py_inspect
 import re
 import typing
 from typing import Any
@@ -1768,7 +1769,18 @@ class RelationshipProperty(
                 arg_origin, abc.Collection
             ):
                 if self.collection_class is None:
+                    if _py_inspect.isabstract(arg_origin):
+                        raise sa_exc.ArgumentError(
+                            f"Collection annotation type {arg_origin} cannot "
+                            "be instantiated; please provide an explicit "
+                            "'collection_class' parameter "
+                            "(e.g. list, set, etc.) to the "
+                            "relationship() function to accompany this "
+                            "annotation"
+                        )
+
                     self.collection_class = arg_origin
+
             elif not is_write_only and not is_dynamic:
                 self.uselist = False
 
index b1ef87db1719029e305e0e2770d6d0f8533e9004..e1670ed21b0c2eac0d399669e3bf1f35646023d9 100644 (file)
@@ -97,8 +97,12 @@ class GenericProtocol(Protocol[_T]):
     __args__: Tuple[_AnnotationScanType, ...]
     __origin__: Type[_T]
 
-    def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
-        ...
+    # Python's builtin _GenericAlias has this method, however builtins like
+    # list, dict, etc. do not, even though they have ``__origin__`` and
+    # ``__args__``
+    #
+    # def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
+    #     ...
 
 
 class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
@@ -158,10 +162,21 @@ def de_stringify_annotation(
             for elem in annotation.__args__
         )
 
-        return annotation.copy_with(elements)
+        return _copy_generic_annotation_with(annotation, elements)
     return annotation  # type: ignore
 
 
+def _copy_generic_annotation_with(
+    annotation: GenericProtocol[_T], elements: Tuple[_AnnotationScanType, ...]
+) -> Type[_T]:
+    if hasattr(annotation, "copy_with"):
+        # List, Dict, etc. real generics
+        return annotation.copy_with(elements)  # type: ignore
+    else:
+        # Python builtins list, dict, etc.
+        return annotation.__origin__[elements]  # type: ignore
+
+
 def eval_expression(expression: str, module_name: str) -> Any:
     try:
         base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
index 787e09aac9711c7682d351a54b93a7fbca0e1454..ae10f7d8e4200395efe7f26a4b5bc62cde7fbac4 100644 (file)
@@ -7,10 +7,12 @@ Do not edit manually, any change will be lost.
 
 from __future__ import annotations
 
+import collections.abc
 import dataclasses
 import datetime
 from decimal import Decimal
 import enum
+import typing
 from typing import Any
 from typing import ClassVar
 from typing import Dict
@@ -703,6 +705,48 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             )
         )
 
+    @testing.combinations(
+        (collections.abc.Sequence, (str,), testing.requires.python310),
+        (collections.abc.MutableSequence, (str,), testing.requires.python310),
+        (collections.abc.Mapping, (str, str), testing.requires.python310),
+        (
+            collections.abc.MutableMapping,
+            (str, str),
+            testing.requires.python310,
+        ),
+        (typing.Mapping, (str, str), testing.requires.python310),
+        (typing.MutableMapping, (str, str), testing.requires.python310),
+        (typing.Sequence, (str,)),
+        (typing.MutableSequence, (str,)),
+        (list, (str,), testing.requires.python310),
+        (
+            List,
+            (str,),
+        ),
+        (dict, (str, str), testing.requires.python310),
+        (
+            Dict,
+            (str, str),
+        ),
+        id_="sa",
+    )
+    def test_extract_generic_from_pep593(self, container_typ, args):
+        """test #9099"""
+
+        global TestType
+        TestType = Annotated[container_typ[args], 0]
+
+        class Base(DeclarativeBase):
+            type_annotation_map = {TestType: JSON()}
+
+        class MyClass(Base):
+            __tablename__ = "my_table"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[TestType] = mapped_column()
+
+        is_(MyClass.__table__.c.data.type._type_affinity, JSON)
+
     @testing.combinations(
         ("default", lambda ctx: 10),
         ("default", func.foo()),
@@ -1516,6 +1560,94 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         ):
             registry.configure()
 
+    @testing.variation(
+        "datatype",
+        [
+            "typing_sequence",
+            ("collections_sequence", testing.requires.python310),
+            "typing_mutable_sequence",
+            ("collections_mutable_sequence", testing.requires.python310),
+        ],
+    )
+    @testing.variation("include_explicit", [True, False])
+    def test_relationship_abstract_cls_error(
+        self, decl_base, datatype, include_explicit
+    ):
+        """test #9100"""
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+            data: Mapped[str]
+
+        if include_explicit:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+                # note this can be done more succinctly by assigning to
+                # an interim type, however making it explicit here
+                # allows us to further test de-stringifying of these
+                # collection types
+                if datatype.typing_sequence:
+                    bs: Mapped[typing.Sequence[B]] = relationship(
+                        collection_class=list
+                    )
+                elif datatype.collections_sequence:
+                    bs: Mapped[collections.abc.Sequence[B]] = relationship(
+                        collection_class=list
+                    )
+                elif datatype.typing_mutable_sequence:
+                    bs: Mapped[typing.MutableSequence[B]] = relationship(
+                        collection_class=list
+                    )
+                elif datatype.collections_mutable_sequence:
+                    bs: Mapped[
+                        collections.abc.MutableSequence[B]
+                    ] = relationship(collection_class=list)
+                else:
+                    datatype.fail()
+
+            decl_base.registry.configure()
+            self.assert_compile(
+                select(A).join(A.bs),
+                "SELECT a.id FROM a JOIN b ON a.id = b.a_id",
+            )
+        else:
+            with expect_raises_message(
+                sa_exc.ArgumentError,
+                r"Collection annotation type "
+                r".*Sequence.* cannot be "
+                r"instantiated; please provide an explicit "
+                r"'collection_class' parameter \(e.g. list, set, etc.\) to "
+                r"the relationship\(\) function to accompany this annotation",
+            ):
+
+                class A(decl_base):
+                    __tablename__ = "a"
+
+                    id: Mapped[int] = mapped_column(primary_key=True)
+
+                    if datatype.typing_sequence:
+                        bs: Mapped[typing.Sequence[B]] = relationship()
+                    elif datatype.collections_sequence:
+                        bs: Mapped[
+                            collections.abc.Sequence[B]
+                        ] = relationship()
+                    elif datatype.typing_mutable_sequence:
+                        bs: Mapped[typing.MutableSequence[B]] = relationship()
+                    elif datatype.collections_mutable_sequence:
+                        bs: Mapped[
+                            collections.abc.MutableSequence[B]
+                        ] = relationship()
+                    else:
+                        datatype.fail()
+
+                decl_base.registry.configure()
+
     def test_14_style_anno_accepted_w_allow_unmapped(self):
         """test for #8692"""
 
index da73104e26a5bdbcee3f88abc124b17fa3121138..8838afd0ff3c7e8e2b997dc99fe7b1096956bd3b 100644 (file)
@@ -1,7 +1,9 @@
+import collections.abc
 import dataclasses
 import datetime
 from decimal import Decimal
 import enum
+import typing
 from typing import Any
 from typing import ClassVar
 from typing import Dict
@@ -694,6 +696,48 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             )
         )
 
+    @testing.combinations(
+        (collections.abc.Sequence, (str,), testing.requires.python310),
+        (collections.abc.MutableSequence, (str,), testing.requires.python310),
+        (collections.abc.Mapping, (str, str), testing.requires.python310),
+        (
+            collections.abc.MutableMapping,
+            (str, str),
+            testing.requires.python310,
+        ),
+        (typing.Mapping, (str, str), testing.requires.python310),
+        (typing.MutableMapping, (str, str), testing.requires.python310),
+        (typing.Sequence, (str,)),
+        (typing.MutableSequence, (str,)),
+        (list, (str,), testing.requires.python310),
+        (
+            List,
+            (str,),
+        ),
+        (dict, (str, str), testing.requires.python310),
+        (
+            Dict,
+            (str, str),
+        ),
+        id_="sa",
+    )
+    def test_extract_generic_from_pep593(self, container_typ, args):
+        """test #9099"""
+
+        global TestType
+        TestType = Annotated[container_typ[args], 0]
+
+        class Base(DeclarativeBase):
+            type_annotation_map = {TestType: JSON()}
+
+        class MyClass(Base):
+            __tablename__ = "my_table"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[TestType] = mapped_column()
+
+        is_(MyClass.__table__.c.data.type._type_affinity, JSON)
+
     @testing.combinations(
         ("default", lambda ctx: 10),
         ("default", func.foo()),
@@ -1507,6 +1551,94 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         ):
             registry.configure()
 
+    @testing.variation(
+        "datatype",
+        [
+            "typing_sequence",
+            ("collections_sequence", testing.requires.python310),
+            "typing_mutable_sequence",
+            ("collections_mutable_sequence", testing.requires.python310),
+        ],
+    )
+    @testing.variation("include_explicit", [True, False])
+    def test_relationship_abstract_cls_error(
+        self, decl_base, datatype, include_explicit
+    ):
+        """test #9100"""
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+            data: Mapped[str]
+
+        if include_explicit:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+
+                # note this can be done more succinctly by assigning to
+                # an interim type, however making it explicit here
+                # allows us to further test de-stringifying of these
+                # collection types
+                if datatype.typing_sequence:
+                    bs: Mapped[typing.Sequence[B]] = relationship(
+                        collection_class=list
+                    )
+                elif datatype.collections_sequence:
+                    bs: Mapped[collections.abc.Sequence[B]] = relationship(
+                        collection_class=list
+                    )
+                elif datatype.typing_mutable_sequence:
+                    bs: Mapped[typing.MutableSequence[B]] = relationship(
+                        collection_class=list
+                    )
+                elif datatype.collections_mutable_sequence:
+                    bs: Mapped[
+                        collections.abc.MutableSequence[B]
+                    ] = relationship(collection_class=list)
+                else:
+                    datatype.fail()
+
+            decl_base.registry.configure()
+            self.assert_compile(
+                select(A).join(A.bs),
+                "SELECT a.id FROM a JOIN b ON a.id = b.a_id",
+            )
+        else:
+            with expect_raises_message(
+                sa_exc.ArgumentError,
+                r"Collection annotation type "
+                r".*Sequence.* cannot be "
+                r"instantiated; please provide an explicit "
+                r"'collection_class' parameter \(e.g. list, set, etc.\) to "
+                r"the relationship\(\) function to accompany this annotation",
+            ):
+
+                class A(decl_base):
+                    __tablename__ = "a"
+
+                    id: Mapped[int] = mapped_column(primary_key=True)
+
+                    if datatype.typing_sequence:
+                        bs: Mapped[typing.Sequence[B]] = relationship()
+                    elif datatype.collections_sequence:
+                        bs: Mapped[
+                            collections.abc.Sequence[B]
+                        ] = relationship()
+                    elif datatype.typing_mutable_sequence:
+                        bs: Mapped[typing.MutableSequence[B]] = relationship()
+                    elif datatype.collections_mutable_sequence:
+                        bs: Mapped[
+                            collections.abc.MutableSequence[B]
+                        ] = relationship()
+                    else:
+                        datatype.fail()
+
+                decl_base.registry.configure()
+
     def test_14_style_anno_accepted_w_allow_unmapped(self):
         """test for #8692"""