]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
local_remote_pairs/remote_side are comparing against existing
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Feb 2012 01:47:18 +0000 (20:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Feb 2012 01:47:18 +0000 (20:47 -0500)
100%

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
test/orm/test_rel_fn.py

index 5b883a8f5750ad8615ec9bcb9955c6917ba00d30..9bab0c2f4a7422b4db37dc6befb4ebc21cab20c2 100644 (file)
@@ -943,7 +943,6 @@ class RelationshipProperty(StrategizedProperty):
                     prop=self,
                     support_sync=not self.viewonly,
                     can_be_synced_fn=self._columns_are_mapped
-
         )
 
     def _test_new_thing(self):
index 9aebc9f8ae9703e99b1f527a8af8b14a4f770e25..02eab9c2d765da26f0e0ef1f465d5b6a3d717c79 100644 (file)
@@ -147,23 +147,25 @@ class JoinCondition(object):
                         a not in self.consider_as_foreign_keys
                     ):
                     return b
-            elif isinstance(a, schema.Column) and \
+
+            if isinstance(a, schema.Column) and \
                         isinstance(b, schema.Column):
                 if a.references(b):
                     return a
                 elif b.references(a):
                     return b
-            elif secondarycols:
+
+            if secondarycols:
                 if a in secondarycols and b not in secondarycols:
                     return a
                 elif b in secondarycols and a not in secondarycols:
                     return b
 
-        def _annotate_fk(binary, switch):
-            if switch:
-                right, left = binary.left, binary.right
-            else:
-                left, right = binary.left, binary.right
+        def _run_w_switch(binary, fn):
+            binary.left, binary.right = fn(binary, binary.left, binary.right)
+            binary.right, binary.left = fn(binary, binary.right, binary.left)
+
+        def _annotate_fk(binary, left, right):
             can_be_synced = self.can_be_synced_fn(left)
             left = left._annotate({
                 "equated":binary.operator is operators.eq,
@@ -174,26 +176,16 @@ class JoinCondition(object):
                 "equated":binary.operator is operators.eq,
                 "referent":True
             })
-            if switch:
-                binary.right, binary.left = left, right
-            else:
-                binary.left, binary.right = left, right
+            return left, right
 
-        def _annotate_remote(binary, switch):
-            if switch:
-                right, left = binary.left, binary.right
-            else:
-                left, right = binary.left, binary.right
+        def _annotate_remote(binary, left, right):
             left = left._annotate(
                                 {"remote":True})
             if right in parentcols or \
-                secondarycols and right in targetcols:
+                right in targetcols:
                 right = right._annotate(
                                 {"local":True})
-            if switch:
-                binary.right, binary.left = left, right
-            else:
-                binary.left, binary.right = left, right
+            return left, right
 
         def visit_binary(binary):
             if not isinstance(binary.left, sql.ColumnElement) or \
@@ -214,37 +206,39 @@ class JoinCondition(object):
 
             has_foreign = False
             if "foreign" in binary.left._annotations:
-                _annotate_fk(binary, False)
+                binary.left, binary.right = _annotate_fk(
+                                binary, binary.left, binary.right)
                 has_foreign = True
             if "foreign" in binary.right._annotations:
-                _annotate_fk(binary, True)
+                binary.right, binary.left = _annotate_fk(
+                            binary, binary.right, binary.left)
                 has_foreign = True
 
             if "remote" not in binary.left._annotations and \
                 "remote" not in binary.right._annotations:
-                if self._local_remote_pairs:
-                    raise NotImplementedError()
-                elif self._remote_side:
-                    if binary.left in self._remote_side:
-                        _annotate_remote(binary, False)
-                    elif binary.right in self._remote_side:
-                        _annotate_remote(binary, True)
-                elif refers_to_parent_table(binary):
-                    # assume one to many - FKs are "remote"
-                    if "foreign" in binary.left._annotations:
-                        _annotate_remote(binary, False)
-                    elif "foreign" in binary.right._annotations:
-                        _annotate_remote(binary, True)
-                elif secondarycols:
-                    if binary.left in secondarycols:
-                        _annotate_remote(binary, False)
-                    elif binary.right in secondarycols:
-                        _annotate_remote(binary, True)
-                else:
-                    if binary.left in targetcols and has_foreign:
-                        _annotate_remote(binary, False)
-                    elif binary.right in targetcols and has_foreign:
-                        _annotate_remote(binary, True)
+
+                def go(binary, left, right):
+                    if self._local_remote_pairs:
+                        raise NotImplementedError()
+                    elif self._remote_side:
+                        if left in self._remote_side:
+                            return _annotate_remote(binary, left, right)
+                    elif refers_to_parent_table(binary):
+                        # assume one to many - FKs are "remote"
+                        if "foreign" in left._annotations:
+                            return _annotate_remote(binary, left, right)
+                    elif secondarycols:
+                        if left in secondarycols:
+                            return _annotate_remote(binary, left, right)
+                    else:
+                        # TODO: to support the X->Y->Z case 
+                        # we might need to look at parentcols
+                        # and annotate "local" separately...
+                        if left in targetcols and has_foreign \
+                            and right in parentcols or right in secondarycols:
+                            return _annotate_remote(binary, left, right)
+                    return left, right
+                _run_w_switch(binary, go)
 
         self.primaryjoin = visitors.cloned_traverse(
             self.primaryjoin,
@@ -374,8 +368,15 @@ class JoinCondition(object):
             if onetomany_fk and manytoone_fk:
                 # fks on both sides.  test for overlap of local/remote
                 # with foreign key
-                onetomany_local = self.remote_side.intersection(self.foreign_key_columns)
-                manytoone_local = self.local_columns.intersection(self.foreign_key_columns)
+                self_equated = self.remote_columns.intersection(
+                                        self.local_columns
+                                    )
+                onetomany_local = self.remote_columns.\
+                                    intersection(self.foreign_key_columns).\
+                                    difference(self_equated)
+                manytoone_local = self.local_columns.\
+                                    intersection(self.foreign_key_columns).\
+                                    difference(self_equated)
                 if onetomany_local and not manytoone_local:
                     self.direction = ONETOMANY
                 elif manytoone_local and not onetomany_local:
@@ -436,18 +437,18 @@ class JoinCondition(object):
 
     @util.memoized_property
     def local_remote_pairs(self):
-        lrp = []
+        lrp = util.OrderedSet()
         def visit_binary(binary):
             if "remote" in binary.right._annotations and \
                 "local" in binary.left._annotations:
-                lrp.append((binary.left, binary.right))
+                lrp.add((binary.left, binary.right))
             elif "remote" in binary.left._annotations and \
                 "local" in binary.right._annotations:
-                lrp.append((binary.right, binary.left))
+                lrp.add((binary.right, binary.left))
         visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary})
         if self.secondaryjoin is not None:
             visitors.traverse(self.secondaryjoin, {}, {"binary":visit_binary})
