]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- got m2m, local_remote_pairs, etc. working
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Feb 2012 02:16:53 +0000 (21:16 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Feb 2012 02:16:53 +0000 (21:16 -0500)
- using new traversal that returns the product of both sides
of a binary, starting to work with (a+b) == (c+d) types of joins.
primaryjoins on functions working
- annotations working, including reversing local/remote when
doing backref

lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/orm/test_rel_fn.py
test/orm/test_relationships.py
test/sql/test_generative.py
test/sql/test_selectable.py

index 9fd969e3bfd88e0bd6d1ca1e3ca912fe50c057b2..13bd18f088b91f7ad6c1de989278e45fabf043ed 100644 (file)
@@ -44,6 +44,11 @@ from sqlalchemy.orm.properties import (
      PropertyLoader,
      SynonymProperty,
      )
+from sqlalchemy.orm.relationships import (
+    foreign,
+    remote,
+    remote_foreign
+)
 from sqlalchemy.orm import mapper as mapperlib
 from sqlalchemy.orm.mapper import reconstructor, validates
 from sqlalchemy.orm import strategies
@@ -81,6 +86,7 @@ __all__ = (
     'dynamic_loader',
     'eagerload',
     'eagerload_all',
+    'foreign',
     'immediateload',
     'join',
     'joinedload',
@@ -96,6 +102,8 @@ __all__ = (
     'reconstructor',
     'relationship',
     'relation',
+    'remote',
+    'remote_foreign',
     'scoped_session',
     'sessionmaker',
     'subqueryload',
index 77da33b5ff518d7b5a16cfc0eec560211e1a812e..38237b2d4b213fdd270f0ccfb5705246e01df4bd 100644 (file)
@@ -16,7 +16,7 @@ from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
     join_condition, _shallow_annotate
 from sqlalchemy.sql import operators, expression, visitors
 from sqlalchemy.orm import attributes, dependency, mapper, \
-    object_mapper, strategies, configure_mappers
+    object_mapper, strategies, configure_mappers, relationships
 from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \
     _orm_annotate, _orm_deannotate
 
@@ -915,13 +915,12 @@ class RelationshipProperty(StrategizedProperty):
         self._check_conflicts()
         self._process_dependent_arguments()
         self._setup_join_conditions()
-        self._extra_determine_direction()
+        self._check_cascade_settings()
         self._post_init()
         self._generate_backref()
         super(RelationshipProperty, self).do_init()
 
     def _setup_join_conditions(self):
-        import relationships
         self._join_condition = jc = relationships.JoinCondition(
                     parent_selectable=self.parent.mapped_table,
                     child_selectable=self.mapper.mapped_table,
@@ -946,8 +945,8 @@ class RelationshipProperty(StrategizedProperty):
         self.local_remote_pairs = jc.local_remote_pairs
         self.remote_side = jc.remote_columns
         self.synchronize_pairs = jc.synchronize_pairs
-        self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
         self._calculated_foreign_keys = jc.foreign_key_columns
+        self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
 
     def _check_conflicts(self):
         """Test that this relationship is legal, warn about 
@@ -1035,7 +1034,7 @@ class RelationshipProperty(StrategizedProperty):
                             (self.key, self.parent.class_)
                         )
 
-    def _extra_determine_direction(self):
+    def _check_cascade_settings(self):
         if self.cascade.delete_orphan and not self.single_parent \
             and (self.direction is MANYTOMANY or self.direction
                  is MANYTOONE):
@@ -1064,7 +1063,6 @@ class RelationshipProperty(StrategizedProperty):
                 return False
         return True
 
-
     def _generate_backref(self):
         if not self.is_primary():
             return
@@ -1083,13 +1081,15 @@ class RelationshipProperty(StrategizedProperty):
                 pj = kwargs.pop('primaryjoin', self.secondaryjoin)
                 sj = kwargs.pop('secondaryjoin', self.primaryjoin)
             else:
-                pj = kwargs.pop('primaryjoin', self.primaryjoin)
+                pj = kwargs.pop('primaryjoin', 
+                        self._join_condition.primaryjoin_reverse_remote)
                 sj = kwargs.pop('secondaryjoin', None)
                 if sj:
                     raise sa_exc.InvalidRequestError(
                         "Can't assign 'secondaryjoin' on a backref against "
                         "a non-secondary relationship."
                             )
+
             foreign_keys = kwargs.pop('foreign_keys',
                     self._user_defined_foreign_keys)
             parent = self.parent.primary_mapper()
@@ -1112,21 +1112,6 @@ class RelationshipProperty(StrategizedProperty):
             self._add_reverse_property(self.back_populates)
 
     def _post_init(self):
-        self.logger.info('%s setup primary join %s', self,
-                         self.primaryjoin)
-        self.logger.info('%s setup secondary join %s', self,
-                         self.secondaryjoin)
-        self.logger.info('%s synchronize pairs [%s]', self,
-                         ','.join('(%s => %s)' % (l, r) for (l, r) in
-                         self.synchronize_pairs))
-        self.logger.info('%s secondary synchronize pairs [%s]', self,
-                         ','.join('(%s => %s)' % (l, r) for (l, r) in
-                         self.secondary_synchronize_pairs or []))
-        self.logger.info('%s local/remote pairs [%s]', self,
-                         ','.join('(%s / %s)' % (l, r) for (l, r) in
-                         self.local_remote_pairs))
-        self.logger.info('%s relationship direction %s', self,
-                         self.direction)
         if self.uselist is None:
             self.uselist = self.direction is not MANYTOONE
         if not self.viewonly:
@@ -1141,46 +1126,6 @@ class RelationshipProperty(StrategizedProperty):
         strategy = self._get_strategy(strategies.LazyLoader)
         return strategy.use_get
 
-    def _refers_to_parent_table(self):
-        alt = self._alt_refers_to_parent_table()
-        pt = self.parent.mapped_table
-        mt = self.mapper.mapped_table
-        for c, f in self.synchronize_pairs:
-            if (
-                pt.is_derived_from(c.table) and \
-                pt.is_derived_from(f.table) and \
-                mt.is_derived_from(c.table) and \
-                mt.is_derived_from(f.table)
-            ):
-                assert alt
-                return True
-        else:
-            assert not alt
-            return False
-
-    def _alt_refers_to_parent_table(self):
-        pt = self.parent.mapped_table
-        mt = self.mapper.mapped_table
-        result = [False]
-        def visit_binary(binary):
-            c, f = binary.left, binary.right
-            if (
-                isinstance(c, expression.ColumnClause) and \
-                isinstance(f, expression.ColumnClause) and \
-                pt.is_derived_from(c.table) and \
-                pt.is_derived_from(f.table) and \
-                mt.is_derived_from(c.table) and \
-                mt.is_derived_from(f.table)
-            ):
-                result[0] = True
-
-        visitors.traverse(
-                    self.primaryjoin,
-                    {},
-                    {"binary":visit_binary}
-                )
-        return result[0]
-
     @util.memoized_property
     def _is_self_referential(self):
         return self.mapper.common_parent(self.parent)
index 723d4529563c5c81c28f9c755c90d63941c40c94..d8c2659b650b2c2c103bf4c80504024984acbbd7 100644 (file)
@@ -15,7 +15,7 @@ and `secondaryjoin` aspects of :func:`.relationship`.
 
 from sqlalchemy import sql, util, log, exc as sa_exc, schema
 from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
-    join_condition, _shallow_annotate
+    join_condition, _shallow_annotate, visit_binary_product
 from sqlalchemy.sql import operators, expression, visitors
 from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, ONETOMANY
 
@@ -78,7 +78,34 @@ class JoinCondition(object):
         self._determine_joins()
         self._annotate_fks()
         self._annotate_remote()
+        self._annotate_local()
         self._determine_direction()
+        self._setup_pairs()
+        self._check_foreign_cols(self.primaryjoin, True)
+        if self.secondaryjoin is not None:
+            self._check_foreign_cols(self.secondaryjoin, False)
+        self._check_remote_side()
+        self._log_joins()
+
+    def _log_joins(self):
+        if self.prop is None:
+            return
+        log = self.prop.logger
+        log.info('%s setup primary join %s', self,
+                         self.primaryjoin)
+        log.info('%s setup secondary join %s', self,
+                         self.secondaryjoin)
+        log.info('%s synchronize pairs [%s]', self,
+                         ','.join('(%s => %s)' % (l, r) for (l, r) in
+                         self.synchronize_pairs))
+        log.info('%s secondary synchronize pairs [%s]', self,
+                         ','.join('(%s => %s)' % (l, r) for (l, r) in
+                         self.secondary_synchronize_pairs or []))
+        log.info('%s local/remote pairs [%s]', self,
+                         ','.join('(%s / %s)' % (l, r) for (l, r) in
+                         self.local_remote_pairs))
+        log.info('%s relationship direction %s', self,
+                         self.direction)
 
     def _determine_joins(self):
         """Determine the 'primaryjoin' and 'secondaryjoin' attributes,
@@ -128,28 +155,60 @@ class JoinCondition(object):
                     "'secondaryjoin' is needed as well."
                     % self.prop)
 
+    @util.memoized_property
+    def primaryjoin_reverse_remote(self):
+        def replace(element):
+            if "remote" in element._annotations:
+                v = element._annotations.copy()
+                del v['remote']
+                v['local'] = True
+                return element._with_annotations(v)
+            elif "local" in element._annotations:
+                v = element._annotations.copy()
+                del v['local']
+                v['remote'] = True
+                return element._with_annotations(v)
+        return visitors.replacement_traverse(self.primaryjoin, {}, replace)
+
+    def _has_annotation(self, clause, annotation):
+        for col in visitors.iterate(clause, {}):
+            if annotation in col._annotations:
+                return True
+        else:
+            return False
+
     def _annotate_fks(self):
+        if self._has_annotation(self.primaryjoin, "foreign"):
+            return
+
+        if self.consider_as_foreign_keys:
+            self._annotate_from_fk_list()
+        else:
+            self._annotate_present_fks()
+
+    def _annotate_from_fk_list(self):
+        def check_fk(col):
+            if col in self.consider_as_foreign_keys:
+                return col._annotate({"foreign":True})
+        self.primaryjoin = visitors.replacement_traverse(
+            self.primaryjoin,
+            {},
+            check_fk
+        )
+        if self.secondaryjoin is not None:
+            self.secondaryjoin = visitors.replacement_traverse(
+                self.secondaryjoin,
+                {},
+                check_fk
+            )
+
+    def _annotate_present_fks(self):
         if self.secondary is not None:
             secondarycols = util.column_set(self.secondary.c)
         else:
             secondarycols = set()
 
-        def col_is(a, b):
-            return a.compare(b)
-
         def is_foreign(a, b):
-            if self.consider_as_foreign_keys:
-                if a in self.consider_as_foreign_keys and (
-                        col_is(a, b) or 
-                        b not in self.consider_as_foreign_keys
-                    ):
-                    return a
-                elif b in self.consider_as_foreign_keys and (
-                        col_is(a, b) or 
-                        a not in self.consider_as_foreign_keys
-                    ):
-                    return b
-
             if isinstance(a, schema.Column) and \
                         isinstance(b, schema.Column):
                 if a.references(b):
@@ -163,19 +222,6 @@ class JoinCondition(object):
                 elif b in secondarycols and a not in secondarycols:
                     return b
 
-        def _annotate_fk(binary, left, right):
-            can_be_synced = self.can_be_synced_fn(left)
-            left = left._annotate({
-                #"equated":binary.operator is operators.eq,
-                "can_be_synced":can_be_synced and \
-                    binary.operator is operators.eq
-            })
-            right = right._annotate({
-                #"equated":binary.operator is operators.eq,
-                "referent":True
-            })
-            return left, right
-
         def visit_binary(binary):
             if not isinstance(binary.left, sql.ColumnElement) or \
                         not isinstance(binary.right, sql.ColumnElement):
@@ -185,20 +231,12 @@ class JoinCondition(object):
                 "foreign" not in binary.right._annotations:
                 col = is_foreign(binary.left, binary.right)
                 if col is not None:
-                    if col is binary.left:
+                    if col.compare(binary.left):
                         binary.left = binary.left._annotate(
                                             {"foreign":True})
-                    elif col is binary.right:
+                    elif col.compare(binary.right):
                         binary.right = binary.right._annotate(
                                             {"foreign":True})
-                    # TODO: when the two cols are the same.
-
-            if "foreign" in binary.left._annotations:
-                binary.left, binary.right = _annotate_fk(
-                                binary, binary.left, binary.right)
-            if "foreign" in binary.right._annotations:
-                binary.right, binary.left = _annotate_fk(
-                            binary, binary.right, binary.left)
 
         self.primaryjoin = visitors.cloned_traverse(
             self.primaryjoin,
@@ -211,11 +249,6 @@ class JoinCondition(object):
                 {},
                 {"binary":visit_binary}
             )
-        self._check_foreign_cols(
-                        self.primaryjoin, True)
-        if self.secondaryjoin is not None:
-            self._check_foreign_cols(
-                        self.secondaryjoin, False)
 
     def _refers_to_parent_table(self):
         pt = self.parent_selectable
@@ -241,18 +274,14 @@ class JoinCondition(object):
         return result[0]
 
     def _annotate_remote(self):
-        parentcols = util.column_set(self.parent_selectable.c)
+        if self._has_annotation(self.primaryjoin, "remote"):
+            return
 
-        for col in visitors.iterate(self.primaryjoin, {}):
-            if "remote" in col._annotations:
-                has_remote_annotations = True
-                break
-        else:
-            has_remote_annotations = False
+        parentcols = util.column_set(self.parent_selectable.c)
 
         def _annotate_selfref(fn):
             def visit_binary(binary):
-                equated = binary.left is binary.right
+                equated = binary.left.compare(binary.right)
                 if isinstance(binary.left, sql.ColumnElement) and \
                     isinstance(binary.right, sql.ColumnElement):
                     # assume one to many - FKs are "remote"
@@ -267,44 +296,72 @@ class JoinCondition(object):
                                     self.primaryjoin, {}, 
                                     {"binary":visit_binary})
 
-        if not has_remote_annotations:
+        if self.secondary is not None:
+            def repl(element):
+                if self.secondary.c.contains_column(element):
+                    return element._annotate({"remote":True})
+            self.primaryjoin = visitors.replacement_traverse(
+                                        self.primaryjoin, {},  repl)
+            self.secondaryjoin = visitors.replacement_traverse(
+                                        self.secondaryjoin, {}, repl)
+        elif self._local_remote_pairs or self._remote_side:
+
             if self._local_remote_pairs:
-                raise NotImplementedError()
-            elif self._remote_side:
-                if self._refers_to_parent_table():
-                    _annotate_selfref(lambda col:col in self._remote_side)
-                else:
-                    def repl(element):
-                        if element in self._remote_side:
-                            return element._annotate({"remote":True})
-                    self.primaryjoin = visitors.replacement_traverse(
-                                                self.primaryjoin, {},  repl)
-            elif self.secondary is not None:
-                def repl(element):
-                    if self.secondary.c.contains_column(element):
-                        return element._annotate({"remote":True})
-                self.primaryjoin = visitors.replacement_traverse(
-                                            self.primaryjoin, {},  repl)
-                self.secondaryjoin = visitors.replacement_traverse(
-                                            self.secondaryjoin, {}, repl)
-            elif self._refers_to_parent_table():
-                _annotate_selfref(lambda col:"foreign" in col._annotations)
+                if self._remote_side:
+                    raise sa_exc.ArgumentError(
+                            "remote_side argument is redundant "
+                            "against more detailed _local_remote_side "
+                            "argument.")
+
+                remote_side = [r for (l, r) in self._local_remote_pairs]
+            else:
+                remote_side = self._remote_side
+
+            if self._refers_to_parent_table():
+                _annotate_selfref(lambda col:col in remote_side)
             else:
                 def repl(element):
-                    if self.child_selectable.c.contains_column(element):
+                    if element in remote_side:
                         return element._annotate({"remote":True})
-
                 self.primaryjoin = visitors.replacement_traverse(
                                             self.primaryjoin, {},  repl)
+        elif self._refers_to_parent_table():
+            _annotate_selfref(lambda col:"foreign" in col._annotations)
+        else:
+            def repl(element):
+                if self.child_selectable.c.contains_column(element):
+                    return element._annotate({"remote":True})
+
+            self.primaryjoin = visitors.replacement_traverse(
+                                        self.primaryjoin, {},  repl)
+
+    def _annotate_local(self):
+        if self._has_annotation(self.primaryjoin, "local"):
+            return
+
+        parentcols = util.column_set(self.parent_selectable.c)
+
+        if self._local_remote_pairs:
+            local_side = util.column_set([l for (l, r) 
+                                in self._local_remote_pairs])
+        else:
+            local_side = util.column_set(self.parent_selectable.c)
 
         def locals_(elem):
             if "remote" not in elem._annotations and \
-                elem in parentcols:
+                elem in local_side:
                 return elem._annotate({"local":True})
         self.primaryjoin = visitors.replacement_traverse(
                 self.primaryjoin, {}, locals_
             )
 
+    def _check_remote_side(self):
+        if not self.local_remote_pairs:
+            raise sa_exc.ArgumentError('Relationship %s could '
+                    'not determine any local/remote column '
+                    'pairs from remote side argument %r'
+                    % (self.prop, self._remote_side))
+
     def _check_foreign_cols(self, join_condition, primary):
         """Check the foreign key columns collected and emit error messages."""
 
@@ -315,11 +372,10 @@ class JoinCondition(object):
 
         has_foreign = bool(foreign_cols)
 
-        if self.support_sync:
-            for col in foreign_cols:
-                if col._annotations.get("can_be_synced"):
-                    can_sync = True
-                    break
+        if primary:
+            can_sync = bool(self.synchronize_pairs)
+        else:
+            can_sync = bool(self.secondary_synchronize_pairs)
 
         if self.support_sync and can_sync or \
             (not self.support_sync and has_foreign):
@@ -407,6 +463,44 @@ class JoinCondition(object):
                         "key columns are present in neither the parent "
                         "nor the child's mapped tables" % self.prop)
 
+    def _setup_pairs(self):
+        sync_pairs = []
+        lrp = util.OrderedSet([])
+        secondary_sync_pairs = []
+
+        def go(joincond, collection):
+            def visit_binary(binary, left, right):
+                if "remote" in right._annotations and \
+                    "remote" not in left._annotations and \
+                    self.can_be_synced_fn(left):
+                    lrp.add((left, right))
+                elif "remote" in left._annotations and \
+                    "remote" not in right._annotations and \
+                    self.can_be_synced_fn(right):
+                    lrp.add((right, left))
+                if binary.operator is operators.eq:
+                    # and \
+                    #binary.left.compare(left) and \
+                    #binary.right.compare(right):
+                    if "foreign" in right._annotations:
+                        collection.append((left, right))
+                    elif "foreign" in left._annotations:
+                        collection.append((right, left))
+            visit_binary_product(visit_binary, joincond)
+
+        for joincond, collection in [
+            (self.primaryjoin, sync_pairs),
+            (self.secondaryjoin, secondary_sync_pairs)
+        ]:
+            if joincond is None:
+                continue
+            go(joincond, collection)
+
+        self.local_remote_pairs = list(lrp)
+        self.synchronize_pairs = sync_pairs
+        self.secondary_synchronize_pairs = secondary_sync_pairs
+
+
     @util.memoized_property
     def remote_columns(self):
         return self._gather_join_annotations("remote")
@@ -415,38 +509,6 @@ class JoinCondition(object):
     def local_columns(self):
         return self._gather_join_annotations("local")
 
-    @util.memoized_property
-    def synchronize_pairs(self):
-        parentcols = util.column_set(self.parent_selectable.c)
-        targetcols = util.column_set(self.child_selectable.c)
-        result = []
-        for l, r in self.local_remote_pairs:
-            if self.secondary is not None:
-                if "foreign" in r._annotations and \
-                    l in parentcols:
-                    result.append((l, r))
-            elif "foreign" in r._annotations and \
-                "can_be_synced" in r._annotations:
-                result.append((l, r))
-            elif "foreign" in l._annotations and \
-                "can_be_synced" in l._annotations:
-                result.append((r, l))
-        return result
-
-    @util.memoized_property
-    def secondary_synchronize_pairs(self):
-        parentcols = util.column_set(self.parent_selectable.c)
-        targetcols = util.column_set(self.child_selectable.c)
-        result = []
-        if self.secondary is None:
-            return result
-
-        for l, r in self.local_remote_pairs:
-            if "foreign" in l._annotations and \
-                r in targetcols:
-                result.append((l, r))
-        return result
-
     @util.memoized_property
     def foreign_key_columns(self):
         return self._gather_join_annotations("foreign")
@@ -470,24 +532,6 @@ class JoinCondition(object):
             if annotation.issubset(col._annotations)
         ])
 
