]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
a simplification to syncrule generation, which also allows more flexible configuration
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Oct 2006 07:57:12 +0000 (07:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Oct 2006 07:57:12 +0000 (07:57 +0000)
of which columns are to be involved in the synchronization via foreignkey property.
foreignkey param is a little more important now and should have its role clarified
particularly for self-referential mappers.

lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/sync.py
test/orm/eagertest3.py
test/orm/relationships.py

index 7bb550b4dc2eb0f02dab5b0268176b08a53c8508..898d03eb68bbd470531a693bcee0e61ad97f6e20 100644 (file)
@@ -35,6 +35,7 @@ class DependencyProcessor(object):
         self.direction = prop.direction
         self.is_backref = prop.is_backref
         self.post_update = prop.post_update
+        self.foreignkey = prop.foreignkey
         self.key = prop.key
 
         self._compile_synchronizers()
@@ -84,15 +85,12 @@ class DependencyProcessor(object):
 
         The list of rules is used within commits by the _synchronize() method when dependent 
         objects are processed."""
-        parent_tables = util.Set(self.parent.tables + [self.parent.mapped_table])
-        target_tables = util.Set(self.mapper.tables + [self.mapper.mapped_table])
-
         self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
         if self.direction == sync.MANYTOMANY:
-            self.syncrules.compile(self.prop.primaryjoin, parent_tables, [self.secondary], False)
-            self.syncrules.compile(self.prop.secondaryjoin, target_tables, [self.secondary], True)
+            self.syncrules.compile(self.prop.primaryjoin, issecondary=False)
+            self.syncrules.compile(self.prop.secondaryjoin, issecondary=True)
         else:
-            self.syncrules.compile(self.prop.primaryjoin, parent_tables, target_tables)
+            self.syncrules.compile(self.prop.primaryjoin, foreignkey=self.foreignkey)
         
     def get_object_dependencies(self, obj, uowcommit, passive = True):
         """returns the list of objects that are dependent on the given object, as according to the relationship
index 096e1ca33b46033bfd595c3ad7b50fa53c162859..1f4a8391db0a693853480300dba65b84b01e6c44 100644 (file)
@@ -271,7 +271,7 @@ class Mapper(object):
                     # stricter set of tables to create "sync rules" by,based on the immediate
                     # inherited table, rather than all inherited tables
                     self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
-                    self._synchronizer.compile(self.mapped_table.onclause, util.Set([self.inherits.local_table]), sqlutil.TableFinder(self.local_table))
+                    self._synchronizer.compile(self.mapped_table.onclause)
             else:
                 self._synchronizer = None
                 self.mapped_table = self.local_table
index aeaa5d1d24ce05828084a9d4bcb47930065b0708..5f0331e16113a6fc8ee5effce5fd82f1f8929819 100644 (file)
@@ -33,51 +33,57 @@ class ClauseSynchronizer(object):
         self.direction = direction
         self.syncrules = []
 
-    def compile(self, sqlclause, source_tables, target_tables, issecondary=None):
-        def check_for_table(binary, list1, list2):
-            #print "check for table", str(binary), [str(c) for c in l]
-            if binary.left.table in list1 and binary.right.table in list2:
-                return (binary.left, binary.right)
-            elif binary.right.table in list1 and binary.left.table in list2:
-                return (binary.right, binary.left)
-            else:
-                return (None, None)
-                
+    def compile(self, sqlclause, issecondary=None, foreignkey=None):
         def compile_binary(binary):
-            """assembles a SyncRule given a single binary condition"""
+            """assemble a SyncRule given a single binary condition"""
             if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                 return
 
-            if binary.left.table == binary.right.table:
-                # self-cyclical relation
-                if binary.left.primary_key:
-                    source = binary.left
-                    dest = binary.right
-                elif binary.right.primary_key:
-                    source = binary.right
-                    dest = binary.left
+            source_column = None
+            dest_column = None
+            if foreignkey is not None:
+                # for self-referential relationships,
+                # the best we can do right now is figure out which side
+                # is the primary key
+                # TODO: need some better way for this
+                if binary.left.table == binary.right.table:
+                    if binary.left.primary_key:
+                        source_column = binary.left
+                        dest_column = binary.right
+                    elif binary.right.primary_key:
+                        source_column = binary.right
+                        dest_column = binary.left
+                    else:
+                        raise ArgumentError("Can't locate a primary key column in self-referential equality clause '%s'" % str(binary))
+                # for other relationships we are more flexible
+                # and go off the 'foreignkey' property
+                elif binary.left in foreignkey:
+                    dest_column = binary.left
+                    source_column = binary.right
+                elif binary.right in foreignkey:
+                    dest_column = binary.right
+                    source_column = binary.left
                 else:
-                    raise ArgumentError("Cant determine direction for relationship %s = %s" % (binary.left.table.fullname, binary.right.table.fullname))
+                    return
+            else:
+                if binary.left in [f.column for f in binary.right.foreign_keys]:
+                    dest_column = binary.right
+                    source_column = binary.left
+                elif binary.right in [f.column for f in binary.left.foreign_keys]:
+                    dest_column = binary.left
+                    source_column = binary.right
+            
+            if source_column and dest_column:    
                 if self.direction == ONETOMANY:
-                    self.syncrules.append(SyncRule(self.parent_mapper, source, dest, dest_mapper=self.child_mapper))
+                    self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper))
                 elif self.direction == MANYTOONE:
-                    self.syncrules.append(SyncRule(self.child_mapper, source, dest, dest_mapper=self.parent_mapper))
+                    self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper))
                 else:
-                    raise AssertionError("assert failed")
-            else:
-                (pt, tt) = check_for_table(binary, source_tables, target_tables)
-                #print "OK", binary, [t.name for t in source_tables], [t.name for t in target_tables]
-                if pt and tt:
-                    if self.direction == ONETOMANY:
-                        self.syncrules.append(SyncRule(self.parent_mapper, pt, tt, dest_mapper=self.child_mapper))
-                    elif self.direction == MANYTOONE:
-                        self.syncrules.append(SyncRule(self.child_mapper, tt, pt, dest_mapper=self.parent_mapper))
+                    if not issecondary:
+                        self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper, issecondary=issecondary))
                     else:
-                        if not issecondary:
-                            self.syncrules.append(SyncRule(self.parent_mapper, pt, tt, dest_mapper=self.child_mapper, issecondary=issecondary))
-                        else:
-                            self.syncrules.append(SyncRule(self.child_mapper, pt, tt, dest_mapper=self.parent_mapper, issecondary=issecondary))
-                            
+                        self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))
+
         rules_added = len(self.syncrules)
         processor = BinaryVisitor(compile_binary)
         sqlclause.accept_visitor(processor)
@@ -131,7 +137,8 @@ class SyncRule(object):
             dest[self.dest_column.key] = value
         else:
             if clearkeys and self.dest_primary_key():
-                return
+                raise exceptions.AssertionError("Dependency rule tried to blank-out a primary key column")
+                
             if logging.is_debug_enabled(self.logger):
                 self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value))
             self.dest_mapper._setattrbycolumn(dest, self.dest_column, value)
index 5cfc9eedd041d4f4b8a9cfcf4a680342dfd8fab2..4cb56d5d128b7f78e36ae71e0038841414ba7df6 100644 (file)
@@ -42,7 +42,9 @@ class EagerTest(AssertMixin):
         mapper(Test,tests,properties={
             'owner':relation(Owner,backref='tests'),
             'category':relation(Category),
-            'owner_option': relation(Option,primaryjoin=and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id),uselist=False ) 
+            'owner_option': relation(Option,primaryjoin=and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id),
+                foreignkey=[options.c.test_id, options.c.owner_id],
+            uselist=False ) 
         })
 
         s=create_session()
index ae0101108746474984993431f69bb91135b46d1b..e0bb5fecd829f8f240c677384bc4d83bff295713 100644 (file)
@@ -217,6 +217,141 @@ class RelationTest2(testbase.PersistTest):
         assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1'
         assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
         
