]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added experimental relation() flag to help with primaryjoins
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Apr 2008 15:49:39 +0000 (15:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Apr 2008 15:49:39 +0000 (15:49 +0000)
across functions, etc., _local_remote_pairs=[tuples].
This complements a complex primaryjoin condition allowing
you to provide the individual column pairs which comprise
the relation's local and remote sides.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
test/orm/relationships.py

diff --git a/CHANGES b/CHANGES
index fc2c38b0ef76de096cea9bc35b94c1851e626d9b..37861592d781c9eabc93001af644e90868fc85b1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -13,6 +13,12 @@ CHANGES
     - Also re-established viewonly relation() configurations that
       join across multiple tables.
       
+    - added experimental relation() flag to help with primaryjoins
+      across functions, etc., _local_remote_pairs=[tuples].  
+      This complements a complex primaryjoin condition allowing 
+      you to provide the individual column pairs which comprise
+      the relation's local and remote sides.
+      
     - removed ancient assertion that mapped selectables require
       "alias names" - the mapper creates its own alias now if
       none is present.  Though in this case you need to use 
index 560c304da2a7c80512d1d8a22ce541cea4588967..5508b5c8e3d33d99a5a7791c307b61f92b45fcf1 100644 (file)
@@ -185,11 +185,11 @@ class Mapper(object):
 
     def __log(self, msg):
         if self.__should_log_info:
-            self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg)
+            self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg)
 
     def __log_debug(self, msg):
         if self.__should_log_debug:
-            self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (not self.non_primary and "|non-primary" or "") + ") " + msg)
+            self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg)
 
     def _is_orphan(self, obj):
         o = False