-    @util.memoized_property
-    def local_remote_pairs(self):
-        lrp = util.OrderedSet()
-        def visit_binary(binary):
-            if "remote" in binary.right._annotations and \
-                "remote" not in binary.left._annotations and \
-                isinstance(binary.left, expression.ColumnClause) and \
-                self.can_be_synced_fn(binary.left):
-                lrp.add((binary.left, binary.right))
-            elif "remote" in binary.left._annotations and \
-                "remote" not in binary.right._annotations and \
-                isinstance(binary.right, expression.ColumnClause) and \
-                self.can_be_synced_fn(binary.right):
-                lrp.add((binary.right, binary.left))
-        visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary})
-        if self.secondaryjoin is not None:
-            visitors.traverse(self.secondaryjoin, {}, {"binary":visit_binary})
-        return list(lrp)
 
     def join_targets(self, source_selectable, 
                             dest_selectable,
@@ -604,147 +648,6 @@ def _create_lazy_clause(cls, prop, reverse_direction=False):
     return lazywhere, bind_to_col, equated_columns
 
 
-def _determine_synchronize_pairs(self):
-    """Resolve 'primary'/foreign' column pairs from the primaryjoin
-    and secondaryjoin arguments.
-
-    """
-    if self.local_remote_pairs:
-        if not self._user_defined_foreign_keys:
-            raise sa_exc.ArgumentError(
-                    "foreign_keys argument is "
-                    "required with _local_remote_pairs argument")
-        self.synchronize_pairs = []
-        for l, r in self.local_remote_pairs:
-            if r in self._user_defined_foreign_keys:
-                self.synchronize_pairs.append((l, r))
-            elif l in self._user_defined_foreign_keys:
-                self.synchronize_pairs.append((r, l))
-    else:
-        self.synchronize_pairs = self._sync_pairs_from_join(
-                                            self.primaryjoin, 
-                                            True)
-
-    self._calculated_foreign_keys = util.column_set(
-                            r for (l, r) in
-                            self.synchronize_pairs)
-
-    if self.secondaryjoin is not None:
-        self.secondary_synchronize_pairs = self._sync_pairs_from_join(
-                                                    self.secondaryjoin, 
-                                                    False)
-        self._calculated_foreign_keys.update(
-                            r for (l, r) in
-                            self.secondary_synchronize_pairs)
-    else:
-        self.secondary_synchronize_pairs = None
-
-
-def _determine_local_remote_pairs(self):
-    """Determine pairs of columns representing "local" to 
-    "remote", where "local" columns are on the parent mapper,
-    "remote" are on the target mapper.
-
-    These pairs are used on the load side only to generate
-    lazy loading clauses.
-
-    """
-    if not self.local_remote_pairs and not self.remote_side:
-        # the most common, trivial case.   Derive 
-        # local/remote pairs from the synchronize pairs.
-        eq_pairs = util.unique_list(
-                        self.synchronize_pairs + 
-                        (self.secondary_synchronize_pairs or []))
-        if self.direction is MANYTOONE:
-            self.local_remote_pairs = [(r, l) for l, r in eq_pairs]
-        else:
-            self.local_remote_pairs = eq_pairs
-
-    # "remote_side" specified, derive from the primaryjoin
-    # plus remote_side, similarly to how synchronize_pairs
-    # were determined.
-    elif self.remote_side:
-        if self.local_remote_pairs:
-            raise sa_exc.ArgumentError('remote_side argument is '
-                'redundant against more detailed '
-                '_local_remote_side argument.')
-        if self.direction is MANYTOONE:
-            self.local_remote_pairs = [(r, l) for (l, r) in
-                    criterion_as_pairs(self.primaryjoin,
-                    consider_as_referenced_keys=self.remote_side,
-                    any_operator=True)]
-
-        else:
-            self.local_remote_pairs = \
-                criterion_as_pairs(self.primaryjoin,
-                    consider_as_foreign_keys=self.remote_side,
-                    any_operator=True)
-        if not self.local_remote_pairs:
-            raise sa_exc.ArgumentError('Relationship %s could '
-                    'not determine any local/remote column '
-                    'pairs from remote side argument %r'
-                    % (self, self.remote_side))
-    # else local_remote_pairs were sent explcitly via
-    # ._local_remote_pairs.
-
-    # create local_side/remote_side accessors
-    self.local_side = util.ordered_column_set(
-                        l for l, r in self.local_remote_pairs)
-    self.remote_side = util.ordered_column_set(
-                        r for l, r in self.local_remote_pairs)
-
-    # check that the non-foreign key column in the local/remote
-    # collection is mapped.  The foreign key
-    # which the individual mapped column references directly may
-    # itself be in a non-mapped table; see
-    # test.orm.test_relationships.ViewOnlyComplexJoin.test_basic
-    # for an example of this.
-    if self.direction is ONETOMANY:
-        for col in self.local_side:
-            if not self._columns_are_mapped(col):
-                raise sa_exc.ArgumentError(
-                        "Local column '%s' is not "
-                        "part of mapping %s.  Specify remote_side "
-                        "argument to indicate which column lazy join "
-                        "condition should compare against." % (col,
-                        self.parent))
-    elif self.direction is MANYTOONE:
-        for col in self.remote_side:
-            if not self._columns_are_mapped(col):
-                raise sa_exc.ArgumentError(
-                        "Remote column '%s' is not "
-                        "part of mapping %s. Specify remote_side "
-                        "argument to indicate which column lazy join "
-                        "condition should bind." % (col, self.mapper))
-
-    count = [0]
-    def clone(elem):
-        if set(['local', 'remote']).intersection(elem._annotations):
-            return None
-        elif elem in self.local_side and elem in self.remote_side:
-            # TODO: OK this still sucks.  this is basically,
-            # refuse, refuse, refuse the temptation to guess!
-            # but crap we really have to guess don't we.  we 
-            # might want to traverse here with cloned_traverse
-            # so we can see the binary exprs and do it at that 
-            # level....
-            if count[0] % 2 == 0:
-                elem = elem._annotate({'local':True})
-            else:
-                elem = elem._annotate({'remote':True})
-            count[0] += 1
-        elif elem in self.local_side:
-            elem = elem._annotate({'local':True})
-        elif elem in self.remote_side:
-            elem = elem._annotate({'remote':True})
-        else:
-            elem = None
-        return elem
-
-    self.primaryjoin = visitors.replacement_traverse(
-                            self.primaryjoin, {}, clone
-                        )
-
 
 def _criterion_exists(self, criterion=None, **kwargs):
     if getattr(self, '_of_type', None):
index 5f4b182d08e2f8b19fb53b32361d6b475cd28552..32023428190488a7f6d0d586ed79855a84d7bcdc 100644 (file)
@@ -785,6 +785,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
             leftmost_mapper, leftmost_prop = \
                                     subq_mapper, \
                                     subq_mapper._props[subq_path[1]]
+        # TODO: local cols might not be unique here
         leftmost_cols, remote_cols = self._local_remote_columns(leftmost_prop)
 
         leftmost_attr = [
@@ -846,6 +847,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
             # self.parent is more specific than subq_path[-2]
             parent_alias = mapperutil.AliasedClass(self.parent)
 
+        # TODO: local cols might not be unique here
         local_cols, remote_cols = \
                         self._local_remote_columns(self.parent_property)
 
@@ -885,6 +887,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
         if prop.secondary is None:
             return zip(*prop.local_remote_pairs)
         else:
+            # TODO: this isn't going to work for readonly....
             return \
                 [p[0] for p in prop.synchronize_pairs],\
                 [
@@ -930,6 +933,7 @@ class SubqueryLoader(AbstractRelationshipLoader):
         if ('subquery', reduced_path) not in context.attributes:
             return None, None, None
 
+        # TODO: local_cols might not be unique here
         local_cols, remote_cols = self._local_remote_columns(self.parent_property)
 
         q = context.attributes[('subquery', reduced_path)]
index 0cd5b059486be9c720cdd468f1aa88f689594066..f17f675f410d2d52731830aaf8304a372dfb4702 100644 (file)
@@ -366,7 +366,18 @@ def _orm_annotate(element, exclude=None):
     """
     return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude)
 
-_orm_deannotate = sql_util._deep_deannotate
+def _orm_deannotate(element):
+    """Remove annotations that link a column to a particular mapping.
+    
+    Note this doesn't affect "remote" and "foreign" annotations
+    passed by the :func:`.orm.foreign` and :func:`.orm.remote`
+    annotators.
+    
+    """
+
+    return sql_util._deep_deannotate(element, 
+                values=("_orm_adapt", "parententity")
+            )
 
 class _ORMJoin(expression.Join):
     """Extend Join to support ORM constructs as input."""
index 72099a5f5eb0361f213eddfc9f09b8b11f3a4c41..ebf4de9a2ef5e1a461f5e8712cc2ad9b2f54018b 100644 (file)
@@ -1576,18 +1576,30 @@ class ClauseElement(Visitable):
             return id(self)
 
     def _annotate(self, values):
-        """return a copy of this ClauseElement with the given annotations
-        dictionary.
+        """return a copy of this ClauseElement with annotations
+        updated by the given dictionary.
 
         """
         return sqlutil.Annotated(self, values)
 
-    def _deannotate(self):
-        """return a copy of this ClauseElement with an empty annotations
-        dictionary.
+    def _with_annotations(self, values):
+        """return a copy of this ClauseElement with annotations
+        replaced by the given dictionary.
 
         """
-        return self._clone()
+        return sqlutil.Annotated(self, values)
+
+    def _deannotate(self, values=None):
+        """return a copy of this :class:`.ClauseElement` with annotations
+        removed.
+        
+        :param values: optional tuple of individual values
+         to remove.
+
+        """
+        # since we have no annotations we return
+        # self
+        return self
 
     def unique_params(self, *optionaldict, **kwargs):
         """Return a copy with :func:`bindparam()` elments replaced.
index f0509c16f73a987535a8242ceff77a2e62c35608..9a45a577750e06bf6c661e3559d26c26f92c468a 100644 (file)
@@ -62,6 +62,61 @@ def find_join_source(clauses, join_to):
     else:
         return None, None
 
+
+def visit_binary_product(fn, expr):
+    """Produce a traversal of the given expression, delivering
+    column comparisons to the given function.
+    
+    The function is of the form::
+    
+        def my_fn(binary, left, right)
+    
+    For each binary expression located which has a 
+    comparison operator, the product of "left" and
+    "right" will be delivered to that function,
+    in terms of that binary.
+    
+    Hence an expression like::
+    
+        and_(
+            (a + b) == q + func.sum(e + f),
+            j == r
+        )
+    
+    would have the traversal::
+    
+        a <eq> q
+        a <eq> e
+        a <eq> f
+        b <eq> q
+        b <eq> e
+        b <eq> f
+        j <eq> r
+
+    That is, every combination of "left" and
+    "right" that doesn't further contain
+    a binary comparison is passed as pairs.
+    
+    """
+    stack = []
+    def visit(element):
+        if element.__visit_name__ == 'binary' and \
+            operators.is_comparison(element.operator):
+            stack.insert(0, element)
+            for l in visit(element.left):
+                for r in visit(element.right):
+                    fn(stack[0], l, r)
+            stack.pop(0)
+            for elem in element.get_children():
+                visit(elem)
+        else:
+            if isinstance(element, expression.ColumnClause):
+                yield element
+            for elem in element.get_children():
+                for e in visit(elem):
+                    yield e
+    list(visit(expr))
+
 def find_tables(clause, check_columns=False, 
                 include_aliases=False, include_joins=False, 
                 include_selects=False, include_crud=False):
@@ -357,13 +412,22 @@ class Annotated(object):
     def _annotate(self, values):
         _values = self._annotations.copy()
         _values.update(values)
+        return self._with_annotations(_values)
+
+    def _with_annotations(self, values):
         clone = self.__class__.__new__(self.__class__)
         clone.__dict__ = self.__dict__.copy()
-        clone._annotations = _values
+        clone._annotations = values
         return clone
 
-    def _deannotate(self):
-        return self.__element
+    def _deannotate(self, values=None):
+        if values is None:
+            return self.__element
+        else:
+            _values = self._annotations.copy()
+            for v in values:
+                _values.pop(v, None)
+            return self._with_annotations(_values)
 
     def _compiler_dispatch(self, visitor, **kw):
         return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@@ -426,11 +490,11 @@ def _deep_annotate(element, annotations, exclude=None):
         element = clone(element)
     return element
 
-def _deep_deannotate(element):
-    """Deep copy the given element, removing all annotations."""
+def _deep_deannotate(element, values=None):
+    """Deep copy the given element, removing annotations."""
 
     def clone(elem):
-        elem = elem._deannotate()
+        elem = elem._deannotate(values=values)
         elem._copy_internals(clone=clone)
         return elem
 
index d3d346bbaf51589e4ef4a1d23a2c70188e25bbe2..346cb90c10dc0805c66b135cd27f5206dcb848d8 100644 (file)
@@ -6,9 +6,8 @@ from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, \
     select, ForeignKeyConstraint, exc
 from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
 
-class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
-    __dialect__ = 'default'
 
+class _JoinFixtures(object):
     @classmethod
     def setup_class(cls):
         m = MetaData()
@@ -36,6 +35,28 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
                 ['composite_selfref.id', 'composite_selfref.group_id']
             )
         )
+        cls.m2mleft = Table('m2mlft', m,
+            Column('id', Integer, primary_key=True),
+        )
+        cls.m2mright = Table('m2mrgt', m,
+            Column('id', Integer, primary_key=True),
+        )
+        cls.m2msecondary = Table('m2msecondary', m,
+            Column('lid', Integer, ForeignKey('m2mlft.id'), primary_key=True),
+            Column('rid', Integer, ForeignKey('m2mrgt.id'), primary_key=True),
+        )
+
+    def _join_fixture_m2m_selfref(self, **kw):
+        return relationships.JoinCondition(
+                    self.m2mleft, 
+                    self.m2mright, 
+                    self.m2mleft, 
+                    self.m2mright,
+                    secondary=self.m2msecondary,
+                    primaryjoin=self.m2mleft.c.id==self.m2msecondary.c.lid,
+                    secondaryjoin=self.m2mright.c.id==self.m2msecondary.c.rid,
+                    **kw
+                )
 
     def _join_fixture_o2m(self, **kw):
         return relationships.JoinCondition(
@@ -120,6 +141,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             **kw
         )
 
+class ColumnCollectionsTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
     def test_determine_remote_columns_compound_1(self):
         joincond = self._join_fixture_compound_expression_1(
                                 support_sync=False)
@@ -133,7 +155,25 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
                                 support_sync=False)
         eq_(
             joincond.local_remote_pairs,
-            []
+            [
+                (self.left.c.x, self.right.c.x), 
+                (self.left.c.x, self.right.c.y), 
+                (self.left.c.y, self.right.c.x),
+                (self.left.c.y, self.right.c.y)
+            ]
+        )
+
+    def test_determine_local_remote_compound_2(self):
+        joincond = self._join_fixture_compound_expression_2(
+                                support_sync=False)
+        eq_(
+            joincond.local_remote_pairs,
+            [
+                (self.left.c.x, self.right.c.x), 
+                (self.left.c.x, self.right.c.y), 
+                (self.left.c.y, self.right.c.x),
+                (self.left.c.y, self.right.c.y)
+            ]
         )
 
     def test_err_local_remote_compound_1(self):
