]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
simplify remote annotation significantly, and also
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 8 Feb 2012 15:14:36 +0000 (10:14 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 8 Feb 2012 15:14:36 +0000 (10:14 -0500)
catch the actual remote columns more accurately.

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
test/orm/test_rel_fn.py
test/orm/test_relationships.py

index 9bab0c2f4a7422b4db37dc6befb4ebc21cab20c2..95343016276134ee10b9722eca95a148a4f61a13 100644 (file)
@@ -949,6 +949,7 @@ class RelationshipProperty(StrategizedProperty):
         assert self.jc.direction is self.direction
         assert self.jc.remote_side == self.remote_side
         assert self.jc.local_remote_pairs == self.local_remote_pairs
+        pass
 
     def _check_conflicts(self):
         """Test that this relationship is legal, warn about 
@@ -1510,6 +1511,7 @@ class RelationshipProperty(StrategizedProperty):
         return strategy.use_get
 
     def _refers_to_parent_table(self):
+        alt = self._alt_refers_to_parent_table()
         pt = self.parent.mapped_table
         mt = self.mapper.mapped_table
         for c, f in self.synchronize_pairs:
@@ -1519,10 +1521,35 @@ class RelationshipProperty(StrategizedProperty):
                 mt.is_derived_from(c.table) and \
                 mt.is_derived_from(f.table)
             ):
+                assert alt
                 return True
         else:
+            assert not alt
             return False
 
+    def _alt_refers_to_parent_table(self):
+        pt = self.parent.mapped_table
+        mt = self.mapper.mapped_table
+        result = [False]
+        def visit_binary(binary):
+            c, f = binary.left, binary.right
+            if (
+                isinstance(c, expression.ColumnClause) and \
+                isinstance(f, expression.ColumnClause) and \
+                pt.is_derived_from(c.table) and \
+                pt.is_derived_from(f.table) and \
+                mt.is_derived_from(c.table) and \
+                mt.is_derived_from(f.table)
+            ):
+                result[0] = True
+
+        visitors.traverse(
+                    self.primaryjoin,
+                    {},
+                    {"binary":visit_binary}
+                )
+        return result[0]
+
     @util.memoized_property
     def _is_self_referential(self):
         return self.mapper.common_parent(self.parent)
index 02eab9c2d765da26f0e0ef1f465d5b6a3d717c79..cb07f234a37c47f33246908d4820479a8d105860 100644 (file)
@@ -19,6 +19,27 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
 from sqlalchemy.sql import operators, expression, visitors
 from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY
 
+def remote(expr):
+    return _annotate_columns(expr, {"remote":True})
+
+def foreign(expr):
+    return _annotate_columns(expr, {"foreign":True})
+
+def remote_foreign(expr):
+    return _annotate_columns(expr, {"foreign":True, 
+                                "remote":True})
+
+def _annotate_columns(element, annotations):
+    def clone(elem):
+        if isinstance(elem, expression.ColumnClause):
+            elem = elem._annotate(annotations.copy())
+        elem._copy_internals(clone=clone)
+        return elem
+
+    if element is not None:
+        element = clone(element)
+    return element
+
 class JoinCondition(object):
     def __init__(self, 
                     parent_selectable, 
@@ -55,7 +76,8 @@ class JoinCondition(object):
         self.support_sync = support_sync
         self.can_be_synced_fn = can_be_synced_fn
         self._determine_joins()
-        self._parse_joins()
+        self._annotate_fks()
+        self._annotate_remote()
         self._determine_direction()
 
     def _determine_joins(self):
@@ -106,13 +128,7 @@ class JoinCondition(object):
                     "'secondaryjoin' is needed as well."
                     % self.prop)
 
-    def _parse_joins(self):
-        """Apply 'remote', 'local' and 'foreign' annotations
-        to the primary and secondary join conditions.
-
-        """
-        parentcols = util.column_set(self.parent_selectable.c)
-        targetcols = util.column_set(self.child_selectable.c)
+    def _annotate_fks(self):
         if self.secondary is not None:
             secondarycols = util.column_set(self.secondary.c)
         else:
@@ -121,20 +137,6 @@ class JoinCondition(object):
         def col_is(a, b):
             return a.compare(b)
 
-        def refers_to_parent_table(binary):
-            pt = self.parent_selectable
-            mt = self.child_selectable
-            c, f = binary.left, binary.right
-            if (
-                pt.is_derived_from(c.table) and \
-                pt.is_derived_from(f.table) and \
-                mt.is_derived_from(c.table) and \
-                mt.is_derived_from(f.table)
-            ):
-                return True
-            else:
-                return False
-
         def is_foreign(a, b):
             if self.consider_as_foreign_keys:
                 if a in self.consider_as_foreign_keys and (
@@ -161,32 +163,19 @@ class JoinCondition(object):
                 elif b in secondarycols and a not in secondarycols:
                     return b
 
-        def _run_w_switch(binary, fn):
-            binary.left, binary.right = fn(binary, binary.left, binary.right)
-            binary.right, binary.left = fn(binary, binary.right, binary.left)
-
         def _annotate_fk(binary, left, right):
             can_be_synced = self.can_be_synced_fn(left)
             left = left._annotate({
-                "equated":binary.operator is operators.eq,
+                #"equated":binary.operator is operators.eq,
                 "can_be_synced":can_be_synced and \
                     binary.operator is operators.eq
             })
             right = right._annotate({
-                "equated":binary.operator is operators.eq,
+                #"equated":binary.operator is operators.eq,
                 "referent":True
             })
             return left, right
 
-        def _annotate_remote(binary, left, right):
-            left = left._annotate(
-                                {"remote":True})
-            if right in parentcols or \
-                right in targetcols:
-                right = right._annotate(
-                                {"local":True})
-            return left, right
-
         def visit_binary(binary):
             if not isinstance(binary.left, sql.ColumnElement) or \
                         not isinstance(binary.right, sql.ColumnElement):
@@ -204,41 +193,12 @@ class JoinCondition(object):
                                             {"foreign":True})
                     # TODO: when the two cols are the same.
 
-            has_foreign = False
             if "foreign" in binary.left._annotations:
                 binary.left, binary.right = _annotate_fk(
                                 binary, binary.left, binary.right)
-                has_foreign = True
             if "foreign" in binary.right._annotations:
                 binary.right, binary.left = _annotate_fk(
                             binary, binary.right, binary.left)
-                has_foreign = True
-
-            if "remote" not in binary.left._annotations and \
-                "remote" not in binary.right._annotations:
-
-                def go(binary, left, right):
-                    if self._local_remote_pairs:
-                        raise NotImplementedError()
-                    elif self._remote_side:
-                        if left in self._remote_side:
-                            return _annotate_remote(binary, left, right)
-                    elif refers_to_parent_table(binary):
-                        # assume one to many - FKs are "remote"
-                        if "foreign" in left._annotations:
-                            return _annotate_remote(binary, left, right)
-                    elif secondarycols:
-                        if left in secondarycols:
-                            return _annotate_remote(binary, left, right)
-                    else:
-                        # TODO: to support the X->Y->Z case 
-                        # we might need to look at parentcols
-                        # and annotate "local" separately...
-                        if left in targetcols and has_foreign \
-                            and right in parentcols or right in secondarycols:
-                            return _annotate_remote(binary, left, right)
-                    return left, right
-                _run_w_switch(binary, go)
 
         self.primaryjoin = visitors.cloned_traverse(
             self.primaryjoin,
@@ -257,11 +217,63 @@ class JoinCondition(object):
             self._check_foreign_cols(
                         self.secondaryjoin, False)
 
+    def _refers_to_parent_table(self):
+        pt = self.parent_selectable
+        mt = self.child_selectable
+        result = [False]
+        def visit_binary(binary):
+            c, f = binary.left, binary.right
+            if (
+                isinstance(c, expression.ColumnClause) and \
+                isinstance(f, expression.ColumnClause) and \
+                pt.is_derived_from(c.table) and \
+                pt.is_derived_from(f.table) and \
+                mt.is_derived_from(c.table) and \
+                mt.is_derived_from(f.table)
+            ):
+                result[0] = True
+
+        visitors.traverse(
+                    self.primaryjoin,
+                    {},
+                    {"binary":visit_binary}
+                )
+        return result[0]
+
+    def _annotate_remote(self):
+        for col in visitors.iterate(self.primaryjoin, {}):
+            if "remote" in col._annotations:
+                return
+
+        if self._local_remote_pairs:
+            raise NotImplementedError()
+        elif self._remote_side:
+            def repl(element):
+                if element in self._remote_side:
+                    return element._annotate({"remote":True})
+        elif self.secondary is not None:
+            def repl(element):
+                if self.secondary.c.contains_column(element):
+                    return element._annotate({"remote":True})
+        elif self._refers_to_parent_table():
+            def repl(element):
+                # assume one to many - FKs are "remote"
+                if "foreign" in element._annotations:
+                    return element._annotate({"remote":True})
+        else:
+            def repl(element):
+                if self.child_selectable.c.contains_column(element):
+                    return element._annotate({"remote":True})
+
+        self.primaryjoin = visitors.replacement_traverse(
+                                        self.primaryjoin, {},  repl)
+        if self.secondaryjoin is not None:
+            self.secondaryjoin = visitors.replacement_traverse(
+                                        self.secondaryjoin, {}, repl)
+
 
     def _check_foreign_cols(self, join_condition, primary):
         """Check the foreign key columns collected and emit error messages."""
-        # TODO: don't worry, we can simplify this once we
-        # encourage configuration via direct annotation
 
         can_sync = False
 
@@ -284,66 +296,30 @@ class JoinCondition(object):
         # to report.  Check for a join condition using any operator 
         # (not just ==), perhaps they need to turn on "viewonly=True".
         if self.support_sync and has_foreign and not can_sync:
-
-            err = "Could not locate any "\
-                    "foreign-key-equated, locally mapped column "\
-                    "pairs for %s "\
-                    "condition '%s' on relationship %s." % (
+            err = "Could not locate any simple equality expressions "\
+                    "involving foreign key columns for %s join condition "\
+                    "'%s' on relationship %s." % (
                         primary and 'primaryjoin' or 'secondaryjoin', 
                         join_condition, 
                         self.prop
                     )
-
-            # TODO: this needs to be changed to detect that
-            # annotations were present and whatnot.   the future
-            # foreignkey(col) annotation will cover establishing
-            # the col as foreign to it's mate
-            if not self.consider_as_foreign_keys:
-                err += "  Ensure that the "\
-                        "referencing Column objects have a "\
-                        "ForeignKey present, or are otherwise part "\
-                        "of a ForeignKeyConstraint on their parent "\
-                        "Table, or specify the foreign_keys parameter "\
-                        "to this relationship."
-
-            err += "  For more "\
-                    "relaxed rules on join conditions, the "\
-                    "relationship may be marked as viewonly=True."
+            err += "  Ensure that referencing columns are associated with a "\
+                    "ForeignKey or ForeignKeyConstraint, or are annotated "\
+                    "in the join condition with the foreign() annotation. "\
+                    "To allow comparison operators other than '==', "\
+                    "the relationship can be marked as viewonly=True."
 
             raise sa_exc.ArgumentError(err)
         else:
-            if self.consider_as_foreign_keys:
-                raise sa_exc.ArgumentError("Could not determine "
-                        "relationship direction for %s condition "
-                        "'%s', on relationship %s, using manual "
-                        "'foreign_keys' setting.  Do the columns "
-                        "in 'foreign_keys' represent all, and "
-                        "only, the 'foreign' columns in this join "
-                        "condition?  Does the %s Table already "
-                        "have adequate ForeignKey and/or "
-                        "ForeignKeyConstraint objects established "
-                        "(in which case 'foreign_keys' is usually "
-                        "unnecessary)?" 
-                        % (
-                            primary and 'primaryjoin' or 'secondaryjoin',
-                            join_condition, 
-                            self.prop,
-                            primary and 'mapped' or 'secondary'
-                        ))
-            else:
-                raise sa_exc.ArgumentError("Could not determine "
-                        "relationship direction for %s condition "
-                        "'%s', on relationship %s. Ensure that the "
-                        "referencing Column objects have a "
-                        "ForeignKey present, or are otherwise part "
-                        "of a ForeignKeyConstraint on their parent "
-                        "Table, or specify the foreign_keys parameter " 
-                        "to this relationship."
-                        % (
-                            primary and 'primaryjoin' or 'secondaryjoin', 
-                            join_condition, 
-                            self.prop
-                        ))
+            err = "Could not locate any relevant foreign key columns "\
+                    "for %s join condition '%s' on relationship %s." % (
+                        primary and 'primaryjoin' or 'secondaryjoin', 
+                        join_condition, 
+                        self.prop
+                    )
+            err += "Ensure that referencing columns are associated with a "\
+                    "a ForeignKey or ForeignKeyConstraint, or are annotated "\
+                    "in the join condition with the foreign() annotation."
 
     def _determine_direction(self):
         """Determine if this relationship is one to many, many to one, 
@@ -399,14 +375,21 @@ class JoinCondition(object):
                         "nor the child's mapped tables" % self.prop)
 
     @util.memoized_property
