]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
allow callable for relationship.back_populates
authorPriyanshu Parikh <parikhpriyanshu7@gmail.com>
Sun, 15 Oct 2023 14:34:25 +0000 (10:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Jan 2024 16:14:20 +0000 (11:14 -0500)
The :paramref:`_orm.relationship.back_populates` argument to
:func:`_orm.relationship` may now be passed as a Python callable, which
resolves to either the direct linked ORM attribute, or a string value as
before.  ORM attributes are also accepted directly by
:paramref:`_orm.relationship.back_populates`.   This change allows type
checkers and IDEs to confirm the argument for
:paramref:`_orm.relationship.back_populates` is valid. Thanks to Priyanshu
Parikh for the help on suggesting and helping to implement this feature.

An attempt was made to do this for ForeignKey as well, however this
is not feasible since there is no "deferred configuration" step
for Table objects; Table objects set each other up on ForeignKey
as they are created, such as setting the type of a column
in a referencing Column when the referenced table is created.  We
have no way to know which Table a foreign key intends to represent
when it's a callable whereas when it's a string, we do know, and
we actually make a lot of use of that string to match it to the
target table as that target is created (see _fk_memos).  However
the commit keeps a little bit of the cleanup to ForeignKey intact.

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Fixes: #10050
Closes: #10260
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10260
Pull-request-sha: 6f21002e1d5bbe291111655f33b19e4eb4b3cb84

Change-Id: I8e0a40c9898ec91d44f2df06dcc22f33b06745c3

doc/build/changelog/migration_21.rst
doc/build/changelog/unreleased_21/10050.rst [new file with mode: 0644]
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/schema.py
test/orm/test_relationships.py

index 0795a3fe9fd9f59db612e88cd2fe16751a61e9dc..8edea83839908800aa181ab568a7421afc1c6610 100644 (file)
@@ -34,3 +34,33 @@ need to be aware of this extra installation dependency.
 
 :ticket:`10197`
 
+
+.. _change_10050:
+
+ORM Relationship allows callable for back_populates
+---------------------------------------------------
+
+To help produce code that is more amenable to IDE-level linting and type
+checking, the :paramref:`_orm.relationship.back_populates` parameter now
+accepts both direct references to a class-bound attribute as well as
+lambdas which do the same::
+
+    class A(Base):
+        __tablename__ = "a"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+
+        # use a lambda: to link to B.a directly when it exists
+        bs: Mapped[list[B]] = relationship(back_populates=lambda: B.a)
+
+
+    class B(Base):
+        __tablename__ = "b"
+        id: Mapped[int] = mapped_column(primary_key=True)
+        a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+        # A.bs already exists, so can link directly
+        a: Mapped[A] = relationship(back_populates=A.bs)
+
+:ticket:`10050`
+
diff --git a/doc/build/changelog/unreleased_21/10050.rst b/doc/build/changelog/unreleased_21/10050.rst
new file mode 100644 (file)
index 0000000..a1c1753
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: feature, orm
+    :tickets: 10050
+
+    The :paramref:`_orm.relationship.back_populates` argument to
+    :func:`_orm.relationship` may now be passed as a Python callable, which
+    resolves to either the direct linked ORM attribute, or a string value as
+    before.  ORM attributes are also accepted directly by
+    :paramref:`_orm.relationship.back_populates`.   This change allows type
+    checkers and IDEs to confirm the argument for
+    :paramref:`_orm.relationship.back_populates` is valid. Thanks to Priyanshu
+    Parikh for the help on suggesting and helping to implement this feature.
+
+    .. seealso::
+
+        :ref:`change_10050`
+
index ba9bb516f842106676c1edbe3bdd96a15b865531..3a7f826e1d19c4a5e1ecdbe401fd7bd73f95a643 100644 (file)
@@ -28,6 +28,7 @@ from .properties import MappedColumn
 from .properties import MappedSQLExpression
 from .query import AliasOption
 from .relationships import _RelationshipArgumentType
+from .relationships import _RelationshipBackPopulatesArgument
 from .relationships import _RelationshipSecondaryArgument
 from .relationships import Relationship
 from .relationships import RelationshipProperty
@@ -922,7 +923,7 @@ def relationship(
     ] = None,
     primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
     secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
-    back_populates: Optional[str] = None,
+    back_populates: Optional[_RelationshipBackPopulatesArgument] = None,
     order_by: _ORMOrderByArgument = False,
     backref: Optional[ORMBackrefArgument] = None,
     overlaps: Optional[str] = None,
index 30cbec96a1ac29c73dcafb95c285163c0fade2de..58b413bed93c43dcab1e2d3339776913a0278035 100644 (file)
@@ -176,6 +176,13 @@ _ORMOrderByArgument = Union[
     Callable[[], Iterable[_ColumnExpressionArgument[Any]]],
     Iterable[Union[str, _ColumnExpressionArgument[Any]]],
 ]
+_RelationshipBackPopulatesArgument = Union[
+    str,
+    PropComparator[Any],
+    Callable[[], Union[str, PropComparator[Any]]],
+]
+
+
 ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
 
 _ORMColCollectionElement = Union[
@@ -273,10 +280,32 @@ class _RelationshipArg(Generic[_T1, _T2]):
         else:
             self.resolved = attr_value
 
+    def effective_value(self) -> Any:
+        if self.resolved is not None:
+            return self.resolved
+        else:
+            return self.argument
+
 
 _RelationshipOrderByArg = Union[Literal[False], Tuple[ColumnElement[Any], ...]]
 
 
+@dataclasses.dataclass
+class _StringRelationshipArg(_RelationshipArg[_T1, _T2]):
+    def _resolve_against_registry(
+        self, clsregistry_resolver: Callable[[str, bool], _class_resolver]
+    ) -> None:
+        attr_value = self.argument
+
+        if callable(attr_value):
+            attr_value = attr_value()
+
+        if isinstance(attr_value, attributes.QueryableAttribute):
+            attr_value = attr_value.key  # type: ignore
+
+        self.resolved = attr_value
+
+
 class _RelationshipArgs(NamedTuple):
     """stores user-passed parameters that are resolved at mapper configuration
     time.
@@ -302,6 +331,9 @@ class _RelationshipArgs(NamedTuple):
     remote_side: _RelationshipArg[
         Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
     ]
+    back_populates: _StringRelationshipArg[
+        Optional[_RelationshipBackPopulatesArgument], str
+    ]
 
 
 @log.class_logger
@@ -372,7 +404,7 @@ class RelationshipProperty(
         ] = None,
         primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
         secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
-        back_populates: Optional[str] = None,
+        back_populates: Optional[_RelationshipBackPopulatesArgument] = None,
         order_by: _ORMOrderByArgument = False,
         backref: Optional[ORMBackrefArgument] = None,
         overlaps: Optional[str] = None,
@@ -417,6 +449,7 @@ class RelationshipProperty(
             _RelationshipArg("order_by", order_by, None),
             _RelationshipArg("foreign_keys", foreign_keys, None),
             _RelationshipArg("remote_side", remote_side, None),
+            _StringRelationshipArg("back_populates", back_populates, None),
         )
 
         self.post_update = post_update
@@ -487,9 +520,7 @@ class RelationshipProperty(
         # mypy ignoring the @property setter
         self.cascade = cascade  # type: ignore
 
-        self.back_populates = back_populates
-
-        if self.back_populates:
+        if back_populates:
             if backref:
                 raise sa_exc.ArgumentError(
                     "backref and back_populates keyword arguments "
@@ -499,6 +530,14 @@ class RelationshipProperty(
         else:
             self.backref = backref
 
+    @property
+    def back_populates(self) -> str:
+        return self._init_args.back_populates.effective_value()  # type: ignore
+
+    @back_populates.setter
+    def back_populates(self, value: str) -> None:
+        self._init_args.back_populates.argument = value
+
     def _warn_for_persistence_only_flags(self, **kw: Any) -> None:
         for k, v in kw.items():
             if v != self._persistence_only[k]:
@@ -1672,6 +1711,7 @@ class RelationshipProperty(
             "secondary",
             "foreign_keys",
             "remote_side",
+            "back_populates",
         ):
             rel_arg = getattr(init_args, attr)
 
@@ -2054,7 +2094,10 @@ class RelationshipProperty(
 
         if self.parent.non_primary:
             return
-        if self.backref is not None and not self.back_populates:
+
+        resolve_back_populates = self._init_args.back_populates.resolved
+
+        if self.backref is not None and not resolve_back_populates:
             kwargs: Dict[str, Any]
             if isinstance(self.backref, str):
                 backref_key, kwargs = self.backref, {}
@@ -2125,8 +2168,18 @@ class RelationshipProperty(
                 backref_key, relationship, warn_for_existing=True
             )
 
-        if self.back_populates:
-            self._add_reverse_property(self.back_populates)
+        if resolve_back_populates:
+            if isinstance(resolve_back_populates, PropComparator):
+                back_populates = resolve_back_populates.prop.key
+            elif isinstance(resolve_back_populates, str):
+                back_populates = resolve_back_populates
+            else:
+                # need test coverage for this case as well
+                raise sa_exc.ArgumentError(
+                    f"Invalid back_populates value: {resolve_back_populates!r}"
+                )
+
+            self._add_reverse_property(back_populates)
 
     @util.preload_module("sqlalchemy.orm.dependency")
     def _post_init(self) -> None:
index f5f6fb1775b7ae9854bd0e7f9b8c7d01edb63541..7c3e58b4bca575ae2282ef2782b5edda188b3856 100644 (file)
@@ -266,6 +266,8 @@ used for :class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, etc.
 
 """
 
+_DDLColumnReferenceArgument = _DDLColumnArgument
+
 _DMLTableArgument = Union[
     "TableClause",
     "Join",
index 78586937b14ccd04ed3866aef0548975bc391ebb..7d3d1f521ed9aaeb3626664ac7a0588c58ec3c20 100644 (file)
@@ -92,6 +92,7 @@ from ..util.typing import TypeGuard
 if typing.TYPE_CHECKING:
     from ._typing import _AutoIncrementType
     from ._typing import _DDLColumnArgument
+    from ._typing import _DDLColumnReferenceArgument
     from ._typing import _InfoType
     from ._typing import _TextCoercedExpressionArgument
     from ._typing import _TypeEngineArgument
@@ -2768,9 +2769,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
 
     _table_column: Optional[Column[Any]]
 
+    _colspec: Union[str, Column[Any]]
+
     def __init__(
         self,
-        column: _DDLColumnArgument,
+        column: _DDLColumnReferenceArgument,
         _constraint: Optional[ForeignKeyConstraint] = None,
         use_alter: bool = False,
         name: _ConstraintNameArgument = None,
@@ -2856,21 +2859,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
 
         """
 
-        self._colspec = coercions.expect(roles.DDLReferredColumnRole, column)
         self._unresolvable = _unresolvable
 
-        if isinstance(self._colspec, str):
-            self._table_column = None
-        else:
-            self._table_column = self._colspec
-
-            if not isinstance(
-                self._table_column.table, (type(None), TableClause)
-            ):
-                raise exc.ArgumentError(
-                    "ForeignKey received Column not bound "
-                    "to a Table, got: %r" % self._table_column.table
-                )
+        self._colspec, self._table_column = self._parse_colspec_argument(
+            column
+        )
 
         # the linked ForeignKeyConstraint.
         # ForeignKey will create this when parent Column
@@ -2895,6 +2888,33 @@ class ForeignKey(DialectKWArgs, SchemaItem):
             self.info = info
         self._unvalidated_dialect_kw = dialect_kw
 
+    def _resolve_colspec_argument(
+        self,
+    ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]:
+        argument = self._colspec
+
+        return self._parse_colspec_argument(argument)
+
+    def _parse_colspec_argument(
+        self,
+        argument: _DDLColumnArgument,
+    ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]:
+        _colspec = coercions.expect(roles.DDLReferredColumnRole, argument)
+
+        if isinstance(_colspec, str):
+            _table_column = None
+        else:
+            assert isinstance(_colspec, ColumnClause)
+            _table_column = _colspec
+
+            if not isinstance(_table_column.table, (type(None), TableClause)):
+                raise exc.ArgumentError(
+                    "ForeignKey received Column not bound "
+                    "to a Table, got: %r" % _table_column.table
+                )
+
+        return _colspec, _table_column
+
     def __repr__(self) -> str:
         return "ForeignKey(%r)" % self._get_colspec()
 
@@ -2954,6 +2974,9 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         argument first passed to the object's constructor.
 
         """
+
+        _colspec, effective_table_column = self._resolve_colspec_argument()
+
         if schema not in (None, RETAIN_SCHEMA):
             _schema, tname, colname = self._column_tokens
             if table_name is not None:
@@ -2968,28 +2991,30 @@ class ForeignKey(DialectKWArgs, SchemaItem):
                 return "%s.%s.%s" % (schema, table_name, colname)
             else:
                 return "%s.%s" % (table_name, colname)
-        elif self._table_column is not None:
-            if self._table_column.table is None:
+        elif effective_table_column is not None:
+            if effective_table_column.table is None:
                 if _is_copy:
                     raise exc.InvalidRequestError(
                         f"Can't copy ForeignKey object which refers to "
-                        f"non-table bound Column {self._table_column!r}"
+                        f"non-table bound Column {effective_table_column!r}"
                     )
                 else:
-                    return self._table_column.key
+                    return effective_table_column.key
             return "%s.%s" % (
-                self._table_column.table.fullname,
-                self._table_column.key,
+                effective_table_column.table.fullname,
+                effective_table_column.key,
             )
         else:
-            assert isinstance(self._colspec, str)
-            return self._colspec
+            assert isinstance(_colspec, str)
+            return _colspec
 
     @property
     def _referred_schema(self) -> Optional[str]:
         return self._column_tokens[0]
 
-    def _table_key(self) -> Any:
+    def _table_key_within_construction(self) -> Any:
+        """get the table key but only safely"""
+
         if self._table_column is not None:
             if self._table_column.table is None:
                 return None
@@ -3028,10 +3053,6 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         """parse a string-based _colspec into its component parts."""
 
         m = self._get_colspec().split(".")
-        if m is None:
-            raise exc.ArgumentError(
-                f"Invalid foreign key column specification: {self._colspec}"
-            )
         if len(m) == 1:
             tname = m.pop()
             colname = None
@@ -3121,7 +3142,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         if _column is None:
             raise exc.NoReferencedColumnError(
                 "Could not initialize target column "
-                f"for ForeignKey '{self._colspec}' "
+                f"for ForeignKey '{self._get_colspec()}' "
                 f"on table '{parenttable.name}': "
                 f"table '{table.name}' has no column named '{key}'",
                 table.name,
@@ -3157,7 +3178,6 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         is raised.
 
         """
-
         return self._resolve_column()
 
     @overload
@@ -3175,7 +3195,9 @@ class ForeignKey(DialectKWArgs, SchemaItem):
     ) -> Optional[Column[Any]]:
         _column: Column[Any]
 
-        if isinstance(self._colspec, str):
+        _colspec, effective_table_column = self._resolve_colspec_argument()
+
+        if isinstance(_colspec, str):
             parenttable, tablekey, colname = self._resolve_col_tokens()
 
             if self._unresolvable or tablekey not in parenttable.metadata:
@@ -3201,11 +3223,12 @@ class ForeignKey(DialectKWArgs, SchemaItem):
                     parenttable, table, colname
                 )
 
-        elif hasattr(self._colspec, "__clause_element__"):
-            _column = self._colspec.__clause_element__()
+        elif hasattr(_colspec, "__clause_element__"):
+            _column = _colspec.__clause_element__()
             return _column
         else:
-            _column = self._colspec
+            assert isinstance(_colspec, Column)
+            _column = _colspec
             return _column
 
     def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
@@ -3257,7 +3280,9 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         table.foreign_keys.add(self)
         # set up remote ".column" attribute, or a note to pick it
         # up when the other Table/Column shows up
-        if isinstance(self._colspec, str):
+
+        _colspec, _ = self._resolve_colspec_argument()
+        if isinstance(_colspec, str):
             parenttable, table_key, colname = self._resolve_col_tokens()
             fk_key = (table_key, colname)
             if table_key in parenttable.metadata.tables:
@@ -3273,12 +3298,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
                     self._set_target_column(_column)
 
             parenttable.metadata._fk_memos[fk_key].append(self)
-        elif hasattr(self._colspec, "__clause_element__"):
-            _column = self._colspec.__clause_element__()
+        elif hasattr(_colspec, "__clause_element__"):
+            _column = _colspec.__clause_element__()
             self._set_target_column(_column)
         else:
-            _column = self._colspec
-            self._set_target_column(_column)
+            self._set_target_column(_colspec)
 
 
 if TYPE_CHECKING:
@@ -4603,7 +4627,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
     def __init__(
         self,
         columns: _typing_Sequence[_DDLColumnArgument],
-        refcolumns: _typing_Sequence[_DDLColumnArgument],
+        refcolumns: _typing_Sequence[_DDLColumnReferenceArgument],
         name: _ConstraintNameArgument = None,
         onupdate: Optional[str] = None,
         ondelete: Optional[str] = None,
@@ -4789,7 +4813,9 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
         return self.elements[0].column.table
 
     def _validate_dest_table(self, table: Table) -> None:
-        table_keys = {elem._table_key() for elem in self.elements}
+        table_keys = {
+            elem._table_key_within_construction() for elem in self.elements
+        }
         if None not in table_keys and len(table_keys) > 1:
             elem0, elem1 = sorted(table_keys)[0:2]
             raise exc.ArgumentError(
@@ -4862,7 +4888,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
                     schema=schema,
                     table_name=target_table.name
                     if target_table is not None
-                    and x._table_key() == x.parent.table.key
+                    and x._table_key_within_construction()
+                    == x.parent.table.key
                     else None,
                     _is_copy=True,
                 )
index 969196ad8ca8a853074bf0088f6714b54d21d78c..d644d26793b0004d88ea0651a05d9bcd30bdf850 100644 (file)
@@ -2297,6 +2297,98 @@ class ManualBackrefTest(_fixtures.FixtureTest):
         assert a1.user is u1
         assert a1 in u1.addresses
 
+    @testing.variation(
+        "argtype", ["str", "callable_str", "prop", "callable_prop"]
+    )
+    def test_o2m_with_callable(self, argtype):
+        """test #10050"""
+
+        users, Address, addresses, User = (
+            self.tables.users,
+            self.classes.Address,
+            self.tables.addresses,
+            self.classes.User,
+        )
+
+        if argtype.str:
+            abp, ubp = "user", "addresses"
+        elif argtype.callable_str:
+            abp, ubp = lambda: "user", lambda: "addresses"
+        elif argtype.prop:
+            abp, ubp = lambda: "user", lambda: "addresses"
+        elif argtype.callable_prop:
+            abp, ubp = lambda: Address.user, lambda: User.addresses
+        else:
+            argtype.fail()
+
+        self.mapper_registry.map_imperatively(
+            User,
+            users,
+            properties={
+                "addresses": relationship(Address, back_populates=abp)
+            },
+        )
+
+        if argtype.prop:
+            ubp = User.addresses
+
+        self.mapper_registry.map_imperatively(
+            Address,
+            addresses,
+            properties={"user": relationship(User, back_populates=ubp)},
+        )
+
+        sess = fixture_session()
+
+        u1 = User(name="u1")
+        a1 = Address(email_address="foo")
+        u1.addresses.append(a1)
+        assert a1.user is u1
+
+        sess.add(u1)
+        sess.flush()
+        sess.expire_all()
+        assert sess.query(Address).one() is a1
+        assert a1.user is u1
+        assert a1 in u1.addresses
+
+    @testing.variation("argtype", ["plain", "callable"])
+    def test_invalid_backref_type(self, argtype):
+        """test #10050"""
+
+        users, Address, addresses, User = (
+            self.tables.users,
+            self.classes.Address,
+            self.tables.addresses,
+            self.classes.User,
+        )
+
+        if argtype.plain:
+            abp, ubp = object(), "addresses"
+        elif argtype.callable:
+            abp, ubp = lambda: object(), lambda: "addresses"
+        else:
+            argtype.fail()
+
+        self.mapper_registry.map_imperatively(
+            User,
+            users,
+            properties={
+                "addresses": relationship(Address, back_populates=abp)
+            },
+        )
+
+        self.mapper_registry.map_imperatively(
+            Address,
+            addresses,
+            properties={"user": relationship(User, back_populates=ubp)},
+        )
+
+        with expect_raises_message(
+            exc.ArgumentError, r"Invalid back_populates value: <object"
+        ):
+            self.mapper_registry.configure()
+
     def test_invalid_key(self):
         users, Address, addresses, User = (
             self.tables.users,