From: Mike Bayer Date: Fri, 17 Mar 2006 21:11:59 +0000 (+0000) Subject: identified more issues with inheritance. mapper inheritance is more closed-minded... X-Git-Tag: rel_0_1_5~61 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0e599da0cfd64c0921ba31bb8957aa5d409318c0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git identified more issues with inheritance. mapper inheritance is more closed-minded about how it creates the join crit erion as well as the sync rules in inheritance. syncrules have been tightened up to be smarter about creating a new SyncRule given lists of tables and a join clause. properties also checks for relation direction against the "noninherited table" which for the moment makes it a stronger requirement that a relation to a mapper must relate to that mapper's main table, not any tables that it inherits from. --- diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index b3378a76b2..32271aff5c 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -63,10 +63,23 @@ class Mapper(object): if inherits is not None: self.primarytable = inherits.primarytable - # inherit_condition is optional since the join can figure it out + # inherit_condition is optional. + if inherit_condition is None: + # figure out inherit condition from our table to the immediate table + # of the inherited mapper, not its full table which could pull in other + # stuff we dont want (allows test/inheritance.InheritTest4 to pass) + inherit_condition = sql.join(inherits.noninherited_table, table).onclause self.table = sql.join(inherits.table, table, inherit_condition) + #print "inherit condition", str(self.table.onclause) + + # generate sync rules. similarly to creating the on clause, specify a + # stricter set of tables to create "sync rules" by,based on the immediate + # inherited table, rather than all inherited tables self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table)) + self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), TableFinder(table)) + # the old rule + #self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table)) + self.inherits = inherits self.noninherited_table = table else: @@ -965,7 +978,8 @@ class TableFinder(sql.ClauseVisitor): def __init__(self, table, check_columns=False): self.tables = [] self.check_columns = check_columns - table.accept_visitor(self) + if table is not None: + table.accept_visitor(self) def visit_table(self, table): self.tables.append(table) def __len__(self): @@ -977,7 +991,7 @@ class TableFinder(sql.ClauseVisitor): def __contains__(self, obj): return obj in self.tables def __add__(self, obj): - return self.tables + obj + return self.tables + list(obj) def visit_column(self, column): if self.check_columns: column.table.accept_visitor(self) diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 34b0ae48e4..e5f702b578 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -228,11 +228,9 @@ class PropertyLoader(MapperProperty): return PropertyLoader.ONETOMANY elif self.secondaryjoin is not None: return PropertyLoader.MANYTOMANY - elif self.foreigntable == self.target: - #elif self.foreigntable is self.target or self.foreigntable in self.mapper.tables: + elif self.foreigntable == self.mapper.noninherited_table: return PropertyLoader.ONETOMANY - elif self.foreigntable == self.parent.table: - #elif self.foreigntable is self.parent.table or self.foreigntable in self.parent.tables: + elif self.foreigntable == self.parent.noninherited_table: return PropertyLoader.MANYTOONE else: raise ArgumentError("Cant determine relation direction") @@ -529,6 +527,7 @@ class PropertyLoader(MapperProperty): The list of rules is used within commits by the _synchronize() method when dependent objects are processed.""" + parent_tables = util.HashSet(self.parent.tables + [self.parent.primarytable]) target_tables = util.HashSet(self.mapper.tables + [self.mapper.primarytable]) diff --git a/lib/sqlalchemy/mapping/sync.py b/lib/sqlalchemy/mapping/sync.py index b322780118..e690737ae9 100644 --- a/lib/sqlalchemy/mapping/sync.py +++ b/lib/sqlalchemy/mapping/sync.py @@ -24,13 +24,15 @@ class ClauseSynchronizer(object): self.syncrules = [] def compile(self, sqlclause, source_tables, target_tables, issecondary=None): - def check_for_table(binary, l): - for col in [binary.left, binary.right]: - if col.table in l: - return col + def check_for_table(binary, list1, list2): + #print "check for table", str(binary), [str(c) for c in l] + if binary.left.table in list1 and binary.right.table in list2: + return (binary.left, binary.right) + elif binary.right.table in list1 and binary.left.table in list2: + return (binary.right, binary.left) else: - return None - + return (None, None) + def compile_binary(binary): """assembles a SyncRule given a single binary condition""" if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): @@ -53,8 +55,7 @@ class ClauseSynchronizer(object): else: raise AssertionError("assert failed") else: - pt = check_for_table(binary, source_tables) - tt = check_for_table(binary, target_tables) + (pt, tt) = check_for_table(binary, source_tables, target_tables) #print "OK", binary, [t.name for t in source_tables], [t.name for t in target_tables] if pt and tt: if self.direction == ONETOMANY: @@ -94,7 +95,7 @@ class SyncRule(object): self.issecondary = issecondary self.dest_mapper = dest_mapper self.dest_column = dest_column - #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper, direction + #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper def execute(self, source, dest, obj, child, clearkeys): if source is None: diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 98d3a70f14..85bbfdc5ac 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -255,6 +255,8 @@ class HashSet(object): return self.map.has_key(item) def clear(self): self.map.clear() + def intersection(self, l): + return HashSet([x for x in l if self.contains(x)]) def empty(self): return len(self.map) == 0 def append(self, item): diff --git a/test/inheritance.py b/test/inheritance.py index 5273a62ec1..71cc78882e 100644 --- a/test/inheritance.py +++ b/test/inheritance.py @@ -14,80 +14,83 @@ class Group( Principal ): pass class InheritTest(testbase.AssertMixin): - def setUpAll(self): - global principals - global users - global groups - global user_group_map - principals = Table( - 'principals', - testbase.db, - Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True), - Column('name', String(50), nullable=False), - ) - - users = Table( - 'prin_users', - testbase.db, - Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), - Column('password', String(50), nullable=False), - Column('email', String(50), nullable=False), - Column('login_id', String(50), nullable=False), - - ) - - groups = Table( - 'prin_groups', - testbase.db, - Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), - - ) - - user_group_map = Table( - 'prin_user_group_map', - testbase.db, - Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ), - Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ), - #Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), ), - #Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), ), - - ) - - principals.create() - users.create() - groups.create() - user_group_map.create() - def tearDownAll(self): - user_group_map.drop() - groups.drop() - users.drop() - principals.drop() - testbase.db.tables.clear() - def setUp(self): - objectstore.clear() - clear_mappers() - - def testbasic(self): - assign_mapper( Principal, principals ) - assign_mapper( - User, - users, - inherits=Principal.mapper - ) - - assign_mapper( - Group, - groups, - inherits=Principal.mapper, - properties=dict( users = relation(User.mapper, user_group_map, lazy=True, backref="groups") ) - ) - - g = Group(name="group1") - g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1")) - - objectstore.commit() + """deals with inheritance and many-to-many relationships""" + def setUpAll(self): + global principals + global users + global groups + global user_group_map + principals = Table( + 'principals', + testbase.db, + Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True), + Column('name', String(50), nullable=False), + ) + + users = Table( + 'prin_users', + testbase.db, + Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), + Column('password', String(50), nullable=False), + Column('email', String(50), nullable=False), + Column('login_id', String(50), nullable=False), + + ) + + groups = Table( + 'prin_groups', + testbase.db, + Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), + + ) + + user_group_map = Table( + 'prin_user_group_map', + testbase.db, + Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ), + Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ), + #Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), ), + #Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), ), + + ) + + principals.create() + users.create() + groups.create() + user_group_map.create() + def tearDownAll(self): + user_group_map.drop() + groups.drop() + users.drop() + principals.drop() + testbase.db.tables.clear() + def setUp(self): + objectstore.clear() + clear_mappers() + + def testbasic(self): + assign_mapper( Principal, principals ) + assign_mapper( + User, + users, + inherits=Principal.mapper + ) + + assign_mapper( + Group, + groups, + inherits=Principal.mapper, + properties=dict( users = relation(User.mapper, user_group_map, lazy=True, backref="groups") ) + ) + g = Group(name="group1") + g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1")) + + objectstore.commit() + # TODO: put an assertion + class InheritTest2(testbase.AssertMixin): + """deals with inheritance and many-to-many relationships""" def setUpAll(self): engine = testbase.db global foo, bar, foo_bar @@ -155,6 +158,7 @@ class InheritTest2(testbase.AssertMixin): ) class InheritTest3(testbase.AssertMixin): + """deals with inheritance and many-to-many relationships""" def setUpAll(self): engine = testbase.db global foo, bar, blub, bar_foo, blub_bar, blub_foo,tables @@ -217,9 +221,11 @@ class InheritTest3(testbase.AssertMixin): b.foos.append(Foo("foo #1")) b.foos.append(Foo("foo #2")) objectstore.commit() + compare = repr(b) + repr(b.foos) objectstore.clear() l = Bar.mapper.select() - print l[0], l[0].foos + self.echo(repr(l[0]) + repr(l[0].foos)) + self.assert_(repr(l[0]) + repr(l[0].foos) == compare) def testadvanced(self): class Foo(object): @@ -274,7 +280,76 @@ class InheritTest3(testbase.AssertMixin): self.echo(x) self.assert_(repr(x) == compare) +class InheritTest4(testbase.AssertMixin): + """deals with inheritance and one-to-many relationships""" + def setUpAll(self): + engine = testbase.db + global foo, bar, blub, tables + engine.engine.echo = 'debug' + # the 'data' columns are to appease SQLite which cant handle a blank INSERT + foo = Table('foo', engine, + Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('data', String(20))) + + bar = Table('bar', engine, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(20))) + + blub = Table('blub', engine, + Column('id', Integer, ForeignKey('bar.id'), primary_key=True), + Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False), + Column('data', String(20))) + + tables = [foo, bar, blub] + for table in tables: + table.create() + def tearDownAll(self): + for table in reversed(tables): + table.drop() + testbase.db.tables.clear() + + def tearDown(self): + for table in reversed(tables): + table.delete().execute() + + def testbasic(self): + class Foo(object): + def __init__(self, data=None): + self.data = data + def __repr__(self): + return "Foo id %d, data %s" % (self.id, self.data) + Foo.mapper = mapper(Foo, foo) + + class Bar(Foo): + def __repr__(self): + return "Bar id %d, data %s" % (self.id, self.data) + + Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper) + + class Blub(Bar): + def __repr__(self): + return "Blub id %d, data %s" % (self.id, self.data) + Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={ + # bug was raised specifically based on the order of cols in the join.... +# 'parent_foo':relation(Foo.mapper, primaryjoin=blub.c.foo_id==foo.c.id) +# 'parent_foo':relation(Foo.mapper, primaryjoin=foo.c.id==blub.c.foo_id) + 'parent_foo':relation(Foo.mapper) + }) + + b1 = Blub("blub #1") + b2 = Blub("blub #2") + f = Foo("foo #1") + b1.parent_foo = f + b2.parent_foo = f + objectstore.commit() + compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo) + objectstore.clear() + l = Blub.mapper.select() + result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo) + self.echo(result) + self.assert_(compare == result) + self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') if __name__ == "__main__": testbase.main()