]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
identified more issues with inheritance. mapper inheritance is more closed-minded...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 17 Mar 2006 21:11:59 +0000 (21:11 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 17 Mar 2006 21:11:59 +0000 (21:11 +0000)
erion as well as the sync rules in inheritance.  syncrules have been tightened up to be smarter about creating a new
SyncRule given lists of tables and a join clause.  properties also checks for relation direction against the "noninherited table" which for the moment makes it a stronger requirement that a relation to a mapper must relate to that mapper's main table, not any tables that it inherits from.

lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/mapping/sync.py
lib/sqlalchemy/util.py
test/inheritance.py

index b3378a76b2583d1eeb302472231ef77838109cc0..32271aff5cd00852466e662c975f5cffeb92d887 100644 (file)
@@ -63,10 +63,23 @@ class Mapper(object):
 
         if inherits is not None:
             self.primarytable = inherits.primarytable
-            # inherit_condition is optional since the join can figure it out
+            # inherit_condition is optional.
+            if inherit_condition is None:
+                # figure out inherit condition from our table to the immediate table
+                # of the inherited mapper, not its full table which could pull in other 
+                # stuff we dont want (allows test/inheritance.InheritTest4 to pass)
+                inherit_condition = sql.join(inherits.noninherited_table, table).onclause
             self.table = sql.join(inherits.table, table, inherit_condition)
+            #print "inherit condition", str(self.table.onclause)
+
+            # generate sync rules.  similarly to creating the on clause, specify a 
+            # stricter set of tables to create "sync rules" by,based on the immediate
+            # inherited table, rather than all inherited tables
             self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
-            self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table))
+            self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), TableFinder(table))
+            # the old rule
+            #self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table))
+
             self.inherits = inherits
             self.noninherited_table = table
         else:
@@ -965,7 +978,8 @@ class TableFinder(sql.ClauseVisitor):
     def __init__(self, table, check_columns=False):
         self.tables = []
         self.check_columns = check_columns
-        table.accept_visitor(self)
+        if table is not None:
+            table.accept_visitor(self)
     def visit_table(self, table):
         self.tables.append(table)
     def __len__(self):
@@ -977,7 +991,7 @@ class TableFinder(sql.ClauseVisitor):
     def __contains__(self, obj):
         return obj in self.tables
     def __add__(self, obj):
-        return self.tables + obj
+        return self.tables + list(obj)
     def visit_column(self, column):
         if self.check_columns:
             column.table.accept_visitor(self)
index 34b0ae48e4376a3607b1796613ea57abd0f106dd..e5f702b578c7acc350dcf75d8a07a149c4153a8c 100644 (file)
@@ -228,11 +228,9 @@ class PropertyLoader(MapperProperty):
                 return PropertyLoader.ONETOMANY
         elif self.secondaryjoin is not None:
             return PropertyLoader.MANYTOMANY
-        elif self.foreigntable == self.target:
-        #elif self.foreigntable is self.target or self.foreigntable in self.mapper.tables:
+        elif self.foreigntable == self.mapper.noninherited_table:
             return PropertyLoader.ONETOMANY
-        elif self.foreigntable == self.parent.table:
-        #elif self.foreigntable is self.parent.table or self.foreigntable in self.parent.tables:
+        elif self.foreigntable == self.parent.noninherited_table:
             return PropertyLoader.MANYTOONE
         else:
             raise ArgumentError("Cant determine relation direction")
