]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- reworked the "SQL assertion" code to something more flexible and based off of Conne...
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Dec 2008 02:16:52 +0000 (02:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Dec 2008 02:16:52 +0000 (02:16 +0000)
will make use of the enhanced flexibility.

test/orm/cycles.py
test/orm/eager_relations.py
test/orm/mapper.py
test/orm/unitofwork.py
test/testlib/assertsql.py [new file with mode: 0644]
test/testlib/engines.py
test/testlib/testing.py

index c92228def758ee2bb789a6fa220b9ebd6747cdb0..2aafc5c54adf8944764d6013d1100fb9c0cf3771 100644 (file)
@@ -10,6 +10,7 @@ from testlib import testing
 from testlib.sa import Table, Column, Integer, String, ForeignKey
 from testlib.sa.orm import mapper, relation, backref, create_session
 from testlib.testing import eq_
+from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL
 from orm import _base
 
 
@@ -558,82 +559,31 @@ class OneToManyManyToOneTest(_base.MappedTest):
         sess.add(b)
         sess.add(p)
 
-        self.assert_sql(testing.db, sess.flush, [
-            ("INSERT INTO person (favorite_ball_id, data) "
-             "VALUES (:favorite_ball_id, :data)",
-             {'favorite_ball_id': None, 'data':'some data'}),
-
-            ("INSERT INTO ball (person_id, data) "
-             "VALUES (:person_id, :data)",
-             lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
-            ("INSERT INTO ball (person_id, data) "
-             "VALUES (:person_id, :data)",
-             lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
-            ("INSERT INTO ball (person_id, data) "
-             "VALUES (:person_id, :data)",
-                lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
-            ("INSERT INTO ball (person_id, data) "
-             "VALUES (:person_id, :data)",
-             lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
-            ("UPDATE person SET favorite_ball_id=:favorite_ball_id "
-             "WHERE person.id = :person_id",
-             lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id})
-            ],
-
-            with_sequences= [
-            ("INSERT INTO person (id, favorite_ball_id, data) "
-             "VALUES (:id, :favorite_ball_id, :data)",
-             lambda ctx: {'id':ctx.last_inserted_ids()[0],
-                          'favorite_ball_id': None,
-                          'data':'some data'}),
-
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx: {'id':ctx.last_inserted_ids()[0],
-                          'person_id':p.id,
-                          'data':'some data'}),
-
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx: {'id':ctx.last_inserted_ids()[0],
-                          'person_id':p.id,
-                          'data':'some data'}),
-
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx: {'id':ctx.last_inserted_ids()[0],
-                          'person_id':p.id,
-                          'data':'some data'}),
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx: {'id':ctx.last_inserted_ids()[0],
-                          'person_id':p.id,
-                          'data':'some data'}),
-            # heres the post update
-            ("UPDATE person SET favorite_ball_id=:favorite_ball_id "
-            "WHERE person.id = :person_id",
-            lambda ctx: {'favorite_ball_id':p.favorite.id, 'person_id':p.id})])
+        self.assert_sql_execution(
+            testing.db,
+            sess.flush,
+            RegexSQL("^INSERT INTO person", {'data':'some data'}),
+            RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+            RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+            RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+            RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+            ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id "
+                        "WHERE person.id = :person_id",
+                        lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id}
+             ),
+        )
 
         sess.delete(p)
 
-        self.assert_sql(testing.db, sess.flush, [
-            # heres the post update (which is a pre-update with deletes)
-            ("UPDATE person SET favorite_ball_id=:favorite_ball_id "
-             "WHERE person.id = :person_id",
-             lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}),
-
-            ("DELETE FROM ball WHERE ball.id = :id",
-             None),
-             # order cant be predicted, but something like:
-             #lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]),
-
-            ("DELETE FROM person WHERE person.id = :id",
-             lambda ctx:[{'id': p.id}])])
-
+        self.assert_sql_execution(
+            testing.db, 
+            sess.flush, 
+            ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id "
+                "WHERE person.id = :person_id",
+                lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}),
+            ExactSQL("DELETE FROM ball WHERE ball.id = :id", None), # lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}])
+            ExactSQL("DELETE FROM person WHERE person.id = :id", lambda ctx:[{'id': p.id}])
+        )
 
     @testing.resolve_artifact_names
     def testpostupdate_o2m(self):
@@ -664,112 +614,74 @@ class OneToManyManyToOneTest(_base.MappedTest):
         sess = create_session()
         sess.add_all((b,p,b2,b3,b4))
 
