]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added use callable for relationship.back_populates and ForeignKey.column 10260/head
authorPriyanshu Parikh <parikhpriyanshu7@gmail.com>
Wed, 11 Oct 2023 12:48:03 +0000 (18:18 +0530)
committerPriyanshu Parikh <parikhpriyanshu7@gmail.com>
Wed, 11 Oct 2023 16:29:38 +0000 (21:59 +0530)
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/schema.py
test/orm/test_relationships.py
test/sql/test_metadata.py

index df36c38641677b7f57334cc270fcdcf305361a09..376607ace0b379f77ffa2f1bb9071d5db3eedac7 100644 (file)
@@ -28,6 +28,7 @@ from .properties import MappedColumn
 from .properties import MappedSQLExpression
 from .query import AliasOption
 from .relationships import _RelationshipArgumentType
+from .relationships import _RelationshipBackPopulatesArgument
 from .relationships import _RelationshipSecondaryArgument
 from .relationships import Relationship
 from .relationships import RelationshipProperty
@@ -917,7 +918,7 @@ def relationship(
     ] = None,
     primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
     secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
-    back_populates: Optional[str] = None,
+    back_populates: Optional[_RelationshipBackPopulatesArgument] = None,
     order_by: _ORMOrderByArgument = False,
     backref: Optional[ORMBackrefArgument] = None,
     overlaps: Optional[str] = None,
index 7ea30d7b180ae0c46dc089dab4e3ac93afc2b455..d63f182bfb34caa04db3479d27cd12b296c6b6ac 100644 (file)
@@ -176,6 +176,11 @@ _ORMOrderByArgument = Union[
     Callable[[], Iterable[_ColumnExpressionArgument[Any]]],
     Iterable[Union[str, _ColumnExpressionArgument[Any]]],
 ]