@@ -160,14 +200,71 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             set([self.right.c.x, self.right.c.y])
         )
 
-    def test_determine_local_remote_compound_2(self):
-        joincond = self._join_fixture_compound_expression_2(
-                                support_sync=False)
+
+    def test_determine_remote_columns_o2m(self):
+        joincond = self._join_fixture_o2m()
+        eq_(
+            joincond.remote_columns,
+            set([self.right.c.lid])
+        )
+
+    def test_determine_remote_columns_o2m_selfref(self):
+        joincond = self._join_fixture_o2m_selfref()
+        eq_(
+            joincond.remote_columns,
+            set([self.selfref.c.sid])
+        )
+
+    def test_determine_remote_columns_o2m_composite_selfref(self):
+        joincond = self._join_fixture_o2m_composite_selfref()
+        eq_(
+            joincond.remote_columns,
+            set([self.composite_selfref.c.parent_id, 
+                self.composite_selfref.c.group_id])
+        )
+
+    def test_determine_remote_columns_m2o_composite_selfref(self):
+        joincond = self._join_fixture_m2o_composite_selfref()
+        eq_(
+            joincond.remote_columns,
+            set([self.composite_selfref.c.id, 
+                self.composite_selfref.c.group_id])
+        )
+
+    def test_determine_remote_columns_m2o(self):
+        joincond = self._join_fixture_m2o()
+        eq_(
+            joincond.remote_columns,
+            set([self.left.c.id])
+        )
+
+    def test_determine_local_remote_pairs_o2m(self):
+        joincond = self._join_fixture_o2m()
         eq_(
             joincond.local_remote_pairs,
-            []
+            [(self.left.c.id, self.right.c.lid)]
+        )
+
+    def test_determine_synchronize_pairs_m2m_selfref(self):
+        joincond = self._join_fixture_m2m_selfref()
+        eq_(
+            joincond.synchronize_pairs,
+            [(self.m2mleft.c.id, self.m2msecondary.c.lid)]
+        )
+        eq_(
+            joincond.secondary_synchronize_pairs,
+            [(self.m2mright.c.id, self.m2msecondary.c.rid)]
         )
 