-        self.assert_sql(testing.db, sess.flush, [
-            ("INSERT INTO ball (person_id, data) "
+        self.assert_sql_execution(
+            testing.db,
+            sess.flush,
+            CompiledSQL("INSERT INTO ball (person_id, data) "
              "VALUES (:person_id, :data)",
              {'person_id':None, 'data':'some data'}),
 
-            ("INSERT INTO ball (person_id, data) "
+            CompiledSQL("INSERT INTO ball (person_id, data) "
              "VALUES (:person_id, :data)",
              {'person_id':None, 'data':'some data'}),
 
-            ("INSERT INTO ball (person_id, data) "
+            CompiledSQL("INSERT INTO ball (person_id, data) "
              "VALUES (:person_id, :data)",
              {'person_id':None, 'data':'some data'}),
 
-            ("INSERT INTO ball (person_id, data) "
+            CompiledSQL("INSERT INTO ball (person_id, data) "
              "VALUES (:person_id, :data)",
              {'person_id':None, 'data':'some data'}),
 
-            ("INSERT INTO person (favorite_ball_id, data) "
+            CompiledSQL("INSERT INTO person (favorite_ball_id, data) "
              "VALUES (:favorite_ball_id, :data)",
              lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}),
 
-            # heres the post update on each one-to-many item
-            ("UPDATE ball SET person_id=:person_id "
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
              lambda ctx:{'person_id':p.id,'ball_id':b.id}),
 
-            ("UPDATE ball SET person_id=:person_id "
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
              lambda ctx:{'person_id':p.id,'ball_id':b2.id}),
 
-            ("UPDATE ball SET person_id=:person_id "
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
              lambda ctx:{'person_id':p.id,'ball_id':b3.id}),
 
-            ("UPDATE ball SET person_id=:person_id "
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b4.id})],
-
-            with_sequences=[
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx: {'id':ctx.last_inserted_ids()[0],
-                          'person_id':None,
-                          'data':'some data'}),
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx:{'id':ctx.last_inserted_ids()[0],
-                         'person_id':None,
-                         'data':'some data'}),
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx:{'id':ctx.last_inserted_ids()[0],
-                         'person_id':None,
-                         'data':'some data'}),
-            ("INSERT INTO ball (id, person_id, data) "
-             "VALUES (:id, :person_id, :data)",
-             lambda ctx:{'id':ctx.last_inserted_ids()[0],
-                         'person_id':None,
-                         'data':'some data'}),
-            ("INSERT INTO person (id, favorite_ball_id, data) "
-             "VALUES (:id, :favorite_ball_id, :data)",
-             lambda ctx:{'id':ctx.last_inserted_ids()[0],
-                         'favorite_ball_id':b.id,
-                         'data':'some data'}),
-            ("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b.id}),
-
-            ("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b2.id}),
-
-            ("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b3.id}),
-
-            ("UPDATE ball SET person_id=:person_id "
-             "WHERE ball.id = :ball_id",
-             lambda ctx:{'person_id':p.id,'ball_id':b4.id})])
-
+             lambda ctx:{'person_id':p.id,'ball_id':b4.id}),
+        )
+        
         sess.delete(p)
-        self.assert_sql(testing.db, sess.flush, [
-            ("UPDATE ball SET person_id=:person_id "
+        
+        self.assert_sql_execution(testing.db, sess.flush, 
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
              lambda ctx:{'person_id': None, 'ball_id': b.id}),
 
-            ("UPDATE ball SET person_id=:person_id "
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
              lambda ctx:{'person_id': None, 'ball_id': b2.id}),
 
-            ("UPDATE ball SET person_id=:person_id "
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
              lambda ctx:{'person_id': None, 'ball_id': b3.id}),
 
-            ("UPDATE ball SET person_id=:person_id "
+            CompiledSQL("UPDATE ball SET person_id=:person_id "
              "WHERE ball.id = :ball_id",
              lambda ctx:{'person_id': None, 'ball_id': b4.id}),
 
-            ("DELETE FROM person WHERE person.id = :id",
+            CompiledSQL("DELETE FROM person WHERE person.id = :id",
              lambda ctx:[{'id':p.id}]),
 
-            ("DELETE FROM ball WHERE ball.id = :id",
+            CompiledSQL("DELETE FROM ball WHERE ball.id = :id",
              lambda ctx:[{'id': b.id},
                          {'id': b2.id},
                          {'id': b3.id},
-                         {'id': b4.id}])])
+                         {'id': b4.id}])
+        )
 
 
 class SelfReferentialPostUpdateTest(_base.MappedTest):
@@ -859,20 +771,24 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
         remove_child(root, cats)
         # pre-trigger lazy loader on 'cats' to make the test easier
         cats.children
-        self.assert_sql(testing.db, lambda: session.flush(), [
-            ("UPDATE node SET prev_sibling_id=:prev_sibling_id "
+        self.assert_sql_execution(
+            testing.db, 
+            session.flush,
+            CompiledSQL("UPDATE node SET prev_sibling_id=:prev_sibling_id "
              "WHERE node.id = :node_id",
              lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}),
 
-            ("UPDATE node SET next_sibling_id=:next_sibling_id "
+            CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
              "WHERE node.id = :node_id",
              lambda ctx:{'next_sibling_id':stories.id, 'node_id':about.id}),
 
-            ("UPDATE node SET next_sibling_id=:next_sibling_id "
+            CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
              "WHERE node.id = :node_id",
              lambda ctx:{'next_sibling_id':None, 'node_id':cats.id}),
-            ("DELETE FROM node WHERE node.id = :id",
-             lambda ctx:[{'id':cats.id}])])
+             
+            CompiledSQL("DELETE FROM node WHERE node.id = :id",
+             lambda ctx:[{'id':cats.id}])
+        )
 
 
 class SelfReferentialPostUpdateTest2(_base.MappedTest):
index 94fabd0b11bb17e5a98fcdad982cf3fe7835f0bb..d704f9992516bc8f7b75742ba22603996550d934 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy.orm import eagerload, deferred, undefer
 from testlib.sa import Table, Column, Integer, String, ForeignKey, and_
 from testlib.sa.orm import mapper, relation, create_session, lazyload
 from testlib.testing import eq_
+from testlib.assertsql import CompiledSQL
 from orm import _base, _fixtures
 
 class EagerTest(_fixtures.FixtureTest):
@@ -1020,16 +1021,13 @@ class SelfReferentialEagerTest(_base.MappedTest):
             d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first()
 
         # test that the query isn't wrapping the initial query for eager loading.
-        # testing only sqlite for now since the query text is slightly different on other
-        # dialects
-        if testing.against('sqlite'):
-            self.assert_sql(testing.db, go, [
-                (
-                    "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes "
-                    "WHERE nodes.data = :data_1 ORDER BY nodes.id  LIMIT 1 OFFSET 0",
-                    {'data_1': 'n1'}
-                ),
-            ])
+        self.assert_sql_execution(testing.db, go, 
+            CompiledSQL(
+                "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes "
+                "WHERE nodes.data = :data_1 ORDER BY nodes.id  LIMIT 1 OFFSET 0",
+                {'data_1': 'n1'}
+            )
+        )
 
     @testing.fails_on('maxdb')
     @testing.resolve_artifact_names
index b18d8ac0015f1f963d9bed084dbe9eeaa0da3d4c..7e2d3b5958a0f1c7f3f790a234f44eb97e8a3ab1 100644 (file)
@@ -906,12 +906,9 @@ class OptionsTest(_fixtures.FixtureTest):
 
         sess.clear()
 
-        # test that eager loading doesnt modify parent mapper
-        def go():
-            u = sess.query(User).filter_by(id=8).one()
-            eq_(u.id, 8)
-            eq_(len(u.addresses), 3)
-        assert "tbl_row_count" not in self.capture_sql(testing.db, go)
+        u = sess.query(User).filter_by(id=8).one()
+        eq_(u.id, 8)
+        eq_(len(u.addresses), 3)
 
     @testing.fails_on('maxdb')
     @testing.resolve_artifact_names
@@ -1396,7 +1393,10 @@ class DeferredTest(_fixtures.FixtureTest):
         sess = create_session()
         q = sess.query(Order).order_by(Order.id).options(defer('user_id'))
 
-        self.sql_eq_(q.all, [
+        def go():
+            q.all()[0].user_id
+                
+        self.sql_eq_(go, [
             ("SELECT orders.id AS orders_id, "
              "orders.address_id AS orders_address_id, "
              "orders.description AS orders_description, "
@@ -1404,7 +1404,7 @@ class DeferredTest(_fixtures.FixtureTest):
              "FROM orders ORDER BY orders.id", {}),
             ("SELECT orders.user_id AS orders_user_id "
              "FROM orders WHERE orders.id = :param_1",
-             {'param_1':3})])
+             {'param_1':1})])
         sess.clear()
 
         q2 = q.options(sa.orm.undefer('user_id'))
index 975743fb0da0044b4230b58a01a3d03973dff51a..627e7cb99b56829f6d9f777c78fb707b7f1f57a4 100644 (file)
@@ -13,7 +13,7 @@ from testlib.testing import eq_, ne_
 from orm import _base, _fixtures
 from engine import _base as engine_base
 import pickleable
-
+from testlib.assertsql import AllOf, CompiledSQL
 
 class UnitOfWorkTest(object):
     pass
@@ -1602,33 +1602,22 @@ class ManyToOneTest(_fixtures.FixtureTest):
         objects[2].email_address = 'imnew@foo.bar'
         objects[3].user = User()
         objects[3].user.name = 'imnewlyadded'
-        self.assert_sql(testing.db,
+        self.assert_sql_execution(testing.db,
                         session.flush,
-                        [
-            ("INSERT INTO users (name) VALUES (:name)",
-             {'name': 'imnewlyadded'} ),
-
-            {"UPDATE addresses SET email_address=:email_address "
-             "WHERE addresses.id = :addresses_id":
-             lambda ctx: {'email_address': 'imnew@foo.bar',
-                          'addresses_id': objects[2].id},
-             "UPDATE addresses SET user_id=:user_id "
-             "WHERE addresses.id = :addresses_id":
-             lambda ctx: {'user_id': objects[3].user.id,
-                          'addresses_id': objects[3].id}},
-                        ],
-                        with_sequences=[
-            ("INSERT INTO users (id, name) VALUES (:id, :name)",
-             lambda ctx:{'name': 'imnewlyadded',
-                         'id':ctx.last_inserted_ids()[0]}),
-            {"UPDATE addresses SET email_address=:email_address "
-             "WHERE addresses.id = :addresses_id":
-             lambda ctx: {'email_address': 'imnew@foo.bar',
-                          'addresses_id': objects[2].id},
-             ("UPDATE addresses SET user_id=:user_id "
-              "WHERE addresses.id = :addresses_id"):
-             lambda ctx: {'user_id': objects[3].user.id,
-                          'addresses_id': objects[3].id}}])
+                        CompiledSQL("INSERT INTO users (name) VALUES (:name)",
+                         {'name': 'imnewlyadded'} ),
+
+                         AllOf(
+                            CompiledSQL("UPDATE addresses SET email_address=:email_address "
+                                        "WHERE addresses.id = :addresses_id",
+                                        lambda ctx: {'email_address': 'imnew@foo.bar',
+                                          'addresses_id': objects[2].id}),
+                            CompiledSQL("UPDATE addresses SET user_id=:user_id "
+                                          "WHERE addresses.id = :addresses_id",
+                                          lambda ctx: {'user_id': objects[3].user.id,
+                                                       'addresses_id': objects[3].id})
+                        )
+                    )
 
         l = sa.select([users, addresses],
                       sa.and_(users.c.id==addresses.c.user_id,
@@ -1813,44 +1802,40 @@ class ManyToManyTest(_fixtures.FixtureTest):
         k = Keyword()
         k.name = 'yellow'
         objects[5].keywords.append(k)
-        self.assert_sql(testing.db, session.flush, [
-            {"UPDATE items SET description=:description "
-             "WHERE items.id = :items_id":
-             {'description': 'item4updated',
-              'items_id': objects[4].id},
-             "INSERT INTO keywords (name) "
-             "VALUES (:name)":
-             {'name': 'yellow'}},
-            ("INSERT INTO item_keywords (item_id, keyword_id) "
-             "VALUES (:item_id, :keyword_id)",
-             lambda ctx: [{'item_id': objects[5].id,
-                           'keyword_id': k.id}])],
-                        with_sequences = [
-              {"UPDATE items SET description=:description "
-               "WHERE items.id = :items_id":
-               {'description': 'item4updated',
-                'items_id': objects[4].id},
-               "INSERT INTO keywords (id, name) "
-               "VALUES (:id, :name)":
-               lambda ctx: {'name': 'yellow',
-                            'id':ctx.last_inserted_ids()[0]}},
-              ("INSERT INTO item_keywords (item_id, keyword_id) "
-               "VALUES (:item_id, :keyword_id)",
-               lambda ctx: [{'item_id': objects[5].id,
-                             'keyword_id': k.id}])])
+        self.assert_sql_execution(
+            testing.db, 
+            session.flush, 
+            AllOf(
+                CompiledSQL("UPDATE items SET description=:description "
+                 "WHERE items.id = :items_id",
+                     {'description': 'item4updated',
+                      'items_id': objects[4].id},
+                ),
+                CompiledSQL("INSERT INTO keywords (name) "
+                    "VALUES (:name)",
+                    {'name': 'yellow'},
+                )
+            ),
+            CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) "
+                    "VALUES (:item_id, :keyword_id)",
+                     lambda ctx: [{'item_id': objects[5].id,
+                                   'keyword_id': k.id}])
+            )
 
         objects[2].keywords.append(k)
         dkid = objects[5].keywords[1].id
         del objects[5].keywords[1]
-        self.assert_sql(testing.db, session.flush, [
-            ("DELETE FROM item_keywords "
-             "WHERE item_keywords.item_id = :item_id AND "
-             "item_keywords.keyword_id = :keyword_id",
-             [{'item_id': objects[5].id, 'keyword_id': dkid}]),
-            ("INSERT INTO item_keywords (item_id, keyword_id) "
-             "VALUES (:item_id, :keyword_id)",
-             lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}]
-             )])
+        self.assert_sql_execution(
+            testing.db, 
+            session.flush, 
+            CompiledSQL("DELETE FROM item_keywords "
+                     "WHERE item_keywords.item_id = :item_id AND "
+                     "item_keywords.keyword_id = :keyword_id",
+                     [{'item_id': objects[5].id, 'keyword_id': dkid}]),
+            CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) "
+                    "VALUES (:item_id, :keyword_id)",
+                    lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}]
+             ))
 
         session.delete(objects[3])
         session.flush()
