]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged sync_simplify branch
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Apr 2008 00:21:28 +0000 (00:21 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Apr 2008 00:21:28 +0000 (00:21 +0000)
- The methodology behind "primaryjoin"/"secondaryjoin" has
been refactored.  Behavior should be slightly more
intelligent, primarily in terms of error messages which
have been pared down to be more readable.  In a slight
number of scenarios it can better resolve the correct
foreign key than before.
- moved collections unit test from relationships.py to collection.py
- PropertyLoader now has "synchronize_pairs" and "equated_pairs"
collections which allow easy access to the source/destination
parent/child relation between columns (might change names)
- factored out ClauseSynchronizer (finally)
- added many more tests for priamryjoin/secondaryjoin
error checks

12 files changed:
CHANGES
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/util.py
test/orm/collection.py
test/orm/inheritance/polymorph2.py
test/orm/relationships.py

diff --git a/CHANGES b/CHANGES
index a83db182c18de8beae3c60562822b4dac8355eca..3280389eb2a0c62494844a989972adba2a90c737 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -27,6 +27,13 @@ CHANGES
     - Added a more aggressive check for "uncompiled mappers",
       helps particularly with declarative layer [ticket:995]
 
+    - The methodology behind "primaryjoin"/"secondaryjoin" has
+      been refactored.  Behavior should be slightly more
+      intelligent, primarily in terms of error messages which
+      have been pared down to be more readable.  In a slight
+      number of scenarios it can better resolve the correct 
+      foreign key than before.
+
     - Added comparable_property(), adds query Comparator behavior
       to regular, unmanaged Python properties
 
index 8519d2260d89a4140f0200c29165160f3df39a03..c667460a71796bece5c567ab23a36b95dc0e674c 100644 (file)
@@ -11,8 +11,8 @@
 """
 
 from sqlalchemy.orm import sync
-from sqlalchemy.orm.sync import ONETOMANY,MANYTOONE,MANYTOMANY
 from sqlalchemy import sql, util, exceptions
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
 
 
 def create_dependency_processor(prop):
@@ -43,8 +43,8 @@ class DependencyProcessor(object):
         self.passive_updates = prop.passive_updates
         self.enable_typechecks = prop.enable_typechecks
         self.key = prop.key
-
-        self._compile_synchronizers()
+        if not self.prop.synchronize_pairs:
+            raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s.  No target attributes to populate between parent and child are present" % self.prop)
 
     def _get_instrumented_attribute(self):
         """Return the ``InstrumentedAttribute`` handled by this
@@ -121,20 +121,6 @@ class DependencyProcessor(object):
 
         raise NotImplementedError()
 
-    def _compile_synchronizers(self):
-        """Assemble a list of *synchronization rules*.
-
-        These are fired to populate attributes from one side
-        of a relation to another.
-        """
-
-        self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
-        if self.direction == sync.MANYTOMANY:
-            self.syncrules.compile(self.prop.primaryjoin, issecondary=False, foreign_keys=self.foreign_keys)
-            self.syncrules.compile(self.prop.secondaryjoin, issecondary=True, foreign_keys=self.foreign_keys)
-        else:
-            self.syncrules.compile(self.prop.primaryjoin, foreign_keys=self.foreign_keys)
-
 
     def _conditional_post_update(self, state, uowcommit, related):
         """Execute a post_update call.
@@ -153,11 +139,11 @@ class DependencyProcessor(object):
         if state is not None and self.post_update:
             for x in related:
                 if x is not None:
-                    uowcommit.register_object(state, postupdate=True, post_update_cols=self.syncrules.dest_columns())
+                    uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs])
                     break
 
     def _pks_changed(self, uowcommit, state):
-        return self.syncrules.source_changes(uowcommit, state)
+        raise NotImplementedError()
 
     def __str__(self):
         return "%s(%s)" % (self.__class__.__name__, str(self.prop))
@@ -259,7 +245,13 @@ class OneToManyDP(DependencyProcessor):
         if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
             return
         self._verify_canload(child)
-        self.syncrules.execute(source, dest, source, child, clearkeys)
+        if clearkeys:
+            sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
+        else:
+            sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs)
+
+    def _pks_changed(self, uowcommit, state):
+        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class DetectKeySwitch(DependencyProcessor):
     """a special DP that works for many-to-one relations, fires off for
@@ -298,7 +290,11 @@ class DetectKeySwitch(DependencyProcessor):
                     elem.dict[self.key]._state in switchers
                 ]:
                 uowcommit.register_object(s, listonly=self.passive_updates)
-                self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
+                sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs)
+                #self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
+
+    def _pks_changed(self, uowcommit, state):
+        return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
 
 class ManyToOneDP(DependencyProcessor):
     def __init__(self, prop):
@@ -368,12 +364,14 @@ class ManyToOneDP(DependencyProcessor):
 
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
-        source = child
-        dest = state
-        if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
+        if state is None or (not self.post_update and uowcommit.is_deleted(state)):
             return
-        self._verify_canload(child)
-        self.syncrules.execute(source, dest, dest, child, clearkeys)
+
+        if clearkeys or child is None:
+            sync.clear(state, self.parent, self.prop.synchronize_pairs)
+        else:
+            self._verify_canload(child)
+            sync.populate(child, self.mapper, state, self.parent, self.prop.synchronize_pairs)
 
 class ManyToManyDP(DependencyProcessor):
     def register_dependencies(self, uowcommit):
@@ -433,7 +431,10 @@ class ManyToManyDP(DependencyProcessor):
                 if not self.passive_updates and unchanged and self._pks_changed(uowcommit, state):
                     for child in unchanged:
                         associationrow = {}
-                        self.syncrules.update(associationrow, state, child, "old_")
+                        sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs)
+                        sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs)
+
+                        #self.syncrules.update(associationrow, state, child, "old_")
                         secondary_update.append(associationrow)
 
         if secondary_delete:
@@ -470,7 +471,12 @@ class ManyToManyDP(DependencyProcessor):
         if associationrow is None:
             return
         self._verify_canload(child)
-        self.syncrules.execute(None, associationrow, state, child, clearkeys)
+        
+        sync.populate_dict(state, self.parent, associationrow, self.prop.synchronize_pairs)
+        sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs)
+
+    def _pks_changed(self, uowcommit, state):
+        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class AssociationDP(OneToManyDP):
     def __init__(self, *args, **kwargs):
index f00c424213158248e66428783c81ba0215e213a7..d61ebe9603f7ab1962c0dc3355b76bb40a5eed32 100644 (file)
@@ -27,6 +27,10 @@ __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
 EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE')
 EXT_STOP = util.symbol('EXT_STOP')
 
