From: Priyanshu Parikh Date: Sun, 15 Oct 2023 14:34:25 +0000 (-0400) Subject: allow callable for relationship.back_populates X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=95464adc3c81827bd1c072674dc8c4e17463d8cb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git allow callable for relationship.back_populates The :paramref:`_orm.relationship.back_populates` argument to :func:`_orm.relationship` may now be passed as a Python callable, which resolves to either the direct linked ORM attribute, or a string value as before. ORM attributes are also accepted directly by :paramref:`_orm.relationship.back_populates`. This change allows type checkers and IDEs to confirm the argument for :paramref:`_orm.relationship.back_populates` is valid. Thanks to Priyanshu Parikh for the help on suggesting and helping to implement this feature. An attempt was made to do this for ForeignKey as well, however this is not feasible since there is no "deferred configuration" step for Table objects; Table objects set each other up on ForeignKey as they are created, such as setting the type of a column in a referencing Column when the referenced table is created. We have no way to know which Table a foreign key intends to represent when it's a callable whereas when it's a string, we do know, and we actually make a lot of use of that string to match it to the target table as that target is created (see _fk_memos). However the commit keeps a little bit of the cleanup to ForeignKey intact. Co-authored-by: Mike Bayer Fixes: #10050 Closes: #10260 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10260 Pull-request-sha: 6f21002e1d5bbe291111655f33b19e4eb4b3cb84 Change-Id: I8e0a40c9898ec91d44f2df06dcc22f33b06745c3 --- diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index 0795a3fe9f..8edea83839 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -34,3 +34,33 @@ need to be aware of this extra installation dependency. :ticket:`10197` + +.. _change_10050: + +ORM Relationship allows callable for back_populates +--------------------------------------------------- + +To help produce code that is more amenable to IDE-level linting and type +checking, the :paramref:`_orm.relationship.back_populates` parameter now +accepts both direct references to a class-bound attribute as well as +lambdas which do the same:: + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + # use a lambda: to link to B.a directly when it exists + bs: Mapped[list[B]] = relationship(back_populates=lambda: B.a) + + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + + # A.bs already exists, so can link directly + a: Mapped[A] = relationship(back_populates=A.bs) + +:ticket:`10050` + diff --git a/doc/build/changelog/unreleased_21/10050.rst b/doc/build/changelog/unreleased_21/10050.rst new file mode 100644 index 0000000000..a1c1753a1c --- /dev/null +++ b/doc/build/changelog/unreleased_21/10050.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: feature, orm + :tickets: 10050 + + The :paramref:`_orm.relationship.back_populates` argument to + :func:`_orm.relationship` may now be passed as a Python callable, which + resolves to either the direct linked ORM attribute, or a string value as + before. ORM attributes are also accepted directly by + :paramref:`_orm.relationship.back_populates`. This change allows type + checkers and IDEs to confirm the argument for + :paramref:`_orm.relationship.back_populates` is valid. Thanks to Priyanshu + Parikh for the help on suggesting and helping to implement this feature. + + .. seealso:: + + :ref:`change_10050` + diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index ba9bb516f8..3a7f826e1d 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -28,6 +28,7 @@ from .properties import MappedColumn from .properties import MappedSQLExpression from .query import AliasOption from .relationships import _RelationshipArgumentType +from .relationships import _RelationshipBackPopulatesArgument from .relationships import _RelationshipSecondaryArgument from .relationships import Relationship from .relationships import RelationshipProperty @@ -922,7 +923,7 @@ def relationship( ] = None, primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, - back_populates: Optional[str] = None, + back_populates: Optional[_RelationshipBackPopulatesArgument] = None, order_by: _ORMOrderByArgument = False, backref: Optional[ORMBackrefArgument] = None, overlaps: Optional[str] = None, diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 30cbec96a1..58b413bed9 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -176,6 +176,13 @@ _ORMOrderByArgument = Union[ Callable[[], Iterable[_ColumnExpressionArgument[Any]]], Iterable[Union[str, _ColumnExpressionArgument[Any]]], ] +_RelationshipBackPopulatesArgument = Union[ + str, + PropComparator[Any], + Callable[[], Union[str, PropComparator[Any]]], +] + + ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] _ORMColCollectionElement = Union[ @@ -273,10 +280,32 @@ class _RelationshipArg(Generic[_T1, _T2]): else: self.resolved = attr_value + def effective_value(self) -> Any: + if self.resolved is not None: + return self.resolved + else: + return self.argument + _RelationshipOrderByArg = Union[Literal[False], Tuple[ColumnElement[Any], ...]] +@dataclasses.dataclass +class _StringRelationshipArg(_RelationshipArg[_T1, _T2]): + def _resolve_against_registry( + self, clsregistry_resolver: Callable[[str, bool], _class_resolver] + ) -> None: + attr_value = self.argument + + if callable(attr_value): + attr_value = attr_value() + + if isinstance(attr_value, attributes.QueryableAttribute): + attr_value = attr_value.key # type: ignore + + self.resolved = attr_value + + class _RelationshipArgs(NamedTuple): """stores user-passed parameters that are resolved at mapper configuration time. @@ -302,6 +331,9 @@ class _RelationshipArgs(NamedTuple): remote_side: _RelationshipArg[ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] ] + back_populates: _StringRelationshipArg[ + Optional[_RelationshipBackPopulatesArgument], str + ] @log.class_logger @@ -372,7 +404,7 @@ class RelationshipProperty( ] = None, primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, - back_populates: Optional[str] = None, + back_populates: Optional[_RelationshipBackPopulatesArgument] = None, order_by: _ORMOrderByArgument = False, backref: Optional[ORMBackrefArgument] = None, overlaps: Optional[str] = None, @@ -417,6 +449,7 @@ class RelationshipProperty( _RelationshipArg("order_by", order_by, None), _RelationshipArg("foreign_keys", foreign_keys, None), _RelationshipArg("remote_side", remote_side, None), + _StringRelationshipArg("back_populates", back_populates, None), ) self.post_update = post_update @@ -487,9 +520,7 @@ class RelationshipProperty( # mypy ignoring the @property setter self.cascade = cascade # type: ignore - self.back_populates = back_populates - - if self.back_populates: + if back_populates: if backref: raise sa_exc.ArgumentError( "backref and back_populates keyword arguments " @@ -499,6 +530,14 @@ class RelationshipProperty( else: self.backref = backref + @property + def back_populates(self) -> str: + return self._init_args.back_populates.effective_value() # type: ignore + + @back_populates.setter + def back_populates(self, value: str) -> None: + self._init_args.back_populates.argument = value + def _warn_for_persistence_only_flags(self, **kw: Any) -> None: for k, v in kw.items(): if v != self._persistence_only[k]: @@ -1672,6 +1711,7 @@ class RelationshipProperty( "secondary", "foreign_keys", "remote_side", + "back_populates", ): rel_arg = getattr(init_args, attr) @@ -2054,7 +2094,10 @@ class RelationshipProperty( if self.parent.non_primary: return - if self.backref is not None and not self.back_populates: + + resolve_back_populates = self._init_args.back_populates.resolved + + if self.backref is not None and not resolve_back_populates: kwargs: Dict[str, Any] if isinstance(self.backref, str): backref_key, kwargs = self.backref, {} @@ -2125,8 +2168,18 @@ class RelationshipProperty( backref_key, relationship, warn_for_existing=True ) - if self.back_populates: - self._add_reverse_property(self.back_populates) + if resolve_back_populates: + if isinstance(resolve_back_populates, PropComparator): + back_populates = resolve_back_populates.prop.key + elif isinstance(resolve_back_populates, str): + back_populates = resolve_back_populates + else: + # need test coverage for this case as well + raise sa_exc.ArgumentError( + f"Invalid back_populates value: {resolve_back_populates!r}" + ) + + self._add_reverse_property(back_populates) @util.preload_module("sqlalchemy.orm.dependency") def _post_init(self) -> None: diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index f5f6fb1775..7c3e58b4bc 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -266,6 +266,8 @@ used for :class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, etc. """ +_DDLColumnReferenceArgument = _DDLColumnArgument + _DMLTableArgument = Union[ "TableClause", "Join", diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 78586937b1..7d3d1f521e 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -92,6 +92,7 @@ from ..util.typing import TypeGuard if typing.TYPE_CHECKING: from ._typing import _AutoIncrementType from ._typing import _DDLColumnArgument + from ._typing import _DDLColumnReferenceArgument from ._typing import _InfoType from ._typing import _TextCoercedExpressionArgument from ._typing import _TypeEngineArgument @@ -2768,9 +2769,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): _table_column: Optional[Column[Any]] + _colspec: Union[str, Column[Any]] + def __init__( self, - column: _DDLColumnArgument, + column: _DDLColumnReferenceArgument, _constraint: Optional[ForeignKeyConstraint] = None, use_alter: bool = False, name: _ConstraintNameArgument = None, @@ -2856,21 +2859,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ - self._colspec = coercions.expect(roles.DDLReferredColumnRole, column) self._unresolvable = _unresolvable - if isinstance(self._colspec, str): - self._table_column = None - else: - self._table_column = self._colspec - - if not isinstance( - self._table_column.table, (type(None), TableClause) - ): - raise exc.ArgumentError( - "ForeignKey received Column not bound " - "to a Table, got: %r" % self._table_column.table - ) + self._colspec, self._table_column = self._parse_colspec_argument( + column + ) # the linked ForeignKeyConstraint. # ForeignKey will create this when parent Column @@ -2895,6 +2888,33 @@ class ForeignKey(DialectKWArgs, SchemaItem): self.info = info self._unvalidated_dialect_kw = dialect_kw + def _resolve_colspec_argument( + self, + ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]: + argument = self._colspec + + return self._parse_colspec_argument(argument) + + def _parse_colspec_argument( + self, + argument: _DDLColumnArgument, + ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]: + _colspec = coercions.expect(roles.DDLReferredColumnRole, argument) + + if isinstance(_colspec, str): + _table_column = None + else: + assert isinstance(_colspec, ColumnClause) + _table_column = _colspec + + if not isinstance(_table_column.table, (type(None), TableClause)): + raise exc.ArgumentError( + "ForeignKey received Column not bound " + "to a Table, got: %r" % _table_column.table + ) + + return _colspec, _table_column + def __repr__(self) -> str: return "ForeignKey(%r)" % self._get_colspec() @@ -2954,6 +2974,9 @@ class ForeignKey(DialectKWArgs, SchemaItem): argument first passed to the object's constructor. """ + + _colspec, effective_table_column = self._resolve_colspec_argument() + if schema not in (None, RETAIN_SCHEMA): _schema, tname, colname = self._column_tokens if table_name is not None: @@ -2968,28 +2991,30 @@ class ForeignKey(DialectKWArgs, SchemaItem): return "%s.%s.%s" % (schema, table_name, colname) else: return "%s.%s" % (table_name, colname) - elif self._table_column is not None: - if self._table_column.table is None: + elif effective_table_column is not None: + if effective_table_column.table is None: if _is_copy: raise exc.InvalidRequestError( f"Can't copy ForeignKey object which refers to " - f"non-table bound Column {self._table_column!r}" + f"non-table bound Column {effective_table_column!r}" ) else: - return self._table_column.key + return effective_table_column.key return "%s.%s" % ( - self._table_column.table.fullname, - self._table_column.key, + effective_table_column.table.fullname, + effective_table_column.key, ) else: - assert isinstance(self._colspec, str) - return self._colspec + assert isinstance(_colspec, str) + return _colspec @property def _referred_schema(self) -> Optional[str]: return self._column_tokens[0] - def _table_key(self) -> Any: + def _table_key_within_construction(self) -> Any: + """get the table key but only safely""" + if self._table_column is not None: if self._table_column.table is None: return None @@ -3028,10 +3053,6 @@ class ForeignKey(DialectKWArgs, SchemaItem): """parse a string-based _colspec into its component parts.""" m = self._get_colspec().split(".") - if m is None: - raise exc.ArgumentError( - f"Invalid foreign key column specification: {self._colspec}" - ) if len(m) == 1: tname = m.pop() colname = None @@ -3121,7 +3142,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): if _column is None: raise exc.NoReferencedColumnError( "Could not initialize target column " - f"for ForeignKey '{self._colspec}' " + f"for ForeignKey '{self._get_colspec()}' " f"on table '{parenttable.name}': " f"table '{table.name}' has no column named '{key}'", table.name, @@ -3157,7 +3178,6 @@ class ForeignKey(DialectKWArgs, SchemaItem): is raised. """ - return self._resolve_column() @overload @@ -3175,7 +3195,9 @@ class ForeignKey(DialectKWArgs, SchemaItem): ) -> Optional[Column[Any]]: _column: Column[Any] - if isinstance(self._colspec, str): + _colspec, effective_table_column = self._resolve_colspec_argument() + + if isinstance(_colspec, str): parenttable, tablekey, colname = self._resolve_col_tokens() if self._unresolvable or tablekey not in parenttable.metadata: @@ -3201,11 +3223,12 @@ class ForeignKey(DialectKWArgs, SchemaItem): parenttable, table, colname ) - elif hasattr(self._colspec, "__clause_element__"): - _column = self._colspec.__clause_element__() + elif hasattr(_colspec, "__clause_element__"): + _column = _colspec.__clause_element__() return _column else: - _column = self._colspec + assert isinstance(_colspec, Column) + _column = _colspec return _column def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: @@ -3257,7 +3280,9 @@ class ForeignKey(DialectKWArgs, SchemaItem): table.foreign_keys.add(self) # set up remote ".column" attribute, or a note to pick it # up when the other Table/Column shows up - if isinstance(self._colspec, str): + + _colspec, _ = self._resolve_colspec_argument() + if isinstance(_colspec, str): parenttable, table_key, colname = self._resolve_col_tokens() fk_key = (table_key, colname) if table_key in parenttable.metadata.tables: @@ -3273,12 +3298,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._set_target_column(_column) parenttable.metadata._fk_memos[fk_key].append(self) - elif hasattr(self._colspec, "__clause_element__"): - _column = self._colspec.__clause_element__() + elif hasattr(_colspec, "__clause_element__"): + _column = _colspec.__clause_element__() self._set_target_column(_column) else: - _column = self._colspec - self._set_target_column(_column) + self._set_target_column(_colspec) if TYPE_CHECKING: @@ -4603,7 +4627,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): def __init__( self, columns: _typing_Sequence[_DDLColumnArgument], - refcolumns: _typing_Sequence[_DDLColumnArgument], + refcolumns: _typing_Sequence[_DDLColumnReferenceArgument], name: _ConstraintNameArgument = None, onupdate: Optional[str] = None, ondelete: Optional[str] = None, @@ -4789,7 +4813,9 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.elements[0].column.table def _validate_dest_table(self, table: Table) -> None: - table_keys = {elem._table_key() for elem in self.elements} + table_keys = { + elem._table_key_within_construction() for elem in self.elements + } if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( @@ -4862,7 +4888,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): schema=schema, table_name=target_table.name if target_table is not None - and x._table_key() == x.parent.table.key + and x._table_key_within_construction() + == x.parent.table.key else None, _is_copy=True, ) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 969196ad8c..d644d26793 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -2297,6 +2297,98 @@ class ManualBackrefTest(_fixtures.FixtureTest): assert a1.user is u1 assert a1 in u1.addresses + @testing.variation( + "argtype", ["str", "callable_str", "prop", "callable_prop"] + ) + def test_o2m_with_callable(self, argtype): + """test #10050""" + + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + if argtype.str: + abp, ubp = "user", "addresses" + elif argtype.callable_str: + abp, ubp = lambda: "user", lambda: "addresses" + elif argtype.prop: + abp, ubp = lambda: "user", lambda: "addresses" + elif argtype.callable_prop: + abp, ubp = lambda: Address.user, lambda: User.addresses + else: + argtype.fail() + + self.mapper_registry.map_imperatively( + User, + users, + properties={ + "addresses": relationship(Address, back_populates=abp) + }, + ) + + if argtype.prop: + ubp = User.addresses + + self.mapper_registry.map_imperatively( + Address, + addresses, + properties={"user": relationship(User, back_populates=ubp)}, + ) + + sess = fixture_session() + + u1 = User(name="u1") + a1 = Address(email_address="foo") + u1.addresses.append(a1) + assert a1.user is u1 + + sess.add(u1) + sess.flush() + sess.expire_all() + assert sess.query(Address).one() is a1 + assert a1.user is u1 + assert a1 in u1.addresses + + @testing.variation("argtype", ["plain", "callable"]) + def test_invalid_backref_type(self, argtype): + """test #10050""" + + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + if argtype.plain: + abp, ubp = object(), "addresses" + elif argtype.callable: + abp, ubp = lambda: object(), lambda: "addresses" + else: + argtype.fail() + + self.mapper_registry.map_imperatively( + User, + users, + properties={ + "addresses": relationship(Address, back_populates=abp) + }, + ) + + self.mapper_registry.map_imperatively( + Address, + addresses, + properties={"user": relationship(User, back_populates=ubp)}, + ) + + with expect_raises_message( + exc.ArgumentError, r"Invalid back_populates value: