From: Mike Bayer Date: Sun, 3 Jul 2022 20:25:15 +0000 (-0400) Subject: runtime annotation fixes for relationship X-Git-Tag: rel_2_0_0b1~191 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=148711cb8515a19b6177dc07655cc6e652de0553;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git runtime annotation fixes for relationship * derive uselist=False when fwd ref passed to relationship This case needs to work whether or not the class name is a forward ref. we dont allow the colleciton to be a forward ref so this will work. * fix issues with MappedCollection When using string annotations or __future__.annotations, we need to do more parsing in order to get the target collection properly Change-Id: I9e5a1358b62d060a8815826f98190801a9cc0b68 --- diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 4f19ba9460..539cf2600a 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -87,6 +87,10 @@ from .interfaces import PropComparator as PropComparator from .interfaces import UserDefinedOption as UserDefinedOption from .loading import merge_frozen_result as merge_frozen_result from .loading import merge_result as merge_result +from .mapped_collection import attribute_mapped_collection +from .mapped_collection import column_mapped_collection +from .mapped_collection import mapped_collection +from .mapped_collection import MappedCollection from .mapper import configure_mappers as configure_mappers from .mapper import Mapper as Mapper from .mapper import reconstructor as reconstructor diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index b3fcd29ea3..dd79eb1d09 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -463,6 +463,7 @@ class _class_resolver: generic_match = re.match(r"(.+)\[(.+)\]", name) if generic_match: + clsarg = generic_match.group(2).strip("'") raise exc.InvalidRequestError( f"When initializing mapper {self.prop.parent}, " f'expression "relationship({self.arg!r})" seems to be ' @@ -470,7 +471,7 @@ class _class_resolver: "please state the generic argument " "using an annotation, e.g. " f'"{self.prop.key}: Mapped[{generic_match.group(1)}' - f'[{generic_match.group(2)}]] = relationship()"' + f"['{clsarg}']] = relationship()\"" ) from err else: raise exc.InvalidRequestError( diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 630f6898fa..77a95a195b 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1724,11 +1724,12 @@ class Relationship( self.collection_class = collection_class else: self.uselist = False + if argument.__args__: # type: ignore if issubclass( argument.__origin__, typing.Mapping # type: ignore ): - type_arg = argument.__args__[1] # type: ignore + type_arg = argument.__args__[-1] # type: ignore else: type_arg = argument.__args__[0] # type: ignore if hasattr(type_arg, "__forward_arg__"): @@ -1743,6 +1744,12 @@ class Relationship( elif hasattr(argument, "__forward_arg__"): argument = argument.__forward_arg__ # type: ignore + # we don't allow the collection class to be a + # __forward_arg__ right now, so if we see a forward arg here, + # we know there was no collection class either + if self.collection_class is None: + self.uselist = False + self.argument = argument @util.preload_module("sqlalchemy.orm.mapper") diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 317abe2b47..02080a27f9 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1958,8 +1958,12 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any: def _is_mapped_annotation( raw_annotation: _AnnotationScanType, cls: Type[Any] ) -> bool: - annotated = de_stringify_annotation(cls, raw_annotation) - return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") + try: + annotated = de_stringify_annotation(cls, raw_annotation) + except NameError: + return False + else: + return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") def _cleanup_mapped_str_annotation(annotation: str) -> str: @@ -1984,7 +1988,10 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str: # stack: ['Mapped', 'List', 'Address'] if not re.match(r"""^["'].*["']$""", stack[-1]): - stack[-1] = f'"{stack[-1]}"' + stripchars = "\"' " + stack[-1] = ", ".join( + f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",") + ) # stack: ['Mapped', 'List', '"Address"'] annotation = "[".join(stack) + ("]" * (len(stack) - 1)) @@ -2007,6 +2014,7 @@ def _extract_mapped_subtype( Includes error raise scenarios and other options. """ + if raw_annotation is None: if required: @@ -2017,9 +2025,19 @@ def _extract_mapped_subtype( ) return None - annotated = de_stringify_annotation( - cls, raw_annotation, _cleanup_mapped_str_annotation - ) + try: + annotated = de_stringify_annotation( + cls, raw_annotation, _cleanup_mapped_str_annotation + ) + except NameError as ne: + if raiseerr and "Mapped[" in raw_annotation: # type: ignore + raise sa_exc.ArgumentError( + f"Could not interpret annotation {raw_annotation}. " + "Check that it's not using names that might not be imported " + "at the module level. See chained stack trace for more hints." + ) from ne + + annotated = raw_annotation # type: ignore if is_dataclass_field: return annotated diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 653301f1f2..45fe63765b 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -113,8 +113,10 @@ def de_stringify_annotation( try: annotation = eval(annotation, base_globals, None) - except NameError: - pass + except NameError as err: + raise NameError( + f"Could not de-stringify annotation {annotation}" + ) from err return annotation # type: ignore diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index 44976b5d88..f5111bfc79 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -38,6 +38,7 @@ from sqlalchemy.testing import eq_regex from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import ne_ @@ -547,6 +548,37 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase): ): A() + def test_one_to_one_example(self, dc_decl_base: Type[MappedAsDataclass]): + """test example in the relationship docs will derive uselist=False + correctly""" + + class Parent(dc_decl_base): + __tablename__ = "parent" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + child: Mapped["Child"] = relationship( # noqa: F821 + back_populates="parent", default=None + ) + + class Child(dc_decl_base): + __tablename__ = "child" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + parent_id: Mapped[int] = mapped_column( + ForeignKey("parent.id"), init=False + ) + parent: Mapped["Parent"] = relationship( + back_populates="child", default=None + ) + + c1 = Child() + p1 = Parent(child=c1) + is_(p1.child, c1) + is_(c1.parent, p1) + + p2 = Parent() + is_(p2.child, None) + def test_replace_operation_works_w_history_etc( self, registry: _RegistryType ): diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index f8abd686a0..74cbebb4da 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -1,13 +1,21 @@ from __future__ import annotations from typing import List +from typing import Set +from typing import TypeVar +from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy.orm import attribute_mapped_collection from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedCollection from sqlalchemy.orm import relationship +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true from .test_typed_mapping import MappedColumnTest # noqa from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest @@ -17,6 +25,13 @@ having ``from __future__ import annotations`` in effect. """ +_R = TypeVar("_R") + + +class MappedOneArg(MappedCollection[str, _R]): + pass + + class RelationshipLHSTest(_RelationshipLHSTest): def test_bidirectional_literal_annotations(self, decl_base): """test the 'string cleanup' function in orm/util.py, where @@ -54,3 +69,142 @@ class RelationshipLHSTest(_RelationshipLHSTest): b1 = B() a1.bs.append(b1) is_(a1, b1.a) + + def test_collection_class_uselist_implicit_fwd(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + bs_list: Mapped[List[B]] = relationship( # noqa: F821 + viewonly=True + ) + bs_set: Mapped[Set[B]] = relationship(viewonly=True) # noqa: F821 + bs_list_warg: Mapped[List[B]] = relationship( # noqa: F821 + "B", viewonly=True + ) + bs_set_warg: Mapped[Set[B]] = relationship( # noqa: F821 + "B", viewonly=True + ) + + b_one_to_one: Mapped[B] = relationship(viewonly=True) # noqa: F821 + + b_one_to_one_warg: Mapped[B] = relationship( # noqa: F821 + "B", viewonly=True + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + a: Mapped[A] = relationship(viewonly=True) + a_warg: Mapped[A] = relationship("A", viewonly=True) + + is_(A.__mapper__.attrs["bs_list"].collection_class, list) + is_(A.__mapper__.attrs["bs_set"].collection_class, set) + is_(A.__mapper__.attrs["bs_list_warg"].collection_class, list) + is_(A.__mapper__.attrs["bs_set_warg"].collection_class, set) + is_true(A.__mapper__.attrs["bs_list"].uselist) + is_true(A.__mapper__.attrs["bs_set"].uselist) + is_true(A.__mapper__.attrs["bs_list_warg"].uselist) + is_true(A.__mapper__.attrs["bs_set_warg"].uselist) + + is_false(A.__mapper__.attrs["b_one_to_one"].uselist) + is_false(A.__mapper__.attrs["b_one_to_one_warg"].uselist) + + is_false(B.__mapper__.attrs["a"].uselist) + is_false(B.__mapper__.attrs["a_warg"].uselist) + + def test_collection_class_dict_attr_mapped_collection_literal_annotations( + self, decl_base + ): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + bs: Mapped[MappedCollection[str, B]] = relationship( # noqa: F821 + collection_class=attribute_mapped_collection("name") + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + name: Mapped[str] = mapped_column() + + self._assert_dict(A, B) + + def test_collection_cls_attr_mapped_collection_dbl_literal_annotations( + self, decl_base + ): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + bs: Mapped[ + MappedCollection[str, "B"] + ] = relationship( # noqa: F821 + collection_class=attribute_mapped_collection("name") + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + name: Mapped[str] = mapped_column() + + self._assert_dict(A, B) + + def test_collection_cls_not_locatable(self, decl_base): + class MyCollection(MappedCollection): + pass + + with expect_raises_message( + exc.ArgumentError, + r"Could not interpret annotation Mapped\[MyCollection\['B'\]\].", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + bs: Mapped[MyCollection["B"]] = relationship( # noqa: F821 + collection_class=attribute_mapped_collection("name") + ) + + def test_collection_cls_one_arg(self, decl_base): + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + bs: Mapped[MappedOneArg["B"]] = relationship( # noqa: F821 + collection_class=attribute_mapped_collection("name") + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + name: Mapped[str] = mapped_column() + + self._assert_dict(A, B) + + def _assert_dict(self, A, B): + A.registry.configure() + + a1 = A() + b1 = B(name="foo") + + # collection appender on MappedCollection + a1.bs.set(b1) + + is_(a1.bs["foo"], b1) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index beb5d783bf..3bf3a01823 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1077,6 +1077,15 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): "B", viewonly=True ) + # note this is string annotation + b_one_to_one: Mapped["B"] = relationship( # noqa: F821 + viewonly=True + ) + + b_one_to_one_warg: Mapped["B"] = relationship( # noqa: F821 + "B", viewonly=True + ) + class B(decl_base): __tablename__ = "b" id: Mapped[int] = mapped_column(Integer, primary_key=True) @@ -1094,9 +1103,36 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(A.__mapper__.attrs["bs_list_warg"].uselist) is_true(A.__mapper__.attrs["bs_set_warg"].uselist) + is_false(A.__mapper__.attrs["b_one_to_one"].uselist) + is_false(A.__mapper__.attrs["b_one_to_one_warg"].uselist) + is_false(B.__mapper__.attrs["a"].uselist) is_false(B.__mapper__.attrs["a_warg"].uselist) + def test_one_to_one_example(self, decl_base: Type[DeclarativeBase]): + """test example in the relationship docs will derive uselist=False + correctly""" + + class Parent(decl_base): + __tablename__ = "parent" + + id: Mapped[int] = mapped_column(primary_key=True) + child: Mapped["Child"] = relationship( # noqa: F821 + back_populates="parent" + ) + + class Child(decl_base): + __tablename__ = "child" + + id: Mapped[int] = mapped_column(primary_key=True) + parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id")) + parent: Mapped["Parent"] = relationship(back_populates="child") + + c1 = Child() + p1 = Parent(child=c1) + is_(p1.child, c1) + is_(c1.parent, p1) + def test_collection_class_dict_no_collection(self, decl_base): class A(decl_base): __tablename__ = "a" diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py index 14c0e81ee6..0dda9f52f5 100644 --- a/test/orm/test_deferred.py +++ b/test/orm/test_deferred.py @@ -10,6 +10,7 @@ from sqlalchemy import util from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import contains_eager +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import defaultload from sqlalchemy.orm import defer from sqlalchemy.orm import deferred @@ -18,6 +19,8 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import lazyload from sqlalchemy.orm import Load from sqlalchemy.orm import load_only +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import query_expression from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload @@ -86,6 +89,47 @@ class DeferredTest(AssertsCompiledSQL, _fixtures.FixtureTest): ], ) + def test_basic_w_new_style(self): + """sanity check that mapped_column(deferred=True) works""" + + class Base(DeclarativeBase): + pass + + class Order(Base): + __tablename__ = "orders" + + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] + address_id: Mapped[int] + isopen: Mapped[bool] + description: Mapped[str] = mapped_column(deferred=True) + + q = fixture_session().query(Order).order_by(Order.id) + + def go(): + result = q.all() + o2 = result[2] + o2.description + + self.sql_eq_( + go, + [ + ( + "SELECT orders.id AS orders_id, " + "orders.user_id AS orders_user_id, " + "orders.address_id AS orders_address_id, " + "orders.isopen AS orders_isopen " + "FROM orders ORDER BY orders.id", + {}, + ), + ( + "SELECT orders.description AS orders_description " + "FROM orders WHERE orders.id = :pk_1", + {"pk_1": 3}, + ), + ], + ) + def test_defer_primary_key(self): """what happens when we try to defer the primary key?"""