@@ -1999,33 +1984,20 @@ class SaveTest2(_fixtures.FixtureTest):
 
         session.add_all(fixture())
 
-        self.assert_sql(testing.db, session.flush, [
-            ("INSERT INTO users (name) VALUES (:name)",
+        self.assert_sql_execution(
+            testing.db, 
+            session.flush, 
+            CompiledSQL("INSERT INTO users (name) VALUES (:name)",
              {'name': 'u1'}),
-            ("INSERT INTO users (name) VALUES (:name)",
+            CompiledSQL("INSERT INTO users (name) VALUES (:name)",
              {'name': 'u2'}),
-            ("INSERT INTO addresses (user_id, email_address) "
+            CompiledSQL("INSERT INTO addresses (user_id, email_address) "
              "VALUES (:user_id, :email_address)",
              {'user_id': 1, 'email_address': 'a1'}),
-            ("INSERT INTO addresses (user_id, email_address) "
+            CompiledSQL("INSERT INTO addresses (user_id, email_address) "
              "VALUES (:user_id, :email_address)",
-             {'user_id': 2, 'email_address': 'a2'})],
-            with_sequences = [
-            ("INSERT INTO users (id, name) "
-             "VALUES (:id, :name)",
-             lambda ctx: {'name': 'u1', 'id':ctx.last_inserted_ids()[0]}),
-            ("INSERT INTO users (id, name) "
-             "VALUES (:id, :name)",
-             lambda ctx: {'name': 'u2', 'id':ctx.last_inserted_ids()[0]}),
-            ("INSERT INTO addresses (id, user_id, email_address) "
-             "VALUES (:id, :user_id, :email_address)",
-             lambda ctx:{'user_id': 1, 'email_address': 'a1',
-                         'id':ctx.last_inserted_ids()[0]}),
-            ("INSERT INTO addresses (id, user_id, email_address) "
-             "VALUES (:id, :user_id, :email_address)",
-             lambda ctx:{'user_id': 2, 'email_address': 'a2',
-                         'id':ctx.last_inserted_ids()[0]})])
-
+             {'user_id': 2, 'email_address': 'a2'}),
+        )
 
 class SaveTest3(_base.MappedTest):
     def define_tables(self, metadata):
diff --git a/test/testlib/assertsql.py b/test/testlib/assertsql.py
new file mode 100644 (file)
index 0000000..33a7e5b
--- /dev/null
@@ -0,0 +1,281 @@
+from sqlalchemy.interfaces import ConnectionProxy
+import re
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.base import Connection
+import testing
+
+class AssertRule(object):
+    def process_execute(self, clauseelement, *multiparams, **params):
+        pass
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        pass
+        
+    def is_consumed(self):
+        """Return True if this rule has been consumed, False if not.
+        
+        Should raise an AssertionError if this rule's condition has definitely failed.
+        
+        """
+        raise NotImplementedError()
+    
+    def rule_passed(self):
+        """Return True if the last test of this rule passed, False if failed, None if no test was applied."""
+        
+        raise NotImplementedError()
+        
+    def consume_final(self):
+        """Return True if this rule has been consumed.
+        
+        Should raise an AssertionError if this rule's condition has not been consumed or has failed.
+        
+        """
+        
+        if self._result is None:
+            assert False, "Rule has not been consumed"
+            
+        return self.is_consumed()
+
+class SQLMatchRule(AssertRule):
+    def __init__(self):
+        self._result = None
+        self._errmsg = ""
+    
+    def rule_passed(self):
+        return self._result
+        
+    def is_consumed(self):
+        if self._result is None:
+            return False
+            
+        assert self._result, self._errmsg
+        
+        return True
+    
+class ExactSQL(SQLMatchRule):
+    def __init__(self, sql, params=None):
+        SQLMatchRule.__init__(self)
+        self.sql = sql
+        self.params = params
+    
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        if not context:
+            return
+            
+        _received_statement = _process_engine_statement(statement, context)
+        _received_parameters = context.compiled_parameters
+        
+        # TODO: remove this step once all unit tests
+        # are migrated, as ExactSQL should really be *exact* SQL 
+        sql = _process_assertion_statement(self.sql, context)
+        
+        equivalent = _received_statement == sql
+        if self.params:
+            if callable(self.params):
+                params = self.params(context)
+            else:
+                params = self.params
+
+            if not isinstance(params, list):
+                params = [params]
+            equivalent = equivalent and params == context.compiled_parameters
+        else:
+            params = {}
+        
+        
+        self._result = equivalent
+        if not self._result:
+            self._errmsg = "Testing for exact statement %r exact params %r, " \
+                "received %r with params %r" % (sql, params, _received_statement, _received_parameters)
+    
+
+class RegexSQL(SQLMatchRule):
+    def __init__(self, regex, params=None):
+        SQLMatchRule.__init__(self)
+        self.regex = re.compile(regex)
+        self.orig_regex = regex
+        self.params = params
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        if not context:
+            return
+
+        _received_statement = _process_engine_statement(statement, context)
+        _received_parameters = context.compiled_parameters
+
+        equivalent = bool(self.regex.match(_received_statement))
+        if self.params:
+            if callable(self.params):
+                params = self.params(context)
+            else:
+                params = self.params
+
+            if not isinstance(params, list):
+                params = [params]
+            
+            # do a positive compare only
+            for param, received in zip(params, _received_parameters):
+                for k, v in param.iteritems():
+                    if k not in received or received[k] != v:
+                        equivalent = False
+                        break
+        else:
+            params = {}
+
+        self._result = equivalent
+        if not self._result:
+            self._errmsg = "Testing for regex %r partial params %r, "\
+                "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
+
+class CompiledSQL(SQLMatchRule):
+    def __init__(self, statement, params):
+        SQLMatchRule.__init__(self)
+        self.statement = statement
+        self.params = params
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        if not context:
+            return
+
+        _received_parameters = context.compiled_parameters
+        
+        # recompile from the context, using the default dialect
+        compiled = context.compiled.statement.\
+                compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
+                
+        _received_statement = re.sub(r'\n', '', str(compiled))
+        
+        equivalent = self.statement == _received_statement
+        if self.params:
+            if callable(self.params):
+                params = self.params(context)
+            else:
+                params = self.params
+
+            if not isinstance(params, list):
+                params = [params]
+            
+            # do a positive compare only
+            for param, received in zip(params, _received_parameters):
+                for k, v in param.iteritems():
+                    if k not in received or received[k] != v:
+                        equivalent = False
+                        break
+        else:
+            params = {}
+
+        self._result = equivalent
+        if not self._result:
+            self._errmsg = "Testing for compiled statement %r partial params %r, " \
+                    "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
+    
+        
+class CountStatements(AssertRule):
+    def __init__(self, count):
+        self.count = count
+        self._statement_count = 0
+        
+    def process_execute(self, clauseelement, *multiparams, **params):
+        self._statement_count += 1
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        pass
+
+    def is_consumed(self):
+        return False
+    
+    def consume_final(self):
+        assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count)
+        return True
+        
+class AllOf(AssertRule):
+    def __init__(self, *rules):
+        self.rules = set(rules)
+        
+    def process_execute(self, clauseelement, *multiparams, **params):
+        for rule in self.rules:
+            rule.process_execute(clauseelement, *multiparams, **params)
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        for rule in self.rules:
+            rule.process_cursor_execute(statement, parameters, context, executemany)
+
+    def is_consumed(self):
+        if not self.rules:
+            return True
+        
+        for rule in list(self.rules):
+            if rule.rule_passed(): # a rule passed, move on
+                self.rules.remove(rule)
+                return len(self.rules) == 0
+
+        assert False, "No assertion rules were satisfied for statement"
+    
+    def consume_final(self):
+        return len(self.rules) == 0
+        
+def _process_engine_statement(query, context):
+    if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
+        query = query[:-25]
+    
+    query = re.sub(r'\n', '', query)
+    
+    return query
+    
+def _process_assertion_statement(query, context):
+    paramstyle = context.dialect.paramstyle
+    if paramstyle == 'named':
+        pass
+    elif paramstyle =='pyformat':
+        query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+    else:
+        # positional params
+        repl = None
+        if paramstyle=='qmark':
+            repl = "?"
+        elif paramstyle=='format':
+            repl = r"%s"
+        elif paramstyle=='numeric':
+            repl = None
+        query = re.sub(r':([\w_]+)', repl, query)
+
+    return query
+
+class SQLAssert(ConnectionProxy):
+    rules = None
+    
+    def add_rules(self, rules):
+        self.rules = list(rules)
+    
+    def statement_complete(self):
+        for rule in self.rules:
+            if not rule.consume_final():
+                assert False, "All statements are complete, but pending assertion rules remain"
+
+    def clear_rules(self):
+        del self.rules
+        
+    def execute(self, conn, execute, clauseelement, *multiparams, **params):
+        result = execute(clauseelement, *multiparams, **params)
+
+        if self.rules is not None:
+            if not self.rules:
+                assert False, "All rules have been exhausted, but further statements remain"
+            rule = self.rules[0]
+            rule.process_execute(clauseelement, *multiparams, **params)
+            if rule.is_consumed():
+                self.rules.pop(0)
+            
+        return result
+        
+    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+        result = execute(cursor, statement, parameters, context)
+        
+        if self.rules:
+            rule = self.rules[0]
+            rule.process_cursor_execute(statement, parameters, context, executemany)
+
+        return result
+
+asserter = SQLAssert()
+    
index a6e9fb8b72d1e6c98857ea6e3bb024bc49e23b23..000b188ce26ff94f14ed5710d92f50f5536f6388 100644 (file)
@@ -104,20 +104,18 @@ def testing_engine(url=None, options=None):
     """Produce an engine configured by --options with optional overrides."""
 
     from sqlalchemy import create_engine
-    from testlib.testing import ExecutionContextWrapper
+    from testlib.assertsql import asserter
 
     url = url or config.db_url
     options = options or config.db_opts
 
+    options.setdefault('proxy', asserter)
+    
     listeners = options.setdefault('listeners', [])
     listeners.append(testing_reaper)
 
     engine = create_engine(url, **options)
 
-    create_context = engine.dialect.create_execution_context
-    def create_exec_context(*args, **kwargs):
-        return ExecutionContextWrapper(create_context(*args, **kwargs))
-    engine.dialect.create_execution_context = create_exec_context
     return engine
 
 def utf8_engine(url=None, options=None):
index f6a16a4f8b479d8db0b324dbb7ec0c04c744e482..5f5d323c79a2b6a87765de6191afa773176be5ed 100644 (file)
@@ -13,6 +13,7 @@ from cStringIO import StringIO
 
 import testlib.config as config
 from testlib.compat import _function_named
+from testlib import assertsql
 
 # Delayed imports
 MetaData = None
@@ -459,110 +460,6 @@ def fixture(table, columns, *rows):
                                     for column_values in rows])
     table.append_ddl_listener('after-create', onload)
 
