]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug regarding inherit_condition passed
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 8 Aug 2008 04:25:17 +0000 (04:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 8 Aug 2008 04:25:17 +0000 (04:25 +0000)
with "A=B" versus "B=A" leading to errors
[ticket:1039]

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/basic.py

diff --git a/CHANGES b/CHANGES
index e31a4ba38c67ee2f0c52ac6af7566ad4c3dd1236..acba941d80317d6ec392d276d4a1f3742caecb5b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,6 +1,13 @@
 =======
 CHANGES
 =======
+0.4.8
+=====
+- orm
+    - Fixed bug regarding inherit_condition passed
+      with "A=B" versus "B=A" leading to errors
+      [ticket:1039]
+      
 0.4.7p1
 =====
 - orm 
index d39e743e017f6c15b7c009b1fbbc680cc3f33564..b86fdc07b18c575f60073f11a3a6c540c56c5744 100644 (file)
@@ -1578,12 +1578,13 @@ class Mapper(object):
         def visit_binary(binary):
             leftcol = binary.left
             rightcol = binary.right
+
             if leftcol is None or rightcol is None:
                 return
             if leftcol.table not in needs_tables:
                 binary.left = sql.bindparam(None, None, type_=binary.right.type)
                 param_names.append((leftcol, binary.left))
-            elif rightcol not in needs_tables:
+            elif rightcol.table not in needs_tables:
                 binary.right = sql.bindparam(None, None, type_=binary.right.type)
                 param_names.append((rightcol, binary.right))
 
@@ -1595,7 +1596,7 @@ class Mapper(object):
                 break
             if not mapper.single:
                 allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
-
+        
         return sql.and_(*allconds), param_names
 
 Mapper.logger = logging.class_logger(Mapper)
index cb754fed596a13d565beba05d988eeb0f8493b5c..f59973d0263e867e5afbcd68bb9cc9e278b40951 100644 (file)
@@ -275,7 +275,42 @@ class GetTest(ORMTest):
     test_get_polymorphic = create_test(True, 'test_get_polymorphic')
     test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
 
+class InheritConditionTest(ORMTest):
+    def define_tables(self, metadata):
+        global base, child
+        base = Table('base', metadata,
+            Column('id', Integer, primary_key = True),
+            Column('type', String(40)))
+
+        child = Table('child', metadata,
+            Column('base_id', Integer, ForeignKey(base.c.id), primary_key = True),
+            Column('title', String(64)),
+            Column('parent_id', Integer, ForeignKey('child.base_id'))
+        )
+    
+    def test_inherit_cond(self):
+        class Base(fixtures.Base):
+            pass
+        class Child(Base):
+            pass
+            
+        mapper(Base, base, polymorphic_on=base.c.type)
+        mapper(Child, child, inherits=Base, inherit_condition=child.c.base_id==base.c.id, polymorphic_identity='c')
+        
+        sess = create_session()
+        sess.save(Child(title='c1'))
+        sess.save(Child(title='c2'))
+        sess.flush()
+        sess.clear()
 
+        for inherit_cond in (child.c.base_id==base.c.id, base.c.id==child.c.base_id):
+            clear_mappers()
+            mapper(Base, base, polymorphic_on=base.c.type)
+            mapper(Child, child, inherits=Base, inherit_condition=inherit_cond, polymorphic_identity='c')
+            
+            assert sess.query(Base).order_by(Base.id).all() == [Child(title='c1'), Child(title='c2')]
+        
+        
 class ConstructionTest(ORMTest):
     def define_tables(self, metadata):
         global content_type, content, product