+class RelationTest3(testbase.PersistTest):
+    def setUpAll(self):
+        global jobs, pageversions, pages, metadata, Job, Page, PageVersion, PageComment
+        import datetime
+        metadata = BoundMetaData(testbase.db)  
+        jobs = Table("jobs", metadata,
+                        Column("jobno", Unicode(15), primary_key=True),
+                        Column("created", DateTime, nullable=False, default=datetime.datetime.now),
+                        Column("deleted", Boolean, nullable=False, default=False))
+        pageversions = Table("pageversions", metadata,
+                        Column("jobno", Unicode(15), primary_key=True),
+                        Column("pagename", Unicode(30), primary_key=True),
+                        Column("version", Integer, primary_key=True, default=1),
+                        Column("created", DateTime, nullable=False, default=datetime.datetime.now),
+                        Column("md5sum", String(32)),
+                        Column("width", Integer, nullable=False, default=0),
+                        Column("height", Integer, nullable=False, default=0),
+                        ForeignKeyConstraint(["jobno", "pagename"], ["pages.jobno", "pages.pagename"])
+                        )
+        pages = Table("pages", metadata,
+                        Column("jobno", Unicode(15), ForeignKey("jobs.jobno"), primary_key=True),
+                        Column("pagename", Unicode(30), primary_key=True),
+                        Column("created", DateTime, nullable=False, default=datetime.datetime.now),
+                        Column("deleted", Boolean, nullable=False, default=False),
+                        Column("current_version", Integer))
+        pagecomments = Table("pagecomments", metadata,
+            Column("jobno", Unicode(15), primary_key=True),
+            Column("pagename", Unicode(30), primary_key=True),
+            Column("comment_id", Integer, primary_key=True),
+            Column("content", Unicode),
+            ForeignKeyConstraint(["jobno", "pagename"], ["pages.jobno", "pages.pagename"])
+        )
+
+        metadata.create_all()
+        class Job(object):
+            def __init__(self, jobno=None):
+                self.jobno = jobno
+            def create_page(self, pagename, *args, **kwargs):
+                return Page(job=self, pagename=pagename, *args, **kwargs)
+        class PageVersion(object):
+            def __init__(self, page=None, version=None):
+                self.page = page
+                self.version = version
+        class Page(object):
+            def __init__(self, job=None, pagename=None):
+                self.job = job
+                self.pagename = pagename
+                self.currentversion = PageVersion(self, 1)
+            def __repr__(self):
+                return "Page jobno:%s pagename:%s %s" % (self.jobno, self.pagename, getattr(self, '_instance_key', None))
+            def add_version(self):
+                self.currentversion = PageVersion(self, self.currentversion.version+1)
+                comment = self.add_comment()
+                comment.closeable = False
+                comment.content = u'some content'
+                return self.currentversion
+            def add_comment(self):
+                nextnum = max([-1] + [c.comment_id for c in self.comments]) + 1
+                newcomment = PageComment()
+                newcomment.comment_id = nextnum
+                self.comments.append(newcomment)
+                newcomment.created_version = self.currentversion.version
+                return newcomment
+        class PageComment(object):
+            pass
+        mapper(Job, jobs)
+        mapper(PageVersion, pageversions)
+        mapper(Page, pages, properties={
+            'job': relation(Job, backref=backref('pages', cascade="all, delete-orphan", order_by=pages.c.pagename)),
+            'currentversion': relation(PageVersion,
+                            foreignkey=pages.c.current_version,
+                            primaryjoin=and_(pages.c.jobno==pageversions.c.jobno,
+                                             pages.c.pagename==pageversions.c.pagename,
+                                             pages.c.current_version==pageversions.c.version),
+                            post_update=True),
+            'versions': relation(PageVersion, cascade="all, delete-orphan",
+                            primaryjoin=and_(pages.c.jobno==pageversions.c.jobno,
+                                             pages.c.pagename==pageversions.c.pagename),
+                            order_by=pageversions.c.version,
+                            backref=backref('page', lazy=False,
+                                            primaryjoin=and_(pages.c.jobno==pageversions.c.jobno,
+                                                             pages.c.pagename==pageversions.c.pagename)))
+        })
+        mapper(PageComment, pagecomments, properties={
+            'page': relation(Page, primaryjoin=and_(pages.c.jobno==pagecomments.c.jobno,
+                                                    pages.c.pagename==pagecomments.c.pagename),
+                                backref=backref("comments", cascade="all, delete-orphan",
+                                                primaryjoin=and_(pages.c.jobno==pagecomments.c.jobno,
+                                                                 pages.c.pagename==pagecomments.c.pagename),
+                                                order_by=pagecomments.c.comment_id))
+        })
+
+
+    def tearDownAll(self):
+        clear_mappers()
+        metadata.drop_all()    
+
+    def testbasic(self):
+        """test the combination of complicated join conditions with post_update"""
+        j1 = Job('somejob')
+        j1.create_page('page1')
+        j1.create_page('page2')
+        j1.create_page('page3')
+
+        j2 = Job('somejob2')
+        j2.create_page('page1')
+        j2.create_page('page2')
+        j2.create_page('page3')
+
+        j2.pages[0].add_version()
+        j2.pages[0].add_version()
+        j2.pages[1].add_version()
+        print j2.pages
+        print j2.pages[0].versions
+        print j2.pages[1].versions
+        s = create_session()
+
+        s.save(j1)
+        s.save(j2)
+        s.flush()
+
+        s.clear()
+        j = s.query(Job).get_by(jobno='somejob')
+        oldp = list(j.pages)
+        j.pages = []
+
+        s.flush()
+
+        s.clear()
+        j = s.query(Job).get_by(jobno='somejob2')
+        j.pages[1].current_version = 12
+        s.delete(j)
+        s.flush()
+        
+        
         
         
 if __name__ == "__main__":