+    def test_determine_remote_columns_m2o_selfref(self):
+        joincond = self._join_fixture_m2o_selfref()
+        eq_(
+            joincond.remote_columns,
+            set([self.selfref.c.id])
+        )
+
+
+class DirectionTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
     def test_determine_direction_compound_2(self):
         joincond = self._join_fixture_compound_expression_2(
                                 support_sync=False)
@@ -176,60 +273,46 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             ONETOMANY
         )
 
-    def test_determine_join_o2m(self):
-        joincond = self._join_fixture_o2m()
-        self.assert_compile(
-                joincond.primaryjoin,
-                "lft.id = rgt.lid"
-        )
-
     def test_determine_direction_o2m(self):
         joincond = self._join_fixture_o2m()
         is_(joincond.direction, ONETOMANY)
 
-    def test_determine_remote_columns_o2m(self):
-        joincond = self._join_fixture_o2m()
-        eq_(
-            joincond.remote_columns,
-            set([self.right.c.lid])
-        )
-
-    def test_determine_join_o2m_selfref(self):
-        joincond = self._join_fixture_o2m_selfref()
-        self.assert_compile(
-                joincond.primaryjoin,
-                "selfref.id = selfref.sid"
-        )
-
     def test_determine_direction_o2m_selfref(self):
         joincond = self._join_fixture_o2m_selfref()
         is_(joincond.direction, ONETOMANY)
 