@@ -529,6 +527,7 @@ class PropertyLoader(MapperProperty):
         The list of rules is used within commits by the _synchronize() method when dependent 
         objects are processed."""
 
+
         parent_tables = util.HashSet(self.parent.tables + [self.parent.primarytable])
         target_tables = util.HashSet(self.mapper.tables + [self.mapper.primarytable])
 
index b3227801181779599fa337f99b52e34113e55e8a..e690737ae930e5315044458774e4648b4e546f49 100644 (file)
@@ -24,13 +24,15 @@ class ClauseSynchronizer(object):
         self.syncrules = []
 
     def compile(self, sqlclause, source_tables, target_tables, issecondary=None):
-        def check_for_table(binary, l):
-            for col in [binary.left, binary.right]:
-                if col.table in l:
-                    return col
+        def check_for_table(binary, list1, list2):
+            #print "check for table", str(binary), [str(c) for c in l]
+            if binary.left.table in list1 and binary.right.table in list2:
+                return (binary.left, binary.right)
+            elif binary.right.table in list1 and binary.left.table in list2:
+                return (binary.right, binary.left)
             else:
-                return None
-        
+                return (None, None)
+                
         def compile_binary(binary):
             """assembles a SyncRule given a single binary condition"""
             if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
@@ -53,8 +55,7 @@ class ClauseSynchronizer(object):
                 else:
                     raise AssertionError("assert failed")
             else:
-                pt = check_for_table(binary, source_tables)
-                tt = check_for_table(binary, target_tables)
+                (pt, tt) = check_for_table(binary, source_tables, target_tables)
                 #print "OK", binary, [t.name for t in source_tables], [t.name for t in target_tables]
                 if pt and tt:
                     if self.direction == ONETOMANY:
@@ -94,7 +95,7 @@ class SyncRule(object):
         self.issecondary = issecondary
         self.dest_mapper = dest_mapper
         self.dest_column = dest_column
-        #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper, direction
+        #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
 
     def execute(self, source, dest, obj, child, clearkeys):
         if source is None:
index 98d3a70f146c5db333d5530a5906dd919bd8e950..85bbfdc5ac5998fc541d8aea2884a1c5703b07ce 100644 (file)
@@ -255,6 +255,8 @@ class HashSet(object):
         return self.map.has_key(item)
     def clear(self):
         self.map.clear()
+    def intersection(self, l):
+        return HashSet([x for x in l if self.contains(x)])
     def empty(self):
         return len(self.map) == 0
     def append(self, item):
index 5273a62ec177a9454c31542293c8a7177cf10400..71cc78882ee6aeabf0dbc80049b775f3362dc916 100644 (file)
@@ -14,80 +14,83 @@ class Group( Principal ):
     pass
 
 class InheritTest(testbase.AssertMixin):
-        def setUpAll(self):
-            global principals
-            global users
-            global groups
-            global user_group_map
-            principals = Table(
-                'principals',
-                testbase.db,
-                Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True),
-                Column('name', String(50), nullable=False),    
-                )
-
-            users = Table(
-                'prin_users',
-                testbase.db,
-                Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
-                Column('password', String(50), nullable=False),
-                Column('email', String(50), nullable=False),
-                Column('login_id', String(50), nullable=False),
-
-                )
-
-            groups = Table(
-                'prin_groups',
-                testbase.db,
-                Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
-
-                )
-
-            user_group_map = Table(
-                'prin_user_group_map',
-                testbase.db,
-                Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ),
-                Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ),
-                #Column('user_id', Integer, ForeignKey( "prin_users.principal_id"),  ),
-                #Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"),  ),    
-
-                )
-
-            principals.create()
-            users.create()
-            groups.create()
-            user_group_map.create()
-        def tearDownAll(self):
-            user_group_map.drop()
-            groups.drop()
-            users.drop()
-            principals.drop()
-            testbase.db.tables.clear()
-        def setUp(self):
-            objectstore.clear()
-            clear_mappers()
-            
-        def testbasic(self):
-            assign_mapper( Principal, principals )
-            assign_mapper( 
-                User, 
-                users,
-                inherits=Principal.mapper
-                )
-
-            assign_mapper( 
-                Group,
-                groups,
-                inherits=Principal.mapper,
-                properties=dict( users = relation(User.mapper, user_group_map, lazy=True, backref="groups") )
-                )
-
-            g = Group(name="group1")
-            g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1"))
-            
-            objectstore.commit()
+    """deals with inheritance and many-to-many relationships"""
+    def setUpAll(self):
+        global principals
+        global users
+        global groups
+        global user_group_map
+        principals = Table(
+            'principals',
+            testbase.db,
+            Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True),
+            Column('name', String(50), nullable=False),    
+            )
+
+        users = Table(
+            'prin_users',
+            testbase.db,
+            Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
+            Column('password', String(50), nullable=False),
+            Column('email', String(50), nullable=False),
+            Column('login_id', String(50), nullable=False),
+
+            )
+
+        groups = Table(
+            'prin_groups',
+            testbase.db,
+            Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True),
+
+            )
+
+        user_group_map = Table(
+            'prin_user_group_map',
+            testbase.db,
+            Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ),
+            Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ),
+            #Column('user_id', Integer, ForeignKey( "prin_users.principal_id"),  ),
+            #Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"),  ),    
+
+            )
+
+        principals.create()
+        users.create()
+        groups.create()
+        user_group_map.create()
+    def tearDownAll(self):
+        user_group_map.drop()
+        groups.drop()
+        users.drop()
+        principals.drop()
+        testbase.db.tables.clear()
+    def setUp(self):
+        objectstore.clear()
+        clear_mappers()
+        
+    def testbasic(self):
+        assign_mapper( Principal, principals )
+        assign_mapper( 
+            User, 
+            users,
+            inherits=Principal.mapper
+            )
+
+        assign_mapper( 
+            Group,
+            groups,
+            inherits=Principal.mapper,
+            properties=dict( users = relation(User.mapper, user_group_map, lazy=True, backref="groups") )
+            )
 
+        g = Group(name="group1")
+        g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1"))
+        
+        objectstore.commit()
+        # TODO: put an assertion
+        
 class InheritTest2(testbase.AssertMixin):
+    """deals with inheritance and many-to-many relationships"""
     def setUpAll(self):
         engine = testbase.db
         global foo, bar, foo_bar
@@ -155,6 +158,7 @@ class InheritTest2(testbase.AssertMixin):
             )
 
 class InheritTest3(testbase.AssertMixin):
+    """deals with inheritance and many-to-many relationships"""
     def setUpAll(self):
         engine = testbase.db
         global foo, bar, blub, bar_foo, blub_bar, blub_foo,tables
@@ -217,9 +221,11 @@ class InheritTest3(testbase.AssertMixin):
         b.foos.append(Foo("foo #1"))
         b.foos.append(Foo("foo #2"))
         objectstore.commit()
+        compare = repr(b) + repr(b.foos)
         objectstore.clear()
         l = Bar.mapper.select()
-        print l[0], l[0].foos
+        self.echo(repr(l[0]) + repr(l[0].foos))
+        self.assert_(repr(l[0]) + repr(l[0].foos) == compare)
     
     def testadvanced(self):    
         class Foo(object):
@@ -274,7 +280,76 @@ class InheritTest3(testbase.AssertMixin):
         self.echo(x)
         self.assert_(repr(x) == compare)
         
+class InheritTest4(testbase.AssertMixin):
+    """deals with inheritance and one-to-many relationships"""
+    def setUpAll(self):
+        engine = testbase.db
+        global foo, bar, blub, tables
+        engine.engine.echo = 'debug'
+        # the 'data' columns are to appease SQLite which cant handle a blank INSERT
+        foo = Table('foo', engine,
+            Column('id', Integer, Sequence('foo_seq'), primary_key=True),
+            Column('data', String(20)))
+
+        bar = Table('bar', engine,
+            Column('id', Integer, ForeignKey('foo.id'), primary_key=True),
+            Column('data', String(20)))
+
+        blub = Table('blub', engine,
+            Column('id', Integer, ForeignKey('bar.id'), primary_key=True),
+            Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
+            Column('data', String(20)))
+
+        tables = [foo, bar, blub]
+        for table in tables:
+            table.create()
+    def tearDownAll(self):
+        for table in reversed(tables):
+            table.drop()
+        testbase.db.tables.clear()
+
+    def tearDown(self):
+        for table in reversed(tables):
+            table.delete().execute()
+
+    def testbasic(self):
+        class Foo(object):
+            def __init__(self, data=None):
+                self.data = data
+            def __repr__(self):
+                return "Foo id %d, data %s" % (self.id, self.data)
+        Foo.mapper = mapper(Foo, foo)
+
+        class Bar(Foo):
+            def __repr__(self):
+                return "Bar id %d, data %s" % (self.id, self.data)
+
+        Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper)
+        
+        class Blub(Bar):
+            def __repr__(self):
+                return "Blub id %d, data %s" % (self.id, self.data)
 
+        Blub.mapper = mapper(Blub, blub, inherits=Bar.mapper, properties={
+            # bug was raised specifically based on the order of cols in the join....
+#            'parent_foo':relation(Foo.mapper, primaryjoin=blub.c.foo_id==foo.c.id)
+#            'parent_foo':relation(Foo.mapper, primaryjoin=foo.c.id==blub.c.foo_id)
+            'parent_foo':relation(Foo.mapper)
+        })
+
+        b1 = Blub("blub #1")
+        b2 = Blub("blub #2")
+        f = Foo("foo #1")
+        b1.parent_foo = f
+        b2.parent_foo = f
+        objectstore.commit()
+        compare = repr(b1) + repr(b2) + repr(b1.parent_foo) + repr(b2.parent_foo)
+        objectstore.clear()
+        l = Blub.mapper.select()
+        result = repr(l[0]) + repr(l[1]) + repr(l[0].parent_foo) + repr(l[1].parent_foo)
+        self.echo(result)
+        self.assert_(compare == result)
+        self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
 
 if __name__ == "__main__":    
     testbase.main()