From: Mike Bayer Date: Thu, 18 Sep 2008 21:41:37 +0000 (+0000) Subject: - "non-batch" mode in mapper(), a feature which allows X-Git-Tag: rel_0_5rc2~43 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=9632a752d2661326aabbfc02185f17318d41aef2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - "non-batch" mode in mapper(), a feature which allows mapper extension methods to be called as each instance is updated/inserted, now honors the insert order of the objects given. - added some tests, some commented out, involving [ticket:1171] --- diff --git a/CHANGES b/CHANGES index 41e38e72bc..56d2ec7254 100644 --- a/CHANGES +++ b/CHANGES @@ -6,7 +6,12 @@ CHANGES 0.5.0rc2 ======== - +- orm + - "non-batch" mode in mapper(), a feature which allows + mapper extension methods to be called as each instance + is updated/inserted, now honors the insert order + of the objects given. + - sql - column.in_(someselect) can now be used as a columns-clause expression without the subquery diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 03f8b45534..21cbe3f2b8 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1047,7 +1047,9 @@ class Mapper(object): # if batch=false, call _save_obj separately for each object if not single and not self.batch: - for state in states: + def comparator(a, b): + return cmp(getattr(a, 'insert_order', 0), getattr(b, 'insert_order', 0)) + for state in sorted(states, comparator): self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 5266a682b6..0951220392 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -155,6 +155,7 @@ class CompositeProperty(ColumnProperty): else: values = other.__composite_values__() return sql.and_(*[a==b for a, b in zip(self.prop.columns, values)]) + def __ne__(self, other): return sql.not_(self.__eq__(other)) diff --git a/test/orm/query.py b/test/orm/query.py index ffc7104c82..236680c77d 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -250,12 +250,30 @@ class InvalidGenerationsTest(QueryTest): q = s.query(User).order_by(User.name) self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") -class OperatorTest(QueryTest): +class OperatorTest(QueryTest, AssertsCompiledSQL): """test sql.Comparator implementation for MapperProperties""" def _test(self, clause, expected): - c = str(clause.compile(dialect = default.DefaultDialect())) - assert c == expected, "%s != %s" % (c, expected) + self.assert_compile(clause, expected, dialect=default.DefaultDialect()) + + def define_tables(self, metadata): + global nodes + nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('parent_id', Integer, ForeignKey('nodes.id')), + Column('data', String(30))) + + def insert_data(self): + global Node + + class Node(Base): + pass + + mapper(Node, nodes, properties={ + 'children':relation(Node, + backref=backref('parent', remote_side=[nodes.c.id]) + ) + }) def test_arithmetic(self): create_session().query(User) @@ -276,6 +294,8 @@ class OperatorTest(QueryTest): def test_comparison(self): create_session().query(User) + ualias = aliased(User) + for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'), (operator.gt, '>', '<'), (operator.eq, '=', '='), @@ -291,6 +311,10 @@ class OperatorTest(QueryTest): (literal('a'), 'b', ':param_1', ':param_2'), (literal('a'), User.id, ':param_1', 'users.id'), (literal('a'), literal('b'), ':param_1', ':param_2'), + (ualias.id, literal('b'), 'users_1.id', ':param_1'), + (User.id, ualias.name, 'users.id', 'users_1.name'), + (User.name, ualias.name, 'users.name', 'users_1.name'), + (ualias.name, User.name, 'users_1.name', 'users.name'), ): # the compiled clause should match either (e.g.): @@ -303,8 +327,51 @@ class OperatorTest(QueryTest): "\n'" + compiled + "'\n does not match\n'" + fwd_sql + "'\n or\n'" + rev_sql + "'") + def test_relation(self): + self._test(User.addresses.any(Address.id==17), + "EXISTS (SELECT 1 " + "FROM addresses " + "WHERE users.id = addresses.user_id AND addresses.id = :id_1)" + ) + + self._test(Address.user == User(id=7), ":param_1 = addresses.user_id") + + def test_selfref_relation(self): + + # auto self-referential aliasing + self._test( + Node.children.any(Node.data=='n1'), + "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " + "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" + ) + + # manual aliasing + nalias = aliased(Node) + + # fails + #self._test( + # nalias.children.any(Node.data=='some data'), + # "EXISTS (SELECT 1 FROM nodes WHERE " + # "nodes_1.id = nodes.parent_id AND nodes.data = :data_1)") + + # fails + #self._test( + # Node.children.any(nalias.data=='some data'), + # "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " + # "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" + # ) + + self._test( + nalias.parent == Node(id=7), + ":param_1 = nodes_1.parent_id" + ) + + self._test( + nalias.children.contains(Node(id=7)), "nodes_1.id = :param_1" + ) + def test_op(self): - assert str(User.name.op('ilike')('17').compile(dialect=default.DefaultDialect())) == "users.name ilike :name_1" + self._test(User.name.op('ilike')('17'), "users.name ilike :name_1") def test_in(self): self._test(User.id.in_(['a', 'b']), @@ -314,6 +381,12 @@ class OperatorTest(QueryTest): self._test(User.id.between('a', 'b'), "users.id BETWEEN :id_1 AND :id_2") + def test_selfref_between(self): + ualias = aliased(User) + self._test(User.id.between(ualias.id, ualias.id), "users.id BETWEEN users_1.id AND users_1.id") + # fails: + # self._test(ualias.id.between(User.id, User.id), "users_1.id BETWEEN users.id AND users.id") + def test_clauses(self): for (expr, compare) in ( (func.max(User.id), "max(users.id)"), @@ -325,6 +398,7 @@ class OperatorTest(QueryTest): c = expr.compile(dialect=default.DefaultDialect()) assert str(c) == compare, "%s != %s" % (str(c), compare) + class RawSelectTest(QueryTest, AssertsCompiledSQL): """compare a bunch of select() tests with the equivalent Query using straight table/columns. diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 90134d1428..05f4d88f3b 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -1511,9 +1511,11 @@ class SaveTest(_fixtures.FixtureTest): def test_batch_mode(self): """The 'batch=False' flag on mapper()""" + names = [] class TestExtension(sa.orm.MapperExtension): def before_insert(self, mapper, connection, instance): self.current_instance = instance + names.append(instance.name) def after_insert(self, mapper, connection, instance): assert instance is self.current_instance @@ -1524,18 +1526,25 @@ class SaveTest(_fixtures.FixtureTest): session = create_session() session.add_all((u1, u2)) session.flush() + + u3 = User(name='user3') + u4 = User(name='user4') + u5 = User(name='user5') + + session.add_all([u4, u5, u3]) + session.flush() + + # test insert ordering is maintained + assert names == ['user1', 'user2', 'user4', 'user5', 'user3'] session.clear() - + sa.orm.clear_mappers() m = mapper(User, users, extension=TestExtension()) u1 = User(name='user1') u2 = User(name='user2') - try: - session.flush() - assert False - except AssertionError: - assert True + session.add_all((u1, u2)) + self.assertRaises(AssertionError, session.flush) class ManyToOneTest(_fixtures.FixtureTest):