+ONETOMANY = util.symbol('ONETOMANY')
+MANYTOONE = util.symbol('MANYTOONE')
+MANYTOMANY = util.symbol('MANYTOMANY')
+
 class MapperExtension(object):
     """Base implementation for customizing Mapper behavior.
 
index 12e7d03a9537e1d8db60e5062f0da591a19b37e6..22f5678d692bf26063f3a8a2903a5f74a88ca811 100644 (file)
@@ -111,7 +111,8 @@ class Mapper(object):
         self._dependency_processors = []
         self._clause_adapter = None
         self._requires_row_aliasing = False
-
+        self.__inherits_equated_pairs = None
+        
         if not issubclass(class_, object):
             raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
 
@@ -171,11 +172,11 @@ class Mapper(object):
         self.__should_log_info = logging.is_info_enabled(self.logger)
         self.__should_log_debug = logging.is_debug_enabled(self.logger)
 
-        self._compile_class()
-        self._compile_inheritance()
-        self._compile_extensions()
-        self._compile_properties()
-        self._compile_pks()
+        self.__compile_class()
+        self.__compile_inheritance()
+        self.__compile_extensions()
+        self.__compile_properties()
+        self.__compile_pks()
         global __new_mappers
         __new_mappers = True
         self.__log("constructed")
@@ -352,17 +353,17 @@ class Mapper(object):
         to execute once all mappers have been constructed.
         """
 
-        self.__log("_initialize_properties() started")
+        self.__log("__initialize_properties() started")
         l = [(key, prop) for key, prop in self.__props.iteritems()]
         for key, prop in l:
             self.__log("initialize prop " + key)
             if getattr(prop, 'key', None) is None:
                 prop.init(key, self)
-        self.__log("_initialize_properties() complete")
+        self.__log("__initialize_properties() complete")
         self.__props_init = True
 
 
-    def _compile_extensions(self):
+    def __compile_extensions(self):
         """Go through the global_extensions list as well as the list
         of ``MapperExtensions`` specified for this ``Mapper`` and
         creates a linked list of those extensions.
@@ -393,7 +394,7 @@ class Mapper(object):
         for ext in extlist:
             self.extension.append(ext)
 
-    def _compile_inheritance(self):
+    def __compile_inheritance(self):
         """Configure settings related to inherting and/or inherited mappers being present."""
 
         if self.inherits:
@@ -412,7 +413,6 @@ class Mapper(object):
                 self.single = True
             if not self.local_table is self.inherits.local_table:
                 if self.concrete:
-                    self._synchronizer = None
                     self.mapped_table = self.local_table
                     for mapper in self.iterate_to_root():
                         if mapper.polymorphic_on:
@@ -424,17 +424,10 @@ class Mapper(object):
                         # stuff we dont want (allows test/inheritance.InheritTest4 to pass)
                         self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause
                     self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)
-                    # generate sync rules.  similarly to creating the on clause, specify a
-                    # stricter set of tables to create "sync rules" by,based on the immediate
-                    # inherited table, rather than all inherited tables
-                    self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
-                    if self.inherit_foreign_keys:
-                        fks = util.Set(self.inherit_foreign_keys)
-                    else:
-                        fks = None
-                    self._synchronizer.compile(self.mapped_table.onclause, foreign_keys=fks)
+                    
+                    fks = util.to_set(self.inherit_foreign_keys)
+                    self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks)
             else:
-                self._synchronizer = None
                 self.mapped_table = self.local_table
             if self.polymorphic_identity is not None:
                 self.inherits.polymorphic_map[self.polymorphic_identity] = self
@@ -470,7 +463,6 @@ class Mapper(object):
         else:
             self._all_tables = util.Set()
             self.base_mapper = self
-            self._synchronizer = None
             self.mapped_table = self.local_table
             if self.polymorphic_identity:
                 if self.polymorphic_on is None:
@@ -481,7 +473,7 @@ class Mapper(object):
         if self.mapped_table is None:
             raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified.  (Are you using the return value of table.create()?  It no longer has a return value.)" % str(self))
 
-    def _compile_pks(self):
+    def __compile_pks(self):
 
         self.tables = sqlutil.find_tables(self.mapped_table)
 
@@ -634,7 +626,7 @@ class Mapper(object):
 
             return getattr(getattr(cls, clskey), key)
 
-    def _compile_properties(self):
+    def __compile_properties(self):
 
         # object attribute names mapped to MapperProperty objects
         self.__props = util.OrderedDict()
@@ -770,7 +762,7 @@ class Mapper(object):
         for mapper in self._inheriting_mappers:
             mapper._adapt_inherited_property(key, prop)
 
-    def _compile_class(self):
+    def __compile_class(self):
         """If this mapper is to be a primary mapper (i.e. the
         non_primary flag is not set), associate this Mapper with the
         given class_ and entity name.
@@ -1169,8 +1161,8 @@ class Mapper(object):
                     # TODO: this fires off more than needed, try to organize syncrules
                     # per table
                     for m in util.reversed(list(mapper.iterate_to_root())):
-                        if m._synchronizer:
-                            m._synchronizer.execute(state, state)
+                        if m.__inherits_equated_pairs:
+                            m._synchronize_inherited(state)
 
                     # testlib.pragma exempt:__hash__
                     inserted_objects.add((state, connection))
@@ -1186,6 +1178,9 @@ class Mapper(object):
                     if 'after_update' in mapper.extension.methods:
                         mapper.extension.after_update(mapper, connection, state.obj())
 
+    def _synchronize_inherited(self, state):
+        sync.populate(state, self, state, self, self.__inherits_equated_pairs)
+
     def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
         values on an instance.  For columns which are marked as being generated
