]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
this version has easy cases going well. hard cases not so much
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Feb 2012 00:49:26 +0000 (19:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Feb 2012 00:49:26 +0000 (19:49 -0500)
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
test/orm/test_rel_fn.py

index a590ad7069d3d8b8d46b3658028e4ab6c34488a2..5b883a8f5750ad8615ec9bcb9955c6917ba00d30 100644 (file)
@@ -33,9 +33,9 @@ from descriptor_props import CompositeProperty, SynonymProperty, \
 
 class ColumnProperty(StrategizedProperty):
     """Describes an object attribute that corresponds to a table column.
-    
+
     Public constructor is the :func:`.orm.column_property` function.
-    
+
     """
 
     def __init__(self, *columns, **kwargs):
@@ -176,13 +176,13 @@ log.class_logger(ColumnProperty)
 class RelationshipProperty(StrategizedProperty):
     """Describes an object property that holds a single item or list
     of items that correspond to a related database table.
-    
+
     Public constructor is the :func:`.orm.relationship` function.
-    
+
     Of note here is the :class:`.RelationshipProperty.Comparator`
     class, which implements comparison operations for scalar-
     and collection-referencing mapped attributes.
-    
+
     """
 
     strategy_wildcard_key = 'relationship:*'
@@ -292,7 +292,7 @@ class RelationshipProperty(StrategizedProperty):
         def __init__(self, prop, mapper, of_type=None, adapter=None):
             """Construction of :class:`.RelationshipProperty.Comparator`
             is internal to the ORM's attribute mechanics.
-            
+
             """
             self.prop = prop
             self.mapper = mapper
@@ -331,10 +331,10 @@ class RelationshipProperty(StrategizedProperty):
         def of_type(self, cls):
             """Produce a construct that represents a particular 'subtype' of
             attribute for the parent class.
-            
+
             Currently this is usable in conjunction with :meth:`.Query.join`
             and :meth:`.Query.outerjoin`.
-            
+
             """
             return RelationshipProperty.Comparator(
                                         self.property, 
@@ -344,7 +344,7 @@ class RelationshipProperty(StrategizedProperty):
         def in_(self, other):
             """Produce an IN clause - this is not implemented 
             for :func:`~.orm.relationship`-based attributes at this time.
-            
+
             """
             raise NotImplementedError('in_() not yet supported for '
                     'relationships.  For a simple many-to-one, use '
@@ -361,15 +361,15 @@ class RelationshipProperty(StrategizedProperty):
 
             this will typically produce a
             clause such as::
-  
+
               mytable.related_id == <some id>
-  
+
             Where ``<some id>`` is the primary key of the given 
             object.
-  
+
             The ``==`` operator provides partial functionality for non-
             many-to-one comparisons:
-  
+
             * Comparisons against collections are not supported.
               Use :meth:`~.RelationshipProperty.Comparator.contains`.
             * Compared to a scalar one-to-many, will produce a 
@@ -465,42 +465,42 @@ class RelationshipProperty(StrategizedProperty):
         def any(self, criterion=None, **kwargs):
             """Produce an expression that tests a collection against
             particular criterion, using EXISTS.
-            
+
             An expression like::
-            
+
                 session.query(MyClass).filter(
                     MyClass.somereference.any(SomeRelated.x==2)
                 )
-                
-                
+
+
             Will produce a query like::
-            
+
                 SELECT * FROM my_table WHERE
                 EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id 
                 AND related.x=2)
-                
+
             Because :meth:`~.RelationshipProperty.Comparator.any` uses
             a correlated subquery, its performance is not nearly as
             good when compared against large target tables as that of
             using a join.
-            
+
             :meth:`~.RelationshipProperty.Comparator.any` is particularly
             useful for testing for empty collections::
-            
+
                 session.query(MyClass).filter(
                     ~MyClass.somereference.any()
                 )
-            
+
             will produce::
-            
+
                 SELECT * FROM my_table WHERE
                 NOT EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id)
-                
+
             :meth:`~.RelationshipProperty.Comparator.any` is only
             valid for collections, i.e. a :func:`.relationship`
             that has ``uselist=True``.  For scalar references,
             use :meth:`~.RelationshipProperty.Comparator.has`.
-            
+
             """
             if not self.property.uselist:
                 raise sa_exc.InvalidRequestError(
@@ -515,14 +515,14 @@ class RelationshipProperty(StrategizedProperty):
             particular criterion, using EXISTS.
 
             An expression like::
-            
+
                 session.query(MyClass).filter(
                     MyClass.somereference.has(SomeRelated.x==2)
                 )
-                
-                
+
+
             Will produce a query like::
-            
+
                 SELECT * FROM my_table WHERE
                 EXISTS (SELECT 1 FROM related WHERE related.id==my_table.related_id
                 AND related.x=2)
@@ -531,12 +531,12 @@ class RelationshipProperty(StrategizedProperty):
             a correlated subquery, its performance is not nearly as
             good when compared against large target tables as that of
             using a join.
-            
+
             :meth:`~.RelationshipProperty.Comparator.has` is only
             valid for scalar references, i.e. a :func:`.relationship`
             that has ``uselist=False``.  For collection references,
             use :meth:`~.RelationshipProperty.Comparator.any`.
-            
+
             """
             if self.property.uselist:
                 raise sa_exc.InvalidRequestError(
@@ -547,44 +547,44 @@ class RelationshipProperty(StrategizedProperty):
         def contains(self, other, **kwargs):
             """Return a simple expression that tests a collection for 
             containment of a particular item.
-            
+
             :meth:`~.RelationshipProperty.Comparator.contains` is
             only valid for a collection, i.e. a
             :func:`~.orm.relationship` that implements
             one-to-many or many-to-many with ``uselist=True``.
-            
+
             When used in a simple one-to-many context, an 
             expression like::
-            
+
                 MyClass.contains(other)
-                
+
             Produces a clause like::
-            
+
                 mytable.id == <some id>
-                
+
             Where ``<some id>`` is the value of the foreign key
             attribute on ``other`` which refers to the primary
             key of its parent object. From this it follows that
             :meth:`~.RelationshipProperty.Comparator.contains` is
             very useful when used with simple one-to-many
             operations.
-            
+
             For many-to-many operations, the behavior of
             :meth:`~.RelationshipProperty.Comparator.contains`
             has more caveats. The association table will be
             rendered in the statement, producing an "implicit"
             join, that is, includes multiple tables in the FROM
             clause which are equated in the WHERE clause::
-            
+
                 query(MyClass).filter(MyClass.contains(other))
-                
+
             Produces a query like::
-            
+
                 SELECT * FROM my_table, my_association_table AS
                 my_association_table_1 WHERE
                 my_table.id = my_association_table_1.parent_id
                 AND my_association_table_1.child_id = <some id>
-                
+
             Where ``<some id>`` would be the primary key of
             ``other``. From the above, it is clear that
             :meth:`~.RelationshipProperty.Comparator.contains`
@@ -598,7 +598,7 @@ class RelationshipProperty(StrategizedProperty):
             a less-performant alternative using EXISTS, or refer
             to :meth:`.Query.outerjoin` as well as :ref:`ormtutorial_joins`
             for more details on constructing outer joins.
-            
+
             """
             if not self.property.uselist:
                 raise sa_exc.InvalidRequestError(
@@ -649,19 +649,19 @@ class RelationshipProperty(StrategizedProperty):
             """Implement the ``!=`` operator.
 
             In a many-to-one context, such as::
-  
+
               MyClass.some_prop != <some object>
-  
+
             This will typically produce a clause such as::
-  
+
               mytable.related_id != <some id>
-  
+
             Where ``<some id>`` is the primary key of the
             given object.
-  
+
             The ``!=`` operator provides partial functionality for non-
             many-to-one comparisons:
-  
+
             * Comparisons against collections are not supported.
               Use
               :meth:`~.RelationshipProperty.Comparator.contains`
@@ -682,7 +682,7 @@ class RelationshipProperty(StrategizedProperty):
               membership tests.
             * Comparisons against ``None`` given in a one-to-many
               or many-to-many context produce an EXISTS clause.
-                
+
             """
             if isinstance(other, (NoneType, expression._Null)):
                 if self.property.direction == MANYTOONE:
@@ -880,9 +880,9 @@ class RelationshipProperty(StrategizedProperty):
     def mapper(self):
         """Return the targeted :class:`.Mapper` for this 
         :class:`.RelationshipProperty`.
-        
+
         This is a lazy-initializing static attribute.
-        
+
         """
         if isinstance(self.argument, type):
             mapper_ = mapper.class_mapper(self.argument,
@@ -914,14 +914,43 @@ class RelationshipProperty(StrategizedProperty):
     def do_init(self):
         self._check_conflicts()
         self._process_dependent_arguments()
+        self._create_new_thing()
         self._determine_joins()
         self._determine_synchronize_pairs()
         self._determine_direction()
         self._determine_local_remote_pairs()
+        self._test_new_thing()
         self._post_init()
         self._generate_backref()
         super(RelationshipProperty, self).do_init()
 
+    def _create_new_thing(self):
+        import relationships
+        self.jc = relationships.JoinCondition(
+                    parent_selectable=self.parent.mapped_table,
+                    child_selectable=self.mapper.mapped_table,
+                    parent_local_selectable=self.parent.local_table,
+                    child_local_selectable=self.mapper.local_table,
+                    primaryjoin=self.primaryjoin,
+                    secondary=self.secondary,
+                    secondaryjoin=self.secondaryjoin,
+                    parent_equivalents=self.parent._equivalent_columns,
+                    child_equivalents=self.mapper._equivalent_columns,
+                    consider_as_foreign_keys=self._user_defined_foreign_keys,
+                    local_remote_pairs=self.local_remote_pairs,
+                    remote_side=self.remote_side,
+                    self_referential=self._is_self_referential,
+                    prop=self,
+                    support_sync=not self.viewonly,
+                    can_be_synced_fn=self._columns_are_mapped
+
+        )
+
+    def _test_new_thing(self):
+        assert self.jc.direction is self.direction
+        assert self.jc.remote_side == self.remote_side
+        assert self.jc.local_remote_pairs == self.local_remote_pairs
+
     def _check_conflicts(self):
         """Test that this relationship is legal, warn about 
         inheritance conflicts."""
@@ -952,9 +981,9 @@ class RelationshipProperty(StrategizedProperty):
     def _process_dependent_arguments(self):
         """Convert incoming configuration arguments to their 
         proper form.
-        
+
         Callables are resolved, ORM annotations removed.
-        
+
         """
         # accept callables for other attributes which may require
         # deferred initialization.  This technique is used
@@ -1011,10 +1040,10 @@ class RelationshipProperty(StrategizedProperty):
     def _determine_joins(self):
         """Determine the 'primaryjoin' and 'secondaryjoin' attributes,
         if not passed to the constructor already.
-        
+
         This is based on analysis of the foreign key relationships
         between the parent and target mapped selectables.
-        
+
         """
         if self.secondaryjoin is not None and self.secondary is None:
             raise sa_exc.ArgumentError("Property '" + self.key
@@ -1056,7 +1085,7 @@ class RelationshipProperty(StrategizedProperty):
     def _columns_are_mapped(self, *cols):
         """Return True if all columns in the given collection are 
         mapped by the tables referenced by this :class:`.Relationship`.
-        
+
         """
         for c in cols:
             if self.secondary is not None \
@@ -1071,11 +1100,11 @@ class RelationshipProperty(StrategizedProperty):
         """Determine a list of "source"/"destination" column pairs
         based on the given join condition, as well as the
         foreign keys argument.
-        
+
         "source" would be a column referenced by a foreign key,
         and "destination" would be the column who has a foreign key
         reference to "source".
-        
+
         """
 
         fks = self._user_defined_foreign_keys
@@ -1186,7 +1215,7 @@ class RelationshipProperty(StrategizedProperty):
     def _determine_synchronize_pairs(self):
         """Resolve 'primary'/foreign' column pairs from the primaryjoin
         and secondaryjoin arguments.
-        
+
         """
         if self.local_remote_pairs:
             if not self._user_defined_foreign_keys:
@@ -1221,10 +1250,10 @@ class RelationshipProperty(StrategizedProperty):
     def _determine_direction(self):
         """Determine if this relationship is one to many, many to one, 
         many to many.
-        
+
         This is derived from the primaryjoin, presence of "secondary",
         and in the case of self-referential the "remote side".
-        
+
         """
         if self.secondaryjoin is not None:
             self.direction = MANYTOMANY
@@ -1304,7 +1333,7 @@ class RelationshipProperty(StrategizedProperty):
         """Determine pairs of columns representing "local" to 
         "remote", where "local" columns are on the parent mapper,
         "remote" are on the target mapper.
-        
+
         These pairs are used on the load side only to generate
         lazy loading clauses.
 
index bb70f8a11b6a23597647b0231d04d6e3eb0796f3..9aebc9f8ae9703e99b1f527a8af8b14a4f770e25 100644 (file)
@@ -33,29 +33,30 @@ class JoinCondition(object):
                     consider_as_foreign_keys=None,
                     local_remote_pairs=None,
                     remote_side=None,
-                    extra_child_criterion=None,
                     self_referential=False,
                     prop=None,
-                    support_sync=True
+                    support_sync=True,
+                    can_be_synced_fn=lambda c: True
                     ):
         self.parent_selectable = parent_selectable
         self.parent_local_selectable = parent_local_selectable
         self.child_selectable = child_selectable
-        self.child_local_selecatble = child_local_selectable
+        self.child_local_selectable = child_local_selectable
         self.parent_equivalents = parent_equivalents
         self.child_equivalents = child_equivalents
         self.primaryjoin = primaryjoin
         self.secondaryjoin = secondaryjoin
         self.secondary = secondary
-        self.extra_child_criterion = extra_child_criterion
         self.consider_as_foreign_keys = consider_as_foreign_keys
-        self.local_remote_pairs = local_remote_pairs
-        self.remote_side = remote_side
+        self._local_remote_pairs = local_remote_pairs
+        self._remote_side = remote_side
         self.prop = prop
         self.self_referential = self_referential
         self.support_sync = support_sync
+        self.can_be_synced_fn = can_be_synced_fn
         self._determine_joins()
         self._parse_joins()
+        self._determine_direction()
 
     def _determine_joins(self):
         """Determine the 'primaryjoin' and 'secondaryjoin' attributes,
@@ -89,7 +90,7 @@ class JoinCondition(object):
                         join_condition(
                                 self.parent_selectable, 
                                 self.secondary, 
-                                a_subset=parent_local_selectable)
+                                a_subset=self.parent_local_selectable)
             else:
                 if self.primaryjoin is None:
                     self.primaryjoin = \
@@ -108,14 +109,32 @@ class JoinCondition(object):
     def _parse_joins(self):
         """Apply 'remote', 'local' and 'foreign' annotations
         to the primary and secondary join conditions.
-        
+
         """
         parentcols = util.column_set(self.parent_selectable.c)
         targetcols = util.column_set(self.child_selectable.c)
+        if self.secondary is not None:
+            secondarycols = util.column_set(self.secondary.c)
+        else:
+            secondarycols = set()
 
         def col_is(a, b):
             return a.compare(b)
 
+        def refers_to_parent_table(binary):
+            pt = self.parent_selectable
+            mt = self.child_selectable
+            c, f = binary.left, binary.right
+            if (
+                pt.is_derived_from(c.table) and \
+                pt.is_derived_from(f.table) and \
+                mt.is_derived_from(c.table) and \
+                mt.is_derived_from(f.table)
+            ):
+                return True
+            else:
+                return False
+
         def is_foreign(a, b):
             if self.consider_as_foreign_keys:
                 if a in self.consider_as_foreign_keys and (
@@ -134,12 +153,49 @@ class JoinCondition(object):
                     return a
                 elif b.references(a):
                     return b
+            elif secondarycols:
+                if a in secondarycols and b not in secondarycols:
+                    return a
+                elif b in secondarycols and a not in secondarycols:
+                    return b
 
-        any_operator = not self.support_sync
+        def _annotate_fk(binary, switch):
+            if switch:
+                right, left = binary.left, binary.right
+            else:
+                left, right = binary.left, binary.right
+            can_be_synced = self.can_be_synced_fn(left)
+            left = left._annotate({
+                "equated":binary.operator is operators.eq,
+                "can_be_synced":can_be_synced and \
+                    binary.operator is operators.eq
+            })
+            right = right._annotate({
+                "equated":binary.operator is operators.eq,
+                "referent":True
+            })
+            if switch:
+                binary.right, binary.left = left, right
+            else:
+                binary.left, binary.right = left, right
+
+        def _annotate_remote(binary, switch):
+            if switch:
+                right, left = binary.left, binary.right
+            else:
+                left, right = binary.left, binary.right
+            left = left._annotate(
+                                {"remote":True})
+            if right in parentcols or \
+                secondarycols and right in targetcols:
+                right = right._annotate(
+                                {"local":True})
+            if switch:
+                binary.right, binary.left = left, right
+            else:
+                binary.left, binary.right = left, right
 
         def visit_binary(binary):
-            #if not any_operator and binary.operator is not operators.eq:
-            #    return
             if not isinstance(binary.left, sql.ColumnElement) or \
                         not isinstance(binary.right, sql.ColumnElement):
                 return
@@ -156,49 +212,247 @@ class JoinCondition(object):
                                             {"foreign":True})
                     # TODO: when the two cols are the same.
 
+            has_foreign = False
+            if "foreign" in binary.left._annotations:
+                _annotate_fk(binary, False)
+                has_foreign = True
+            if "foreign" in binary.right._annotations:
+                _annotate_fk(binary, True)
+                has_foreign = True
+
             if "remote" not in binary.left._annotations and \
                 "remote" not in binary.right._annotations:
-                if self.local_remote_pairs:
+                if self._local_remote_pairs:
                     raise NotImplementedError()
-                elif self.remote_side:
-                    raise NotImplementedError()
-                elif self.self_referential:
-                    # assume one to many - FKs are "Remote"
+                elif self._remote_side:
+                    if binary.left in self._remote_side:
+                        _annotate_remote(binary, False)
+                    elif binary.right in self._remote_side:
+                        _annotate_remote(binary, True)
+                elif refers_to_parent_table(binary):
+                    # assume one to many - FKs are "remote"
                     if "foreign" in binary.left._annotations:
-                        binary.left = binary.left._annotate(
-                                            {"remote":True})
-                        if binary.right in parentcols:
-                            binary.right = binary.right._annotate(
-                                            {"local":True})
+                        _annotate_remote(binary, False)
                     elif "foreign" in binary.right._annotations:
-                        binary.right = binary.right._annotate(
-                                            {"remote":True})
-                        if binary.left in parentcols:
-                            binary.left = binary.left._annotate(
-                                            {"local":True})
+                        _annotate_remote(binary, True)
+                elif secondarycols:
+                    if binary.left in secondarycols:
+                        _annotate_remote(binary, False)
+                    elif binary.right in secondarycols:
+                        _annotate_remote(binary, True)
                 else:
-                    if binary.left in targetcols:
-                        binary.left = binary.left._annotate(
-                                            {"remote":True})
-                        if binary.right in parentcols:
-                            binary.right = binary.right._annotate(
-                                            {"local":True})
-                    elif binary.right in targetcols:
-                        binary.right = binary.right._annotate(
-                                            {"remote":True})
-                        if binary.left in parentcols:
-                            binary.left = binary.left._annotate(
-                                            {"local":True})
+                    if binary.left in targetcols and has_foreign:
+                        _annotate_remote(binary, False)
+                    elif binary.right in targetcols and has_foreign:
+                        _annotate_remote(binary, True)
 
         self.primaryjoin = visitors.cloned_traverse(
             self.primaryjoin,
             {},
             {"binary":visit_binary}
         )
+        if self.secondaryjoin is not None:
+            self.secondaryjoin = visitors.cloned_traverse(
+                self.secondaryjoin,
+                {},
+                {"binary":visit_binary}
+            )
+        self._check_foreign_cols(
+                        self.primaryjoin, True)
+        if self.secondaryjoin is not None:
+            self._check_foreign_cols(
+                        self.secondaryjoin, False)
+
+
+    def _check_foreign_cols(self, join_condition, primary):
+        """Check the foreign key columns collected and emit error messages."""
+        # TODO: don't worry, we can simplify this once we
+        # encourage configuration via direct annotation
+
+        can_sync = False
+
+        foreign_cols = self._gather_columns_with_annotation(
+                                join_condition, "foreign")
+
+        has_foreign = bool(foreign_cols)
+
+        if self.support_sync:
+            for col in foreign_cols:
+                if col._annotations.get("can_be_synced"):
+                    can_sync = True
+                    break
+
+        if self.support_sync and can_sync or \
+            (not self.support_sync and has_foreign):
+            return
+
+        # from here below is just determining the best error message
+        # to report.  Check for a join condition using any operator 
+        # (not just ==), perhaps they need to turn on "viewonly=True".
+        if self.support_sync and has_foreign and not can_sync:
+
+            err = "Could not locate any "\
+                    "foreign-key-equated, locally mapped column "\
+                    "pairs for %s "\
+                    "condition '%s' on relationship %s." % (
+                        primary and 'primaryjoin' or 'secondaryjoin', 
+                        join_condition, 
+                        self.prop
+                    )
+
+            # TODO: this needs to be changed to detect that
+            # annotations were present and whatnot.   the future
+            # foreignkey(col) annotation will cover establishing
+            # the col as foreign to it's mate
+            if not self.consider_as_foreign_keys:
+                err += "  Ensure that the "\
+                        "referencing Column objects have a "\
+                        "ForeignKey present, or are otherwise part "\
+                        "of a ForeignKeyConstraint on their parent "\
+                        "Table, or specify the foreign_keys parameter "\
+                        "to this relationship."
+
+            err += "  For more "\
+                    "relaxed rules on join conditions, the "\
+                    "relationship may be marked as viewonly=True."
+
+            raise sa_exc.ArgumentError(err)
+        else:
+            if self.consider_as_foreign_keys:
+                raise sa_exc.ArgumentError("Could not determine "
+                        "relationship direction for %s condition "
+                        "'%s', on relationship %s, using manual "
+                        "'foreign_keys' setting.  Do the columns "
+                        "in 'foreign_keys' represent all, and "
+                        "only, the 'foreign' columns in this join "
+                        "condition?  Does the %s Table already "
+                        "have adequate ForeignKey and/or "
+                        "ForeignKeyConstraint objects established "
+                        "(in which case 'foreign_keys' is usually "
+                        "unnecessary)?" 
+                        % (
+                            primary and 'primaryjoin' or 'secondaryjoin',
+                            join_condition, 
+                            self.prop,
+                            primary and 'mapped' or 'secondary'
+                        ))
+            else:
+                raise sa_exc.ArgumentError("Could not determine "
+                        "relationship direction for %s condition "
+                        "'%s', on relationship %s. Ensure that the "
+                        "referencing Column objects have a "
+                        "ForeignKey present, or are otherwise part "
+                        "of a ForeignKeyConstraint on their parent "
+                        "Table, or specify the foreign_keys parameter " 
+                        "to this relationship."
+                        % (
+                            primary and 'primaryjoin' or 'secondaryjoin', 
+                            join_condition, 
+                            self.prop
+                        ))
+
+    def _determine_direction(self):
+        """Determine if this relationship is one to many, many to one, 
+        many to many.
+
+        """
+        if self.secondaryjoin is not None:
+            self.direction = MANYTOMANY
+        else:
+            parentcols = util.column_set(self.parent_selectable.c)
+            targetcols = util.column_set(self.child_selectable.c)
+
+            # fk collection which suggests ONETOMANY.
+            onetomany_fk = targetcols.intersection(
+                            self.foreign_key_columns)
+
+            # fk collection which suggests MANYTOONE.
+
+            manytoone_fk = parentcols.intersection(
+                            self.foreign_key_columns)
+
+            if onetomany_fk and manytoone_fk:
+                # fks on both sides.  test for overlap of local/remote
+                # with foreign key
+                onetomany_local = self.remote_side.intersection(self.foreign_key_columns)
+                manytoone_local = self.local_columns.intersection(self.foreign_key_columns)
+                if onetomany_local and not manytoone_local:
+                    self.direction = ONETOMANY
+                elif manytoone_local and not onetomany_local:
+                    self.direction = MANYTOONE
+                else:
+                    raise sa_exc.ArgumentError(
+                            "Can't determine relationship"
+                            " direction for relationship '%s' - foreign "
+                            "key columns are present in both the parent "
+                            "and the child's mapped tables.  Specify "
+                            "'foreign_keys' argument." % self.prop)
+            elif onetomany_fk:
+                self.direction = ONETOMANY
+            elif manytoone_fk:
+                self.direction = MANYTOONE
+            else:
+                raise sa_exc.ArgumentError("Can't determine relationship "
+                        "direction for relationship '%s' - foreign "
+                        "key columns are present in neither the parent "
+                        "nor the child's mapped tables" % self.prop)
+
+    @util.memoized_property
+    def remote_columns(self):
+        return self._gather_join_annotations("remote")
+
+    remote_side = remote_columns
+
+    @util.memoized_property
+    def local_columns(self):
+        return self._gather_join_annotations("local")
+
+    @util.memoized_property
+    def foreign_key_columns(self):
+        return self._gather_join_annotations("foreign")
+
+    @util.memoized_property
+    def referent_columns(self):
+        return self._gather_join_annotations("referent")
+
+    def _gather_join_annotations(self, annotation):
+        s = set(
+            self._gather_columns_with_annotation(self.primaryjoin, 
+                                                    annotation)
+        )
+        if self.secondaryjoin is not None:
+            s.update(
+                self._gather_columns_with_annotation(self.secondaryjoin, 
+                                                    annotation)
+            )
+        return s
+
+    def _gather_columns_with_annotation(self, clause, *annotation):
+        annotation = set(annotation)
+        return set([
+            col for col in visitors.iterate(clause, {})
+            if annotation.issubset(col._annotations)
+        ])
+
+    @util.memoized_property
+    def local_remote_pairs(self):
+        lrp = []
+        def visit_binary(binary):
+            if "remote" in binary.right._annotations and \
+                "local" in binary.left._annotations:
+                lrp.append((binary.left, binary.right))
+            elif "remote" in binary.left._annotations and \
+                "local" in binary.right._annotations:
+                lrp.append((binary.right, binary.left))
+        visitors.traverse(self.primaryjoin, {}, {"binary":visit_binary})
+        if self.secondaryjoin is not None:
+            visitors.traverse(self.secondaryjoin, {}, {"binary":visit_binary})
+        return lrp
 
     def join_targets(self, source_selectable, 
                             dest_selectable,
-                            aliased):
+                            aliased,
+                            single_crit=None):
         """Given a source and destination selectable, create a
         join between them.
 
@@ -225,7 +479,6 @@ class JoinCondition(object):
         # this is analogous to the "_adjust_for_single_table_inheritance()"
         # method in Query.
 
-        single_crit = self.extra_child_criterion
         if single_crit is not None:
             if secondaryjoin is not None:
                 secondaryjoin = secondaryjoin & single_crit
index b642aca59f6346353a5e673a96dc7c17f674fcb2..48743a2ebd13eb024a20590b72b8974572a12e57 100644 (file)
-from test.lib.testing import assert_raises, assert_raises_message, eq_, AssertsCompiledSQL
+from test.lib.testing import assert_raises, assert_raises_message, eq_, AssertsCompiledSQL, is_
 from test.lib import fixtures
 from sqlalchemy.orm import relationships
 from sqlalchemy import MetaData, Table, Column, ForeignKey, Integer, select
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
 
 class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
-    def _join_fixture_one(self):
+    @classmethod
+    def setup_class(cls):
         m = MetaData()
-        left = Table('lft', m,
+        cls.left = Table('lft', m,
             Column('id', Integer, primary_key=True),
         )
-        right = Table('rgt', m,
+        cls.right = Table('rgt', m,
             Column('id', Integer, primary_key=True),
             Column('lid', Integer, ForeignKey('lft.id'))
         )
+        cls.selfref = Table('selfref', m,
+            Column('id', Integer, primary_key=True),
+            Column('sid', Integer, ForeignKey('selfref.id'))
+        )
+
+    def _join_fixture_o2m(self, **kw):
+        return relationships.JoinCondition(
+                    self.left, 
+                    self.right, 
+                    self.left, 
+                    self.right,
+                    **kw
+                )
+
+    def _join_fixture_m2o(self, **kw):
         return relationships.JoinCondition(
-                    left, right, left, right,
+                    self.right, 
+                    self.left, 
+                    self.right,
+                    self.left,
+                    **kw
                 )
 
-    def test_determine_join(self):
-        joincond = self._join_fixture_one()
+    def _join_fixture_o2m_selfref(self, **kw):
+        return relationships.JoinCondition(
+            self.selfref,
+            self.selfref,
+            self.selfref,
+            self.selfref,
+            **kw
+        )
+
+    def _join_fixture_m2o_selfref(self, **kw):
+        return relationships.JoinCondition(
+            self.selfref,
+            self.selfref,
+            self.selfref,
+            self.selfref,
+            remote_side=set([self.selfref.c.id]),
+            **kw
+        )
+
+    def test_determine_join_o2m(self):
+        joincond = self._join_fixture_o2m()
         self.assert_compile(
                 joincond.primaryjoin,
                 "lft.id = rgt.lid"
         )
 
-    def test_join_targets_plain(self):
-        joincond = self._join_fixture_one()
+    def test_determine_direction_o2m(self):
+        joincond = self._join_fixture_o2m()
+        is_(joincond.direction, ONETOMANY)
+
+    def test_determine_remote_side_o2m(self):
+        joincond = self._join_fixture_o2m()
+        eq_(
+            joincond.remote_side,
+            set([self.right.c.lid])
+        )
+
+    def test_determine_join_o2m_selfref(self):
+        joincond = self._join_fixture_o2m_selfref()
+        self.assert_compile(
+                joincond.primaryjoin,
+                "selfref.id = selfref.sid"
+        )
+
+    def test_determine_direction_o2m_selfref(self):
+        joincond = self._join_fixture_o2m_selfref()
+        is_(joincond.direction, ONETOMANY)
+
+    def test_determine_remote_side_o2m_selfref(self):
+        joincond = self._join_fixture_o2m_selfref()
+        eq_(
+            joincond.remote_side,
+            set([self.selfref.c.sid])
+        )
+
+    def test_determine_join_m2o_selfref(self):
+        joincond = self._join_fixture_m2o_selfref()
+        self.assert_compile(
+                joincond.primaryjoin,
+                "selfref.id = selfref.sid"
+        )
+
+    def test_determine_direction_m2o_selfref(self):
+        joincond = self._join_fixture_m2o_selfref()
+        is_(joincond.direction, MANYTOONE)
+
+    def test_determine_remote_side_m2o_selfref(self):
+        joincond = self._join_fixture_m2o_selfref()
+        eq_(
+            joincond.remote_side,
+            set([self.selfref.c.id])
+        )
+
+    def test_determine_join_m2o(self):
+        joincond = self._join_fixture_m2o()
+        self.assert_compile(
+                joincond.primaryjoin,
+                "lft.id = rgt.lid"
+        )
+
+    def test_determine_direction_m2o(self):
+        joincond = self._join_fixture_m2o()
+        is_(joincond.direction, MANYTOONE)
+
+    def test_determine_remote_side_m2o(self):
+        joincond = self._join_fixture_m2o()
+        eq_(
+            joincond.remote_side,
+            set([self.left.c.id])
+        )
+
+    def test_determine_local_remote_pairs_o2m(self):
+        joincond = self._join_fixture_o2m()
+        eq_(
+            joincond.local_remote_pairs,
+            [(self.left.c.id, self.right.c.lid)]
+        )
+
+    def test_join_targets_o2m_plain(self):
+        joincond = self._join_fixture_o2m()
         pj, sj, sec, adapter = joincond.join_targets(
                                     joincond.parent_selectable, 
                                     joincond.child_selectable, 
@@ -36,8 +148,8 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             pj, "lft.id = rgt.lid"
         )
 
-    def test_join_targets_left_aliased(self):
-        joincond = self._join_fixture_one()
+    def test_join_targets_o2m_left_aliased(self):
+        joincond = self._join_fixture_o2m()
         left = select([joincond.parent_selectable]).alias('pj')
         pj, sj, sec, adapter = joincond.join_targets(
                                     left, 
@@ -47,8 +159,8 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
             pj, "pj.id = rgt.lid"
         )
 
-    def test_join_targets_right_aliased(self):
-        joincond = self._join_fixture_one()
+    def test_join_targets_o2m_right_aliased(self):
+        joincond = self._join_fixture_o2m()
         right = select([joincond.child_selectable]).alias('pj')
         pj, sj, sec, adapter = joincond.join_targets(
                                     joincond.parent_selectable, 
@@ -59,14 +171,14 @@ class JoinCondTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def _test_lazy_clause_o2m(self):
-        joincond = self._join_fixture_one()
+        joincond = self._join_fixture_o2m()
         self.assert_compile(
             relationships.create_lazy_clause(joincond),
             ""
         )
 
     def _test_lazy_clause_o2m_reverse(self):
-        joincond = self._join_fixture_one()
+        joincond = self._join_fixture_o2m()
         self.assert_compile(
             relationships.create_lazy_clause(joincond, 
                                 reverse_direction=True),