From a2f90fd0038ab6586022762467ed95a84e62e224 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 10 Dec 2008 02:16:52 +0000 Subject: [PATCH] - reworked the "SQL assertion" code to something more flexible and based off of ConnectionProxy. upcoming changes to dependency.py will make use of the enhanced flexibility. --- test/orm/cycles.py | 200 ++++++++----------------- test/orm/eager_relations.py | 18 +-- test/orm/mapper.py | 16 +- test/orm/unitofwork.py | 138 +++++++----------- test/testlib/assertsql.py | 281 ++++++++++++++++++++++++++++++++++++ test/testlib/engines.py | 8 +- test/testlib/testing.py | 154 ++++---------------- 7 files changed, 438 insertions(+), 377 deletions(-) create mode 100644 test/testlib/assertsql.py diff --git a/test/orm/cycles.py b/test/orm/cycles.py index c92228def7..2aafc5c54a 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -10,6 +10,7 @@ from testlib import testing from testlib.sa import Table, Column, Integer, String, ForeignKey from testlib.sa.orm import mapper, relation, backref, create_session from testlib.testing import eq_ +from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL from orm import _base @@ -558,82 +559,31 @@ class OneToManyManyToOneTest(_base.MappedTest): sess.add(b) sess.add(p) - self.assert_sql(testing.db, sess.flush, [ - ("INSERT INTO person (favorite_ball_id, data) " - "VALUES (:favorite_ball_id, :data)", - {'favorite_ball_id': None, 'data':'some data'}), - - ("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - lambda ctx:{'person_id':p.id, 'data':'some data'}), - - ("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - lambda ctx:{'person_id':p.id, 'data':'some data'}), - - ("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - lambda ctx:{'person_id':p.id, 'data':'some data'}), - - ("INSERT INTO ball (person_id, data) " - "VALUES (:person_id, :data)", - lambda ctx:{'person_id':p.id, 'data':'some data'}), - - ("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id}) - ], - - with_sequences= [ - ("INSERT INTO person (id, favorite_ball_id, data) " - "VALUES (:id, :favorite_ball_id, :data)", - lambda ctx: {'id':ctx.last_inserted_ids()[0], - 'favorite_ball_id': None, - 'data':'some data'}), - - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx: {'id':ctx.last_inserted_ids()[0], - 'person_id':p.id, - 'data':'some data'}), - - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx: {'id':ctx.last_inserted_ids()[0], - 'person_id':p.id, - 'data':'some data'}), - - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx: {'id':ctx.last_inserted_ids()[0], - 'person_id':p.id, - 'data':'some data'}), - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx: {'id':ctx.last_inserted_ids()[0], - 'person_id':p.id, - 'data':'some data'}), - # heres the post update - ("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx: {'favorite_ball_id':p.favorite.id, 'person_id':p.id})]) + self.assert_sql_execution( + testing.db, + sess.flush, + RegexSQL("^INSERT INTO person", {'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}), + ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " + "WHERE person.id = :person_id", + lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id} + ), + ) sess.delete(p) - self.assert_sql(testing.db, sess.flush, [ - # heres the post update (which is a pre-update with deletes) - ("UPDATE person SET favorite_ball_id=:favorite_ball_id " - "WHERE person.id = :person_id", - lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}), - - ("DELETE FROM ball WHERE ball.id = :id", - None), - # order cant be predicted, but something like: - #lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]), - - ("DELETE FROM person WHERE person.id = :id", - lambda ctx:[{'id': p.id}])]) - + self.assert_sql_execution( + testing.db, + sess.flush, + ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id " + "WHERE person.id = :person_id", + lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}), + ExactSQL("DELETE FROM ball WHERE ball.id = :id", None), # lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]) + ExactSQL("DELETE FROM person WHERE person.id = :id", lambda ctx:[{'id': p.id}]) + ) @testing.resolve_artifact_names def testpostupdate_o2m(self): @@ -664,112 +614,74 @@ class OneToManyManyToOneTest(_base.MappedTest): sess = create_session() sess.add_all((b,p,b2,b3,b4)) - self.assert_sql(testing.db, sess.flush, [ - ("INSERT INTO ball (person_id, data) " + self.assert_sql_execution( + testing.db, + sess.flush, + CompiledSQL("INSERT INTO ball (person_id, data) " "VALUES (:person_id, :data)", {'person_id':None, 'data':'some data'}), - ("INSERT INTO ball (person_id, data) " + CompiledSQL("INSERT INTO ball (person_id, data) " "VALUES (:person_id, :data)", {'person_id':None, 'data':'some data'}), - ("INSERT INTO ball (person_id, data) " + CompiledSQL("INSERT INTO ball (person_id, data) " "VALUES (:person_id, :data)", {'person_id':None, 'data':'some data'}), - ("INSERT INTO ball (person_id, data) " + CompiledSQL("INSERT INTO ball (person_id, data) " "VALUES (:person_id, :data)", {'person_id':None, 'data':'some data'}), - ("INSERT INTO person (favorite_ball_id, data) " + CompiledSQL("INSERT INTO person (favorite_ball_id, data) " "VALUES (:favorite_ball_id, :data)", lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}), - # heres the post update on each one-to-many item - ("UPDATE ball SET person_id=:person_id " + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", lambda ctx:{'person_id':p.id,'ball_id':b.id}), - ("UPDATE ball SET person_id=:person_id " + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", lambda ctx:{'person_id':p.id,'ball_id':b2.id}), - ("UPDATE ball SET person_id=:person_id " + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", lambda ctx:{'person_id':p.id,'ball_id':b3.id}), - ("UPDATE ball SET person_id=:person_id " + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b4.id})], - - with_sequences=[ - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx: {'id':ctx.last_inserted_ids()[0], - 'person_id':None, - 'data':'some data'}), - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx:{'id':ctx.last_inserted_ids()[0], - 'person_id':None, - 'data':'some data'}), - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx:{'id':ctx.last_inserted_ids()[0], - 'person_id':None, - 'data':'some data'}), - ("INSERT INTO ball (id, person_id, data) " - "VALUES (:id, :person_id, :data)", - lambda ctx:{'id':ctx.last_inserted_ids()[0], - 'person_id':None, - 'data':'some data'}), - ("INSERT INTO person (id, favorite_ball_id, data) " - "VALUES (:id, :favorite_ball_id, :data)", - lambda ctx:{'id':ctx.last_inserted_ids()[0], - 'favorite_ball_id':b.id, - 'data':'some data'}), - ("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b.id}), - - ("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b2.id}), - - ("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b3.id}), - - ("UPDATE ball SET person_id=:person_id " - "WHERE ball.id = :ball_id", - lambda ctx:{'person_id':p.id,'ball_id':b4.id})]) - + lambda ctx:{'person_id':p.id,'ball_id':b4.id}), + ) + sess.delete(p) - self.assert_sql(testing.db, sess.flush, [ - ("UPDATE ball SET person_id=:person_id " + + self.assert_sql_execution(testing.db, sess.flush, + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", lambda ctx:{'person_id': None, 'ball_id': b.id}), - ("UPDATE ball SET person_id=:person_id " + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", lambda ctx:{'person_id': None, 'ball_id': b2.id}), - ("UPDATE ball SET person_id=:person_id " + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", lambda ctx:{'person_id': None, 'ball_id': b3.id}), - ("UPDATE ball SET person_id=:person_id " + CompiledSQL("UPDATE ball SET person_id=:person_id " "WHERE ball.id = :ball_id", lambda ctx:{'person_id': None, 'ball_id': b4.id}), - ("DELETE FROM person WHERE person.id = :id", + CompiledSQL("DELETE FROM person WHERE person.id = :id", lambda ctx:[{'id':p.id}]), - ("DELETE FROM ball WHERE ball.id = :id", + CompiledSQL("DELETE FROM ball WHERE ball.id = :id", lambda ctx:[{'id': b.id}, {'id': b2.id}, {'id': b3.id}, - {'id': b4.id}])]) + {'id': b4.id}]) + ) class SelfReferentialPostUpdateTest(_base.MappedTest): @@ -859,20 +771,24 @@ class SelfReferentialPostUpdateTest(_base.MappedTest): remove_child(root, cats) # pre-trigger lazy loader on 'cats' to make the test easier cats.children - self.assert_sql(testing.db, lambda: session.flush(), [ - ("UPDATE node SET prev_sibling_id=:prev_sibling_id " + self.assert_sql_execution( + testing.db, + session.flush, + CompiledSQL("UPDATE node SET prev_sibling_id=:prev_sibling_id " "WHERE node.id = :node_id", lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}), - ("UPDATE node SET next_sibling_id=:next_sibling_id " + CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " "WHERE node.id = :node_id", lambda ctx:{'next_sibling_id':stories.id, 'node_id':about.id}), - ("UPDATE node SET next_sibling_id=:next_sibling_id " + CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id " "WHERE node.id = :node_id", lambda ctx:{'next_sibling_id':None, 'node_id':cats.id}), - ("DELETE FROM node WHERE node.id = :id", - lambda ctx:[{'id':cats.id}])]) + + CompiledSQL("DELETE FROM node WHERE node.id = :id", + lambda ctx:[{'id':cats.id}]) + ) class SelfReferentialPostUpdateTest2(_base.MappedTest): diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 94fabd0b11..d704f99925 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import eagerload, deferred, undefer from testlib.sa import Table, Column, Integer, String, ForeignKey, and_ from testlib.sa.orm import mapper, relation, create_session, lazyload from testlib.testing import eq_ +from testlib.assertsql import CompiledSQL from orm import _base, _fixtures class EagerTest(_fixtures.FixtureTest): @@ -1020,16 +1021,13 @@ class SelfReferentialEagerTest(_base.MappedTest): d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first() # test that the query isn't wrapping the initial query for eager loading. - # testing only sqlite for now since the query text is slightly different on other - # dialects - if testing.against('sqlite'): - self.assert_sql(testing.db, go, [ - ( - "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes " - "WHERE nodes.data = :data_1 ORDER BY nodes.id LIMIT 1 OFFSET 0", - {'data_1': 'n1'} - ), - ]) + self.assert_sql_execution(testing.db, go, + CompiledSQL( + "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes " + "WHERE nodes.data = :data_1 ORDER BY nodes.id LIMIT 1 OFFSET 0", + {'data_1': 'n1'} + ) + ) @testing.fails_on('maxdb') @testing.resolve_artifact_names diff --git a/test/orm/mapper.py b/test/orm/mapper.py index b18d8ac001..7e2d3b5958 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -906,12 +906,9 @@ class OptionsTest(_fixtures.FixtureTest): sess.clear() - # test that eager loading doesnt modify parent mapper - def go(): - u = sess.query(User).filter_by(id=8).one() - eq_(u.id, 8) - eq_(len(u.addresses), 3) - assert "tbl_row_count" not in self.capture_sql(testing.db, go) + u = sess.query(User).filter_by(id=8).one() + eq_(u.id, 8) + eq_(len(u.addresses), 3) @testing.fails_on('maxdb') @testing.resolve_artifact_names @@ -1396,7 +1393,10 @@ class DeferredTest(_fixtures.FixtureTest): sess = create_session() q = sess.query(Order).order_by(Order.id).options(defer('user_id')) - self.sql_eq_(q.all, [ + def go(): + q.all()[0].user_id + + self.sql_eq_(go, [ ("SELECT orders.id AS orders_id, " "orders.address_id AS orders_address_id, " "orders.description AS orders_description, " @@ -1404,7 +1404,7 @@ class DeferredTest(_fixtures.FixtureTest): "FROM orders ORDER BY orders.id", {}), ("SELECT orders.user_id AS orders_user_id " "FROM orders WHERE orders.id = :param_1", - {'param_1':3})]) + {'param_1':1})]) sess.clear() q2 = q.options(sa.orm.undefer('user_id')) diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 975743fb0d..627e7cb99b 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -13,7 +13,7 @@ from testlib.testing import eq_, ne_ from orm import _base, _fixtures from engine import _base as engine_base import pickleable - +from testlib.assertsql import AllOf, CompiledSQL class UnitOfWorkTest(object): pass @@ -1602,33 +1602,22 @@ class ManyToOneTest(_fixtures.FixtureTest): objects[2].email_address = 'imnew@foo.bar' objects[3].user = User() objects[3].user.name = 'imnewlyadded' - self.assert_sql(testing.db, + self.assert_sql_execution(testing.db, session.flush, - [ - ("INSERT INTO users (name) VALUES (:name)", - {'name': 'imnewlyadded'} ), - - {"UPDATE addresses SET email_address=:email_address " - "WHERE addresses.id = :addresses_id": - lambda ctx: {'email_address': 'imnew@foo.bar', - 'addresses_id': objects[2].id}, - "UPDATE addresses SET user_id=:user_id " - "WHERE addresses.id = :addresses_id": - lambda ctx: {'user_id': objects[3].user.id, - 'addresses_id': objects[3].id}}, - ], - with_sequences=[ - ("INSERT INTO users (id, name) VALUES (:id, :name)", - lambda ctx:{'name': 'imnewlyadded', - 'id':ctx.last_inserted_ids()[0]}), - {"UPDATE addresses SET email_address=:email_address " - "WHERE addresses.id = :addresses_id": - lambda ctx: {'email_address': 'imnew@foo.bar', - 'addresses_id': objects[2].id}, - ("UPDATE addresses SET user_id=:user_id " - "WHERE addresses.id = :addresses_id"): - lambda ctx: {'user_id': objects[3].user.id, - 'addresses_id': objects[3].id}}]) + CompiledSQL("INSERT INTO users (name) VALUES (:name)", + {'name': 'imnewlyadded'} ), + + AllOf( + CompiledSQL("UPDATE addresses SET email_address=:email_address " + "WHERE addresses.id = :addresses_id", + lambda ctx: {'email_address': 'imnew@foo.bar', + 'addresses_id': objects[2].id}), + CompiledSQL("UPDATE addresses SET user_id=:user_id " + "WHERE addresses.id = :addresses_id", + lambda ctx: {'user_id': objects[3].user.id, + 'addresses_id': objects[3].id}) + ) + ) l = sa.select([users, addresses], sa.and_(users.c.id==addresses.c.user_id, @@ -1813,44 +1802,40 @@ class ManyToManyTest(_fixtures.FixtureTest): k = Keyword() k.name = 'yellow' objects[5].keywords.append(k) - self.assert_sql(testing.db, session.flush, [ - {"UPDATE items SET description=:description " - "WHERE items.id = :items_id": - {'description': 'item4updated', - 'items_id': objects[4].id}, - "INSERT INTO keywords (name) " - "VALUES (:name)": - {'name': 'yellow'}}, - ("INSERT INTO item_keywords (item_id, keyword_id) " - "VALUES (:item_id, :keyword_id)", - lambda ctx: [{'item_id': objects[5].id, - 'keyword_id': k.id}])], - with_sequences = [ - {"UPDATE items SET description=:description " - "WHERE items.id = :items_id": - {'description': 'item4updated', - 'items_id': objects[4].id}, - "INSERT INTO keywords (id, name) " - "VALUES (:id, :name)": - lambda ctx: {'name': 'yellow', - 'id':ctx.last_inserted_ids()[0]}}, - ("INSERT INTO item_keywords (item_id, keyword_id) " - "VALUES (:item_id, :keyword_id)", - lambda ctx: [{'item_id': objects[5].id, - 'keyword_id': k.id}])]) + self.assert_sql_execution( + testing.db, + session.flush, + AllOf( + CompiledSQL("UPDATE items SET description=:description " + "WHERE items.id = :items_id", + {'description': 'item4updated', + 'items_id': objects[4].id}, + ), + CompiledSQL("INSERT INTO keywords (name) " + "VALUES (:name)", + {'name': 'yellow'}, + ) + ), + CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) " + "VALUES (:item_id, :keyword_id)", + lambda ctx: [{'item_id': objects[5].id, + 'keyword_id': k.id}]) + ) objects[2].keywords.append(k) dkid = objects[5].keywords[1].id del objects[5].keywords[1] - self.assert_sql(testing.db, session.flush, [ - ("DELETE FROM item_keywords " - "WHERE item_keywords.item_id = :item_id AND " - "item_keywords.keyword_id = :keyword_id", - [{'item_id': objects[5].id, 'keyword_id': dkid}]), - ("INSERT INTO item_keywords (item_id, keyword_id) " - "VALUES (:item_id, :keyword_id)", - lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}] - )]) + self.assert_sql_execution( + testing.db, + session.flush, + CompiledSQL("DELETE FROM item_keywords " + "WHERE item_keywords.item_id = :item_id AND " + "item_keywords.keyword_id = :keyword_id", + [{'item_id': objects[5].id, 'keyword_id': dkid}]), + CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) " + "VALUES (:item_id, :keyword_id)", + lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}] + )) session.delete(objects[3]) session.flush() @@ -1999,33 +1984,20 @@ class SaveTest2(_fixtures.FixtureTest): session.add_all(fixture()) - self.assert_sql(testing.db, session.flush, [ - ("INSERT INTO users (name) VALUES (:name)", + self.assert_sql_execution( + testing.db, + session.flush, + CompiledSQL("INSERT INTO users (name) VALUES (:name)", {'name': 'u1'}), - ("INSERT INTO users (name) VALUES (:name)", + CompiledSQL("INSERT INTO users (name) VALUES (:name)", {'name': 'u2'}), - ("INSERT INTO addresses (user_id, email_address) " + CompiledSQL("INSERT INTO addresses (user_id, email_address) " "VALUES (:user_id, :email_address)", {'user_id': 1, 'email_address': 'a1'}), - ("INSERT INTO addresses (user_id, email_address) " + CompiledSQL("INSERT INTO addresses (user_id, email_address) " "VALUES (:user_id, :email_address)", - {'user_id': 2, 'email_address': 'a2'})], - with_sequences = [ - ("INSERT INTO users (id, name) " - "VALUES (:id, :name)", - lambda ctx: {'name': 'u1', 'id':ctx.last_inserted_ids()[0]}), - ("INSERT INTO users (id, name) " - "VALUES (:id, :name)", - lambda ctx: {'name': 'u2', 'id':ctx.last_inserted_ids()[0]}), - ("INSERT INTO addresses (id, user_id, email_address) " - "VALUES (:id, :user_id, :email_address)", - lambda ctx:{'user_id': 1, 'email_address': 'a1', - 'id':ctx.last_inserted_ids()[0]}), - ("INSERT INTO addresses (id, user_id, email_address) " - "VALUES (:id, :user_id, :email_address)", - lambda ctx:{'user_id': 2, 'email_address': 'a2', - 'id':ctx.last_inserted_ids()[0]})]) - + {'user_id': 2, 'email_address': 'a2'}), + ) class SaveTest3(_base.MappedTest): def define_tables(self, metadata): diff --git a/test/testlib/assertsql.py b/test/testlib/assertsql.py new file mode 100644 index 0000000000..33a7e5b64f --- /dev/null +++ b/test/testlib/assertsql.py @@ -0,0 +1,281 @@ +from sqlalchemy.interfaces import ConnectionProxy +import re +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.engine.base import Connection +import testing + +class AssertRule(object): + def process_execute(self, clauseelement, *multiparams, **params): + pass + + def process_cursor_execute(self, statement, parameters, context, executemany): + pass + + def is_consumed(self): + """Return True if this rule has been consumed, False if not. + + Should raise an AssertionError if this rule's condition has definitely failed. + + """ + raise NotImplementedError() + + def rule_passed(self): + """Return True if the last test of this rule passed, False if failed, None if no test was applied.""" + + raise NotImplementedError() + + def consume_final(self): + """Return True if this rule has been consumed. + + Should raise an AssertionError if this rule's condition has not been consumed or has failed. + + """ + + if self._result is None: + assert False, "Rule has not been consumed" + + return self.is_consumed() + +class SQLMatchRule(AssertRule): + def __init__(self): + self._result = None + self._errmsg = "" + + def rule_passed(self): + return self._result + + def is_consumed(self): + if self._result is None: + return False + + assert self._result, self._errmsg + + return True + +class ExactSQL(SQLMatchRule): + def __init__(self, sql, params=None): + SQLMatchRule.__init__(self) + self.sql = sql + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_statement = _process_engine_statement(statement, context) + _received_parameters = context.compiled_parameters + + # TODO: remove this step once all unit tests + # are migrated, as ExactSQL should really be *exact* SQL + sql = _process_assertion_statement(self.sql, context) + + equivalent = _received_statement == sql + if self.params: + if callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + equivalent = equivalent and params == context.compiled_parameters + else: + params = {} + + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for exact statement %r exact params %r, " \ + "received %r with params %r" % (sql, params, _received_statement, _received_parameters) + + +class RegexSQL(SQLMatchRule): + def __init__(self, regex, params=None): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_statement = _process_engine_statement(statement, context) + _received_parameters = context.compiled_parameters + + equivalent = bool(self.regex.match(_received_statement)) + if self.params: + if callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + + # do a positive compare only + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for regex %r partial params %r, "\ + "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters) + +class CompiledSQL(SQLMatchRule): + def __init__(self, statement, params): + SQLMatchRule.__init__(self) + self.statement = statement + self.params = params + + def process_cursor_execute(self, statement, parameters, context, executemany): + if not context: + return + + _received_parameters = context.compiled_parameters + + # recompile from the context, using the default dialect + compiled = context.compiled.statement.\ + compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys) + + _received_statement = re.sub(r'\n', '', str(compiled)) + + equivalent = self.statement == _received_statement + if self.params: + if callable(self.params): + params = self.params(context) + else: + params = self.params + + if not isinstance(params, list): + params = [params] + + # do a positive compare only + for param, received in zip(params, _received_parameters): + for k, v in param.iteritems(): + if k not in received or received[k] != v: + equivalent = False + break + else: + params = {} + + self._result = equivalent + if not self._result: + self._errmsg = "Testing for compiled statement %r partial params %r, " \ + "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters) + + +class CountStatements(AssertRule): + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_execute(self, clauseelement, *multiparams, **params): + self._statement_count += 1 + + def process_cursor_execute(self, statement, parameters, context, executemany): + pass + + def is_consumed(self): + return False + + def consume_final(self): + assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count) + return True + +class AllOf(AssertRule): + def __init__(self, *rules): + self.rules = set(rules) + + def process_execute(self, clauseelement, *multiparams, **params): + for rule in self.rules: + rule.process_execute(clauseelement, *multiparams, **params) + + def process_cursor_execute(self, statement, parameters, context, executemany): + for rule in self.rules: + rule.process_cursor_execute(statement, parameters, context, executemany) + + def is_consumed(self): + if not self.rules: + return True + + for rule in list(self.rules): + if rule.rule_passed(): # a rule passed, move on + self.rules.remove(rule) + return len(self.rules) == 0 + + assert False, "No assertion rules were satisfied for statement" + + def consume_final(self): + return len(self.rules) == 0 + +def _process_engine_statement(query, context): + if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'): + query = query[:-25] + + query = re.sub(r'\n', '', query) + + return query + +def _process_assertion_statement(query, context): + paramstyle = context.dialect.paramstyle + if paramstyle == 'named': + pass + elif paramstyle =='pyformat': + query = re.sub(r':([\w_]+)', r"%(\1)s", query) + else: + # positional params + repl = None + if paramstyle=='qmark': + repl = "?" + elif paramstyle=='format': + repl = r"%s" + elif paramstyle=='numeric': + repl = None + query = re.sub(r':([\w_]+)', repl, query) + + return query + +class SQLAssert(ConnectionProxy): + rules = None + + def add_rules(self, rules): + self.rules = list(rules) + + def statement_complete(self): + for rule in self.rules: + if not rule.consume_final(): + assert False, "All statements are complete, but pending assertion rules remain" + + def clear_rules(self): + del self.rules + + def execute(self, conn, execute, clauseelement, *multiparams, **params): + result = execute(clauseelement, *multiparams, **params) + + if self.rules is not None: + if not self.rules: + assert False, "All rules have been exhausted, but further statements remain" + rule = self.rules[0] + rule.process_execute(clauseelement, *multiparams, **params) + if rule.is_consumed(): + self.rules.pop(0) + + return result + + def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + result = execute(cursor, statement, parameters, context) + + if self.rules: + rule = self.rules[0] + rule.process_cursor_execute(statement, parameters, context, executemany) + + return result + +asserter = SQLAssert() + diff --git a/test/testlib/engines.py b/test/testlib/engines.py index a6e9fb8b72..000b188ce2 100644 --- a/test/testlib/engines.py +++ b/test/testlib/engines.py @@ -104,20 +104,18 @@ def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" from sqlalchemy import create_engine - from testlib.testing import ExecutionContextWrapper + from testlib.assertsql import asserter url = url or config.db_url options = options or config.db_opts + options.setdefault('proxy', asserter) + listeners = options.setdefault('listeners', []) listeners.append(testing_reaper) engine = create_engine(url, **options) - create_context = engine.dialect.create_execution_context - def create_exec_context(*args, **kwargs): - return ExecutionContextWrapper(create_context(*args, **kwargs)) - engine.dialect.create_execution_context = create_exec_context return engine def utf8_engine(url=None, options=None): diff --git a/test/testlib/testing.py b/test/testlib/testing.py index f6a16a4f8b..5f5d323c79 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -13,6 +13,7 @@ from cStringIO import StringIO import testlib.config as config from testlib.compat import _function_named +from testlib import assertsql # Delayed imports MetaData = None @@ -459,110 +460,6 @@ def fixture(table, columns, *rows): for column_values in rows]) table.append_ddl_listener('after-create', onload) -class TestData(object): - """Tracks SQL expressions as they are executed via an instrumented ExecutionContext.""" - - def __init__(self): - self.set_assert_list(None, None) - self.sql_count = 0 - self.buffer = None - - def set_assert_list(self, unittest, list): - self.unittest = unittest - self.assert_list = list - if list is not None: - self.assert_list.reverse() - -testdata = TestData() - - -class ExecutionContextWrapper(object): - """instruments the ExecutionContext created by the Engine so that SQL expressions - can be tracked.""" - - def __init__(self, ctx): - self.__dict__['ctx'] = ctx - def __getattr__(self, key): - return getattr(self.ctx, key) - def __setattr__(self, key, value): - setattr(self.ctx, key, value) - - trailing_underscore_pattern = re.compile(r'(\W:[\w_#]+)_\b',re.MULTILINE) - def post_exec(self): - ctx = self.ctx - statement = unicode(ctx.compiled) - statement = re.sub(r'\n', '', ctx.statement) - if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'): - statement = statement[:-25] - if testdata.buffer is not None: - testdata.buffer.write(statement + "\n") - - if testdata.assert_list is not None: - assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement - item = testdata.assert_list[-1] - if not isinstance(item, dict): - item = testdata.assert_list.pop() - else: - # asserting a dictionary of statements->parameters - # this is to specify query assertions where the queries can be in - # multiple orderings - if '_converted' not in item: - for key in item.keys(): - ckey = self.convert_statement(key) - item[ckey] = item[key] - if ckey != key: - del item[key] - item['_converted'] = True - try: - entry = item.pop(statement) - if len(item) == 1: - testdata.assert_list.pop() - item = (statement, entry) - except KeyError: - assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement) - - (query, params) = item - if callable(params): - params = params(ctx) - if params is not None and not isinstance(params, list): - params = [params] - - parameters = ctx.compiled_parameters - - query = self.convert_statement(query) - equivalent = ( (statement == query) - or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) ) - ) \ - and \ - ( (params is None) or (params == parameters) - or params == [dict([(k.strip('_'), v) - for (k, v) in p.items()]) - for p in parameters] - ) - testdata.unittest.assert_(equivalent, - "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) - testdata.sql_count += 1 - self.ctx.post_exec() - - def convert_statement(self, query): - paramstyle = self.ctx.dialect.paramstyle - if paramstyle == 'named': - pass - elif paramstyle =='pyformat': - query = re.sub(r':([\w_]+)', r"%(\1)s", query) - else: - # positional params - repl = None - if paramstyle=='qmark': - repl = "?" - elif paramstyle=='format': - repl = r"%s" - elif paramstyle=='numeric': - repl = None - query = re.sub(r':([\w_]+)', repl, query) - return query - - def _import_by_name(name): submodule = name.split('.')[-1] return __import__(name, globals(), locals(), [submodule]) @@ -827,36 +724,35 @@ class AssertsExecutionResults(object): cls.__name__, repr(expected_item))) return True - def assert_sql(self, db, callable_, list, with_sequences=None): - global testdata - testdata = TestData() - if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): - testdata.set_assert_list(self, with_sequences) - else: - testdata.set_assert_list(self, list) + def assert_sql_execution(self, db, callable_, *rules): + assertsql.asserter.add_rules(rules) try: callable_() + assertsql.asserter.statement_complete() finally: - testdata.set_assert_list(None, None) + assertsql.asserter.clear_rules() + + def assert_sql(self, db, callable_, list_, with_sequences=None): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): + rules = with_sequences + else: + rules = list_ + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf(*[ + assertsql.ExactSQL(k, v) for k, v in rule.iteritems() + ]) + else: + newrule = assertsql.ExactSQL(*rule) + newrules.append(newrule) + + self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): - global testdata - testdata = TestData() - callable_() - self.assert_(testdata.sql_count == count, - "desired statement count %d does not match %d" % ( - count, testdata.sql_count)) - - def capture_sql(self, db, callable_): - global testdata - testdata = TestData() - buffer = StringIO() - testdata.buffer = buffer - try: - callable_() - return buffer.getvalue() - finally: - testdata.buffer = None + self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + _otest_metadata = None class ORMTest(TestBase, AssertsExecutionResults): -- 2.47.3