-class TestData(object):
-    """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
-
-    def __init__(self):
-        self.set_assert_list(None, None)
-        self.sql_count = 0
-        self.buffer = None
-
-    def set_assert_list(self, unittest, list):
-        self.unittest = unittest
-        self.assert_list = list
-        if list is not None:
-            self.assert_list.reverse()
-
-testdata = TestData()
-
-
-class ExecutionContextWrapper(object):
-    """instruments the ExecutionContext created by the Engine so that SQL expressions
-    can be tracked."""
-
-    def __init__(self, ctx):
-        self.__dict__['ctx'] = ctx
-    def __getattr__(self, key):
-        return getattr(self.ctx, key)
-    def __setattr__(self, key, value):
-        setattr(self.ctx, key, value)
-
-    trailing_underscore_pattern = re.compile(r'(\W:[\w_#]+)_\b',re.MULTILINE)
-    def post_exec(self):
-        ctx = self.ctx
-        statement = unicode(ctx.compiled)
-        statement = re.sub(r'\n', '', ctx.statement)
-        if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'):
-            statement = statement[:-25]
-        if testdata.buffer is not None:
-            testdata.buffer.write(statement + "\n")
-
-        if testdata.assert_list is not None:
-            assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
-            item = testdata.assert_list[-1]
-            if not isinstance(item, dict):
-                item = testdata.assert_list.pop()
-            else:
-                # asserting a dictionary of statements->parameters
-                # this is to specify query assertions where the queries can be in
-                # multiple orderings
-                if '_converted' not in item:
-                    for key in item.keys():
-                        ckey = self.convert_statement(key)
-                        item[ckey] = item[key]
-                        if ckey != key:
-                            del item[key]
-                    item['_converted'] = True
-                try:
-                    entry = item.pop(statement)
-                    if len(item) == 1:
-                        testdata.assert_list.pop()
-                    item = (statement, entry)
-                except KeyError:
-                    assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
-
-            (query, params) = item
-            if callable(params):
-                params = params(ctx)
-            if params is not None and not isinstance(params, list):
-                params = [params]
-
-            parameters = ctx.compiled_parameters
-
-            query = self.convert_statement(query)
-            equivalent = ( (statement == query)
-                           or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
-                         ) \
-                         and \
-                         ( (params is None) or (params == parameters)
-                           or params == [dict([(k.strip('_'), v)
-                                               for (k, v) in p.items()])
-                                         for p in parameters]
-                         )
-            testdata.unittest.assert_(equivalent,
-                    "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
-        testdata.sql_count += 1
-        self.ctx.post_exec()
-
-    def convert_statement(self, query):
-        paramstyle = self.ctx.dialect.paramstyle
-        if paramstyle == 'named':
-            pass
-        elif paramstyle =='pyformat':
-            query = re.sub(r':([\w_]+)', r"%(\1)s", query)
-        else:
-            # positional params
-            repl = None
-            if paramstyle=='qmark':
-                repl = "?"
-            elif paramstyle=='format':
-                repl = r"%s"
-            elif paramstyle=='numeric':
-                repl = None
-            query = re.sub(r':([\w_]+)', repl, query)
-        return query
-
-
 def _import_by_name(name):
     submodule = name.split('.')[-1]
     return __import__(name, globals(), locals(), [submodule])
@@ -827,36 +724,35 @@ class AssertsExecutionResults(object):
                     cls.__name__, repr(expected_item)))
         return True
 
-    def assert_sql(self, db, callable_, list, with_sequences=None):
-        global testdata
-        testdata = TestData()
-        if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
-            testdata.set_assert_list(self, with_sequences)
-        else:
-            testdata.set_assert_list(self, list)
+    def assert_sql_execution(self, db, callable_, *rules):
+        assertsql.asserter.add_rules(rules)
         try:
             callable_()
+            assertsql.asserter.statement_complete()
         finally:
-            testdata.set_assert_list(None, None)
+            assertsql.asserter.clear_rules()
+            
+    def assert_sql(self, db, callable_, list_, with_sequences=None):
+        if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+            rules = with_sequences
+        else:
+            rules = list_
+        
+        newrules = []
+        for rule in rules:
+            if isinstance(rule, dict):
+                newrule = assertsql.AllOf(*[
+                    assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
+                ])
+            else:
+                newrule = assertsql.ExactSQL(*rule)
+            newrules.append(newrule)
+            
+        self.assert_sql_execution(db, callable_, *newrules)
 
     def assert_sql_count(self, db, callable_, count):
-        global testdata
-        testdata = TestData()
-        callable_()
-        self.assert_(testdata.sql_count == count,
-                     "desired statement count %d does not match %d" % (
-                       count, testdata.sql_count))
-
-    def capture_sql(self, db, callable_):
-        global testdata
-        testdata = TestData()
-        buffer = StringIO()
-        testdata.buffer = buffer
-        try:
-            callable_()
-            return buffer.getvalue()
-        finally:
-            testdata.buffer = None
+        self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+
 
 _otest_metadata = None
 class ORMTest(TestBase, AssertsExecutionResults):