index 9f8e852f1d646ee8c84fa5331b9a68b87c7b9c33..fb10357cfb5c7bd7b0166d7bed894d4c8b42cfed 100644 (file)
@@ -11,16 +11,15 @@ invidual ORM-mapped attributes.
 """
 
 from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql.util import ClauseAdapter
+from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns
 from sqlalchemy.sql import visitors, operators, ColumnElement
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm.mapper import _class_to_mapper
 from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
-from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY
 from sqlalchemy.exceptions import ArgumentError
 
-
 __all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty',
            'ComparableProperty', 'PropertyLoader', 'BackRef')
 
@@ -288,7 +287,7 @@ class PropertyLoader(StrategizedProperty):
             
         def __eq__(self, other):
             if other is None:
-                if self.prop.direction == sync.ONETOMANY:
+                if self.prop.direction == ONETOMANY:
                     return ~sql.exists([1], self.prop.primaryjoin)
                 else:
                     return self.prop._optimized_compare(None)
@@ -377,7 +376,7 @@ class PropertyLoader(StrategizedProperty):
             
         def __ne__(self, other):
             if other is None:
-                if self.prop.direction == sync.MANYTOONE:
+                if self.prop.direction == MANYTOONE:
                     return sql.or_(*[x!=None for x in self.prop.foreign_keys])
                 elif self.prop.uselist:
                     return self.any()
@@ -475,14 +474,14 @@ class PropertyLoader(StrategizedProperty):
             return self.argument.class_
 
     def do_init(self):
-        self._determine_targets()
-        self._determine_joins()
-        self._determine_fks()
-        self._determine_direction()
-        self._determine_remote_side()
+        self.__determine_targets()
+        self.__determine_joins()
+        self.__determine_fks()
+        self.__determine_direction()
+        self.__determine_remote_side()
         self._post_init()
 
-    def _determine_targets(self):
+    def __determine_targets(self):
         if isinstance(self.argument, type):
             self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)
         elif isinstance(self.argument, mapper.Mapper):
@@ -507,10 +506,12 @@ class PropertyLoader(StrategizedProperty):
 
         if self.cascade.delete_orphan:
             if self.parent.class_ is self.mapper.class_:
-                raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade rule on a self-referential relationship.  You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
+                raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
+                            "rule on a self-referential relationship.  "
+                            "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
             self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
 
-    def _determine_joins(self):
+    def __determine_joins(self):
         if self.secondaryjoin is not None and self.secondary is None:
             raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
         # if join conditions were not specified, figure them out based on foreign keys
@@ -535,10 +536,11 @@ class PropertyLoader(StrategizedProperty):
                 if self.primaryjoin is None:
                     self.primaryjoin = _search_for_join(self.parent, self.target).onclause
         except exceptions.ArgumentError, e:
-            raise exceptions.ArgumentError("""Error determining primary and/or secondary join for relationship '%s'. If the underlying error cannot be corrected, you should specify the 'primaryjoin' (and 'secondaryjoin', if there is an association table present) keyword arguments to the relation() function (or for backrefs, by specifying the backref using the backref() function with keyword arguments) to explicitly specify the join conditions. Nested error is \"%s\"""" % (str(self), str(e)))
+            raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s.  "
+                        "Specify a 'primaryjoin' expression.  If this is a many-to-many relation, 'secondaryjoin' is needed as well." % (self))
 
 
-    def _col_is_part_of_mappings(self, column):
+    def __col_is_part_of_mappings(self, column):
         if self.secondary is None:
             return self.parent.mapped_table.c.contains_column(column) or \
                 self.target.c.contains_column(column)
@@ -547,61 +549,77 @@ class PropertyLoader(StrategizedProperty):
                 self.target.c.contains_column(column) or \
                 self.secondary.c.contains_column(column) is not None
         
-    def _determine_fks(self):
+    def __determine_fks(self):
         if self._legacy_foreignkey and not self._refers_to_parent_table():
             self.foreign_keys = self._legacy_foreignkey
 
-        self._opposite_side = util.Set()
+        arg_foreign_keys = self.foreign_keys
+        
+        eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly)
+        eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+
+        if not eq_pairs:
+            if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
+                raise exceptions.ArgumentError("Could not locate any equated column pairs for primaryjoin condition '%s' on relation %s. "
+                    "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.primaryjoin, self)
+                )
+            else:
+                raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+                "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self))
+        
+        self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs])
+        self._opposite_side = util.OrderedSet([l for l, r in eq_pairs])
+        self.synchronize_pairs = eq_pairs
+        
+        if self.secondaryjoin:
+            sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys)
+            sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+            
+            if not sq_pairs:
+                if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
+                    raise exceptions.ArgumentError("Could not locate any equated column pairs for secondaryjoin condition '%s' on relation %s. "
+                        "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.secondaryjoin, self)
+                    )
+                else:
+                    raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. "
+                    "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.secondaryjoin, self))
 
-        if self.foreign_keys:
-            def visit_binary(binary):
-                if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
-                    return
-                if binary.left in self.foreign_keys:
-                    self._opposite_side.add(binary.right)
-                if binary.right in self.foreign_keys:
-                    self._opposite_side.add(binary.left)
+            self.foreign_keys.update([r for l, r in sq_pairs])
+            self._opposite_side.update([l for l, r in sq_pairs])
+            self.secondary_synchronize_pairs = sq_pairs
         else:
-            self.foreign_keys = util.Set()
-            def visit_binary(binary):
-                if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
-                    return
-
-                # this check is for when the user put the "view_only" flag on and has tables that have nothing
-                # to do with the relationship's parent/child mappings in the join conditions.  we dont want cols
-                # or clauses related to those external tables dealt with.  see orm.relationships.ViewOnlyTest
-                if not self._col_is_part_of_mappings(binary.left) or not self._col_is_part_of_mappings(binary.right):
-                    return
-
-                for f in binary.left.foreign_keys:
-                    if f.references(binary.right.table):
-                        self.foreign_keys.add(binary.left)
-                        self._opposite_side.add(binary.right)
-                for f in binary.right.foreign_keys:
-                    if f.references(binary.left.table):
-                        self.foreign_keys.add(binary.right)
-                        self._opposite_side.add(binary.left)
-
-        visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
-
-        if not self.foreign_keys:
-            raise exceptions.ArgumentError(
-                "Can't locate any foreign key columns in primary join "
-                "condition '%s' for relationship '%s'.  Specify "
-                "'foreign_keys' argument to indicate which columns in "
-                "the join condition are foreign." %(str(self.primaryjoin), str(self)))
-
-        if self.secondaryjoin is not None:
-            visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
+            self.secondary_synchronize_pairs = None
+    
+    def equated_pairs(self):
+        return zip(self.local_side, self.remote_side)
+    equated_pairs = property(equated_pairs)
+    
+    def __determine_remote_side(self):
+        if self.remote_side:
+            if self.direction is MANYTOONE:
+                eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True)
+            else:
+                eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True)
 
+            if self.secondaryjoin:
+                sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True)
+                sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+                eq_pairs += sq_pairs
+        else:
+            eq_pairs = zip(self._opposite_side, self.foreign_keys)
 
-    def _determine_direction(self):
+        if self.direction is MANYTOONE:
+            self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+        else:
+            self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+            
+    def __determine_direction(self):
         """Determine our *direction*, i.e. do we represent one to
         many, many to many, etc.
         """
 
         if self.secondaryjoin is not None:
-            self.direction = sync.MANYTOMANY
+            self.direction = MANYTOMANY
         elif self._refers_to_parent_table():
             # for a self referential mapper, if the "foreignkey" is a single or composite primary key,
             # then we are "many to one", since the remote site of the relationship identifies a singular entity.
@@ -609,19 +627,19 @@ class PropertyLoader(StrategizedProperty):
             if self._legacy_foreignkey:
                 for f in self._legacy_foreignkey:
                     if not f.primary_key:
