]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- "non-batch" mode in mapper(), a feature which allows
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Sep 2008 21:41:37 +0000 (21:41 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Sep 2008 21:41:37 +0000 (21:41 +0000)
mapper extension methods to be called as each instance
is updated/inserted, now honors the insert order
of the objects given.
- added some tests, some commented out, involving [ticket:1171]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
test/orm/query.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index 41e38e72bca3008e6a40106b2474e3e4669b0800..56d2ec7254c74acceb9d869e2f9313029fd457d4 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,7 +6,12 @@ CHANGES
 
 0.5.0rc2
 ========
-
+- orm
+    - "non-batch" mode in mapper(), a feature which allows
+      mapper extension methods to be called as each instance
+      is updated/inserted, now honors the insert order
+      of the objects given. 
+      
 - sql
     - column.in_(someselect) can now be used as 
       a columns-clause expression without the subquery
index 03f8b455349163adf45cef6f3af70b9cd7a12897..21cbe3f2b8d1ee1dc7cc8dc5101211559433f85b 100644 (file)
@@ -1047,7 +1047,9 @@ class Mapper(object):
 
         # if batch=false, call _save_obj separately for each object
         if not single and not self.batch:
-            for state in states:
+            def comparator(a, b):
+                return cmp(getattr(a, 'insert_order', 0), getattr(b, 'insert_order', 0))
+            for state in sorted(states, comparator):
                 self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
             return
 
index 5266a682b65b5f0e3bcd37b7bba70c4501bfef26..09512203922a737dd05886951ceefb9ca39e05b5 100644 (file)
@@ -155,6 +155,7 @@ class CompositeProperty(ColumnProperty):
             else:
                 values = other.__composite_values__()
             return sql.and_(*[a==b for a, b in zip(self.prop.columns, values)])
+            
         def __ne__(self, other):
             return sql.not_(self.__eq__(other))
 
index ffc7104c8213f8d8736a0fdca575bde151ffc315..236680c77d0cf83c249ac3c472d659356deac54e 100644 (file)
@@ -250,12 +250,30 @@ class InvalidGenerationsTest(QueryTest):
         q = s.query(User).order_by(User.name)
         self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
         
-class OperatorTest(QueryTest):
+class OperatorTest(QueryTest, AssertsCompiledSQL):
     """test sql.Comparator implementation for MapperProperties"""
 
     def _test(self, clause, expected):
-        c = str(clause.compile(dialect = default.DefaultDialect()))
-        assert c == expected, "%s != %s" % (c, expected)
+        self.assert_compile(clause, expected, dialect=default.DefaultDialect())
+
+    def define_tables(self, metadata):
+        global nodes
+        nodes = Table('nodes', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('parent_id', Integer, ForeignKey('nodes.id')),
+            Column('data', String(30)))
+        
+    def insert_data(self):
+        global Node
+
+        class Node(Base):
+            pass
+
+        mapper(Node, nodes, properties={
+            'children':relation(Node, 
+                backref=backref('parent', remote_side=[nodes.c.id])
+            )
+        })
 
     def test_arithmetic(self):
         create_session().query(User)
@@ -276,6 +294,8 @@ class OperatorTest(QueryTest):
 
     def test_comparison(self):
         create_session().query(User)
+        ualias = aliased(User)
+        
         for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'),
                                         (operator.gt, '>', '<'),
                                         (operator.eq, '=', '='),
@@ -291,6 +311,10 @@ class OperatorTest(QueryTest):
                 (literal('a'), 'b', ':param_1', ':param_2'),
                 (literal('a'), User.id, ':param_1', 'users.id'),
                 (literal('a'), literal('b'), ':param_1', ':param_2'),
+                (ualias.id, literal('b'), 'users_1.id', ':param_1'),
+                (User.id, ualias.name, 'users.id', 'users_1.name'),
+                (User.name, ualias.name, 'users.name', 'users_1.name'),
+                (ualias.name, User.name, 'users_1.name', 'users.name'),
                 ):
 
                 # the compiled clause should match either (e.g.):
@@ -303,8 +327,51 @@ class OperatorTest(QueryTest):
                              "\n'" + compiled + "'\n does not match\n'" +
                              fwd_sql + "'\n or\n'" + rev_sql + "'")
 
