]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added new polymorph test, todos for session/cascade
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Jun 2006 18:49:10 +0000 (18:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Jun 2006 18:49:10 +0000 (18:49 +0000)
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/alltests.py
test/inheritance3.py [new file with mode: 0644]

index 8002998fb378e9721b6efa391120a068b53e776c..7d3b341ab0757b115b3cce94b921f3de48202480 100644 (file)
@@ -64,6 +64,7 @@ class DependencyProcessor(object):
         insert/update/delete order (topological sort)."""
         raise NotImplementedError()
 
+    # TODO: all of these preproc rules need to take dependencies into account
     def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
         """used before the flushes' topological sort to traverse through related objects and insure every 
         instance which will require save/update/delete is properly added to the UOWTransaction."""
index 42f442b095dcc87a37558b95651055063a83adf0..56d699cc6607f55ead57e271ad816da45812f33b 100644 (file)
@@ -279,6 +279,7 @@ class Session(object):
             if c is object:
                 self._save_impl(c, entity_name=entity_name)
             else:
+                # TODO: this is running the cascade rules twice
                 self.save_or_update(c, entity_name=entity_name)
 
     def update(self, object, entity_name=None):
@@ -397,7 +398,7 @@ class Session(object):
     def __contains__(self, obj):
         return self._is_attached(obj) and (obj in self.uow.new or self.uow.has_key(obj._instance_key))
     def __iter__(self):
-        return iter(self.uow.identity_map.values())
+        return iter(self.uow.new + self.uow.identity_map.values())
     def _get(self, key):
         return self.uow._get(key)
     def has_key(self, key):
index a9cc1e490264f177f8a12211e603cde658dbd901..c33f344fbea9ea15bd0c71389332ac08e167e924 100644 (file)
@@ -271,6 +271,7 @@ class UOWTransaction(object):
         self.__modified = False
         self.__is_executing = False
         
+    # TODO: shouldnt be able to register stuff here that is not in the enclosing Session
     def register_object(self, obj, isdelete = False, listonly = False, postupdate=False, **kwargs):
         """adds an object to this UOWTransaction to be updated in the database.
 
index a0ee689fc3d1343e9fca2a943e9afc93b12e57f3..183fc1492415815f379a049ceea08f1844f5c517 100644 (file)
@@ -52,6 +52,7 @@ def suite():
         'onetoone',
         'inheritance',
         'inheritance2',
+       'inheritance3',
         'polymorph',
         
         # extensions
diff --git a/test/inheritance3.py b/test/inheritance3.py
new file mode 100644 (file)
index 0000000..15c93e7
--- /dev/null
@@ -0,0 +1,202 @@
+from sqlalchemy import *
+import testbase
+
+
+class BaseObject(object):
+    def __init__(self, *args, **kwargs):
+        for key, value in kwargs.iteritems():
+            setattr(self, key, value)
+class Publication(BaseObject):
+    pass
+
+class Issue(BaseObject):
+    pass
+
+class Location(BaseObject):
+    def __repr__(self):
+        return "%s(%s, %s)" % (self.__class__.__name__, repr(self.issue_id), repr(str(self._name.name)))
+
+    def _get_name(self):
+        return self._name
+
+    def _set_name(self, name):
+        session = create_session()
+        s = session.query(LocationName).selectfirst(location_name_table.c.name==name)
+        session.clear()
+        if s is not None:
+            self._name = s
+
+            return
+
+        found = False
+
+        for i in session.new:
+            if isinstance(i, LocationName) and i.name == name:
+                self._name = i
+                found = True
+
+                break
+
+        if found == False:
+            self._name = LocationName(name=name)
+
+    name = property(_get_name, _set_name)
+
+class LocationName(BaseObject):
+    def __repr__(self):
+        return "%s()" % (self.__class__.__name__)
+
+class PageSize(BaseObject):
+    def __repr__(self):
+        return "%s(%sx%s, %s)" % (self.__class__.__name__, self.width, self.height, self.name)
+        
+class Magazine(BaseObject):
+    def __repr__(self):
+        return "%s(%s, %s)" % (self.__class__.__name__, repr(self.location), repr(self.size))
+
+class Page(BaseObject):
+    def __repr__(self):
+        return "%s(%s)" % (self.__class__.__name__, repr(self.page_no))
+
+class MagazinePage(Page):
+    def __repr__(self):
+        return "%s(%s, %s)" % (self.__class__.__name__, repr(self.page_no), repr(self.magazine))
+
+class ClassifiedPage(MagazinePage):
+    pass
+
+class InheritTest(testbase.AssertMixin):
+    """tests a large polymorphic relationship"""
+    def setUpAll(self):
+        global metadata, publication_table, issue_table, location_table, location_name_table, magazine_table, \
+        page_table, magazine_page_table, classified_page_table, page_size_table
+        
+        metadata = BoundMetaData(testbase.db)
+
+        zerodefault = {} #{'default':0}
+        publication_table = Table('publication', metadata,
+            Column('id', Integer, primary_key=True, default=None),
+            Column('name', String(45), default=''),
+        )
+        issue_table = Table('issue', metadata,
+            Column('id', Integer, primary_key=True, default=None),
+            Column('publication_id', Integer, ForeignKey('publication.id'), **zerodefault),
+            Column('issue', Integer, **zerodefault),
+        )
+        location_table = Table('location', metadata,
+            Column('id', Integer, primary_key=True, default=None),
+            Column('issue_id', Integer, ForeignKey('issue.id'), **zerodefault),
+            Column('ref', CHAR(3), default=''),
+            Column('location_name_id', Integer, ForeignKey('location_name.id'), **zerodefault),
+        )
+        location_name_table = Table('location_name', metadata,
+            Column('id', Integer, primary_key=True, default=None),
+            Column('name', String(45), default=''),
+        )
+        magazine_table = Table('magazine', metadata,
+            Column('id', Integer, primary_key=True, default=None),
+            Column('location_id', Integer, ForeignKey('location.id'), **zerodefault),
+            Column('page_size_id', Integer, ForeignKey('page_size.id'), **zerodefault),
+        )
+        page_table = Table('page', metadata,
+            Column('id', Integer, primary_key=True, default=None),
+            Column('page_no', Integer, **zerodefault),
+            Column('type', CHAR(1), default='p'),
+        )
+        magazine_page_table = Table('magazine_page', metadata,
+            Column('page_id', Integer, ForeignKey('page.id'), primary_key=True, **zerodefault),
+            Column('magazine_id', Integer, ForeignKey('magazine.id'), **zerodefault),
+            Column('orders', TEXT, default=''),
+        )
+        classified_page_table = Table('classified_page', metadata,
+            Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True, **zerodefault),
+            Column('titles', String(45), default=''),
+        )
+        page_size_table = Table('page_size', metadata,
+            Column('id', Integer, primary_key=True, default=None),
+            Column('width', Integer, **zerodefault),
+            Column('height', Integer, **zerodefault),
+            Column('name', String(45), default=''),
+        )
+
+        metadata.create_all()
+        
+        publication_mapper = mapper(Publication, publication_table)
+
+        issue_mapper = mapper(Issue, issue_table, properties = {
+            'publication': relation(Publication, backref=backref('issues', cascade="all, delete-orphan")),
+        })
+
+        location_name_mapper = mapper(LocationName, location_name_table)
+
+        location_mapper = mapper(Location, location_table, properties = {
+            'issue': relation(Issue, backref='locations'),
+            '_name': relation(LocationName),
+        })
+
+        issue_mapper.add_property('locations', relation(Location, lazy=False, private=True, backref='issue'))
+
+        page_size_mapper = mapper(PageSize, page_size_table)
+
+        page_join = polymorphic_union(
+            {
+                'm': page_table.join(magazine_page_table),
+                'c': page_table.join(magazine_page_table).join(classified_page_table),
+                'p': page_table.select(page_table.c.type=='p'),
+            }, None, 'page_join')
+
+        magazine_join = polymorphic_union(
+            {
+                'm': page_table.join(magazine_page_table),
+                'c': page_table.join(magazine_page_table).join(classified_page_table),
+            }, None, 'page_join')
+
+        magazine_mapper = mapper(Magazine, magazine_table, properties = {
+            'location': relation(Location, backref=backref('magazine', uselist=False)),
+            'size': relation(PageSize),
+        })
+
+        page_mapper = mapper(Page, page_table, select_table=page_join, polymorphic_on=page_join.c.type, polymorphic_identity='p')
+
+        magazine_page_mapper = mapper(MagazinePage, magazine_page_table, select_table=magazine_join, inherits=page_mapper, polymorphic_identity='m', properties={
+            'magazine': relation(Magazine, backref=backref('pages', order_by=magazine_join.c.page_no))
+        })
+
+        classified_page_mapper = mapper(ClassifiedPage, classified_page_table, inherits=magazine_page_mapper, polymorphic_identity='c')
+
+    def tearDownAll(self):
+        metadata.drop_all()
+        clear_mappers()
+        
+    def testone(self):
+        session = create_session()
+
+        pub = Publication(name='Test')
+        issue = Issue(issue=46,publication=pub)
+
+        location = Location(ref='ABC',name='London',issue=issue)
+
+        page_size = PageSize(name='A4',width=210,height=297)
+
+        magazine = Magazine(location=location,size=page_size)
+        page = ClassifiedPage(magazine=magazine,page_no=1)
+        page2 = MagazinePage(magazine=magazine,page_no=2)
+        page3 = ClassifiedPage(magazine=magazine,page_no=3)
+        session.save(pub)
+        print [x for x in session]
+        
+        session.flush()
+        print [x for x in session]
+        session.clear()
+
+        session.echo_uow=True
+        session.flush()
+        session.clear()
+        p = session.query(Publication).selectone_by(name='Test')
+
+        print p.issues[0].locations[0].magazine.pages
+        print [page, page2, page3]
+        assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3])
+        
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file