-    def test_determine_remote_columns_o2m_selfref(self):
-        joincond = self._join_fixture_o2m_selfref()
-        eq_(
-            joincond.remote_columns,
-            set([self.selfref.c.sid])
-        )
+    def test_determine_direction_m2o_selfref(self):
+        joincond = self._join_fixture_m2o_selfref()
+        is_(joincond.direction, MANYTOONE)
 
-    def test_join_targets_o2m_selfref(self):
-        joincond = self._join_fixture_o2m_selfref()
-        left = select([joincond.parent_selectable]).alias('pj')
-        pj, sj, sec, adapter = joincond.join_targets(
-                                    left, 
-                                    joincond.child_selectable, 
-                                    True)
+    def test_determine_direction_o2m_composite_selfref(self):
+        joincond = self._join_fixture_o2m_composite_selfref()
+        is_(joincond.direction, ONETOMANY)
+
+    def test_determine_direction_m2o_composite_selfref(self):
+        joincond = self._join_fixture_m2o_composite_selfref()
+        is_(joincond.direction, MANYTOONE)
+
+    def test_determine_direction_m2o(self):
+        joincond = self._join_fixture_m2o()
+        is_(joincond.direction, MANYTOONE)
+
+
+class DetermineJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
+    def test_determine_join_o2m(self):
+        joincond = self._join_fixture_o2m()
         self.assert_compile(
-            pj, "pj.id = selfref.sid"
+                joincond.primaryjoin,
+                "lft.id = rgt.lid"
         )
 