-                        self.direction = sync.ONETOMANY
+                        self.direction = ONETOMANY
                     else:
-                        self.direction = sync.MANYTOONE
+                        self.direction = MANYTOONE
 
             elif self.remote_side:
                 for f in self.foreign_keys:
                     if f in self.remote_side:
-                        self.direction = sync.ONETOMANY
+                        self.direction = ONETOMANY
                         return
                 else:
-                    self.direction = sync.MANYTOONE
+                    self.direction = MANYTOONE
             else:
-                self.direction = sync.ONETOMANY
+                self.direction = ONETOMANY
         else:
             for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]:
                 onetomany = [c for c in self.foreign_keys if mappedtable.c.contains_column(c)]
@@ -635,10 +653,10 @@ class PropertyLoader(StrategizedProperty):
                 elif onetomany and manytoone:
                     continue
                 elif onetomany:
-                    self.direction = sync.ONETOMANY
+                    self.direction = ONETOMANY
                     break
                 elif manytoone:
-                    self.direction = sync.MANYTOONE
+                    self.direction = MANYTOONE
                     break
             else:
                 raise exceptions.ArgumentError(
@@ -647,24 +665,15 @@ class PropertyLoader(StrategizedProperty):
                     "the child's mapped tables.  Specify 'foreign_keys' "
                     "argument." % (str(self)))
 
-    def _determine_remote_side(self):
-        if not self.remote_side:
-            if self.direction is sync.MANYTOONE:
-                self.remote_side = util.Set(self._opposite_side)
-            elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
-                self.remote_side = util.Set(self.foreign_keys)
-
-        self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side)
-
     def _post_init(self):
         if logging.is_info_enabled(self.logger):
             self.logger.info(str(self) + " setup primary join " + str(self.primaryjoin))
             self.logger.info(str(self) + " setup secondary join " + str(self.secondaryjoin))
-            self.logger.info(str(self) + " foreign keys " + str([str(c) for c in self.foreign_keys]))
-            self.logger.info(str(self) + " remote columns " + str([str(c) for c in self.remote_side]))
-            self.logger.info(str(self) + " relation direction " + (self.direction is sync.ONETOMANY and "one-to-many" or (self.direction is sync.MANYTOONE and "many-to-one" or "many-to-many")))
+            self.logger.info(str(self) + " synchronize pairs " + ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs]))
+            self.logger.info(str(self) + " equated pairs " + ",".join(["(%s == %s)" % (l, r) for l, r in self.equated_pairs]))
+            self.logger.info(str(self) + " relation direction " + (self.direction is ONETOMANY and "one-to-many" or (self.direction is MANYTOONE and "many-to-one" or "many-to-many")))
 
-        if self.uselist is None and self.direction is sync.MANYTOONE:
+        if self.uselist is None and self.direction is MANYTOONE:
             self.uselist = False
 
         if self.uselist is None:
@@ -712,9 +721,9 @@ class PropertyLoader(StrategizedProperty):
             primaryjoin = self.primaryjoin
             
             if fromselectable is not frommapper.local_table:
-                if self.direction is sync.ONETOMANY:
+                if self.direction is ONETOMANY:
                     primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
-                elif self.direction is sync.MANYTOONE:
+                elif self.direction is MANYTOONE:
                     primaryjoin = ClauseAdapter(fromselectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
                 elif self.secondaryjoin:
                     primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
index 4028eed6a42cc3c6a07f9092d673bb851e82e694..6bd0d530f53c7ce7c9abe57143178fc1dd60fa02 100644 (file)
@@ -255,14 +255,14 @@ NoLoader.logger = logging.class_logger(NoLoader)
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self.parent_property)
+        (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property)
         
-        self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
+        self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere))
 
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         #from sqlalchemy.orm import query
-        self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.lazywhere)
+        self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere)
         if self.use_get:
             self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
 
