From 6f21002e1d5bbe291111655f33b19e4eb4b3cb84 Mon Sep 17 00:00:00 2001 From: Priyanshu Parikh Date: Wed, 11 Oct 2023 18:18:03 +0530 Subject: [PATCH] Added use callable for relationship.back_populates and ForeignKey.column --- lib/sqlalchemy/orm/_orm_constructors.py | 3 +- lib/sqlalchemy/orm/relationships.py | 43 +++++++++++++++++++--- lib/sqlalchemy/sql/schema.py | 47 ++++++++++++++++++------- test/orm/test_relationships.py | 42 ++++++++++++++++++++++ test/sql/test_metadata.py | 18 ++++++++++ 5 files changed, 135 insertions(+), 18 deletions(-) diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index df36c38641..376607ace0 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -28,6 +28,7 @@ from .properties import MappedColumn from .properties import MappedSQLExpression from .query import AliasOption from .relationships import _RelationshipArgumentType +from .relationships import _RelationshipBackPopulatesArgument from .relationships import _RelationshipSecondaryArgument from .relationships import Relationship from .relationships import RelationshipProperty @@ -917,7 +918,7 @@ def relationship( ] = None, primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, - back_populates: Optional[str] = None, + back_populates: Optional[_RelationshipBackPopulatesArgument] = None, order_by: _ORMOrderByArgument = False, backref: Optional[ORMBackrefArgument] = None, overlaps: Optional[str] = None, diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 7ea30d7b18..d63f182bfb 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -176,6 +176,11 @@ _ORMOrderByArgument = Union[ Callable[[], Iterable[_ColumnExpressionArgument[Any]]], Iterable[Union[str, _ColumnExpressionArgument[Any]]], ] +_RelationshipBackPopulatesArgument = Union[ + str, + PropComparator[Any], + Callable[[], "_RelationshipBackPopulatesArgument"], +] ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] _ORMColCollectionElement = Union[ @@ -274,6 +279,18 @@ class _RelationshipArg(Generic[_T1, _T2]): _RelationshipOrderByArg = Union[Literal[False], Tuple[ColumnElement[Any], ...]] +@dataclasses.dataclass +class _StringRelationshipArg(_RelationshipArg[_T1, _T2]): + def _resolve_against_registry( + self, clsregistry_resolver: Callable[[str, bool], _class_resolver] + ) -> None: + attr_value = self.argument + + if callable(attr_value): + self.resolved = attr_value() + else: + self.resolved = attr_value + class _RelationshipArgs(NamedTuple): """stores user-passed parameters that are resolved at mapper configuration time. @@ -299,6 +316,9 @@ class _RelationshipArgs(NamedTuple): remote_side: _RelationshipArg[ Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] ] + back_populates: _StringRelationshipArg[ + Optional[_RelationshipBackPopulatesArgument], str + ] @log.class_logger @@ -369,7 +389,7 @@ class RelationshipProperty( ] = None, primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, - back_populates: Optional[str] = None, + back_populates: Optional[_RelationshipBackPopulatesArgument] = None, order_by: _ORMOrderByArgument = False, backref: Optional[ORMBackrefArgument] = None, overlaps: Optional[str] = None, @@ -414,6 +434,7 @@ class RelationshipProperty( _RelationshipArg("order_by", order_by, None), _RelationshipArg("foreign_keys", foreign_keys, None), _RelationshipArg("remote_side", remote_side, None), + _StringRelationshipArg("back_populates", back_populates, None), ) self.post_update = post_update @@ -1669,6 +1690,7 @@ class RelationshipProperty( "secondary", "foreign_keys", "remote_side", + "back_populates", ): rel_arg = getattr(init_args, attr) @@ -2054,7 +2076,10 @@ class RelationshipProperty( if self.parent.non_primary: return - if self.backref is not None and not self.back_populates: + + resolve_back_populates = self._init_args.back_populates.resolved + + if self.backref is not None and not resolve_back_populates: kwargs: Dict[str, Any] if isinstance(self.backref, str): backref_key, kwargs = self.backref, {} @@ -2125,8 +2150,18 @@ class RelationshipProperty( backref_key, relationship, warn_for_existing=True ) - if self.back_populates: - self._add_reverse_property(self.back_populates) + if resolve_back_populates: + if isinstance(resolve_back_populates, PropComparator): + back_populates = resolve_back_populates.prop.key + elif isinstance(resolve_back_populates, str): + back_populates = resolve_back_populates + else: + # need test coverage for this case as well + raise sa_exc.ArgumentError( + r"Invalid back_populates value: {resolve_back_populates!r}" + ) + + self._add_reverse_property(back_populates) @util.preload_module("sqlalchemy.orm.dependency") def _post_init(self) -> None: diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index ca389a9a71..91885991e8 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2766,9 +2766,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): _table_column: Optional[Column[Any]] + _colspec: Union[str, Column[Any], Callable[[], _DDLColumnArgument]] + def __init__( self, - column: _DDLColumnArgument, + column: Union[_DDLColumnArgument, Callable[[], _DDLColumnArgument]], _constraint: Optional[ForeignKeyConstraint] = None, use_alter: bool = False, name: _ConstraintNameArgument = None, @@ -2854,12 +2856,18 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ - self._colspec = coercions.expect(roles.DDLReferredColumnRole, column) + if not callable(column): + self._colspec = coercions.expect( + roles.DDLReferredColumnRole, column + ) + else: + self._colspec = column self._unresolvable = _unresolvable - if isinstance(self._colspec, str): + if isinstance(self._colspec, str) or callable(self._colspec): self._table_column = None else: + assert isinstance(self._colspec, ColumnClause) self._table_column = self._colspec if not isinstance( @@ -2952,6 +2960,16 @@ class ForeignKey(DialectKWArgs, SchemaItem): argument first passed to the object's constructor. """ + + effective_table_column: Optional[Column[Any]] + + if callable(self._colspec): + resolved = self._colspec() + column = coercions.expect(roles.DDLReferredColumnRole, resolved) + effective_table_column = column + else: + effective_table_column = self._table_column + if schema not in (None, RETAIN_SCHEMA): _schema, tname, colname = self._column_tokens if table_name is not None: @@ -2966,18 +2984,18 @@ class ForeignKey(DialectKWArgs, SchemaItem): return "%s.%s.%s" % (schema, table_name, colname) else: return "%s.%s" % (table_name, colname) - elif self._table_column is not None: - if self._table_column.table is None: + elif effective_table_column is not None: + if effective_table_column.table is None: if _is_copy: raise exc.InvalidRequestError( f"Can't copy ForeignKey object which refers to " - f"non-table bound Column {self._table_column!r}" + f"non-table bound Column {effective_table_column!r}" ) else: - return self._table_column.key + return effective_table_column.key return "%s.%s" % ( - self._table_column.table.fullname, - self._table_column.key, + effective_table_column.table.fullname, + effective_table_column.key, ) else: assert isinstance(self._colspec, str) @@ -3203,6 +3221,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): _column = self._colspec.__clause_element__() return _column else: + assert isinstance(self._colspec, Column) _column = self._colspec return _column @@ -3255,7 +3274,9 @@ class ForeignKey(DialectKWArgs, SchemaItem): table.foreign_keys.add(self) # set up remote ".column" attribute, or a note to pick it # up when the other Table/Column shows up - if isinstance(self._colspec, str): + colspec = self._get_colspec() + + if isinstance(colspec, str): parenttable, table_key, colname = self._resolve_col_tokens() fk_key = (table_key, colname) if table_key in parenttable.metadata.tables: @@ -3271,11 +3292,11 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._set_target_column(_column) parenttable.metadata._fk_memos[fk_key].append(self) - elif hasattr(self._colspec, "__clause_element__"): - _column = self._colspec.__clause_element__() + elif hasattr(colspec, "__clause_element__"): + _column = colspec.__clause_element__() self._set_target_column(_column) else: - _column = self._colspec + _column = colspec self._set_target_column(_column) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index d6b886be15..d04027373f 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -2296,6 +2296,48 @@ class ManualBackrefTest(_fixtures.FixtureTest): assert a1.user is u1 assert a1 in u1.addresses + def test_o2m_with_callable(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + self.mapper_registry.map_imperatively( + User, + users, + properties={ + "addresses": relationship( + Address, back_populates=lambda: Address.user + ) + }, + ) + + self.mapper_registry.map_imperatively( + Address, + addresses, + properties={ + "user": relationship( + User, back_populates=lambda: User.addresses + ) + }, + ) + + sess = fixture_session() + + u1 = User(name="u1") + a1 = Address(email_address="foo") + u1.addresses.append(a1) + assert a1.user is u1 + + sess.add(u1) + sess.flush() + sess.expire_all() + assert sess.query(Address).one() is a1 + assert a1.user is u1 + assert a1 in u1.addresses + def test_invalid_key(self): users, Address, addresses, User = ( self.tables.users, diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 0b35adc1cc..5776210ef0 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -469,6 +469,24 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): ["b.a", "b.b"], ) + def test_fk_callable(self): + meta = MetaData() + + a = Table( + "a", + meta, + Column("id", Integer, primary_key=True), + ) + + b = Table( + "b", + meta, + Column("id", Integer, primary_key=True), + Column("a_id", ForeignKey(lambda: a.c.id), nullable=False), + ) + + assert b.c.a_id.references(a.c.id) + def test_pickle_metadata_sequence_restated(self): m1 = MetaData() Table( -- 2.47.3