-        return lrp
+        return list(lrp)
 
     def join_targets(self, source_selectable, 
                             dest_selectable,
@@ -563,121 +564,6 @@ def _create_lazy_clause(cls, prop, reverse_direction=False):
 
     return lazywhere, bind_to_col, equated_columns
 
-def _sync_pairs_from_join(self, join_condition, primary):
-    """Determine a list of "source"/"destination" column pairs
-    based on the given join condition, as well as the
-    foreign keys argument.
-
-    "source" would be a column referenced by a foreign key,
-    and "destination" would be the column who has a foreign key
-    reference to "source".
-
-    """
-
-    fks = self._user_defined_foreign_keys
-    # locate pairs
-    eq_pairs = criterion_as_pairs(join_condition,
-            consider_as_foreign_keys=fks,
-            any_operator=self.viewonly)
-
-    # couldn't find any fks, but we have 
-    # "secondary" - assume the "secondary" columns
-    # are the fks
-    if not eq_pairs and \
-            self.secondary is not None and \
-            not fks:
-        fks = set(self.secondary.c)
-        eq_pairs = criterion_as_pairs(join_condition,
-                consider_as_foreign_keys=fks,
-                any_operator=self.viewonly)
-
-        if eq_pairs:
-            util.warn("No ForeignKey objects were present "
-                        "in secondary table '%s'.  Assumed referenced "
-                        "foreign key columns %s for join condition '%s' "
-                        "on relationship %s" % (
-                        self.secondary.description,
-                        ", ".join(sorted(["'%s'" % col for col in fks])),
-                        join_condition,
-                        self
-                    ))
-
-    # Filter out just to columns that are mapped.
-    # If viewonly, allow pairs where the FK col
-    # was part of "foreign keys" - the column it references
-    # may be in an un-mapped table - see 
-    # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic
-    # for an example of this.
-    eq_pairs = [(l, r) for (l, r) in eq_pairs
-                if self._columns_are_mapped(l, r)
-                or self.viewonly and 
-                r in fks]
-
-    if eq_pairs:
-        return eq_pairs
-
-    # from here below is just determining the best error message
-    # to report.  Check for a join condition using any operator 
-    # (not just ==), perhaps they need to turn on "viewonly=True".
-    if not self.viewonly and criterion_as_pairs(join_condition,
-            consider_as_foreign_keys=self._user_defined_foreign_keys,
-            any_operator=True):
-
-        err = "Could not locate any "\
-                "foreign-key-equated, locally mapped column "\
-                "pairs for %s "\
-                "condition '%s' on relationship %s." % (
-                    primary and 'primaryjoin' or 'secondaryjoin', 
-                    join_condition, 
-                    self
-                )
-
-        if not self._user_defined_foreign_keys:
-            err += "  Ensure that the "\
-                    "referencing Column objects have a "\
-                    "ForeignKey present, or are otherwise part "\
-                    "of a ForeignKeyConstraint on their parent "\
-                    "Table, or specify the foreign_keys parameter "\
-                    "to this relationship."
-
-        err += "  For more "\
-                "relaxed rules on join conditions, the "\
-                "relationship may be marked as viewonly=True."
-
-        raise sa_exc.ArgumentError(err)
-    else:
-        if self._user_defined_foreign_keys:
-            raise sa_exc.ArgumentError("Could not determine "
-                    "relationship direction for %s condition "
-                    "'%s', on relationship %s, using manual "
-                    "'foreign_keys' setting.  Do the columns "
-                    "in 'foreign_keys' represent all, and "
-                    "only, the 'foreign' columns in this join "
-                    "condition?  Does the %s Table already "
-                    "have adequate ForeignKey and/or "
-                    "ForeignKeyConstraint objects established "
-                    "(in which case 'foreign_keys' is usually "
-                    "unnecessary)?" 
-                    % (
-                        primary and 'primaryjoin' or 'secondaryjoin',
-                        join_condition, 
-                        self,
-                        primary and 'mapped' or 'secondary'
-                    ))
-        else:
-            raise sa_exc.ArgumentError("Could not determine "
-                    "relationship direction for %s condition "
-                    "'%s', on relationship %s. Ensure that the "
-                    "referencing Column objects have a "
-                    "ForeignKey present, or are otherwise part "
-                    "of a ForeignKeyConstraint on their parent "
-                    "Table, or specify the foreign_keys parameter " 
-                    "to this relationship."
-                    % (
-                        primary and 'primaryjoin' or 'secondaryjoin', 
-                        join_condition, 
-                        self
-                    ))
 
 def _determine_synchronize_pairs(self):
     """Resolve 'primary'/foreign' column pairs from the primaryjoin
@@ -714,87 +600,6 @@ def _determine_synchronize_pairs(self):
     else:
         self.secondary_synchronize_pairs = None
 
-def _determine_direction(self):
-    """Determine if this relationship is one to many, many to one, 
-    many to many.
-
-    This is derived from the primaryjoin, presence of "secondary",
-    and in the case of self-referential the "remote side".
-
-    """
-    if self.secondaryjoin is not None:
-        self.direction = MANYTOMANY
-    elif self._refers_to_parent_table():
-
-        # self referential defaults to ONETOMANY unless the "remote"
-        # side is present and does not reference any foreign key
-        # columns
-
-        if self.local_remote_pairs:
-            remote = [r for (l, r) in self.local_remote_pairs]
-        elif self.remote_side:
-            remote = self.remote_side
-        else:
-            remote = None
-        if not remote or self._calculated_foreign_keys.difference(l for (l,
-                r) in self.synchronize_pairs).intersection(remote):
-            self.direction = ONETOMANY
-        else:
-            self.direction = MANYTOONE
-    else:
-        parentcols = util.column_set(self.parent.mapped_table.c)
-        targetcols = util.column_set(self.mapper.mapped_table.c)
-
-        # fk collection which suggests ONETOMANY.
-        onetomany_fk = targetcols.intersection(
-                        self._calculated_foreign_keys)
-
-        # fk collection which suggests MANYTOONE.
-
-        manytoone_fk = parentcols.intersection(
-                        self._calculated_foreign_keys)
-
-        if onetomany_fk and manytoone_fk:
-            # fks on both sides.  do the same test only based on the
-            # local side.
-            referents = [c for (c, f) in self.synchronize_pairs]
-            onetomany_local = parentcols.intersection(referents)
-            manytoone_local = targetcols.intersection(referents)
-
-            if onetomany_local and not manytoone_local:
-                self.direction = ONETOMANY
-            elif manytoone_local and not onetomany_local:
-                self.direction = MANYTOONE
-            else:
-                raise sa_exc.ArgumentError(
-                        "Can't determine relationship"
-                        " direction for relationship '%s' - foreign "
-                        "key columns are present in both the parent "
-                        "and the child's mapped tables.  Specify "
-                        "'foreign_keys' argument." % self)
-        elif onetomany_fk:
-            self.direction = ONETOMANY
-        elif manytoone_fk:
-            self.direction = MANYTOONE
-        else:
-            raise sa_exc.ArgumentError("Can't determine relationship "
-                    "direction for relationship '%s' - foreign "
-                    "key columns are present in neither the parent "
-                    "nor the child's mapped tables" % self)
-
-    if self.cascade.delete_orphan and not self.single_parent \
-        and (self.direction is MANYTOMANY or self.direction
-             is MANYTOONE):
-        util.warn('On %s, delete-orphan cascade is not supported '
-                  'on a many-to-many or many-to-one relationship '
-                  'when single_parent is not set.   Set '
-                  'single_parent=True on the relationship().'
-                  % self)
-    if self.direction is MANYTOONE and self.passive_deletes:
-        util.warn("On %s, 'passive_deletes' is normally configured "
-                  "on one-to-many, one-to-one, many-to-many "
-                  "relationships only."
-                   % self)
 
 def _determine_local_remote_pairs(self):
     """Determine pairs of columns representing "local" to 
index 48743a2ebd13eb024a20590b72b8974572a12e57..6ce89d604e43b36f5b56fd2545f8800f55af499b 100644 (file)
@@ -1,7 +1,7 @@
 from test.lib.testing import assert_raises, assert_raises_message, eq_, AssertsCompiledSQL, is_
 from test.lib import fixtures
 from sqlalchemy.orm import relationships
-from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, select
+from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, select, ForeignKeyConstraint
 from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
 
 class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -21,6 +21,15 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             Column('id', Integer, primary_key=True),
             Column('sid', Integer, ForeignKey('selfref.id'))
         )