+_RelationshipBackPopulatesArgument = Union[
+    str,
+    PropComparator[Any],
+    Callable[[], "_RelationshipBackPopulatesArgument"],
+]
 ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
 
 _ORMColCollectionElement = Union[
@@ -274,6 +279,18 @@ class _RelationshipArg(Generic[_T1, _T2]):
 _RelationshipOrderByArg = Union[Literal[False], Tuple[ColumnElement[Any], ...]]
 
 
+@dataclasses.dataclass
+class _StringRelationshipArg(_RelationshipArg[_T1, _T2]):
+    def _resolve_against_registry(
+        self, clsregistry_resolver: Callable[[str, bool], _class_resolver]
+    ) -> None:
+        attr_value = self.argument
+
+        if callable(attr_value):
+            self.resolved = attr_value()
+        else:
+            self.resolved = attr_value
+
 class _RelationshipArgs(NamedTuple):
     """stores user-passed parameters that are resolved at mapper configuration
     time.
@@ -299,6 +316,9 @@ class _RelationshipArgs(NamedTuple):
     remote_side: _RelationshipArg[
         Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]]
     ]
+    back_populates: _StringRelationshipArg[
+        Optional[_RelationshipBackPopulatesArgument], str
+    ]
 
 
 @log.class_logger
@@ -369,7 +389,7 @@ class RelationshipProperty(
         ] = None,
         primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
         secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
-        back_populates: Optional[str] = None,
+        back_populates: Optional[_RelationshipBackPopulatesArgument] = None,
         order_by: _ORMOrderByArgument = False,
         backref: Optional[ORMBackrefArgument] = None,
         overlaps: Optional[str] = None,
@@ -414,6 +434,7 @@ class RelationshipProperty(
             _RelationshipArg("order_by", order_by, None),
             _RelationshipArg("foreign_keys", foreign_keys, None),
             _RelationshipArg("remote_side", remote_side, None),
+            _StringRelationshipArg("back_populates", back_populates, None),
         )
 
         self.post_update = post_update
@@ -1669,6 +1690,7 @@ class RelationshipProperty(
             "secondary",
             "foreign_keys",
             "remote_side",
+            "back_populates",
         ):
             rel_arg = getattr(init_args, attr)
 
@@ -2054,7 +2076,10 @@ class RelationshipProperty(
 
         if self.parent.non_primary:
             return
-        if self.backref is not None and not self.back_populates:
+
+        resolve_back_populates = self._init_args.back_populates.resolved
+
+        if self.backref is not None and not resolve_back_populates:
             kwargs: Dict[str, Any]
             if isinstance(self.backref, str):
                 backref_key, kwargs = self.backref, {}
@@ -2125,8 +2150,18 @@ class RelationshipProperty(
                 backref_key, relationship, warn_for_existing=True
             )
 
-        if self.back_populates:
-            self._add_reverse_property(self.back_populates)
+        if resolve_back_populates:
+            if isinstance(resolve_back_populates, PropComparator):
+                back_populates = resolve_back_populates.prop.key
+            elif isinstance(resolve_back_populates, str):
+                back_populates = resolve_back_populates
+            else:
+                # need test coverage for this case as well
+                raise sa_exc.ArgumentError(
+                    r"Invalid back_populates value: {resolve_back_populates!r}"
+                )
+
+            self._add_reverse_property(back_populates)
 
     @util.preload_module("sqlalchemy.orm.dependency")
     def _post_init(self) -> None:
index ca389a9a71a8bd1b7b1217f7cf09743c529dfba7..91885991e8c7a20fccd060c2d5779610c9aa076d 100644 (file)
@@ -2766,9 +2766,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
 
     _table_column: Optional[Column[Any]]
 
+    _colspec: Union[str, Column[Any], Callable[[], _DDLColumnArgument]]
+
     def __init__(
         self,
-        column: _DDLColumnArgument,
+        column: Union[_DDLColumnArgument, Callable[[], _DDLColumnArgument]],
         _constraint: Optional[ForeignKeyConstraint] = None,
         use_alter: bool = False,
         name: _ConstraintNameArgument = None,
@@ -2854,12 +2856,18 @@ class ForeignKey(DialectKWArgs, SchemaItem):
 
         """
 
-        self._colspec = coercions.expect(roles.DDLReferredColumnRole, column)
+        if not callable(column):
+            self._colspec = coercions.expect(
+                roles.DDLReferredColumnRole, column
+            )
+        else:
+            self._colspec = column
         self._unresolvable = _unresolvable
 
-        if isinstance(self._colspec, str):
+        if isinstance(self._colspec, str) or callable(self._colspec):
             self._table_column = None
         else:
+            assert isinstance(self._colspec, ColumnClause)
             self._table_column = self._colspec
 
             if not isinstance(
@@ -2952,6 +2960,16 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         argument first passed to the object's constructor.
 
         """
+
+        effective_table_column: Optional[Column[Any]]
+
+        if callable(self._colspec):
+            resolved = self._colspec()
+            column = coercions.expect(roles.DDLReferredColumnRole, resolved)
+            effective_table_column = column
+        else:
+            effective_table_column = self._table_column
+
         if schema not in (None, RETAIN_SCHEMA):
             _schema, tname, colname = self._column_tokens
             if table_name is not None:
@@ -2966,18 +2984,18 @@ class ForeignKey(DialectKWArgs, SchemaItem):
                 return "%s.%s.%s" % (schema, table_name, colname)
             else:
                 return "%s.%s" % (table_name, colname)
-        elif self._table_column is not None:
-            if self._table_column.table is None:
+        elif effective_table_column is not None:
+            if effective_table_column.table is None:
                 if _is_copy:
                     raise exc.InvalidRequestError(
                         f"Can't copy ForeignKey object which refers to "
-                        f"non-table bound Column {self._table_column!r}"
+                        f"non-table bound Column {effective_table_column!r}"
                     )
                 else:
-                    return self._table_column.key
+                    return effective_table_column.key
             return "%s.%s" % (
-                self._table_column.table.fullname,
-                self._table_column.key,
+                effective_table_column.table.fullname,
+                effective_table_column.key,
             )
         else:
             assert isinstance(self._colspec, str)
@@ -3203,6 +3221,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
             _column = self._colspec.__clause_element__()
             return _column
         else:
+            assert isinstance(self._colspec, Column)
             _column = self._colspec
             return _column
 
@@ -3255,7 +3274,9 @@ class ForeignKey(DialectKWArgs, SchemaItem):
         table.foreign_keys.add(self)
         # set up remote ".column" attribute, or a note to pick it
         # up when the other Table/Column shows up
-        if isinstance(self._colspec, str):
+        colspec = self._get_colspec()
+
+        if isinstance(colspec, str):
             parenttable, table_key, colname = self._resolve_col_tokens()
             fk_key = (table_key, colname)
             if table_key in parenttable.metadata.tables:
@@ -3271,11 +3292,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
                     self._set_target_column(_column)
 
             parenttable.metadata._fk_memos[fk_key].append(self)
-        elif hasattr(self._colspec, "__clause_element__"):
-            _column = self._colspec.__clause_element__()
+        elif hasattr(colspec, "__clause_element__"):
+            _column = colspec.__clause_element__()
             self._set_target_column(_column)
         else:
-            _column = self._colspec
+            _column = colspec
             self._set_target_column(_column)
 
 
index d6b886be151b1dc246f3306160940b08cfbdd1ab..d04027373ff93bfe13d6ffc98a08e09f5e890ec4 100644 (file)
@@ -2296,6 +2296,48 @@ class ManualBackrefTest(_fixtures.FixtureTest):
         assert a1.user is u1
         assert a1 in u1.addresses
 
+    def test_o2m_with_callable(self):
+        users, Address, addresses, User = (
+            self.tables.users,
+            self.classes.Address,
+            self.tables.addresses,
+            self.classes.User,
+        )
+
+        self.mapper_registry.map_imperatively(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    Address, back_populates=lambda: Address.user
+                )
+            },
+        )
+
+        self.mapper_registry.map_imperatively(
+            Address,
+            addresses,
+            properties={
+                "user": relationship(
+                    User, back_populates=lambda: User.addresses
+                )
+            },
+        )
+
+        sess = fixture_session()
+
+        u1 = User(name="u1")
+        a1 = Address(email_address="foo")
+        u1.addresses.append(a1)
+        assert a1.user is u1
+
+        sess.add(u1)
+        sess.flush()
+        sess.expire_all()
+        assert sess.query(Address).one() is a1
+        assert a1.user is u1
+        assert a1 in u1.addresses
+
     def test_invalid_key(self):
         users, Address, addresses, User = (
             self.tables.users,
index 0b35adc1ccc9194647d858521abebc6b7784dd3b..5776210ef0a60c7eeba647204cd6ad9699325c76 100644 (file)
@@ -469,6 +469,24 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
             ["b.a", "b.b"],
         )
 
+    def test_fk_callable(self):
+        meta = MetaData()
+
+        a = Table(
+            "a",
+            meta,
+            Column("id", Integer, primary_key=True),
+        )
+
+        b = Table(
+            "b",
+            meta,
+            Column("id", Integer, primary_key=True),
+            Column("a_id", ForeignKey(lambda: a.c.id), nullable=False),
+        )
+
+        assert b.c.a_id.references(a.c.id)
+
     def test_pickle_metadata_sequence_restated(self):
         m1 = MetaData()
         Table(