:ticket:`10197`
+
+.. _change_10050:
+
+ORM Relationship allows callable for back_populates
+---------------------------------------------------
+
+To help produce code that is more amenable to IDE-level linting and type
+checking, the :paramref:`_orm.relationship.back_populates` parameter now
+accepts both direct references to a class-bound attribute as well as
+lambdas which do the same::
+
+ class A(Base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ # use a lambda: to link to B.a directly when it exists
+ bs: Mapped[list[B]] = relationship(back_populates=lambda: B.a)
+
+
+ class B(Base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(primary_key=True)
+ a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+ # A.bs already exists, so can link directly
+ a: Mapped[A] = relationship(back_populates=A.bs)
+
+:ticket:`10050`
+
--- /dev/null
+.. change::
+ :tags: feature, orm
+ :tickets: 10050
+
+ The :paramref:`_orm.relationship.back_populates` argument to
+ :func:`_orm.relationship` may now be passed as a Python callable, which
+ resolves to either the direct linked ORM attribute, or a string value as
+ before. ORM attributes are also accepted directly by
+ :paramref:`_orm.relationship.back_populates`. This change allows type
+ checkers and IDEs to confirm the argument for
+ :paramref:`_orm.relationship.back_populates` is valid. Thanks to Priyanshu
+ Parikh for the help on suggesting and helping to implement this feature.
+
+ .. seealso::
+
+ :ref:`change_10050`
+
from .properties import MappedSQLExpression
from .query import AliasOption
from .relationships import _RelationshipArgumentType
+from .relationships import _RelationshipBackPopulatesArgument
from .relationships import _RelationshipSecondaryArgument
from .relationships import Relationship
from .relationships import RelationshipProperty
] = None,
primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
- back_populates: Optional[str] = None,
+ back_populates: Optional[_RelationshipBackPopulatesArgument] = None,
order_by: _ORMOrderByArgument = False,
backref: Optional[ORMBackrefArgument] = None,
overlaps: Optional[str] = None,
Callable[[], Iterable[_ColumnExpressionArgument[Any]]],
Iterable[Union[str, _ColumnExpressionArgument[Any]]],
]
+_RelationshipBackPopulatesArgument = Union[
+ str,
+ PropComparator[Any],
+ Callable[[], Union[str, PropComparator[Any]]],
+]
+
+
ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
_ORMColCollectionElement = Union[
else:
self.resolved = attr_value
+ def effective_value(self) -> Any:
+ if self.resolved is not None:
+ return self.resolved
+ else:
+ return self.argument
+
_RelationshipOrderByArg = Union[Literal[False], Tuple[ColumnElement[Any], ...]]
+@dataclasses.dataclass
+class _StringRelationshipArg(_RelationshipArg[_T1, _T2]):
+ def _resolve_against_registry(
+ self, clsregistry_resolver: Callable[[str, bool], _class_resolver]
+ ) -> None:
+ attr_value = self.argument
+
+ if callable(attr_value):
+ attr_value = attr_value()
+
+ if isinstance(attr_value, attributes.QueryableAttribute):
+ attr_value = attr_value.key # type: ignore
+
+ self.resolved = attr_value
+
+
class _RelationshipArgs(NamedTuple):
"""stores user-passed parameters that are resolved at mapper configuration
time.
remote_side: _RelationshipArg[
Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
]
+ back_populates: _StringRelationshipArg[
+ Optional[_RelationshipBackPopulatesArgument], str
+ ]
@log.class_logger
] = None,
primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
- back_populates: Optional[str] = None,
+ back_populates: Optional[_RelationshipBackPopulatesArgument] = None,
order_by: _ORMOrderByArgument = False,
backref: Optional[ORMBackrefArgument] = None,
overlaps: Optional[str] = None,
_RelationshipArg("order_by", order_by, None),
_RelationshipArg("foreign_keys", foreign_keys, None),
_RelationshipArg("remote_side", remote_side, None),
+ _StringRelationshipArg("back_populates", back_populates, None),
)
self.post_update = post_update
# mypy ignoring the @property setter
self.cascade = cascade # type: ignore
- self.back_populates = back_populates
-
- if self.back_populates:
+ if back_populates:
if backref:
raise sa_exc.ArgumentError(
"backref and back_populates keyword arguments "
else:
self.backref = backref
+ @property
+ def back_populates(self) -> str:
+ return self._init_args.back_populates.effective_value() # type: ignore
+
+ @back_populates.setter
+ def back_populates(self, value: str) -> None:
+ self._init_args.back_populates.argument = value
+
def _warn_for_persistence_only_flags(self, **kw: Any) -> None:
for k, v in kw.items():
if v != self._persistence_only[k]:
"secondary",
"foreign_keys",
"remote_side",
+ "back_populates",
):
rel_arg = getattr(init_args, attr)
if self.parent.non_primary:
return
- if self.backref is not None and not self.back_populates:
+
+ resolve_back_populates = self._init_args.back_populates.resolved
+
+ if self.backref is not None and not resolve_back_populates:
kwargs: Dict[str, Any]
if isinstance(self.backref, str):
backref_key, kwargs = self.backref, {}
backref_key, relationship, warn_for_existing=True
)
- if self.back_populates:
- self._add_reverse_property(self.back_populates)
+ if resolve_back_populates:
+ if isinstance(resolve_back_populates, PropComparator):
+ back_populates = resolve_back_populates.prop.key
+ elif isinstance(resolve_back_populates, str):
+ back_populates = resolve_back_populates
+ else:
+ # need test coverage for this case as well
+ raise sa_exc.ArgumentError(
+ f"Invalid back_populates value: {resolve_back_populates!r}"
+ )
+
+ self._add_reverse_property(back_populates)
@util.preload_module("sqlalchemy.orm.dependency")
def _post_init(self) -> None:
"""
+_DDLColumnReferenceArgument = _DDLColumnArgument
+
_DMLTableArgument = Union[
"TableClause",
"Join",
if typing.TYPE_CHECKING:
from ._typing import _AutoIncrementType
from ._typing import _DDLColumnArgument
+ from ._typing import _DDLColumnReferenceArgument
from ._typing import _InfoType
from ._typing import _TextCoercedExpressionArgument
from ._typing import _TypeEngineArgument
_table_column: Optional[Column[Any]]
+ _colspec: Union[str, Column[Any]]
+
def __init__(
self,
- column: _DDLColumnArgument,
+ column: _DDLColumnReferenceArgument,
_constraint: Optional[ForeignKeyConstraint] = None,
use_alter: bool = False,
name: _ConstraintNameArgument = None,
"""
- self._colspec = coercions.expect(roles.DDLReferredColumnRole, column)
self._unresolvable = _unresolvable
- if isinstance(self._colspec, str):
- self._table_column = None
- else:
- self._table_column = self._colspec
-
- if not isinstance(
- self._table_column.table, (type(None), TableClause)
- ):
- raise exc.ArgumentError(
- "ForeignKey received Column not bound "
- "to a Table, got: %r" % self._table_column.table
- )
+ self._colspec, self._table_column = self._parse_colspec_argument(
+ column
+ )
# the linked ForeignKeyConstraint.
# ForeignKey will create this when parent Column
self.info = info
self._unvalidated_dialect_kw = dialect_kw
+ def _resolve_colspec_argument(
+ self,
+ ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]:
+ argument = self._colspec
+
+ return self._parse_colspec_argument(argument)
+
+ def _parse_colspec_argument(
+ self,
+ argument: _DDLColumnArgument,
+ ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]:
+ _colspec = coercions.expect(roles.DDLReferredColumnRole, argument)
+
+ if isinstance(_colspec, str):
+ _table_column = None
+ else:
+ assert isinstance(_colspec, ColumnClause)
+ _table_column = _colspec
+
+ if not isinstance(_table_column.table, (type(None), TableClause)):
+ raise exc.ArgumentError(
+ "ForeignKey received Column not bound "
+ "to a Table, got: %r" % _table_column.table
+ )
+
+ return _colspec, _table_column
+
def __repr__(self) -> str:
return "ForeignKey(%r)" % self._get_colspec()
argument first passed to the object's constructor.
"""
+
+ _colspec, effective_table_column = self._resolve_colspec_argument()
+
if schema not in (None, RETAIN_SCHEMA):
_schema, tname, colname = self._column_tokens
if table_name is not None:
return "%s.%s.%s" % (schema, table_name, colname)
else:
return "%s.%s" % (table_name, colname)
- elif self._table_column is not None:
- if self._table_column.table is None:
+ elif effective_table_column is not None:
+ if effective_table_column.table is None:
if _is_copy:
raise exc.InvalidRequestError(
f"Can't copy ForeignKey object which refers to "
- f"non-table bound Column {self._table_column!r}"
+ f"non-table bound Column {effective_table_column!r}"
)
else:
- return self._table_column.key
+ return effective_table_column.key
return "%s.%s" % (
- self._table_column.table.fullname,
- self._table_column.key,
+ effective_table_column.table.fullname,
+ effective_table_column.key,
)
else:
- assert isinstance(self._colspec, str)
- return self._colspec
+ assert isinstance(_colspec, str)
+ return _colspec
@property
def _referred_schema(self) -> Optional[str]:
return self._column_tokens[0]
- def _table_key(self) -> Any:
+ def _table_key_within_construction(self) -> Any:
+ """get the table key but only safely"""
+
if self._table_column is not None:
if self._table_column.table is None:
return None
"""parse a string-based _colspec into its component parts."""
m = self._get_colspec().split(".")
- if m is None:
- raise exc.ArgumentError(
- f"Invalid foreign key column specification: {self._colspec}"
- )
if len(m) == 1:
tname = m.pop()
colname = None
if _column is None:
raise exc.NoReferencedColumnError(
"Could not initialize target column "
- f"for ForeignKey '{self._colspec}' "
+ f"for ForeignKey '{self._get_colspec()}' "
f"on table '{parenttable.name}': "
f"table '{table.name}' has no column named '{key}'",
table.name,
is raised.
"""
-
return self._resolve_column()
@overload
) -> Optional[Column[Any]]:
_column: Column[Any]
- if isinstance(self._colspec, str):
+ _colspec, effective_table_column = self._resolve_colspec_argument()
+
+ if isinstance(_colspec, str):
parenttable, tablekey, colname = self._resolve_col_tokens()
if self._unresolvable or tablekey not in parenttable.metadata:
parenttable, table, colname
)
- elif hasattr(self._colspec, "__clause_element__"):
- _column = self._colspec.__clause_element__()
+ elif hasattr(_colspec, "__clause_element__"):
+ _column = _colspec.__clause_element__()
return _column
else:
- _column = self._colspec
+ assert isinstance(_colspec, Column)
+ _column = _colspec
return _column
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
table.foreign_keys.add(self)
# set up remote ".column" attribute, or a note to pick it
# up when the other Table/Column shows up
- if isinstance(self._colspec, str):
+
+ _colspec, _ = self._resolve_colspec_argument()
+ if isinstance(_colspec, str):
parenttable, table_key, colname = self._resolve_col_tokens()
fk_key = (table_key, colname)
if table_key in parenttable.metadata.tables:
self._set_target_column(_column)
parenttable.metadata._fk_memos[fk_key].append(self)
- elif hasattr(self._colspec, "__clause_element__"):
- _column = self._colspec.__clause_element__()
+ elif hasattr(_colspec, "__clause_element__"):
+ _column = _colspec.__clause_element__()
self._set_target_column(_column)
else:
- _column = self._colspec
- self._set_target_column(_column)
+ self._set_target_column(_colspec)
if TYPE_CHECKING:
def __init__(
self,
columns: _typing_Sequence[_DDLColumnArgument],
- refcolumns: _typing_Sequence[_DDLColumnArgument],
+ refcolumns: _typing_Sequence[_DDLColumnReferenceArgument],
name: _ConstraintNameArgument = None,
onupdate: Optional[str] = None,
ondelete: Optional[str] = None,
return self.elements[0].column.table
def _validate_dest_table(self, table: Table) -> None:
- table_keys = {elem._table_key() for elem in self.elements}
+ table_keys = {
+ elem._table_key_within_construction() for elem in self.elements
+ }
if None not in table_keys and len(table_keys) > 1:
elem0, elem1 = sorted(table_keys)[0:2]
raise exc.ArgumentError(
schema=schema,
table_name=target_table.name
if target_table is not None
- and x._table_key() == x.parent.table.key
+ and x._table_key_within_construction()
+ == x.parent.table.key
else None,
_is_copy=True,
)
assert a1.user is u1
assert a1 in u1.addresses
+ @testing.variation(
+ "argtype", ["str", "callable_str", "prop", "callable_prop"]
+ )
+ def test_o2m_with_callable(self, argtype):
+ """test #10050"""
+
+ users, Address, addresses, User = (
+ self.tables.users,
+ self.classes.Address,
+ self.tables.addresses,
+ self.classes.User,
+ )
+
+ if argtype.str:
+ abp, ubp = "user", "addresses"
+ elif argtype.callable_str:
+ abp, ubp = lambda: "user", lambda: "addresses"
+ elif argtype.prop:
+ abp, ubp = lambda: "user", lambda: "addresses"
+ elif argtype.callable_prop:
+ abp, ubp = lambda: Address.user, lambda: User.addresses
+ else:
+ argtype.fail()
+
+ self.mapper_registry.map_imperatively(
+ User,
+ users,
+ properties={
+ "addresses": relationship(Address, back_populates=abp)
+ },
+ )
+
+ if argtype.prop:
+ ubp = User.addresses
+
+ self.mapper_registry.map_imperatively(
+ Address,
+ addresses,
+ properties={"user": relationship(User, back_populates=ubp)},
+ )
+
+ sess = fixture_session()
+
+ u1 = User(name="u1")
+ a1 = Address(email_address="foo")
+ u1.addresses.append(a1)
+ assert a1.user is u1
+
+ sess.add(u1)
+ sess.flush()
+ sess.expire_all()
+ assert sess.query(Address).one() is a1
+ assert a1.user is u1
+ assert a1 in u1.addresses
+
+ @testing.variation("argtype", ["plain", "callable"])
+ def test_invalid_backref_type(self, argtype):
+ """test #10050"""
+
+ users, Address, addresses, User = (
+ self.tables.users,
+ self.classes.Address,
+ self.tables.addresses,
+ self.classes.User,
+ )
+
+ if argtype.plain:
+ abp, ubp = object(), "addresses"
+ elif argtype.callable:
+ abp, ubp = lambda: object(), lambda: "addresses"
+ else:
+ argtype.fail()
+
+ self.mapper_registry.map_imperatively(
+ User,
+ users,
+ properties={
+ "addresses": relationship(Address, back_populates=abp)
+ },
+ )
+
+ self.mapper_registry.map_imperatively(
+ Address,
+ addresses,
+ properties={"user": relationship(User, back_populates=ubp)},
+ )
+
+ with expect_raises_message(
+ exc.ArgumentError, r"Invalid back_populates value: <object"
+ ):
+ self.mapper_registry.configure()
+
def test_invalid_key(self):
users, Address, addresses, User = (
self.tables.users,