-        right = select([joincond.child_selectable]).alias('pj')
-        pj, sj, sec, adapter = joincond.join_targets(
-                                    joincond.parent_selectable, 
-                                    right, 
-                                    True)
+    def test_determine_join_o2m_selfref(self):
+        joincond = self._join_fixture_o2m_selfref()
         self.assert_compile(
-            pj, "selfref.id = pj.sid"
+                joincond.primaryjoin,
+                "selfref.id = selfref.sid"
         )
 
     def test_determine_join_m2o_selfref(self):
@@ -239,17 +322,6 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
                 "selfref.id = selfref.sid"
         )
 
-    def test_determine_direction_m2o_selfref(self):
-        joincond = self._join_fixture_m2o_selfref()
-        is_(joincond.direction, MANYTOONE)
-
-    def test_determine_remote_columns_m2o_selfref(self):
-        joincond = self._join_fixture_m2o_selfref()
-        eq_(
-            joincond.remote_columns,
-            set([self.selfref.c.id])
-        )
-
     def test_determine_join_o2m_composite_selfref(self):
         joincond = self._join_fixture_o2m_composite_selfref()
         self.assert_compile(
@@ -258,18 +330,6 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
                 "AND composite_selfref.id = composite_selfref.parent_id"
         )
 
-    def test_determine_direction_o2m_composite_selfref(self):
-        joincond = self._join_fixture_o2m_composite_selfref()
-        is_(joincond.direction, ONETOMANY)
-
-    def test_determine_remote_columns_o2m_composite_selfref(self):
-        joincond = self._join_fixture_o2m_composite_selfref()
-        eq_(
-            joincond.remote_columns,
-            set([self.composite_selfref.c.parent_id, 
-                self.composite_selfref.c.group_id])
-        )
-
     def test_determine_join_m2o_composite_selfref(self):
         joincond = self._join_fixture_m2o_composite_selfref()
         self.assert_compile(
@@ -278,17 +338,7 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
                 "AND composite_selfref.id = composite_selfref.parent_id"
         )
 