+        cls.composite_selfref = Table('composite_selfref', m,
+            Column('id', Integer, primary_key=True),
+            Column('group_id', Integer, primary_key=True),
+            Column('parent_id', Integer),
+            ForeignKeyConstraint(
+                ['parent_id', 'group_id'],
+                ['composite_selfref.id', 'composite_selfref.group_id']
+            )
+        )
 
     def _join_fixture_o2m(self, **kw):
         return relationships.JoinCondition(
@@ -59,6 +68,26 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             **kw
         )
 
+    def _join_fixture_o2m_composite_selfref(self, **kw):
+        return relationships.JoinCondition(
+            self.composite_selfref,
+            self.composite_selfref,
+            self.composite_selfref,
+            self.composite_selfref,
+            **kw
+        )
+
+    def _join_fixture_m2o_composite_selfref(self, **kw):
+        return relationships.JoinCondition(
+            self.composite_selfref,
+            self.composite_selfref,
+            self.composite_selfref,
+            self.composite_selfref,
+            remote_side=set([self.composite_selfref.c.id, 
+                            self.composite_selfref.c.group_id]),
+            **kw
+        )
+
     def test_determine_join_o2m(self):
         joincond = self._join_fixture_o2m()
         self.assert_compile(
@@ -113,6 +142,46 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             set([self.selfref.c.id])
         )
 
+    def test_determine_join_o2m_composite_selfref(self):
+        joincond = self._join_fixture_o2m_composite_selfref()
+        self.assert_compile(
+                joincond.primaryjoin,
+                "composite_selfref.group_id = composite_selfref.group_id "
+                "AND composite_selfref.id = composite_selfref.parent_id"
+        )
+
+    def test_determine_direction_o2m_composite_selfref(self):
+        joincond = self._join_fixture_o2m_composite_selfref()
+        is_(joincond.direction, ONETOMANY)
+
+    def test_determine_remote_side_o2m_composite_selfref(self):
+        joincond = self._join_fixture_o2m_composite_selfref()
+        eq_(
+            joincond.remote_side,
+            set([self.composite_selfref.c.parent_id, 
+                self.composite_selfref.c.group_id])
+        )
+
+    def test_determine_join_m2o_composite_selfref(self):
+        joincond = self._join_fixture_m2o_composite_selfref()
+        self.assert_compile(
+                joincond.primaryjoin,
+                "composite_selfref.group_id = composite_selfref.group_id "
+                "AND composite_selfref.id = composite_selfref.parent_id"
+        )
+
+    def test_determine_direction_m2o_composite_selfref(self):
+        joincond = self._join_fixture_m2o_composite_selfref()
+        is_(joincond.direction, MANYTOONE)
+
+    def test_determine_remote_side_m2o_composite_selfref(self):
+        joincond = self._join_fixture_m2o_composite_selfref()
+        eq_(
+            joincond.remote_side,
+            set([self.composite_selfref.c.id, 
+                self.composite_selfref.c.group_id])
+        )
+
     def test_determine_join_m2o(self):
         joincond = self._join_fixture_m2o()
         self.assert_compile(