]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
runtime annotation fixes for relationship
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 3 Jul 2022 20:25:15 +0000 (16:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Jul 2022 02:33:48 +0000 (22:33 -0400)
* derive uselist=False when fwd ref passed to relationship

  This case needs to work whether or not the class name
  is a forward ref.  we dont allow the colleciton to be a
  forward ref so this will work.

* fix issues with MappedCollection

  When using string annotations or __future__.annotations,
  we need to do more parsing in order to get the target
  collection properly

Change-Id: I9e5a1358b62d060a8815826f98190801a9cc0b68

lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/clsregistry.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_dc_transforms.py
test/orm/declarative/test_tm_future_annotations.py
test/orm/declarative/test_typed_mapping.py
test/orm/test_deferred.py

index 4f19ba946013dcd952c83c808b48303c3026ea2a..539cf2600a9215843f0c9646afb7df6e4b0ba86c 100644 (file)
@@ -87,6 +87,10 @@ from .interfaces import PropComparator as PropComparator
 from .interfaces import UserDefinedOption as UserDefinedOption
 from .loading import merge_frozen_result as merge_frozen_result
 from .loading import merge_result as merge_result
+from .mapped_collection import attribute_mapped_collection
+from .mapped_collection import column_mapped_collection
+from .mapped_collection import mapped_collection
+from .mapped_collection import MappedCollection
 from .mapper import configure_mappers as configure_mappers
 from .mapper import Mapper as Mapper
 from .mapper import reconstructor as reconstructor
index b3fcd29ea38268a1b49dfc53a582edc44811c98f..dd79eb1d096d13cf45f932d40028069bb0531f46 100644 (file)
@@ -463,6 +463,7 @@ class _class_resolver:
         generic_match = re.match(r"(.+)\[(.+)\]", name)
 
         if generic_match:
+            clsarg = generic_match.group(2).strip("'")
             raise exc.InvalidRequestError(
                 f"When initializing mapper {self.prop.parent}, "
                 f'expression "relationship({self.arg!r})" seems to be '
@@ -470,7 +471,7 @@ class _class_resolver:
                 "please state the generic argument "
                 "using an annotation, e.g. "
                 f'"{self.prop.key}: Mapped[{generic_match.group(1)}'
-                f'[{generic_match.group(2)}]] = relationship()"'
+                f"['{clsarg}']] = relationship()\""
             ) from err
         else:
             raise exc.InvalidRequestError(
index 630f6898faa93a556283af6139a08677e9b8adf8..77a95a195b6b1bb0b9a6bfede141e369d0b89e3a 100644 (file)
@@ -1724,11 +1724,12 @@ class Relationship(
                     self.collection_class = collection_class
             else:
                 self.uselist = False
+
             if argument.__args__:  # type: ignore
                 if issubclass(
                     argument.__origin__, typing.Mapping  # type: ignore
                 ):
-                    type_arg = argument.__args__[1]  # type: ignore
+                    type_arg = argument.__args__[-1]  # type: ignore
                 else:
                     type_arg = argument.__args__[0]  # type: ignore
                 if hasattr(type_arg, "__forward_arg__"):
@@ -1743,6 +1744,12 @@ class Relationship(
         elif hasattr(argument, "__forward_arg__"):
             argument = argument.__forward_arg__  # type: ignore
 
+            # we don't allow the collection class to be a
+            # __forward_arg__ right now, so if we see a forward arg here,
+            # we know there was no collection class either
+            if self.collection_class is None:
+                self.uselist = False
+
         self.argument = argument
 
     @util.preload_module("sqlalchemy.orm.mapper")
index 317abe2b479fdc21cac3465ada318a6b213b4aca..02080a27f9a16a33ceb9750d935555d88b8da88d 100644 (file)
@@ -1958,8 +1958,12 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
 def _is_mapped_annotation(
     raw_annotation: _AnnotationScanType, cls: Type[Any]
 ) -> bool:
-    annotated = de_stringify_annotation(cls, raw_annotation)
-    return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
+    try:
+        annotated = de_stringify_annotation(cls, raw_annotation)
+    except NameError:
+        return False
+    else:
+        return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
 
 
 def _cleanup_mapped_str_annotation(annotation: str) -> str:
@@ -1984,7 +1988,10 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
 
         # stack: ['Mapped', 'List', 'Address']
         if not re.match(r"""^["'].*["']$""", stack[-1]):
-            stack[-1] = f'"{stack[-1]}"'
+            stripchars = "\"' "
+            stack[-1] = ", ".join(
+                f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",")
+            )
             # stack: ['Mapped', 'List', '"Address"']
 
             annotation = "[".join(stack) + ("]" * (len(stack) - 1))
@@ -2007,6 +2014,7 @@ def _extract_mapped_subtype(
     Includes error raise scenarios and other options.
 
     """
+
     if raw_annotation is None:
 
         if required:
@@ -2017,9 +2025,19 @@ def _extract_mapped_subtype(
             )
         return None
 
-    annotated = de_stringify_annotation(
-        cls, raw_annotation, _cleanup_mapped_str_annotation
-    )
+    try:
+        annotated = de_stringify_annotation(
+            cls, raw_annotation, _cleanup_mapped_str_annotation
+        )
+    except NameError as ne:
+        if raiseerr and "Mapped[" in raw_annotation:  # type: ignore
+            raise sa_exc.ArgumentError(
+                f"Could not interpret annotation {raw_annotation}.  "
+                "Check that it's not using names that might not be imported "
+                "at the module level.  See chained stack trace for more hints."
+            ) from ne
+
+        annotated = raw_annotation  # type: ignore
 
     if is_dataclass_field:
         return annotated
index 653301f1f2bf62a7ed00af5a9a63e63536b290ea..45fe63765b6ae1325ab0e145f8cf261f86352470 100644 (file)
@@ -113,8 +113,10 @@ def de_stringify_annotation(
 
         try:
             annotation = eval(annotation, base_globals, None)
-        except NameError:
-            pass
+        except NameError as err:
+            raise NameError(
+                f"Could not de-stringify annotation {annotation}"
+            ) from err
     return annotation  # type: ignore
 
 
index 44976b5d88175526ecec50cc2ca15548ad055b5d..f5111bfc79915fc1125d136d9fa9e4ac5243871c 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy.testing import eq_regex
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import ne_
@@ -547,6 +548,37 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase):
         ):
             A()
 
+    def test_one_to_one_example(self, dc_decl_base: Type[MappedAsDataclass]):
+        """test example in the relationship docs will derive uselist=False
+        correctly"""
+
+        class Parent(dc_decl_base):
+            __tablename__ = "parent"
+
+            id: Mapped[int] = mapped_column(init=False, primary_key=True)
+            child: Mapped["Child"] = relationship(  # noqa: F821
+                back_populates="parent", default=None
+            )
+
+        class Child(dc_decl_base):
+            __tablename__ = "child"
+
+            id: Mapped[int] = mapped_column(init=False, primary_key=True)
+            parent_id: Mapped[int] = mapped_column(
+                ForeignKey("parent.id"), init=False
+            )
+            parent: Mapped["Parent"] = relationship(
+                back_populates="child", default=None
+            )
+
+        c1 = Child()
+        p1 = Parent(child=c1)
+        is_(p1.child, c1)
+        is_(c1.parent, p1)
+
+        p2 = Parent()
+        is_(p2.child, None)
+
     def test_replace_operation_works_w_history_etc(
         self, registry: _RegistryType
     ):
index f8abd686a0a21ff6ed8a06e237be1c50cac1f2b6..74cbebb4da7221d85e2ce22cb67949ff27a3e61a 100644 (file)
@@ -1,13 +1,21 @@
 from __future__ import annotations
 
 from typing import List
+from typing import Set
+from typing import TypeVar
 
+from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy.orm import attribute_mapped_collection
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedCollection
 from sqlalchemy.orm import relationship
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
 from .test_typed_mapping import MappedColumnTest  # noqa
 from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
 
@@ -17,6 +25,13 @@ having ``from __future__ import annotations`` in effect.
 """
 
 
+_R = TypeVar("_R")
+
+
+class MappedOneArg(MappedCollection[str, _R]):
+    pass
+
+
 class RelationshipLHSTest(_RelationshipLHSTest):
     def test_bidirectional_literal_annotations(self, decl_base):
         """test the 'string cleanup' function in orm/util.py, where
@@ -54,3 +69,142 @@ class RelationshipLHSTest(_RelationshipLHSTest):
         b1 = B()
         a1.bs.append(b1)
         is_(a1, b1.a)
+
+    def test_collection_class_uselist_implicit_fwd(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+            bs_list: Mapped[List[B]] = relationship(  # noqa: F821
+                viewonly=True
+            )
+            bs_set: Mapped[Set[B]] = relationship(viewonly=True)  # noqa: F821
+            bs_list_warg: Mapped[List[B]] = relationship(  # noqa: F821
+                "B", viewonly=True
+            )
+            bs_set_warg: Mapped[Set[B]] = relationship(  # noqa: F821
+                "B", viewonly=True
+            )
+
+            b_one_to_one: Mapped[B] = relationship(viewonly=True)  # noqa: F821
+
+            b_one_to_one_warg: Mapped[B] = relationship(  # noqa: F821
+                "B", viewonly=True
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+            a: Mapped[A] = relationship(viewonly=True)
+            a_warg: Mapped[A] = relationship("A", viewonly=True)
+
+        is_(A.__mapper__.attrs["bs_list"].collection_class, list)
+        is_(A.__mapper__.attrs["bs_set"].collection_class, set)
+        is_(A.__mapper__.attrs["bs_list_warg"].collection_class, list)
+        is_(A.__mapper__.attrs["bs_set_warg"].collection_class, set)
+        is_true(A.__mapper__.attrs["bs_list"].uselist)
+        is_true(A.__mapper__.attrs["bs_set"].uselist)
+        is_true(A.__mapper__.attrs["bs_list_warg"].uselist)
+        is_true(A.__mapper__.attrs["bs_set_warg"].uselist)
+
+        is_false(A.__mapper__.attrs["b_one_to_one"].uselist)
+        is_false(A.__mapper__.attrs["b_one_to_one_warg"].uselist)
+
+        is_false(B.__mapper__.attrs["a"].uselist)
+        is_false(B.__mapper__.attrs["a_warg"].uselist)
+
+    def test_collection_class_dict_attr_mapped_collection_literal_annotations(
+        self, decl_base
+    ):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            bs: Mapped[MappedCollection[str, B]] = relationship(  # noqa: F821
+                collection_class=attribute_mapped_collection("name")
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+            name: Mapped[str] = mapped_column()
+
+        self._assert_dict(A, B)
+
+    def test_collection_cls_attr_mapped_collection_dbl_literal_annotations(
+        self, decl_base
+    ):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            bs: Mapped[
+                MappedCollection[str, "B"]
+            ] = relationship(  # noqa: F821
+                collection_class=attribute_mapped_collection("name")
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+            name: Mapped[str] = mapped_column()
+
+        self._assert_dict(A, B)
+
+    def test_collection_cls_not_locatable(self, decl_base):
+        class MyCollection(MappedCollection):
+            pass
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Could not interpret annotation Mapped\[MyCollection\['B'\]\].",
+        ):
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped[str] = mapped_column()
+
+                bs: Mapped[MyCollection["B"]] = relationship(  # noqa: F821
+                    collection_class=attribute_mapped_collection("name")
+                )
+
+    def test_collection_cls_one_arg(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            bs: Mapped[MappedOneArg["B"]] = relationship(  # noqa: F821
+                collection_class=attribute_mapped_collection("name")
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+            name: Mapped[str] = mapped_column()
+
+        self._assert_dict(A, B)
+
+    def _assert_dict(self, A, B):
+        A.registry.configure()
+
+        a1 = A()
+        b1 = B(name="foo")
+
+        # collection appender on MappedCollection
+        a1.bs.set(b1)
+
+        is_(a1.bs["foo"], b1)
index beb5d783bf4ede070a7bc5fb9fb09a08ea540f1b..3bf3a01823cc0de09b5ee037d4a2726f5379f5a5 100644 (file)
@@ -1077,6 +1077,15 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 "B", viewonly=True
             )
 
+            # note this is string annotation
+            b_one_to_one: Mapped["B"] = relationship(  # noqa: F821
+                viewonly=True
+            )
+
+            b_one_to_one_warg: Mapped["B"] = relationship(  # noqa: F821
+                "B", viewonly=True
+            )
+
         class B(decl_base):
             __tablename__ = "b"
             id: Mapped[int] = mapped_column(Integer, primary_key=True)
@@ -1094,9 +1103,36 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(A.__mapper__.attrs["bs_list_warg"].uselist)
         is_true(A.__mapper__.attrs["bs_set_warg"].uselist)
 
+        is_false(A.__mapper__.attrs["b_one_to_one"].uselist)
+        is_false(A.__mapper__.attrs["b_one_to_one_warg"].uselist)
+
         is_false(B.__mapper__.attrs["a"].uselist)
         is_false(B.__mapper__.attrs["a_warg"].uselist)
 
+    def test_one_to_one_example(self, decl_base: Type[DeclarativeBase]):
+        """test example in the relationship docs will derive uselist=False
+        correctly"""
+
+        class Parent(decl_base):
+            __tablename__ = "parent"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            child: Mapped["Child"] = relationship(  # noqa: F821
+                back_populates="parent"
+            )
+
+        class Child(decl_base):
+            __tablename__ = "child"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
+            parent: Mapped["Parent"] = relationship(back_populates="child")
+
+        c1 = Child()
+        p1 = Parent(child=c1)
+        is_(p1.child, c1)
+        is_(c1.parent, p1)
+
     def test_collection_class_dict_no_collection(self, decl_base):
         class A(decl_base):
             __tablename__ = "a"
index 14c0e81ee6ac8116ee12691684bfe58089399760..0dda9f52f56d04418a23f99c013bcc92d162971f 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import util
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import attributes
 from sqlalchemy.orm import contains_eager
+from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import defaultload
 from sqlalchemy.orm import defer
 from sqlalchemy.orm import deferred
@@ -18,6 +19,8 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import lazyload
 from sqlalchemy.orm import Load
 from sqlalchemy.orm import load_only
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import query_expression
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
@@ -86,6 +89,47 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest):
             ],
         )
 
+    def test_basic_w_new_style(self):
+        """sanity check that mapped_column(deferred=True) works"""
+
+        class Base(DeclarativeBase):
+            pass
+
+        class Order(Base):
+            __tablename__ = "orders"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            user_id: Mapped[int]
+            address_id: Mapped[int]
+            isopen: Mapped[bool]
+            description: Mapped[str] = mapped_column(deferred=True)
+
+        q = fixture_session().query(Order).order_by(Order.id)
+
+        def go():
+            result = q.all()
+            o2 = result[2]
+            o2.description
+
+        self.sql_eq_(
+            go,
+            [
+                (
+                    "SELECT orders.id AS orders_id, "
+                    "orders.user_id AS orders_user_id, "
+                    "orders.address_id AS orders_address_id, "
+                    "orders.isopen AS orders_isopen "
+                    "FROM orders ORDER BY orders.id",
+                    {},
+                ),
+                (
+                    "SELECT orders.description AS orders_description "
+                    "FROM orders WHERE orders.id = :pk_1",
+                    {"pk_1": 3},
+                ),
+            ],
+        )
+
     def test_defer_primary_key(self):
         """what happens when we try to defer the primary key?"""