+    def test_relation(self):
+        self._test(User.addresses.any(Address.id==17), 
+                        "EXISTS (SELECT 1 "
+                        "FROM addresses "
+                        "WHERE users.id = addresses.user_id AND addresses.id = :id_1)"
+                    )
+
+        self._test(Address.user == User(id=7), ":param_1 = addresses.user_id")
+
+    def test_selfref_relation(self):
+
+        # auto self-referential aliasing
+        self._test(
+            Node.children.any(Node.data=='n1'), 
+                "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE "
+                "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)"
+        )
+        
+        # manual aliasing
+        nalias = aliased(Node)
+        
+        # fails
+        #self._test(
+        #        nalias.children.any(Node.data=='some data'), 
+        #        "EXISTS (SELECT 1 FROM nodes WHERE "
+        #        "nodes_1.id = nodes.parent_id AND nodes.data = :data_1)")
+        
+        # fails
+        #self._test(
+        #        Node.children.any(nalias.data=='some data'), 
+        #        "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE "
+        #        "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)"
+        #        )
+
+        self._test(
+            nalias.parent == Node(id=7), 
+            ":param_1 = nodes_1.parent_id"
+        )
+        
+        self._test(
+            nalias.children.contains(Node(id=7)), "nodes_1.id = :param_1"
+        )
+        
     def test_op(self):
-        assert str(User.name.op('ilike')('17').compile(dialect=default.DefaultDialect())) == "users.name ilike :name_1"
+        self._test(User.name.op('ilike')('17'), "users.name ilike :name_1")
 
     def test_in(self):
          self._test(User.id.in_(['a', 'b']),
@@ -314,6 +381,12 @@ class OperatorTest(QueryTest):
         self._test(User.id.between('a', 'b'),
                    "users.id BETWEEN :id_1 AND :id_2")
 
+    def test_selfref_between(self):
+        ualias = aliased(User)
+        self._test(User.id.between(ualias.id, ualias.id), "users.id BETWEEN users_1.id AND users_1.id")
+        # fails:
+        # self._test(ualias.id.between(User.id, User.id), "users_1.id BETWEEN users.id AND users.id")
+
     def test_clauses(self):
         for (expr, compare) in (
             (func.max(User.id), "max(users.id)"),
@@ -325,6 +398,7 @@ class OperatorTest(QueryTest):
             c = expr.compile(dialect=default.DefaultDialect())
             assert str(c) == compare, "%s != %s" % (str(c), compare)
 
+
 class RawSelectTest(QueryTest, AssertsCompiledSQL):
     """compare a bunch of select() tests with the equivalent Query using straight table/columns.
     
index 90134d1428752959b469a2052635fe475411a64c..05f4d88f3b316c0970980184d4155ba9bb132c00 100644 (file)
@@ -1511,9 +1511,11 @@ class SaveTest(_fixtures.FixtureTest):
     def test_batch_mode(self):
         """The 'batch=False' flag on mapper()"""
 
+        names = []
         class TestExtension(sa.orm.MapperExtension):
             def before_insert(self, mapper, connection, instance):
                 self.current_instance = instance
+                names.append(instance.name)
             def after_insert(self, mapper, connection, instance):
                 assert instance is self.current_instance
 
@@ -1524,18 +1526,25 @@ class SaveTest(_fixtures.FixtureTest):
         session = create_session()
         session.add_all((u1, u2))
         session.flush()
+        
+        u3 = User(name='user3')
+        u4 = User(name='user4')
+        u5 = User(name='user5')
+        
+        session.add_all([u4, u5, u3])
+        session.flush()
+        
+        # test insert ordering is maintained
+        assert names == ['user1', 'user2', 'user4', 'user5', 'user3']
         session.clear()
-
+        
         sa.orm.clear_mappers()
 
         m = mapper(User, users, extension=TestExtension())
         u1 = User(name='user1')
         u2 = User(name='user2')
-        try:
-            session.flush()
-            assert False
-        except AssertionError:
-            assert True
+        session.add_all((u1, u2))
+        self.assertRaises(AssertionError, session.flush)
 
 
 class ManyToOneTest(_fixtures.FixtureTest):