index c4bb323de5eebf8f4a4a84e3f1fa160b260cff97..0eece6466f460c878a8ca6ad4135af66125efbaf 100644 (file)
@@ -210,7 +210,7 @@ class PropertyLoader(StrategizedProperty):
     of items that correspond to a related database table.
     """
 
-    def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None):
+    def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None, _local_remote_pairs=None):
         self.uselist = uselist
         self.argument = argument
         self.entity_name = entity_name
@@ -232,6 +232,7 @@ class PropertyLoader(StrategizedProperty):
         self.enable_typechecks = enable_typechecks
         self.comparator = PropertyLoader.Comparator(self)
         self.join_depth = join_depth
+        self._arg_local_remote_pairs = _local_remote_pairs
         
         if strategy_class:
             self.strategy_class = strategy_class
@@ -538,26 +539,39 @@ class PropertyLoader(StrategizedProperty):
                 self.secondary.c.contains_column(column) is not None
         
     def __determine_fks(self):
+
         if self._legacy_foreignkey and not self._refers_to_parent_table():
             self.foreign_keys = self._legacy_foreignkey
 
         arg_foreign_keys = self.foreign_keys
+
+        if self._arg_local_remote_pairs:
+            if not arg_foreign_keys:
+                raise exceptions.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument")
+            self.foreign_keys = util.OrderedSet(arg_foreign_keys)
+            self._opposite_side = util.OrderedSet()
+            for l, r in self._arg_local_remote_pairs:
+                if r in self.foreign_keys:
+                    self._opposite_side.add(l)
+                elif l in self.foreign_keys:
+                    self._opposite_side.add(r)
+            self.synchronize_pairs = zip(self._opposite_side, self.foreign_keys)
+        else:
+            eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly)
+            eq_pairs = [(l, r) for l, r in eq_pairs if (self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)) or r in arg_foreign_keys]
+
+            if not eq_pairs:
+                if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
+                    raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. "
+                        "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self)
+                    )
+                else:
+                    raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+                    "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self))
         
-        eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly)
-        eq_pairs = [(l, r) for l, r in eq_pairs if (self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)) or r in arg_foreign_keys]
-
-        if not eq_pairs:
-            if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
-                raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. "
-                    "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self)
-                )
-            else:
-                raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
-                "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self))
-        
-        self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs])
-        self._opposite_side = util.OrderedSet([l for l, r in eq_pairs])
-        self.synchronize_pairs = eq_pairs
+            self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs])
+            self._opposite_side = util.OrderedSet([l for l, r in eq_pairs])
+            self.synchronize_pairs = eq_pairs
         
         if self.secondaryjoin:
             sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly)
@@ -579,7 +593,14 @@ class PropertyLoader(StrategizedProperty):
             self.secondary_synchronize_pairs = None
     
     def __determine_remote_side(self):
-        if self.remote_side:
+        if self._arg_local_remote_pairs:
+            if self.remote_side:
+                raise exceptions.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.")
+            if self.direction is MANYTOONE:
+                eq_pairs = [(r, l) for l, r in self._arg_local_remote_pairs]
+            else:
+                eq_pairs = self._arg_local_remote_pairs
+        elif self.remote_side:
             if self.direction is MANYTOONE:
                 eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True)
             else:
@@ -594,11 +615,11 @@ class PropertyLoader(StrategizedProperty):
                     eq_pairs += sq_pairs
                 eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
         
-        self.local_remote_pairs = eq_pairs
         if self.direction is MANYTOONE:
             self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
         else:
             self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+        self.local_remote_pairs = zip(self.local_side, self.remote_side)
         
         if self.direction is ONETOMANY:
             for l in self.local_side:
index 22b69af2aa2e0e2c2bbe0e33381aad82240adc4d..6e9a0f96e1d1188990fc180769b313db98e54fca 100644 (file)
@@ -1052,7 +1052,101 @@ class ViewOnlyTest6(ORMTest):
         })
         mapper(T3, t3)
         self.assertRaisesMessage(exceptions.ArgumentError, "Specify remote_side argument", compile_mappers)
+
+class ExplicitLocalRemoteTest(ORMTest):
+    def define_tables(self, metadata):
+        global t1, t2
+        t1 = Table('t1', metadata, 
+            Column('id', String(50), primary_key=True),
+            Column('data', String(50))
+            )
+        t2 = Table('t2', metadata,     
+            Column('id', Integer, primary_key=True),
+            Column('data', String(50)),
+            Column('t1id', String(50)),
+        )
+
+    def test_onetomany(self):
+        class T1(fixtures.Base):
+            pass
+        class T2(fixtures.Base):
+            pass
+        
+        # use a function within join condition.  but specifying
+        # local_remote_pairs overrides all parsing of the join condition.
+        mapper(T1, t1, properties={
+            't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), 
+                _local_remote_pairs=[(t1.c.id, t2.c.t1id)],
+                foreign_keys=[t2.c.t1id]
+            )
+        })
+        mapper(T2, t2)
+        
+        sess = create_session()
+        a1 = T1(id='number1', data='a1')
+        b1 = T2(data='b1', t1id='NuMbEr1')
+        b2 = T2(data='b2', t1id='Number1')
+        sess.save(a1)
+        sess.save(b1)
+        sess.save(b2)
+        sess.flush()
+        sess.clear()
+        
+        self.assertEquals(sess.query(T1).first(), T1(id='number1', data='a1', t2s=[T2(data='b1', t1id='NuMbEr1'), T2(data='b2', t1id='Number1')]))
+    
+    def test_manytoone(self):
+        class T1(fixtures.Base):
+            pass
+        class T2(fixtures.Base):
+            pass
+        mapper(T1, t1)
+        mapper(T2, t2, properties={
+            't1':relation(T1, primaryjoin=t1.c.id==func.lower(t2.c.t1id),
+                _local_remote_pairs=[(t2.c.t1id, t1.c.id)],
+                foreign_keys=[t2.c.t1id]
+            )
+        })
+        sess = create_session()
+        a1 = T1(id='number1', data='a1')
+        b1 = T2(data='b1', t1id='NuMbEr1')
+        b2 = T2(data='b2', t1id='Number1')
+        sess.save(a1)
+        sess.save(b1)
+        sess.save(b2)
+        sess.flush()
+        sess.clear()
+        self.assertEquals(sess.query(T2).all(), 
+            [
+                T2(data='b1', t1=T1(id='number1', data='a1')),
+                T2(data='b2', t1=T1(id='number1', data='a1'))
+            ]
+        )
     
+    def test_escalation(self):
+        class T1(fixtures.Base):
+            pass
+        class T2(fixtures.Base):
+            pass
+        
+        mapper(T1, t1, properties={
+            't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), 
+                _local_remote_pairs=[(t1.c.id, t2.c.t1id)],
+                foreign_keys=[t2.c.t1id],
+                remote_side=[t2.c.t1id]
+            )
+        })
+        mapper(T2, t2)
+        self.assertRaises(exceptions.ArgumentError, compile_mappers)
+        
+        clear_mappers()
+        mapper(T1, t1, properties={
+            't2s':relation(T2, primaryjoin=t1.c.id==func.lower(t2.c.t1id), 
+                _local_remote_pairs=[(t1.c.id, t2.c.t1id)],
+            )
+        })
+        mapper(T2, t2)
+        self.assertRaises(exceptions.ArgumentError, compile_mappers)
+        
 class InvalidRelationEscalationTest(ORMTest):
     def define_tables(self, metadata):
         global foos, bars, Foo, Bar