@@ -275,10 +275,9 @@ class LazyLoader(AbstractRelationLoader):
             return self._lazy_none_clause(reverse_direction)
             
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
+            (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
         else:
-            (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
-        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+            (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
         def visit_bindparam(bindparam):
             mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
@@ -291,10 +290,9 @@ class LazyLoader(AbstractRelationLoader):
     
     def _lazy_none_clause(self, reverse_direction=False):
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
+            (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
         else:
-            (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
-        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+            (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
         def visit_binary(binary):
             mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
@@ -351,24 +349,20 @@ class LazyLoader(AbstractRelationLoader):
                     instance._state.reset(self.key)
             return (new_execute, None, None)
 
-    def _create_lazy_clause(cls, prop, reverse_direction=False):
-        (primaryjoin, secondaryjoin, remote_side) = (prop.primaryjoin, prop.secondaryjoin, prop.remote_side)
-        
+    def __create_lazy_clause(cls, prop, reverse_direction=False):
         binds = {}
         equated_columns = {}
 
+        secondaryjoin = prop.secondaryjoin
+        equated = dict(prop.equated_pairs)
+        
         def should_bind(targetcol, othercol):
-            if not prop._col_is_part_of_mappings(targetcol):
-                return False
-                
             if reverse_direction and not secondaryjoin:
-                return targetcol in remote_side
+                return othercol in equated
             else:
-                return othercol in remote_side
+                return targetcol in equated
 
         def visit_binary(binary):
-            if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
-                return
             leftcol = binary.left
             rightcol = binary.right
 
@@ -376,31 +370,28 @@ class LazyLoader(AbstractRelationLoader):
             equated_columns[leftcol] = rightcol
 
             if should_bind(leftcol, rightcol):
-                if leftcol in binds:
-                    binary.left = binds[leftcol]
-                else:
-                    binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
+                if leftcol not in binds:
+                    binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
+                binary.left = binds[leftcol]
+            elif should_bind(rightcol, leftcol):
+                if rightcol not in binds:
+                    binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
+                binary.right = binds[rightcol]
 
-            # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
-            # which can happen in rare cases (test/orm/relationships.py RelationTest2)
-            if leftcol is not rightcol and should_bind(rightcol, leftcol):
-                if rightcol in binds:
-                    binary.right = binds[rightcol]
-                else:
-                    binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
-
-                
-        lazywhere = primaryjoin
+        lazywhere = prop.primaryjoin
         
-        if not secondaryjoin or not reverse_direction:
+        if not prop.secondaryjoin or not reverse_direction:
             lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary)
         
-        if secondaryjoin is not None:
+        if prop.secondaryjoin is not None:
             if reverse_direction:
                 secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
-        return (lazywhere, binds, equated_columns)
-    _create_lazy_clause = classmethod(_create_lazy_clause)
+    
+        bind_to_col = dict([(binds[col].key, col) for col in binds])
+        
+        return (lazywhere, bind_to_col, equated_columns)
+    __create_lazy_clause = classmethod(__create_lazy_clause)
     
 LazyLoader.logger = logging.class_logger(LazyLoader)
 
@@ -452,7 +443,7 @@ class LoadLazyAttribute(object):
             ident = []
             allnulls = True
             for primary_key in prop.mapper.primary_key: 
-                val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
+                val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key])
                 allnulls = allnulls and val is None
                 ident.append(val)
             if allnulls:
index d95009a473e7a99e09c02edc069ab5a3762e1e82..39a7b5044c40064267f827163e75fa6f35239de0 100644 (file)
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""Contains the ClauseSynchronizer class, which is used to map
-attributes between two objects in a manner corresponding to a SQL
-clause that compares column values.
+"""private module containing functions used for copying data between instances
+based on join conditions.
 """
 
 from sqlalchemy import schema, exceptions, util
-from sqlalchemy.sql import visitors, operators
+from sqlalchemy.sql import visitors, operators, util as sqlutil
 from sqlalchemy import logging
 from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY  # legacy
 
-ONETOMANY = 0
-MANYTOONE = 1
-MANYTOMANY = 2
-
-class ClauseSynchronizer(object):
-    """Given a SQL clause, usually a series of one or more binary
-    expressions between columns, and a set of 'source' and
-    'destination' mappers, compiles a set of SyncRules corresponding
-    to that information.
-
-    The ClauseSynchronizer can then be executed given a set of
-    parent/child objects or destination dictionary, which will iterate
-    through each of its SyncRules and execute them.  Each SyncRule
-    will copy the value of a single attribute from the parent to the
-    child, corresponding to the pair of columns in a particular binary
-    expression, using the source and destination mappers to map those
-    two columns to object attributes within parent and child.
-    """
-
-    def __init__(self, parent_mapper, child_mapper, direction):
-        self.parent_mapper = parent_mapper
-        self.child_mapper = child_mapper
-        self.direction = direction
-        self.syncrules = []
-
-    def compile(self, sqlclause, foreign_keys=None, issecondary=None):
-        def compile_binary(binary):
-            """Assemble a SyncRule given a single binary condition."""
-
-            if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
-                return
-
-            source_column = None
-            dest_column = None
-
-            if foreign_keys is None:
-                if binary.left.table == binary.right.table:
-                    raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync")
-
-                if binary.left in util.Set([f.column for f in binary.right.foreign_keys]):
-                    dest_column = binary.right
-                    source_column = binary.left
-                elif binary.right in util.Set([f.column for f in binary.left.foreign_keys]):
-                    dest_column = binary.left
-                    source_column = binary.right
-            else:
-                if binary.left in foreign_keys:
-                    source_column = binary.right
-                    dest_column = binary.left
-                elif binary.right in foreign_keys:
-                    source_column = binary.left
-                    dest_column = binary.right
+def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            value = source_mapper._get_state_attr_by_column(source, l)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
 
-            if source_column and dest_column:
-                if self.direction == ONETOMANY:
-                    self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper))
-                elif self.direction == MANYTOONE:
-                    self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper))
-                else:
-                    if not issecondary:
-                        self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper, issecondary=issecondary))
-                    else:
-                        self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))
+        try:
+            dest_mapper._set_state_attr_by_column(dest, r, value)
+        except exceptions.UnmappedColumnError:
+            self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
 
-        rules_added = len(self.syncrules)
-        visitors.traverse(sqlclause, visit_binary=compile_binary)
-        if len(self.syncrules) == rules_added:
-            raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
+def clear(dest, dest_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        if r.primary_key:
+            raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest)))
+        try:
+            dest_mapper._set_state_attr_by_column(dest, r, None)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(True, None, l, dest_mapper, r)
 
-    def dest_columns(self):
-        return [r.dest_column for r in self.syncrules if r.dest_column is not None]
+def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
+            value = source_mapper._get_state_attr_by_column(source, l)
+        except exceptions.UnmappedColumnError:
+            self._raise_col_to_prop(False, source_mapper, l, None, r)
+        dest[r.key] = value
+        dest[old_prefix + r.key] = oldvalue
 
-    def update(self, dest, parent, child, old_prefix):
-        for rule in self.syncrules:
-            rule.update(dest, parent, child, old_prefix)
-        
-    def execute(self, source, dest, obj=None, child=None, clearkeys=None):
-        for rule in self.syncrules:
-            rule.execute(source, dest, obj, child, clearkeys)
-    
-    def source_changes(self, uowcommit, source):
-        for rule in self.syncrules:
-            if rule.source_changes(uowcommit, source):
-                return True
-        else:
-            return False
+def populate_dict(source, source_mapper, dict_, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            value = source_mapper._get_state_attr_by_column(source, l)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(False, source_mapper, l, None, r)
             
-class SyncRule(object):
-    """An instruction indicating how to populate the objects on each
-    side of a relationship.
-
-    E.g. if table1 column A is joined against table2 column
-    B, and we are a one-to-many from table1 to table2, a syncrule
-    would say *take the A attribute from object1 and assign it to the
-    B attribute on object2*.
-    """
+        dict_[r.key] = value
 
-    def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None):
-        self.source_mapper = source_mapper
-        self.source_column = source_column
-        self.issecondary = issecondary
-        self.dest_mapper = dest_mapper
-        self.dest_column = dest_column
-        
-        #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
-
-    def dest_primary_key(self):
-        # late-evaluating boolean since some syncs are created
-        # before the mapper has assembled pks
-        try:
-            return self._dest_primary_key
-        except AttributeError:
-            self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
-            return self._dest_primary_key
-    
-    def _raise_col_to_prop(self, isdest):
-        if isdest:
-            raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (self.dest_column, self.dest_mapper))
-        else:
-            raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (self.source_column, self.source_mapper, self.dest_column))
-                
-    def source_changes(self, uowcommit, source):
+def source_changes(uowcommit, source, source_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
         try:
-            prop = self.source_mapper._get_col_to_prop(self.source_column)
+            prop = source_mapper._get_col_to_prop(l)
         except exceptions.UnmappedColumnError:
-            self._raise_col_to_prop(False)
+            _raise_col_to_prop(False, source_mapper, l, None, r)
         (added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True)
-        return bool(added and deleted)
-    
-    def update(self, dest, parent, child, old_prefix):
-        if self.issecondary is False:
-            source = parent
-        elif self.issecondary is True:
-            source = child
+        if added and deleted:
+            return True
+    else:
+        return False
+
+def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
         try:
-            oldvalue = self.source_mapper._get_committed_attr_by_column(source.obj(), self.source_column)
-            value = self.source_mapper._get_state_attr_by_column(source, self.source_column)
+            prop = dest_mapper._get_col_to_prop(r)
         except exceptions.UnmappedColumnError:
-            self._raise_col_to_prop(False)
-        dest[self.dest_column.key] = value
-        dest[old_prefix + self.dest_column.key] = oldvalue
+            _raise_col_to_prop(True, None, l, dest_mapper, r)
+        (added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True)
+        if added and deleted:
+            return True
+    else:
+        return False
+
+def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
+    if isdest:
+        raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper))
+    else:
+        raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column))
         