-    def remote_columns(self):
+    def liberal_remote_columns(self):
+        # this is temporary until we figure out 
+        # which version of "remote" to use
         return self._gather_join_annotations("remote")
 
+    @util.memoized_property
+    def remote_columns(self):
+        return set([r for l, r in self.local_remote_pairs])
+        #return self._gather_join_annotations("remote")
+
     remote_side = remote_columns
 
     @util.memoized_property
     def local_columns(self):
-        return self._gather_join_annotations("local")
+        return set([l for l, r in self.local_remote_pairs])
 
     @util.memoized_property
     def foreign_key_columns(self):
@@ -440,10 +423,14 @@ class JoinCondition(object):
         lrp = util.OrderedSet()
         def visit_binary(binary):
             if "remote" in binary.right._annotations and \
-                "local" in binary.left._annotations:
+                "remote" not in binary.left._annotations and \
+                isinstance(binary.left, expression.ColumnClause) and \
+                self.can_be_synced_fn(binary.left):
                 lrp.add((binary.left, binary.right))
             elif "remote" in binary.left._annotations and \
-                "local" in binary.right._annotations:
+                "remote" not in binary.right._annotations and \
+                isinstance(binary.right, expression.ColumnClause) and \
+                self.can_be_synced_fn(binary.right):
                 lrp.add((binary.right, binary.left))
         visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary})
         if self.secondaryjoin is not None:
index 30e19bc6864237998756369e2e92d342f61f63fb..72099a5f5eb0361f213eddfc9f09b8b11f3a4c41 100644 (file)
@@ -3384,6 +3384,10 @@ class _BinaryExpression(ColumnElement):
         except:
             raise TypeError("Boolean value of this clause is not defined")
 
+    @property
+    def is_comparison(self):
+        return operators.is_comparison(self.operator)
+
     @property
     def _from_objects(self):
         return self.left._from_objects + self.right._from_objects
index 89f0aaee13e753d7faf5a3986d9a008508ec9b77..b86b50db445daadad4b41c863a1f73b802c033cb 100644 (file)
@@ -521,6 +521,11 @@ def nullslast_op(a):
 
 _commutative = set([eq, ne, add, mul])
 
+_comparison = set([eq, ne, lt, gt, ge, le])
+
+def is_comparison(op):
+    return op in _comparison
+
 def is_commutative(op):
     return op in _commutative
 
index 6ce89d604e43b36f5b56fd2545f8800f55af499b..862149bc13622429112088a0764810e978a3f818 100644 (file)
@@ -1,7 +1,9 @@
-from test.lib.testing import assert_raises, assert_raises_message, eq_, AssertsCompiledSQL, is_
+from test.lib.testing import assert_raises, assert_raises_message, eq_, \
+    AssertsCompiledSQL, is_
 from test.lib import fixtures
 from sqlalchemy.orm import relationships