-    def test_determine_direction_m2o_composite_selfref(self):
-        joincond = self._join_fixture_m2o_composite_selfref()
-        is_(joincond.direction, MANYTOONE)
 
-    def test_determine_remote_columns_m2o_composite_selfref(self):
-        joincond = self._join_fixture_m2o_composite_selfref()
-        eq_(
-            joincond.remote_columns,
-            set([self.composite_selfref.c.id, 
-                self.composite_selfref.c.group_id])
-        )
 
     def test_determine_join_m2o(self):
         joincond = self._join_fixture_m2o()
@@ -297,24 +347,30 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
                 "lft.id = rgt.lid"
         )
 
-    def test_determine_direction_m2o(self):
-        joincond = self._join_fixture_m2o()
-        is_(joincond.direction, MANYTOONE)
+class AdaptedJoinTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = 'default'
 
-    def test_determine_remote_columns_m2o(self):
-        joincond = self._join_fixture_m2o()
-        eq_(
-            joincond.remote_columns,
-            set([self.left.c.id])
+    def test_join_targets_o2m_selfref(self):
+        joincond = self._join_fixture_o2m_selfref()
+        left = select([joincond.parent_selectable]).alias('pj')
+        pj, sj, sec, adapter = joincond.join_targets(
+                                    left, 
+                                    joincond.child_selectable, 
+                                    True)
+        self.assert_compile(
+            pj, "pj.id = selfref.sid"
         )
 
-    def test_determine_local_remote_pairs_o2m(self):
-        joincond = self._join_fixture_o2m()
-        eq_(
-            joincond.local_remote_pairs,
-            [(self.left.c.id, self.right.c.lid)]
+        right = select([joincond.child_selectable]).alias('pj')
+        pj, sj, sec, adapter = joincond.join_targets(
+                                    joincond.parent_selectable, 
+                                    right, 
+                                    True)
+        self.assert_compile(
+            pj, "selfref.id = pj.sid"
         )
 
+
     def test_join_targets_o2m_plain(self):
         joincond = self._join_fixture_o2m()
         pj, sj, sec, adapter = joincond.join_targets(
@@ -347,6 +403,8 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             pj, "lft.id = pj.lid"
         )
 
+class LazyClauseTest(_JoinFixtures, fixtures.TestBase, AssertsCompiledSQL):
+
     def _test_lazy_clause_o2m(self):
         joincond = self._join_fixture_o2m()
         self.assert_compile(
index 0a02cbf9a3aec831043255bfa262869f6eade617..d2dcbe3128a3a32731bb1b0af0a1689a5a524009 100644 (file)
@@ -7,8 +7,9 @@ from test.lib.schema import Table, Column
 from sqlalchemy.orm import mapper, relationship, relation, \
                     backref, create_session, configure_mappers, \
                     clear_mappers, sessionmaker, attributes,\
-                    Session, composite, column_property
-from test.lib.testing import eq_, startswith_, AssertsCompiledSQL
+                    Session, composite, column_property, foreign
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
+from test.lib.testing import eq_, startswith_, AssertsCompiledSQL, is_
 from test.lib import fixtures
 from test.orm import _fixtures
 
@@ -141,12 +142,12 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
         Table('company_t', metadata,
               Column('company_id', Integer, primary_key=True, 
                                 test_needs_autoincrement=True),
-              Column('name', sa.Unicode(30)))
+              Column('name', String(30)))
 
         Table('employee_t', metadata,
               Column('company_id', Integer, primary_key=True),
               Column('emp_id', Integer, primary_key=True),
-              Column('name', sa.Unicode(30)),
+              Column('name', String(30)),
               Column('reports_to_id', Integer),
               sa.ForeignKeyConstraint(
                   ['company_id'],
@@ -158,7 +159,8 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
     @classmethod
     def setup_classes(cls):
         class Company(cls.Basic):
-            pass
+            def __init__(self, name):
+                self.name = name
 
         class Employee(cls.Basic):
             def __init__(self, name, company, emp_id, reports_to=None):
@@ -248,11 +250,16 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
         self._test()
 
     def _test(self):
+        sess = Session()
+        self._setup_data(sess)
+        self._test_lazy_relations(sess)
+        self._test_join_aliasing(sess)
+
+    def _setup_data(self, sess):
         Employee, Company = self.classes.Employee, self.classes.Company
 
-        sess = create_session()
-        c1 = Company()
-        c2 = Company()
+        c1 = Company('c1')
+        c2 = Company('c2')
 
         e1 = Employee(u'emp1', c1, 1)
         e2 = Employee(u'emp2', c1, 2, e1)
@@ -263,10 +270,17 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
         e7 = Employee(u'emp7', c2, 3, e5)
 
         sess.add_all((c1, c2))
-        sess.flush()
-        sess.expunge_all()
+        sess.commit()
+        sess.close()
+
+    def _test_lazy_relations(self, sess):
+        Employee, Company = self.classes.Employee, self.classes.Company
+
+        c1 = sess.query(Company).filter_by(name='c1').one()
+        c2 = sess.query(Company).filter_by(name='c2').one()
+        e1 = sess.query(Employee).filter_by(name='emp1').one()
+        e5 = sess.query(Employee).filter_by(name='emp5').one()
 
-        test_c1 = sess.query(Company).get(c1.company_id)
         test_e1 = sess.query(Employee).get([c1.company_id, e1.emp_id])
         assert test_e1.name == 'emp1', test_e1.name
         test_e5 = sess.query(Employee).get([c2.company_id, e5.emp_id])
@@ -277,6 +291,16 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
         assert sess.query(Employee).\
                 get([c2.company_id, 3]).reports_to.name == 'emp5'
 
+    def _test_join_aliasing(self, sess):
+        Employee, Company = self.classes.Employee, self.classes.Company
+        eq_(
+            [n for n, in sess.query(Employee.name).\
+                    join(Employee.reports_to, aliased=True).\
+                    filter_by(name='emp5').\
+                    reset_joinpoint().\
+                    order_by(Employee.name)],
+            ['emp6', 'emp7']
+        )
 
 class CompositeJoinPartialFK(fixtures.MappedTest, AssertsCompiledSQL):
     __dialect__ = 'default'
@@ -839,7 +863,6 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest):
     def test_mapping(self):
         Subscriber, Address = self.classes.Subscriber, self.classes.Address
 
-        from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
         sess = create_session()
         assert Subscriber.addresses.property.direction is ONETOMANY
         assert Address.customer.property.direction is MANYTOONE
@@ -1733,21 +1756,45 @@ class ExplicitLocalRemoteTest(fixtures.MappedTest):
         class T2(cls.Comparable):
             pass
 
-    def test_onetomany_funcfk(self):
+    def test_onetomany_funcfk_oldstyle(self):
         T2, T1, t2, t1 = (self.classes.T2,
                                 self.classes.T1,
                                 self.tables.t2,
                                 self.tables.t1)
 
-        # use a function within join condition.  but specifying
-        # local_remote_pairs overrides all parsing of the join condition.
+        # old _local_remote_pairs
         mapper(T1, t1, properties={
             't2s':relationship(T2,
                            primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id),
                            _local_remote_pairs=[(t1.c.id, t2.c.t1id)],
-                           foreign_keys=[t2.c.t1id])})
+                           foreign_keys=[t2.c.t1id]
+                           )
+                          })
+        mapper(T2, t2)
+        self._test_onetomany()
+
+    def test_onetomany_funcfk_annotated(self):
+        T2, T1, t2, t1 = (self.classes.T2,
+                                self.classes.T1,
+                                self.tables.t2,
+                                self.tables.t1)
+
+        # use annotation
+        mapper(T1, t1, properties={
+            't2s':relationship(T2,
+                           primaryjoin=t1.c.id==
+                            foreign(sa.func.lower(t2.c.t1id)),
+                           )})
         mapper(T2, t2)
