From: Mike Bayer Date: Wed, 8 Feb 2012 15:14:36 +0000 (-0500) Subject: simplify remote annotation significantly, and also X-Git-Tag: rel_0_8_0b1~477^2~15 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=d1414ad20524c421aa78272c03dce5f839a0aab6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git simplify remote annotation significantly, and also catch the actual remote columns more accurately. --- diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9bab0c2f4a..9534301627 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -949,6 +949,7 @@ class RelationshipProperty(StrategizedProperty): assert self.jc.direction is self.direction assert self.jc.remote_side == self.remote_side assert self.jc.local_remote_pairs == self.local_remote_pairs + pass def _check_conflicts(self): """Test that this relationship is legal, warn about @@ -1510,6 +1511,7 @@ class RelationshipProperty(StrategizedProperty): return strategy.use_get def _refers_to_parent_table(self): + alt = self._alt_refers_to_parent_table() pt = self.parent.mapped_table mt = self.mapper.mapped_table for c, f in self.synchronize_pairs: @@ -1519,10 +1521,35 @@ class RelationshipProperty(StrategizedProperty): mt.is_derived_from(c.table) and \ mt.is_derived_from(f.table) ): + assert alt return True else: + assert not alt return False + def _alt_refers_to_parent_table(self): + pt = self.parent.mapped_table + mt = self.mapper.mapped_table + result = [False] + def visit_binary(binary): + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) and \ + isinstance(f, expression.ColumnClause) and \ + pt.is_derived_from(c.table) and \ + pt.is_derived_from(f.table) and \ + mt.is_derived_from(c.table) and \ + mt.is_derived_from(f.table) + ): + result[0] = True + + visitors.traverse( + self.primaryjoin, + {}, + {"binary":visit_binary} + ) + return result[0] + @util.memoized_property def _is_self_referential(self): return self.mapper.common_parent(self.parent) diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 02eab9c2d7..cb07f234a3 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -19,6 +19,27 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \ from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY +def remote(expr): + return _annotate_columns(expr, {"remote":True}) + +def foreign(expr): + return _annotate_columns(expr, {"foreign":True}) + +def remote_foreign(expr): + return _annotate_columns(expr, {"foreign":True, + "remote":True}) + +def _annotate_columns(element, annotations): + def clone(elem): + if isinstance(elem, expression.ColumnClause): + elem = elem._annotate(annotations.copy()) + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + class JoinCondition(object): def __init__(self, parent_selectable, @@ -55,7 +76,8 @@ class JoinCondition(object): self.support_sync = support_sync self.can_be_synced_fn = can_be_synced_fn self._determine_joins() - self._parse_joins() + self._annotate_fks() + self._annotate_remote() self._determine_direction() def _determine_joins(self): @@ -106,13 +128,7 @@ class JoinCondition(object): "'secondaryjoin' is needed as well." % self.prop) - def _parse_joins(self): - """Apply 'remote', 'local' and 'foreign' annotations - to the primary and secondary join conditions. - - """ - parentcols = util.column_set(self.parent_selectable.c) - targetcols = util.column_set(self.child_selectable.c) + def _annotate_fks(self): if self.secondary is not None: secondarycols = util.column_set(self.secondary.c) else: @@ -121,20 +137,6 @@ class JoinCondition(object): def col_is(a, b): return a.compare(b) - def refers_to_parent_table(binary): - pt = self.parent_selectable - mt = self.child_selectable - c, f = binary.left, binary.right - if ( - pt.is_derived_from(c.table) and \ - pt.is_derived_from(f.table) and \ - mt.is_derived_from(c.table) and \ - mt.is_derived_from(f.table) - ): - return True - else: - return False - def is_foreign(a, b): if self.consider_as_foreign_keys: if a in self.consider_as_foreign_keys and ( @@ -161,32 +163,19 @@ class JoinCondition(object): elif b in secondarycols and a not in secondarycols: return b - def _run_w_switch(binary, fn): - binary.left, binary.right = fn(binary, binary.left, binary.right) - binary.right, binary.left = fn(binary, binary.right, binary.left) - def _annotate_fk(binary, left, right): can_be_synced = self.can_be_synced_fn(left) left = left._annotate({ - "equated":binary.operator is operators.eq, + #"equated":binary.operator is operators.eq, "can_be_synced":can_be_synced and \ binary.operator is operators.eq }) right = right._annotate({ - "equated":binary.operator is operators.eq, + #"equated":binary.operator is operators.eq, "referent":True }) return left, right - def _annotate_remote(binary, left, right): - left = left._annotate( - {"remote":True}) - if right in parentcols or \ - right in targetcols: - right = right._annotate( - {"local":True}) - return left, right - def visit_binary(binary): if not isinstance(binary.left, sql.ColumnElement) or \ not isinstance(binary.right, sql.ColumnElement): @@ -204,41 +193,12 @@ class JoinCondition(object): {"foreign":True}) # TODO: when the two cols are the same. - has_foreign = False if "foreign" in binary.left._annotations: binary.left, binary.right = _annotate_fk( binary, binary.left, binary.right) - has_foreign = True if "foreign" in binary.right._annotations: binary.right, binary.left = _annotate_fk( binary, binary.right, binary.left) - has_foreign = True - - if "remote" not in binary.left._annotations and \ - "remote" not in binary.right._annotations: - - def go(binary, left, right): - if self._local_remote_pairs: - raise NotImplementedError() - elif self._remote_side: - if left in self._remote_side: - return _annotate_remote(binary, left, right) - elif refers_to_parent_table(binary): - # assume one to many - FKs are "remote" - if "foreign" in left._annotations: - return _annotate_remote(binary, left, right) - elif secondarycols: - if left in secondarycols: - return _annotate_remote(binary, left, right) - else: - # TODO: to support the X->Y->Z case - # we might need to look at parentcols - # and annotate "local" separately... - if left in targetcols and has_foreign \ - and right in parentcols or right in secondarycols: - return _annotate_remote(binary, left, right) - return left, right - _run_w_switch(binary, go) self.primaryjoin = visitors.cloned_traverse( self.primaryjoin, @@ -257,11 +217,63 @@ class JoinCondition(object): self._check_foreign_cols( self.secondaryjoin, False) + def _refers_to_parent_table(self): + pt = self.parent_selectable + mt = self.child_selectable + result = [False] + def visit_binary(binary): + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) and \ + isinstance(f, expression.ColumnClause) and \ + pt.is_derived_from(c.table) and \ + pt.is_derived_from(f.table) and \ + mt.is_derived_from(c.table) and \ + mt.is_derived_from(f.table) + ): + result[0] = True + + visitors.traverse( + self.primaryjoin, + {}, + {"binary":visit_binary} + ) + return result[0] + + def _annotate_remote(self): + for col in visitors.iterate(self.primaryjoin, {}): + if "remote" in col._annotations: + return + + if self._local_remote_pairs: + raise NotImplementedError() + elif self._remote_side: + def repl(element): + if element in self._remote_side: + return element._annotate({"remote":True}) + elif self.secondary is not None: + def repl(element): + if self.secondary.c.contains_column(element): + return element._annotate({"remote":True}) + elif self._refers_to_parent_table(): + def repl(element): + # assume one to many - FKs are "remote" + if "foreign" in element._annotations: + return element._annotate({"remote":True}) + else: + def repl(element): + if self.child_selectable.c.contains_column(element): + return element._annotate({"remote":True}) + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl) + def _check_foreign_cols(self, join_condition, primary): """Check the foreign key columns collected and emit error messages.""" - # TODO: don't worry, we can simplify this once we - # encourage configuration via direct annotation can_sync = False @@ -284,66 +296,30 @@ class JoinCondition(object): # to report. Check for a join condition using any operator # (not just ==), perhaps they need to turn on "viewonly=True". if self.support_sync and has_foreign and not can_sync: - - err = "Could not locate any "\ - "foreign-key-equated, locally mapped column "\ - "pairs for %s "\ - "condition '%s' on relationship %s." % ( + err = "Could not locate any simple equality expressions "\ + "involving foreign key columns for %s join condition "\ + "'%s' on relationship %s." % ( primary and 'primaryjoin' or 'secondaryjoin', join_condition, self.prop ) - - # TODO: this needs to be changed to detect that - # annotations were present and whatnot. the future - # foreignkey(col) annotation will cover establishing - # the col as foreign to it's mate - if not self.consider_as_foreign_keys: - err += " Ensure that the "\ - "referencing Column objects have a "\ - "ForeignKey present, or are otherwise part "\ - "of a ForeignKeyConstraint on their parent "\ - "Table, or specify the foreign_keys parameter "\ - "to this relationship." - - err += " For more "\ - "relaxed rules on join conditions, the "\ - "relationship may be marked as viewonly=True." + err += " Ensure that referencing columns are associated with a "\ + "ForeignKey or ForeignKeyConstraint, or are annotated "\ + "in the join condition with the foreign() annotation. "\ + "To allow comparison operators other than '==', "\ + "the relationship can be marked as viewonly=True." raise sa_exc.ArgumentError(err) else: - if self.consider_as_foreign_keys: - raise sa_exc.ArgumentError("Could not determine " - "relationship direction for %s condition " - "'%s', on relationship %s, using manual " - "'foreign_keys' setting. Do the columns " - "in 'foreign_keys' represent all, and " - "only, the 'foreign' columns in this join " - "condition? Does the %s Table already " - "have adequate ForeignKey and/or " - "ForeignKeyConstraint objects established " - "(in which case 'foreign_keys' is usually " - "unnecessary)?" - % ( - primary and 'primaryjoin' or 'secondaryjoin', - join_condition, - self.prop, - primary and 'mapped' or 'secondary' - )) - else: - raise sa_exc.ArgumentError("Could not determine " - "relationship direction for %s condition " - "'%s', on relationship %s. Ensure that the " - "referencing Column objects have a " - "ForeignKey present, or are otherwise part " - "of a ForeignKeyConstraint on their parent " - "Table, or specify the foreign_keys parameter " - "to this relationship." - % ( - primary and 'primaryjoin' or 'secondaryjoin', - join_condition, - self.prop - )) + err = "Could not locate any relevant foreign key columns "\ + "for %s join condition '%s' on relationship %s." % ( + primary and 'primaryjoin' or 'secondaryjoin', + join_condition, + self.prop + ) + err += "Ensure that referencing columns are associated with a "\ + "a ForeignKey or ForeignKeyConstraint, or are annotated "\ + "in the join condition with the foreign() annotation." def _determine_direction(self): """Determine if this relationship is one to many, many to one, @@ -399,14 +375,21 @@ class JoinCondition(object): "nor the child's mapped tables" % self.prop) @util.memoized_property - def remote_columns(self): + def liberal_remote_columns(self): + # this is temporary until we figure out + # which version of "remote" to use return self._gather_join_annotations("remote") + @util.memoized_property + def remote_columns(self): + return set([r for l, r in self.local_remote_pairs]) + #return self._gather_join_annotations("remote") + remote_side = remote_columns @util.memoized_property def local_columns(self): - return self._gather_join_annotations("local") + return set([l for l, r in self.local_remote_pairs]) @util.memoized_property def foreign_key_columns(self): @@ -440,10 +423,14 @@ class JoinCondition(object): lrp = util.OrderedSet() def visit_binary(binary): if "remote" in binary.right._annotations and \ - "local" in binary.left._annotations: + "remote" not in binary.left._annotations and \ + isinstance(binary.left, expression.ColumnClause) and \ + self.can_be_synced_fn(binary.left): lrp.add((binary.left, binary.right)) elif "remote" in binary.left._annotations and \ - "local" in binary.right._annotations: + "remote" not in binary.right._annotations and \ + isinstance(binary.right, expression.ColumnClause) and \ + self.can_be_synced_fn(binary.right): lrp.add((binary.right, binary.left)) visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary}) if self.secondaryjoin is not None: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 30e19bc686..72099a5f5e 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3384,6 +3384,10 @@ class _BinaryExpression(ColumnElement): except: raise TypeError("Boolean value of this clause is not defined") + @property + def is_comparison(self): + return operators.is_comparison(self.operator) + @property def _from_objects(self): return self.left._from_objects + self.right._from_objects diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 89f0aaee13..b86b50db44 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -521,6 +521,11 @@ def nullslast_op(a): _commutative = set([eq, ne, add, mul]) +_comparison = set([eq, ne, lt, gt, ge, le]) + +def is_comparison(op): + return op in _comparison + def is_commutative(op): return op in _commutative diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index 6ce89d604e..862149bc13 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -1,7 +1,9 @@ -from test.lib.testing import assert_raises, assert_raises_message, eq_, AssertsCompiledSQL, is_ +from test.lib.testing import assert_raises, assert_raises_message, eq_, \ + AssertsCompiledSQL, is_ from test.lib import fixtures from sqlalchemy.orm import relationships -from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, select, ForeignKeyConstraint +from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, \ + select, ForeignKeyConstraint, exc from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): @@ -12,10 +14,14 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): m = MetaData() cls.left = Table('lft', m, Column('id', Integer, primary_key=True), + Column('x', Integer), + Column('y', Integer), ) cls.right = Table('rgt', m, Column('id', Integer, primary_key=True), - Column('lid', Integer, ForeignKey('lft.id')) + Column('lid', Integer, ForeignKey('lft.id')), + Column('x', Integer), + Column('y', Integer), ) cls.selfref = Table('selfref', m, Column('id', Integer, primary_key=True), @@ -88,6 +94,88 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): **kw ) + def _join_fixture_compound_expression_1(self, **kw): + return relationships.JoinCondition( + self.left, + self.right, + self.left, + self.right, + primaryjoin=(self.left.c.x + self.left.c.y) == \ + relationships.remote_foreign( + self.right.c.x * self.right.c.y + ), + **kw + ) + + def _join_fixture_compound_expression_2(self, **kw): + return relationships.JoinCondition( + self.left, + self.right, + self.left, + self.right, + primaryjoin=(self.left.c.x + self.left.c.y) == \ + relationships.foreign( + self.right.c.x * self.right.c.y + ), + **kw + ) + + def test_determine_remote_side_compound_1(self): + joincond = self._join_fixture_compound_expression_1( + support_sync=False) + eq_( + joincond.liberal_remote_columns, + set([self.right.c.x, self.right.c.y]) + ) + + def test_determine_local_remote_compound_1(self): + joincond = self._join_fixture_compound_expression_1( + support_sync=False) + eq_( + joincond.local_remote_pairs, + [] + ) + + def test_err_local_remote_compound_1(self): + assert_raises_message( + exc.ArgumentError, + r"Could not locate any simple equality " + "expressions involving foreign key " + "columns for primaryjoin join " + r"condition 'lft.x \+ lft.y = rgt.x \* rgt.y' " + "on relationship None. Ensure that referencing " + "columns are associated with a ForeignKey or " + "ForeignKeyConstraint, or are annotated in the " + r"join condition with the foreign\(\) annotation. " + "To allow comparison operators other " + "than '==', the relationship can be marked as viewonly=True.", + self._join_fixture_compound_expression_1 + ) + + def test_determine_remote_side_compound_2(self): + joincond = self._join_fixture_compound_expression_2( + support_sync=False) + eq_( + joincond.remote_side, + set([]) + ) + + def test_determine_local_remote_compound_2(self): + joincond = self._join_fixture_compound_expression_2( + support_sync=False) + eq_( + joincond.local_remote_pairs, + [] + ) + + def test_determine_direction_compound_2(self): + joincond = self._join_fixture_compound_expression_2( + support_sync=False) + is_( + joincond.direction, + ONETOMANY + ) + def test_determine_join_o2m(self): joincond = self._join_fixture_o2m() self.assert_compile( @@ -177,7 +265,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL): def test_determine_remote_side_m2o_composite_selfref(self): joincond = self._join_fixture_m2o_composite_selfref() eq_( - joincond.remote_side, + joincond.liberal_remote_columns, set([self.composite_selfref.c.id, self.composite_selfref.c.group_id]) ) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 2049088aff..4031a1251c 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -365,6 +365,9 @@ class ComplexPostUpdateTest(fixtures.MappedTest): backref=backref('pages', cascade="all, delete-orphan", order_by=pages.c.pagename)), + # TODO: the remote side + lazyclause isn't + # even coming out correctly here. currentversion/version + # aren't being considered at all. 'currentversion': relationship( PageVersion, uselist=False,