]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added 'inherit_foreign_keys' arg to mapper()
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2007 14:29:34 +0000 (14:29 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2007 14:29:34 +0000 (14:29 +0000)
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/alltests.py
test/orm/inheritance/selects.py [new file with mode: 0644]

index 9316fca000a373ebc32daa487dc3a1db90d6b9e3..40806bdb0a308fa85ea6bc432d903a99269a4f32 100644 (file)
@@ -423,6 +423,11 @@ def mapper(class_, local_table=None, *args, **params):
         ``ClauseElement``) which will define how the two tables are
         joined; defaults to a natural join between the two tables.
 
+      inherit_foreign_keys
+        when inherit_condition is used and the condition contains no
+        ForeignKey columns, specify the "foreign" columns of the join 
+        condition in this list.  else leave as None.
+        
       order_by
         A single ``Column`` or list of ``Columns`` for which
         selection operations should use as the default ordering for
index 1f24e8d90daf8053bb556759d5be41b5ae635c7b..014d593d0ccc3788d222a984b4affb28e1813d65 100644 (file)
@@ -49,6 +49,7 @@ class Mapper(object):
                  non_primary = False,
                  inherits = None,
                  inherit_condition = None,
+                 inherit_foreign_keys = None,
                  extension = None,
                  order_by = False,
                  allow_column_override = False,
@@ -98,6 +99,7 @@ class Mapper(object):
         self.select_table = select_table
         self.local_table = local_table
         self.inherit_condition = inherit_condition
+        self.inherit_foreign_keys = inherit_foreign_keys
         self.extension = extension
         self.properties = properties or {}
         self.allow_column_override = allow_column_override
@@ -342,7 +344,11 @@ class Mapper(object):
                     # stricter set of tables to create "sync rules" by,based on the immediate
                     # inherited table, rather than all inherited tables
                     self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
-                    self._synchronizer.compile(self.mapped_table.onclause)
+                    if self.inherit_foreign_keys:
+                        fks = util.Set(self.inherit_foreign_keys)
+                    else:
+                        fks = None
+                    self._synchronizer.compile(self.mapped_table.onclause, foreign_keys=fks)
             else:
                 self._synchronizer = None
                 self.mapped_table = self.local_table
index 1ab10c060701e37b5c97ba2b8ca98d91a4d22a7a..da59dd8fb72b9c83a907a3534a3d5e124cfea922 100644 (file)
@@ -13,6 +13,7 @@ def suite():
         'orm.inheritance.abc_inheritance',
         'orm.inheritance.productspec',
         'orm.inheritance.magazine',
+        'orm.inheritance.selects',
         
         )
     alltests = unittest.TestSuite()
diff --git a/test/orm/inheritance/selects.py b/test/orm/inheritance/selects.py
new file mode 100644 (file)
index 0000000..1e307ff
--- /dev/null
@@ -0,0 +1,97 @@
+import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
+# TODO: refactor "fixtures" to be part of testlib, so Base is globally available
+_recursion_stack = util.Set()
+class Base(object):
+    def __init__(self, **kwargs):
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
+    
+    def __ne__(self, other):
+        return not self.__eq__(other)
+        
+    def __eq__(self, other):
+        """'passively' compare this object to another.
+        
+        only look at attributes that are present on the source object.
+        
+        """
+        if self in _recursion_stack:
+            return True
+        _recursion_stack.add(self)
+        try:
+            # use __dict__ to avoid instrumented properties
+            for attr in self.__dict__.keys():
+                if attr[0] == '_':
+                    continue
+                value = getattr(self, attr)
+                if hasattr(value, '__iter__') and not isinstance(value, basestring):
+                    try:
+                        # catch AttributeError so that lazy loaders trigger
+                        otherattr = getattr(other, attr)
+                    except AttributeError:
+                        return False
+                    if len(value) != len(getattr(other, attr)):
+                       return False
+                    for (us, them) in zip(value, getattr(other, attr)):
+                        if us != them:
+                            return False
+                    else:
+                        continue
+                else:
+                    if value is not None:
+                        print "KEY", attr, "COMPARING", value, "TO", getattr(other, attr, None)
+                        if value != getattr(other, attr, None):
+                            return False
+            else:
+                return True
+        finally:
+            _recursion_stack.remove(self)
+
+class InheritingSelectablesTest(ORMTest):
+    def define_tables(self, metadata):
+        global foo, bar, baz
+        foo = Table('foo', metadata,
+                    Column('a', String, primary_key=1),
+                    Column('b', String, nullable=0))
+
+        bar = foo.select(foo.c.b == 'bar').alias('bar')
+        baz = foo.select(foo.c.b == 'baz').alias('baz')
+
+    def test_load(self):
+        # TODO: add persistence test also
+        testbase.db.execute(foo.insert(), a='not bar', b='baz')
+        testbase.db.execute(foo.insert(), a='also not bar', b='baz')
+        testbase.db.execute(foo.insert(), a='i am bar', b='bar')
+        testbase.db.execute(foo.insert(), a='also bar', b='bar')
+
+        class Foo(Base): pass
+        class Bar(Foo): pass
+        class Baz(Foo): pass
+
+        mapper(Foo, foo, polymorphic_on=foo.c.b)
+
+        mapper(Baz, baz, 
+                    select_table=foo.join(baz, foo.c.b=='baz').alias('baz'),
+                    inherits=Foo,
+                    inherit_condition=(foo.c.a==baz.c.a),
+                    inherit_foreign_keys=[baz.c.a],
+                    polymorphic_identity='baz')
+
+        mapper(Bar, bar,
+                    select_table=foo.join(bar, foo.c.b=='bar').alias('bar'),
+                    inherits=Foo, 
+                    inherit_condition=(foo.c.a==bar.c.a),
+                    inherit_foreign_keys=[bar.c.a],
+                    polymorphic_identity='bar')
+
+        s = sessionmaker(bind=testbase.db)()
+
+        assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all()
+        assert [Bar(), Bar()] == s.query(Bar).all()
+
+if __name__ == '__main__':
+    testbase.main()