From 354f899be62411b15f509e3f82afc249cc5ca146 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 16 Aug 2007 14:29:34 +0000 Subject: [PATCH] added 'inherit_foreign_keys' arg to mapper() --- lib/sqlalchemy/orm/__init__.py | 5 ++ lib/sqlalchemy/orm/mapper.py | 8 ++- test/orm/inheritance/alltests.py | 1 + test/orm/inheritance/selects.py | 97 ++++++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 test/orm/inheritance/selects.py diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 9316fca000..40806bdb0a 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -423,6 +423,11 @@ def mapper(class_, local_table=None, *args, **params): ``ClauseElement``) which will define how the two tables are joined; defaults to a natural join between the two tables. + inherit_foreign_keys + when inherit_condition is used and the condition contains no + ForeignKey columns, specify the "foreign" columns of the join + condition in this list. else leave as None. + order_by A single ``Column`` or list of ``Columns`` for which selection operations should use as the default ordering for diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 1f24e8d90d..014d593d0c 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -49,6 +49,7 @@ class Mapper(object): non_primary = False, inherits = None, inherit_condition = None, + inherit_foreign_keys = None, extension = None, order_by = False, allow_column_override = False, @@ -98,6 +99,7 @@ class Mapper(object): self.select_table = select_table self.local_table = local_table self.inherit_condition = inherit_condition + self.inherit_foreign_keys = inherit_foreign_keys self.extension = extension self.properties = properties or {} self.allow_column_override = allow_column_override @@ -342,7 +344,11 @@ class Mapper(object): # stricter set of tables to create "sync rules" by,based on the immediate # inherited table, rather than all inherited tables self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.mapped_table.onclause) + if self.inherit_foreign_keys: + fks = util.Set(self.inherit_foreign_keys) + else: + fks = None + self._synchronizer.compile(self.mapped_table.onclause, foreign_keys=fks) else: self._synchronizer = None self.mapped_table = self.local_table diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py index 1ab10c0607..da59dd8fb7 100644 --- a/test/orm/inheritance/alltests.py +++ b/test/orm/inheritance/alltests.py @@ -13,6 +13,7 @@ def suite(): 'orm.inheritance.abc_inheritance', 'orm.inheritance.productspec', 'orm.inheritance.magazine', + 'orm.inheritance.selects', ) alltests = unittest.TestSuite() diff --git a/test/orm/inheritance/selects.py b/test/orm/inheritance/selects.py new file mode 100644 index 0000000000..1e307ffe5c --- /dev/null +++ b/test/orm/inheritance/selects.py @@ -0,0 +1,97 @@ +import testbase +from sqlalchemy import * +from sqlalchemy.orm import * +from testlib import * + +# TODO: refactor "fixtures" to be part of testlib, so Base is globally available +_recursion_stack = util.Set() +class Base(object): + def __init__(self, **kwargs): + for k in kwargs: + setattr(self, k, kwargs[k]) + + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'passively' compare this object to another. + + only look at attributes that are present on the source object. + + """ + if self in _recursion_stack: + return True + _recursion_stack.add(self) + try: + # use __dict__ to avoid instrumented properties + for attr in self.__dict__.keys(): + if attr[0] == '_': + continue + value = getattr(self, attr) + if hasattr(value, '__iter__') and not isinstance(value, basestring): + try: + # catch AttributeError so that lazy loaders trigger + otherattr = getattr(other, attr) + except AttributeError: + return False + if len(value) != len(getattr(other, attr)): + return False + for (us, them) in zip(value, getattr(other, attr)): + if us != them: + return False + else: + continue + else: + if value is not None: + print "KEY", attr, "COMPARING", value, "TO", getattr(other, attr, None) + if value != getattr(other, attr, None): + return False + else: + return True + finally: + _recursion_stack.remove(self) + +class InheritingSelectablesTest(ORMTest): + def define_tables(self, metadata): + global foo, bar, baz + foo = Table('foo', metadata, + Column('a', String, primary_key=1), + Column('b', String, nullable=0)) + + bar = foo.select(foo.c.b == 'bar').alias('bar') + baz = foo.select(foo.c.b == 'baz').alias('baz') + + def test_load(self): + # TODO: add persistence test also + testbase.db.execute(foo.insert(), a='not bar', b='baz') + testbase.db.execute(foo.insert(), a='also not bar', b='baz') + testbase.db.execute(foo.insert(), a='i am bar', b='bar') + testbase.db.execute(foo.insert(), a='also bar', b='bar') + + class Foo(Base): pass + class Bar(Foo): pass + class Baz(Foo): pass + + mapper(Foo, foo, polymorphic_on=foo.c.b) + + mapper(Baz, baz, + select_table=foo.join(baz, foo.c.b=='baz').alias('baz'), + inherits=Foo, + inherit_condition=(foo.c.a==baz.c.a), + inherit_foreign_keys=[baz.c.a], + polymorphic_identity='baz') + + mapper(Bar, bar, + select_table=foo.join(bar, foo.c.b=='bar').alias('bar'), + inherits=Foo, + inherit_condition=(foo.c.a==bar.c.a), + inherit_foreign_keys=[bar.c.a], + polymorphic_identity='bar') + + s = sessionmaker(bind=testbase.db)() + + assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all() + assert [Bar(), Bar()] == s.query(Bar).all() + +if __name__ == '__main__': + testbase.main() -- 2.47.3