-from sqlalchemy import sql, schema, exceptions
+from sqlalchemy import sql, schema, exceptions, util
from sqlalchemy import logging
from sqlalchemy.orm import util as mapperutil
if binary.left.table == binary.right.table:
raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync")
- if binary.left in [f.column for f in binary.right.foreign_keys]:
+ if binary.left in util.Set([f.column for f in binary.right.foreign_keys]):
dest_column = binary.right
source_column = binary.left
- elif binary.right in [f.column for f in binary.left.foreign_keys]:
+ elif binary.right in util.Set([f.column for f in binary.left.foreign_keys]):
dest_column = binary.left
source_column = binary.right
else:
assert alice1.name == alice2.name == 'alice'
assert bob.name == 'bob'
+class SyncCompileTest(testbase.ORMTest):
+ """test that syncrules compile properly on custom inherit conds"""
+ def define_tables(self, metadata):
+ global _a_table, _b_table, _c_table
+
+ _a_table = Table('a', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data1', String)
+ )
+
+ _b_table = Table('b', metadata,
+ Column('a_id', Integer, ForeignKey('a.id'), primary_key=True),
+ Column('data2', String)
+ )
+
+ _c_table = Table('c', metadata,
+ # Column('a_id', Integer, ForeignKey('b.a_id'), primary_key=True), #works
+ Column('b_a_id', Integer, ForeignKey('b.a_id'), primary_key=True),
+ Column('data3', String)
+ )
+
+ def test_joins(self):
+ for j1 in (None, _b_table.c.a_id==_a_table.c.id, _a_table.c.id==_b_table.c.a_id):
+ for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id, _c_table.c.b_a_id==_b_table.c.a_id):
+ self._do_test(j1, j2)
+ for t in _a_table.metadata.table_iterator(reverse=True):
+ t.delete().execute().close()
+
+ def _do_test(self, j1, j2):
+ class A(object):
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ class B(A):
+ pass
+
+ class C(B):
+ pass
+
+ mapper(A, _a_table)
+ mapper(B, _b_table, inherits=A,
+ inherit_condition=j1
+ )
+ mapper(C, _c_table, inherits=B,
+ inherit_condition=j2
+ )
+
+ session = create_session()
+
+ a = A(data1='a1')
+ session.save(a)
+
+ b = B(data1='b1', data2='b2')
+ session.save(b)
+
+ c = C(data1='c1', data2='c2', data3='c3')
+ session.save(c)
+
+ session.flush()
+ session.clear()
+
+ assert len(session.query(A).all()) == 3
+ assert len(session.query(B).all()) == 2
+ assert len(session.query(C).all()) == 1
if __name__ == "__main__":