-    def execute(self, source, dest, parent, child, clearkeys):
-        # TODO: break the "dictionary" case into a separate method like 'update' above,
-        # reduce conditionals
-        if source is None:
-            if self.issecondary is False:
-                source = parent
-            elif self.issecondary is True:
-                source = child
-        if clearkeys or source is None:
-            value = None
-            clearkeys = True
-        else:
-            try:
-                value = self.source_mapper._get_state_attr_by_column(source, self.source_column)
-            except exceptions.UnmappedColumnError:
-                self._raise_col_to_prop(False)
-        if isinstance(dest, dict):
-            dest[self.dest_column.key] = value
-        else:
-            if clearkeys and self.dest_primary_key():
-                raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.state_str(dest)))
-
-            if logging.is_debug_enabled(self.logger):
-                self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.state_str(source), str(self.source_column), mapperutil.state_str(dest), str(self.dest_column), value))
-            try:
-                self.dest_mapper._set_state_attr_by_column(dest, self.dest_column, value)
-            except exceptions.UnmappedColumnError:
-                self._raise_col_to_prop(True)
-
-SyncRule.logger = logging.class_logger(SyncRule)
-
index 96ec5b8ded52475b75a55b6a906f89afe205a430..a4028c1efa05fab2c39ea7ed9e24510a7d10de00 100644 (file)
@@ -553,7 +553,7 @@ class Column(SchemaItem, expression._ColumnClause):
     def references(self, column):
         """Return True if this references the given column via a foreign key."""
         for fk in self.foreign_keys:
-            if fk.column is column:
+            if fk.references(column.table):
                 return True
         else:
             return False
index 8ed561e5f1308b664fb7af67b11ee0da3d923934..5b9ffd4fa7436d82ed2f79153ed2c75b78e3be2a 100644 (file)
@@ -1,10 +1,12 @@
-from sqlalchemy import exceptions, schema, topological, util
+from sqlalchemy import exceptions, schema, topological, util, sql
 from sqlalchemy.sql import expression, operators, visitors
 from itertools import chain
 
 """Utility functions that build upon SQL and Schema constructs."""
 
 def sort_tables(tables, reverse=False):
+    """sort a collection of Table objects in order of their foreign-key dependency."""
+    
     tuples = []
     class TVisitor(schema.SchemaVisitor):
         def visit_foreign_key(_self, fkey):
@@ -24,6 +26,8 @@ def sort_tables(tables, reverse=False):
         return sequence
 
 def find_tables(clause, check_columns=False, include_aliases=False):
+    """locate Table objects within the given expression."""
+    
     tables = []
     kwargs = {}
     if include_aliases:
@@ -44,6 +48,8 @@ def find_tables(clause, check_columns=False, include_aliases=False):
     return tables
 
 def find_columns(clause):
+    """locate Column objects within the given expression."""
+    
     cols = util.Set()
     def visit_column(col):
         cols.add(col)
@@ -93,6 +99,38 @@ def reduce_columns(columns, *clauses):
 
     return expression.ColumnSet(columns.difference(omit))
 
+def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False):
+    """traverse an expression and locate binary criterion pairs."""
+    
+    if consider_as_foreign_keys and consider_as_referenced_keys:
+        raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
+        
+    def visit_binary(binary):
+        if not any_operator and binary.operator != operators.eq:
+            return
+        if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
+            return
+
+        if consider_as_foreign_keys:
+            if binary.left in consider_as_foreign_keys:
+                pairs.append((binary.right, binary.left))
+            elif binary.right in consider_as_foreign_keys:
+                pairs.append((binary.left, binary.right))
+        elif consider_as_referenced_keys:
+            if binary.left in consider_as_referenced_keys:
+                pairs.append((binary.left, binary.right))
+            elif binary.right in consider_as_referenced_keys:
+                pairs.append((binary.right, binary.left))
+        else:
+            if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
+                if binary.left.references(binary.right):
+                    pairs.append((binary.right, binary.left))
+                elif binary.right.references(binary.left):
+                    pairs.append((binary.left, binary.right))
+    pairs = []
+    visitors.traverse(expression, visit_binary=visit_binary)
+    return pairs
+    
 class AliasedRow(object):
     
     def __init__(self, row, map):
@@ -117,7 +155,7 @@ class AliasedRow(object):
         return self.row.keys()
 
 def row_adapter(from_, equivalent_columns=None):
-    """create a row adapter against a selectable."""
+    """create a row adapter callable against a selectable."""
     
     if equivalent_columns is None:
         equivalent_columns = {}
index 4011ac2ceb215ab51cb27d240e924c3534575237..2abec607a68fd7e50e69347c55538b99e328313a 100644 (file)
@@ -1436,5 +1436,293 @@ class DictHelpersTest(ORMTest):
         collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
         self._test_composite_mapped(collection_class)
 