+        self._test_onetomany()
 
+    def _test_onetomany(self):
+        T2, T1, t2, t1 = (self.classes.T2,
+                                self.classes.T1,
+                                self.tables.t2,
+                                self.tables.t1)
+        is_(T1.t2s.property.direction, ONETOMANY)
+        eq_(T1.t2s.property.local_remote_pairs, [(t1.c.id, t2.c.t1id)])
         sess = create_session()
         a1 = T1(id='number1', data='a1')
         a2 = T1(id='number2', data='a2')
index f9333dbf5329b0092474c22716944721636b4f65..d4f324dd7eea9f1ed84d871e4f1c6b17a985a478 100644 (file)
@@ -1,5 +1,5 @@
 from sqlalchemy import *
-from sqlalchemy.sql import table, column, ClauseElement
+from sqlalchemy.sql import table, column, ClauseElement, operators
 from sqlalchemy.sql.expression import  _clone, _from_objects
 from test.lib import *
 from sqlalchemy.sql.visitors import *
@@ -166,6 +166,90 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
         s = set(ClauseVisitor().iterate(bin))
         assert set(ClauseVisitor().iterate(bin)) == set([foo, bar, bin])
 
+class BinaryEndpointTraversalTest(fixtures.TestBase):
+    """test the special binary product visit"""
+
+    def _assert_traversal(self, expr, expected):
+        canary = []
+        def visit(binary, l, r):
+            canary.append((binary.operator, l, r))
+            print binary.operator, l, r
+        sql_util.visit_binary_product(visit, expr)
+        eq_(
+            canary, expected
+        )
+
+    def test_basic(self):
+        a, b = column("a"), column("b")
+        self._assert_traversal(
+            a == b,
+            [
+                (operators.eq, a, b)
+            ]
+        )
+
+    def test_with_tuples(self):
+        a, b, c, d, b1, b1a, b1b, e, f = (
+            column("a"),
+            column("b"),
+            column("c"),
+            column("d"),
+            column("b1"),
+            column("b1a"),
+            column("b1b"),
+            column("e"),
+            column("f")
+        )
+        expr = tuple_(
+            a, b, b1==tuple_(b1a, b1b == d), c
+        ) > tuple_(
+                func.go(e + f)
+            )
+        self._assert_traversal(
+            expr,
+            [
+                (operators.gt, a, e),
+                (operators.gt, a, f),
+                (operators.gt, b, e),
+                (operators.gt, b, f),
+                (operators.eq, b1, b1a),
+                (operators.eq, b1b, d),
+                (operators.gt, c, e),
+                (operators.gt, c, f)
+            ]
+        )
+
+    def test_composed(self):
+        a, b, e, f, q, j, r = (
+            column("a"),
+            column("b"),
+            column("e"),
+            column("f"),
+            column("q"),
+            column("j"),
+            column("r"),
+        )
+        expr = and_(
+            (a + b) == q + func.sum(e + f),
+            and_(
+                j == r,
+                f == q
+            )
+        )
+        self._assert_traversal(
+            expr,
+            [
+                (operators.eq, a, q),
+                (operators.eq, a, e),
+                (operators.eq, a, f),
+                (operators.eq, b, q),
+                (operators.eq, b, e),
+                (operators.eq, b, f),
+                (operators.eq, j, r),
+                (operators.eq, f, q),
+            ]
+        )
+
 class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
     """test copy-in-place behavior of various ClauseElements."""
 
index 6d85f7c4f3e16f906cc2d0ef1b0bf2e54196ef7e..4f1f39014929e762580fc3a948c4085c8aa1ad37 100644 (file)
@@ -1151,5 +1151,7 @@ class AnnotationsTest(fixtures.TestBase):
         assert b2.left is not bin.left 
         assert b3.left is not b2.left is not bin.left
         assert b4.left is bin.left  # since column is immutable
-        assert b4.right is not bin.right is not b2.right is not b3.right
+        assert b4.right is bin.right
+        assert b2.right is not bin.right
+        assert b3.right is b4.right is bin.right