]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixes to the "exists" function involving inheritance (any(), has(),
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 May 2008 19:23:58 +0000 (19:23 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 May 2008 19:23:58 +0000 (19:23 +0000)
~contains()); the full target join will be rendered into the
EXISTS clause for relations that link to subclasses.

CHANGES
lib/sqlalchemy/orm/properties.py
test/orm/inheritance/query.py

diff --git a/CHANGES b/CHANGES
index 35da79f4b071626f987111c1b9c34178f93f0b2b..989a960e491caffd4fdec527575b1ee83b4c4130 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -29,6 +29,10 @@ CHANGES
       class, not the mapped selectable, as the source of column
       attributes - so a warning is still issued.
 
+    - fixes to the "exists" function involving inheritance (any(), has(),
+      ~contains()); the full target join will be rendered into the
+      EXISTS clause for relations that link to subclasses.
+      
     - restored usage of append_result() extension method for primary 
       query rows, when the extension is present and only a single-
       entity result is being returned.
index 847e793986e3c34084f45f57bd270cdbcc35fdc4..33a0ff4326dad1ce67121c2d1fd5fc545904d8ba 100644 (file)
@@ -314,7 +314,7 @@ class PropertyLoader(StrategizedProperty):
                 to_selectable = target_mapper._with_polymorphic_selectable() #mapped_table
             else:
                 to_selectable = None
-
+            
             pj, sj, source, dest, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable)
 
             for k in kwargs:
@@ -738,12 +738,12 @@ class PropertyLoader(StrategizedProperty):
             if dest_polymorphic and self.mapper.with_polymorphic:
                 dest_selectable = self.mapper._with_polymorphic_selectable()
             else:
-                dest_selectable = None
+                dest_selectable = self.mapper.mapped_table
             if self._is_self_referential():
                 if dest_selectable:
                     dest_selectable = dest_selectable.alias()
                 else:
-                    dest_selectable = self.mapper.local_table.alias()
+                    dest_selectable = self.mapper.mapped_table.alias()
                 
         primaryjoin = self.primaryjoin
         if source_selectable:
index 964cbceabbda5d80d83e6c07c681ea3f98a42614..34ead1622cfd16633e73a074c7f83087c3970e1c 100644 (file)
@@ -274,18 +274,10 @@ def make_test(select_type):
                 sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1]
             )
             
-            if select_type == '':
-                # this tests that a hand-rolled criterion in the any() doesn't get clobbered by
-                # aliasing, when the mapper is not set up for polymorphic joins
-                self.assertEquals(
-                    sess.query(Company).filter(Company.employees.any(and_(Engineer.primary_language=='cobol', people.c.person_id==engineers.c.person_id))).one(),
-                    c2
-                    )
-            else:
-                self.assertEquals(
-                    sess.query(Company).filter(Company.employees.any(and_(Engineer.primary_language=='cobol'))).one(),
-                    c2
-                    )
+            self.assertEquals(
+                sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(),
+                c2
+                )
                 
         
         def test_expire(self):
@@ -514,5 +506,70 @@ class SelfReferentialTest(ORMTest):
             sess.query(Engineer).join('reports_to')
         self.assertRaises(exceptions.InvalidRequestError, go)
 
+class M2MFilterTest(ORMTest):
+    keep_mappers = True
+    keep_data = True
+    
+    def define_tables(self, metadata):
+        global people, engineers, Organization
+        
+        organizations = Table('organizations', metadata,
+            Column('id', Integer, Sequence('org_id_seq', optional=True), primary_key=True),
+            Column('name', String(50)),
+            )
+        engineers_to_org = Table('engineers_org', metadata,
+            Column('org_id', Integer, ForeignKey('organizations.id')),
+            Column('engineer_id', Integer, ForeignKey('engineers.person_id')),
+        )
+        
+        people = Table('people', metadata,
+           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('name', String(50)),
+           Column('type', String(30)))
+
+        engineers = Table('engineers', metadata,
+           Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+           Column('primary_language', String(50)),
+          )
+        
+        class Organization(fixtures.Base):
+            pass
+            
+        mapper(Organization, organizations, properties={
+            'engineers':relation(Engineer, secondary=engineers_to_org, backref='organizations')
+        })
+        
+        mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
+        mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
+    
+    def insert_data(self):
+        e1 = Engineer(name='e1')
+        e2 = Engineer(name='e2')
+        e3 = Engineer(name='e3')
+        e4 = Engineer(name='e4')
+        org1 = Organization(name='org1', engineers=[e1, e2])
+        org2 = Organization(name='org2', engineers=[e3, e4])
+        
+        sess = create_session()
+        sess.save(org1)
+        sess.save(org2)
+        sess.flush()
+        
+    def test_not_contains(self):
+        sess = create_session()
+        
+        e1 = sess.query(Person).filter(Engineer.name=='e1').one()
+        
+        # this works
+        self.assertEquals(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')])
+
+        # this had a bug
+        self.assertEquals(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')])
+    
+    def test_any(self):
+        sess = create_session()
+        self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+        self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+        
 if __name__ == "__main__":
     testenv.main()