+# TODO: are these tests redundant vs. the above tests ?
+# remove if so
+class CustomCollectionsTest(ORMTest):
+    def define_tables(self, metadata):
+        global sometable, someothertable
+        sometable = Table('sometable', metadata,
+            Column('col1',Integer, primary_key=True),
+            Column('data', String(30)))
+        someothertable = Table('someothertable', metadata,
+            Column('col1', Integer, primary_key=True),
+            Column('scol1', Integer, ForeignKey(sometable.c.col1)),
+            Column('data', String(20))
+        )
+    def test_basic(self):
+        class MyList(list):
+            pass
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=MyList)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        assert isinstance(f.bars, MyList)
+        
+    def test_lazyload(self):
+        """test that a 'set' can be used as a collection and can lazyload."""
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=set)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        f.bars.add(Bar())
+        f.bars.add(Bar())
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+        f.bars.clear()
+
+    def test_dict(self):
+        """test that a 'dict' can be used as a collection and can lazyload."""
+
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        class AppenderDict(dict):
+            @collection.appender
+            def set(self, item):
+                self[id(item)] = item
+            @collection.remover
+            def remove(self, item):
+                if id(item) in self:
+                    del self[id(item)]
+
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=AppenderDict)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        f.bars.set(Bar())
+        f.bars.set(Bar())
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+        f.bars.clear()
+
+    def test_dict_wrapper(self):
+        """test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
+
+        class Foo(object):
+            pass
+        class Bar(object):
+            def __init__(self, data): self.data = data
+
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar,
+                collection_class=collections.column_mapped_collection(someothertable.c.data))
+        })
+        mapper(Bar, someothertable)
+
+        f = Foo()
+        col = collections.collection_adapter(f.bars)
+        col.append_with_event(Bar('a'))
+        col.append_with_event(Bar('b'))
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+
+        existing = set([id(b) for b in f.bars.values()])
+
+        col = collections.collection_adapter(f.bars)
+        col.append_with_event(Bar('b'))
+        f.bars['a'] = Bar('a')
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+
+        replaced = set([id(b) for b in f.bars.values()])
+        self.assert_(existing != replaced)
+
+    def test_list(self):
+        class Parent(object):
+            pass
+        class Child(object):
+            pass
+
+        mapper(Parent, sometable, properties={
+            'children':relation(Child, collection_class=list)
+        })
+        mapper(Child, someothertable)
+
+        control = list()
+        p = Parent()
+
+        o = Child()
+        control.append(o)
+        p.children.append(o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control.extend(o)
+        p.children.extend(o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control[0] == p.children[0]
+        assert control[-1] == p.children[-1]
+        assert control[1:3] == p.children[1:3]
+
+        del control[1]
+        del p.children[1]
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child()]
+        control[1:3] = o
+        p.children[1:3] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control[1:3] = o
+        p.children[1:3] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control[-1:-2] = o
+        p.children[-1:-2] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control[4:] = o
+        p.children[4:] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(0, o)
+        p.children.insert(0, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(3, o)
+        p.children.insert(3, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(999, o)
+        p.children.insert(999, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[0:1]
+        del p.children[0:1]
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[1:1]
+        del p.children[1:1]
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[1:3]
+        del p.children[1:3]
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[7:]
+        del p.children[7:]
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control.pop() == p.children.pop()
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control.pop(0) == p.children.pop(0)
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control.pop(2) == p.children.pop(2)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(2, o)
+        p.children.insert(2, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        control.remove(o)
+        p.children.remove(o)
+        assert control == p.children
+        assert control == list(p.children)
+
+    def test_custom(self):
+        class Parent(object):
+            pass
+        class Child(object):
+            pass
+
+        class MyCollection(object):
+            def __init__(self):
+                self.data = []
+            @collection.appender
+            def append(self, value):
+                self.data.append(value)
+            @collection.remover
+            def remove(self, value):
+                self.data.remove(value)
+            @collection.iterator
+            def __iter__(self):
+                return iter(self.data)
+
+        mapper(Parent, sometable, properties={
+            'children':relation(Child, collection_class=MyCollection)
+        })
+        mapper(Child, someothertable)
+
+        control = list()
+        p1 = Parent()
+
+        o = Child()
+        control.append(o)
+        p1.children.append(o)
+        assert control == list(p1.children)
+
+        o = Child()
+        control.append(o)
+        p1.children.append(o)
+        assert control == list(p1.children)
+
+        o = Child()
+        control.append(o)
+        p1.children.append(o)
+        assert control == list(p1.children)
+
+        sess = create_session()
+        sess.save(p1)
+        sess.flush()
+        sess.clear()
+
+        p2 = sess.query(Parent).get(p1.col1)
+        o = list(p2.children)
+        assert len(o) == 3
+
 if __name__ == "__main__":
     testenv.main()
index a5e361aa761968a28d51607f379a765fc21fd8df..ed003927bb64d45047ce7ae0d3ce0db4779f6c27 100644 (file)
@@ -42,27 +42,19 @@ class RelationTest1(ORMTest):
             pass
         class Manager(Person):
             pass
-
-        mapper(Person, people, properties={
-            'manager':relation(Manager, primaryjoin=people.c.manager_id==managers.c.person_id, uselist=False)
-        })
-        mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id)
-
-        self.assertRaisesMessage(exceptions.ArgumentError, 
-            r"Can't determine relation direction for relationship 'Person\.manager \(Manager\)' - foreign key columns are present in both the parent and the child's mapped tables\.  Specify 'foreign_keys' argument\.",
-            compile_mappers
-        )
-        clear_mappers()
-
+        
+        # note that up until recently (0.4.4), we had to specify "foreign_keys" here
+        # for this primary join.  
         mapper(Person, people, properties={
             'manager':relation(Manager, primaryjoin=(people.c.manager_id ==
                                                      managers.c.person_id),
-                               foreign_keys=[people.c.manager_id],
                                uselist=False, post_update=True)
         })
         mapper(Manager, managers, inherits=Person,
                inherit_condition=people.c.person_id==managers.c.person_id)
-
+        
+        self.assertEquals(class_mapper(Person).get_property('manager').foreign_keys, set([people.c.manager_id]))
+        
         session = create_session()
         p = Person(name='some person')
         m = Manager(name='some manager')
index 287ee053c6d2ed73bf51253d050a714a653292d3..6583e5584fae6bab445815479798c5000e124e5f 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm import collections
 from sqlalchemy.orm.collections import collection
 from testlib import *
+from testlib import fixtures
 
 class RelationTest(TestBase):
     """An extended topological sort test
@@ -757,291 +758,8 @@ class TypedAssociationTable(ORMTest):
 
         assert t3.count().scalar() == 1
 