-from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, select, ForeignKeyConstraint
+from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, \
+    select, ForeignKeyConstraint, exc
 from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
 
 class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -12,10 +14,14 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
         m = MetaData()
         cls.left = Table('lft', m,
             Column('id', Integer, primary_key=True),
+            Column('x', Integer),
+            Column('y', Integer),
         )
         cls.right = Table('rgt', m,
             Column('id', Integer, primary_key=True),
-            Column('lid', Integer, ForeignKey('lft.id'))
+            Column('lid', Integer, ForeignKey('lft.id')),
+            Column('x', Integer),
+            Column('y', Integer),
         )
         cls.selfref = Table('selfref', m,
             Column('id', Integer, primary_key=True),
@@ -88,6 +94,88 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             **kw
         )
 
+    def _join_fixture_compound_expression_1(self, **kw):
+        return relationships.JoinCondition(
+            self.left,
+            self.right,
+            self.left,
+            self.right,
+            primaryjoin=(self.left.c.x + self.left.c.y) == \
+                            relationships.remote_foreign(
+                                self.right.c.x * self.right.c.y
+                            ),
+            **kw
+        )
+
+    def _join_fixture_compound_expression_2(self, **kw):
+        return relationships.JoinCondition(
+            self.left,
+            self.right,
+            self.left,
+            self.right,
+            primaryjoin=(self.left.c.x + self.left.c.y) == \
+                            relationships.foreign(
+                                self.right.c.x * self.right.c.y
+                            ),
+            **kw
+        )
+
+    def test_determine_remote_side_compound_1(self):
+        joincond = self._join_fixture_compound_expression_1(
+                                support_sync=False)
+        eq_(
+            joincond.liberal_remote_columns,
+            set([self.right.c.x, self.right.c.y])
+        )
+
+    def test_determine_local_remote_compound_1(self):
+        joincond = self._join_fixture_compound_expression_1(
+                                support_sync=False)
+        eq_(
+            joincond.local_remote_pairs,
+            []
+        )
+
+    def test_err_local_remote_compound_1(self):
+        assert_raises_message(
+            exc.ArgumentError,
+            r"Could not locate any simple equality "
+            "expressions involving foreign key "
+            "columns for primaryjoin join "
+            r"condition 'lft.x \+ lft.y = rgt.x \* rgt.y' "
+            "on relationship None.  Ensure that referencing "
+            "columns are associated with a ForeignKey or "
+            "ForeignKeyConstraint, or are annotated in the "
+            r"join condition with the foreign\(\) annotation. "
+            "To allow comparison operators other "
+            "than '==', the relationship can be marked as viewonly=True.",
+            self._join_fixture_compound_expression_1
+        )
+
+    def test_determine_remote_side_compound_2(self):
+        joincond = self._join_fixture_compound_expression_2(
+                                support_sync=False)
+        eq_(
+            joincond.remote_side,
+            set([])
+        )
+
+    def test_determine_local_remote_compound_2(self):
+        joincond = self._join_fixture_compound_expression_2(
+                                support_sync=False)
+        eq_(
+            joincond.local_remote_pairs,
+            []
+        )
+
+    def test_determine_direction_compound_2(self):
+        joincond = self._join_fixture_compound_expression_2(
+                                support_sync=False)
+        is_(
+            joincond.direction,
+            ONETOMANY
+        )
+
     def test_determine_join_o2m(self):
         joincond = self._join_fixture_o2m()
         self.assert_compile(
@@ -177,7 +265,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_determine_remote_side_m2o_composite_selfref(self):
         joincond = self._join_fixture_m2o_composite_selfref()
         eq_(
-            joincond.remote_side,
+            joincond.liberal_remote_columns,
             set([self.composite_selfref.c.id, 
                 self.composite_selfref.c.group_id])
         )
index 2049088aff29680bb0c9af1efd668bebd8664341..4031a1251cb9767ae199aff5e37f7f1f38ef5331 100644 (file)
@@ -365,6 +365,9 @@ class ComplexPostUpdateTest(fixtures.MappedTest):
                  backref=backref('pages',
                                  cascade="all, delete-orphan",
                                  order_by=pages.c.pagename)),
+            # TODO: the remote side + lazyclause isn't 
+            # even coming out correctly here.   currentversion/version
+            # aren't being considered at all.
             'currentversion': relationship(
                  PageVersion,
                  uselist=False,