-# TODO: move these tests to either attributes.py test or its own module
-class CustomCollectionsTest(ORMTest):
-    def define_tables(self, metadata):
-        global sometable, someothertable
-        sometable = Table('sometable', metadata,
-            Column('col1',Integer, primary_key=True),
-            Column('data', String(30)))
-        someothertable = Table('someothertable', metadata,
-            Column('col1', Integer, primary_key=True),
-            Column('scol1', Integer, ForeignKey(sometable.c.col1)),
-            Column('data', String(20))
-        )
-    def testbasic(self):
-        class MyList(list):
-            pass
-        class Foo(object):
-            pass
-        class Bar(object):
-            pass
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar, collection_class=MyList)
-        })
-        mapper(Bar, someothertable)
-        f = Foo()
-        assert isinstance(f.bars, MyList)
-    def testlazyload(self):
-        """test that a 'set' can be used as a collection and can lazyload."""
-        class Foo(object):
-            pass
-        class Bar(object):
-            pass
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar, collection_class=set)
-        })
-        mapper(Bar, someothertable)
-        f = Foo()
-        f.bars.add(Bar())
-        f.bars.add(Bar())
-        sess = create_session()
-        sess.save(f)
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-        f.bars.clear()
-
-    def testdict(self):
-        """test that a 'dict' can be used as a collection and can lazyload."""
-
-        class Foo(object):
-            pass
-        class Bar(object):
-            pass
-        class AppenderDict(dict):
-            @collection.appender
-            def set(self, item):
-                self[id(item)] = item
-            @collection.remover
-            def remove(self, item):
-                if id(item) in self:
-                    del self[id(item)]
-
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar, collection_class=AppenderDict)
-        })
-        mapper(Bar, someothertable)
-        f = Foo()
-        f.bars.set(Bar())
-        f.bars.set(Bar())
-        sess = create_session()
-        sess.save(f)
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-        f.bars.clear()
-
-    def testdictwrapper(self):
-        """test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
-
-        class Foo(object):
-            pass
-        class Bar(object):
-            def __init__(self, data): self.data = data
-
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar,
-                collection_class=collections.column_mapped_collection(someothertable.c.data))
-        })
-        mapper(Bar, someothertable)
-
-        f = Foo()
-        col = collections.collection_adapter(f.bars)
-        col.append_with_event(Bar('a'))
-        col.append_with_event(Bar('b'))
-        sess = create_session()
-        sess.save(f)
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-
-        existing = set([id(b) for b in f.bars.values()])
-
-        col = collections.collection_adapter(f.bars)
-        col.append_with_event(Bar('b'))
-        f.bars['a'] = Bar('a')
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-
-        replaced = set([id(b) for b in f.bars.values()])
-        self.assert_(existing != replaced)
-
-    def testlist(self):
-        class Parent(object):
-            pass
-        class Child(object):
-            pass
-
-        mapper(Parent, sometable, properties={
-            'children':relation(Child, collection_class=list)
-        })
-        mapper(Child, someothertable)
-
-        control = list()
-        p = Parent()
-
-        o = Child()
-        control.append(o)
-        p.children.append(o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control.extend(o)
-        p.children.extend(o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control[0] == p.children[0]
-        assert control[-1] == p.children[-1]
-        assert control[1:3] == p.children[1:3]
-
-        del control[1]
-        del p.children[1]
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child()]
-        control[1:3] = o
-        p.children[1:3] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control[1:3] = o
-        p.children[1:3] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control[-1:-2] = o
-        p.children[-1:-2] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control[4:] = o
-        p.children[4:] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(0, o)
-        p.children.insert(0, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(3, o)
-        p.children.insert(3, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(999, o)
-        p.children.insert(999, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[0:1]
-        del p.children[0:1]
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[1:1]
-        del p.children[1:1]
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[1:3]
-        del p.children[1:3]
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[7:]
-        del p.children[7:]
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control.pop() == p.children.pop()
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control.pop(0) == p.children.pop(0)
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control.pop(2) == p.children.pop(2)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(2, o)
-        p.children.insert(2, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        control.remove(o)
-        p.children.remove(o)
-        assert control == p.children
-        assert control == list(p.children)
-
-    def testobj(self):
-        class Parent(object):
-            pass
-        class Child(object):
-            pass
-
-        class MyCollection(object):
-            def __init__(self):
-                self.data = []
-            @collection.appender
-            def append(self, value):
-                self.data.append(value)
-            @collection.remover
-            def remove(self, value):
-                self.data.remove(value)
-            @collection.iterator
-            def __iter__(self):
-                return iter(self.data)
-
-        mapper(Parent, sometable, properties={
-            'children':relation(Child, collection_class=MyCollection)
-        })
-        mapper(Child, someothertable)
-
-        control = list()
-        p1 = Parent()
-
-        o = Child()
-        control.append(o)
-        p1.children.append(o)
-        assert control == list(p1.children)
-
-        o = Child()
-        control.append(o)
-        p1.children.append(o)
-        assert control == list(p1.children)
-
-        o = Child()
-        control.append(o)
-        p1.children.append(o)
-        assert control == list(p1.children)
-
-        sess = create_session()
-        sess.save(p1)
-        sess.flush()
-        sess.clear()
-
-        p2 = sess.query(Parent).get(p1.col1)
-        o = list(p2.children)
-        assert len(o) == 3
+        
+    
 
 class ViewOnlyTest(ORMTest):
     """test a view_only mapping where a third table is pulled into the primary join condition,
@@ -1144,6 +862,191 @@ class ViewOnlyTest2(ORMTest):
         assert set([x.t2id for x in c1.t2s]) == set([c2a.t2id, c2b.t2id])
         assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id])
 
+class ViewOnlyTest3(ORMTest):
+    def define_tables(self, metadata):
+        global foos, bars
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True))
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+
+    def test_viewonly_join(self):
+        class Foo(fixtures.Base):
+            pass
+        class Bar(fixtures.Base):
+            pass
+
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=[bars.c.fid], viewonly=True)
+        })
+
+        mapper(Bar, bars)
+
+        sess = create_session()
+        sess.save(Foo(id=4))
+        sess.save(Foo(id=9))
+        sess.save(Bar(id=1, fid=2))
+        sess.save(Bar(id=2, fid=3))
+        sess.save(Bar(id=3, fid=6))
+        sess.save(Bar(id=4, fid=7))
+        sess.flush()
+
+        sess = create_session()
+        self.assertEquals(sess.query(Foo).filter_by(id=4).one(), Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)]))
+        self.assertEquals(sess.query(Foo).filter_by(id=9).one(), Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)]))
+
+class InvalidRelationEscalationTest(ORMTest):
+    def define_tables(self, metadata):
+        global foos, bars, Foo, Bar
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+            
+    def test_no_join(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+    def test_no_join_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+        
+    def test_no_equated(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_fks(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=bars.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, foreign_keys=[foos.c.fid])
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_viewonly(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, viewonly=True)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_self_ref_viewonly(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True)
+        })
+
+        mapper(Bar, bars)
+
+        self.assertRaisesMessage(exceptions.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
+
+    def test_no_equated_self_ref_viewonly_fks(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True, foreign_keys=[foos.c.fid])
+        })
+        compile_mappers()
+        self.assertEquals(Foo.foos.property.equated_pairs, [(foos.c.id, foos.c.fid)])
+
+    def test_equated(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id==bars.c.fid)
+        })
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+    
+    def test_equated_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)
+        })
+
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_equated_self_ref_wrong_fks(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])
+        })
+
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+class InvalidRelationEscalationTestM2M(ORMTest):
+    def define_tables(self, metadata):
+        global foos, bars, Foo, Bar, foobars
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True))
+        foobars = Table('foobars', metadata, Column('fid', Integer), Column('bid', Integer))
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True))
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+
+    def test_no_join(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+    def test_no_secondaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+    def test_bad_primaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_bad_secondaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid])
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
+
+    def test_no_equated_secondaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid, foobars.c.bid])
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for secondaryjoin condition", compile_mappers)
+
 
 if __name__ == "__main__":
     testenv.main()