will make use of the enhanced flexibility.
from testlib.sa import Table, Column, Integer, String, ForeignKey
from testlib.sa.orm import mapper, relation, backref, create_session
from testlib.testing import eq_
+from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL
from orm import _base
sess.add(b)
sess.add(p)
- self.assert_sql(testing.db, sess.flush, [
- ("INSERT INTO person (favorite_ball_id, data) "
- "VALUES (:favorite_ball_id, :data)",
- {'favorite_ball_id': None, 'data':'some data'}),
-
- ("INSERT INTO ball (person_id, data) "
- "VALUES (:person_id, :data)",
- lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
- ("INSERT INTO ball (person_id, data) "
- "VALUES (:person_id, :data)",
- lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
- ("INSERT INTO ball (person_id, data) "
- "VALUES (:person_id, :data)",
- lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
- ("INSERT INTO ball (person_id, data) "
- "VALUES (:person_id, :data)",
- lambda ctx:{'person_id':p.id, 'data':'some data'}),
-
- ("UPDATE person SET favorite_ball_id=:favorite_ball_id "
- "WHERE person.id = :person_id",
- lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id})
- ],
-
- with_sequences= [
- ("INSERT INTO person (id, favorite_ball_id, data) "
- "VALUES (:id, :favorite_ball_id, :data)",
- lambda ctx: {'id':ctx.last_inserted_ids()[0],
- 'favorite_ball_id': None,
- 'data':'some data'}),
-
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx: {'id':ctx.last_inserted_ids()[0],
- 'person_id':p.id,
- 'data':'some data'}),
-
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx: {'id':ctx.last_inserted_ids()[0],
- 'person_id':p.id,
- 'data':'some data'}),
-
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx: {'id':ctx.last_inserted_ids()[0],
- 'person_id':p.id,
- 'data':'some data'}),
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx: {'id':ctx.last_inserted_ids()[0],
- 'person_id':p.id,
- 'data':'some data'}),
- # heres the post update
- ("UPDATE person SET favorite_ball_id=:favorite_ball_id "
- "WHERE person.id = :person_id",
- lambda ctx: {'favorite_ball_id':p.favorite.id, 'person_id':p.id})])
+ self.assert_sql_execution(
+ testing.db,
+ sess.flush,
+ RegexSQL("^INSERT INTO person", {'data':'some data'}),
+ RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+ RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+ RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+ RegexSQL("^INSERT INTO ball", lambda c: {'person_id':p.id, 'data':'some data'}),
+ ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id "
+ "WHERE person.id = :person_id",
+ lambda ctx:{'favorite_ball_id':p.favorite.id, 'person_id':p.id}
+ ),
+ )
sess.delete(p)
- self.assert_sql(testing.db, sess.flush, [
- # heres the post update (which is a pre-update with deletes)
- ("UPDATE person SET favorite_ball_id=:favorite_ball_id "
- "WHERE person.id = :person_id",
- lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}),
-
- ("DELETE FROM ball WHERE ball.id = :id",
- None),
- # order cant be predicted, but something like:
- #lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}]),
-
- ("DELETE FROM person WHERE person.id = :id",
- lambda ctx:[{'id': p.id}])])
-
+ self.assert_sql_execution(
+ testing.db,
+ sess.flush,
+ ExactSQL("UPDATE person SET favorite_ball_id=:favorite_ball_id "
+ "WHERE person.id = :person_id",
+ lambda ctx: {'person_id': p.id, 'favorite_ball_id': None}),
+ ExactSQL("DELETE FROM ball WHERE ball.id = :id", None), # lambda ctx:[{'id': 1L}, {'id': 4L}, {'id': 3L}, {'id': 2L}])
+ ExactSQL("DELETE FROM person WHERE person.id = :id", lambda ctx:[{'id': p.id}])
+ )
@testing.resolve_artifact_names
def testpostupdate_o2m(self):
sess = create_session()
sess.add_all((b,p,b2,b3,b4))
- self.assert_sql(testing.db, sess.flush, [
- ("INSERT INTO ball (person_id, data) "
+ self.assert_sql_execution(
+ testing.db,
+ sess.flush,
+ CompiledSQL("INSERT INTO ball (person_id, data) "
"VALUES (:person_id, :data)",
{'person_id':None, 'data':'some data'}),
- ("INSERT INTO ball (person_id, data) "
+ CompiledSQL("INSERT INTO ball (person_id, data) "
"VALUES (:person_id, :data)",
{'person_id':None, 'data':'some data'}),
- ("INSERT INTO ball (person_id, data) "
+ CompiledSQL("INSERT INTO ball (person_id, data) "
"VALUES (:person_id, :data)",
{'person_id':None, 'data':'some data'}),
- ("INSERT INTO ball (person_id, data) "
+ CompiledSQL("INSERT INTO ball (person_id, data) "
"VALUES (:person_id, :data)",
{'person_id':None, 'data':'some data'}),
- ("INSERT INTO person (favorite_ball_id, data) "
+ CompiledSQL("INSERT INTO person (favorite_ball_id, data) "
"VALUES (:favorite_ball_id, :data)",
lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}),
- # heres the post update on each one-to-many item
- ("UPDATE ball SET person_id=:person_id "
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
lambda ctx:{'person_id':p.id,'ball_id':b.id}),
- ("UPDATE ball SET person_id=:person_id "
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
lambda ctx:{'person_id':p.id,'ball_id':b2.id}),
- ("UPDATE ball SET person_id=:person_id "
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
lambda ctx:{'person_id':p.id,'ball_id':b3.id}),
- ("UPDATE ball SET person_id=:person_id "
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b4.id})],
-
- with_sequences=[
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx: {'id':ctx.last_inserted_ids()[0],
- 'person_id':None,
- 'data':'some data'}),
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx:{'id':ctx.last_inserted_ids()[0],
- 'person_id':None,
- 'data':'some data'}),
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx:{'id':ctx.last_inserted_ids()[0],
- 'person_id':None,
- 'data':'some data'}),
- ("INSERT INTO ball (id, person_id, data) "
- "VALUES (:id, :person_id, :data)",
- lambda ctx:{'id':ctx.last_inserted_ids()[0],
- 'person_id':None,
- 'data':'some data'}),
- ("INSERT INTO person (id, favorite_ball_id, data) "
- "VALUES (:id, :favorite_ball_id, :data)",
- lambda ctx:{'id':ctx.last_inserted_ids()[0],
- 'favorite_ball_id':b.id,
- 'data':'some data'}),
- ("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b.id}),
-
- ("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b2.id}),
-
- ("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b3.id}),
-
- ("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b4.id})])
-
+ lambda ctx:{'person_id':p.id,'ball_id':b4.id}),
+ )
+
sess.delete(p)
- self.assert_sql(testing.db, sess.flush, [
- ("UPDATE ball SET person_id=:person_id "
+
+ self.assert_sql_execution(testing.db, sess.flush,
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
lambda ctx:{'person_id': None, 'ball_id': b.id}),
- ("UPDATE ball SET person_id=:person_id "
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
lambda ctx:{'person_id': None, 'ball_id': b2.id}),
- ("UPDATE ball SET person_id=:person_id "
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
lambda ctx:{'person_id': None, 'ball_id': b3.id}),
- ("UPDATE ball SET person_id=:person_id "
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
"WHERE ball.id = :ball_id",
lambda ctx:{'person_id': None, 'ball_id': b4.id}),
- ("DELETE FROM person WHERE person.id = :id",
+ CompiledSQL("DELETE FROM person WHERE person.id = :id",
lambda ctx:[{'id':p.id}]),
- ("DELETE FROM ball WHERE ball.id = :id",
+ CompiledSQL("DELETE FROM ball WHERE ball.id = :id",
lambda ctx:[{'id': b.id},
{'id': b2.id},
{'id': b3.id},
- {'id': b4.id}])])
+ {'id': b4.id}])
+ )
class SelfReferentialPostUpdateTest(_base.MappedTest):
remove_child(root, cats)
# pre-trigger lazy loader on 'cats' to make the test easier
cats.children
- self.assert_sql(testing.db, lambda: session.flush(), [
- ("UPDATE node SET prev_sibling_id=:prev_sibling_id "
+ self.assert_sql_execution(
+ testing.db,
+ session.flush,
+ CompiledSQL("UPDATE node SET prev_sibling_id=:prev_sibling_id "
"WHERE node.id = :node_id",
lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}),
- ("UPDATE node SET next_sibling_id=:next_sibling_id "
+ CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
"WHERE node.id = :node_id",
lambda ctx:{'next_sibling_id':stories.id, 'node_id':about.id}),
- ("UPDATE node SET next_sibling_id=:next_sibling_id "
+ CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
"WHERE node.id = :node_id",
lambda ctx:{'next_sibling_id':None, 'node_id':cats.id}),
- ("DELETE FROM node WHERE node.id = :id",
- lambda ctx:[{'id':cats.id}])])
+
+ CompiledSQL("DELETE FROM node WHERE node.id = :id",
+ lambda ctx:[{'id':cats.id}])
+ )
class SelfReferentialPostUpdateTest2(_base.MappedTest):
from testlib.sa import Table, Column, Integer, String, ForeignKey, and_
from testlib.sa.orm import mapper, relation, create_session, lazyload
from testlib.testing import eq_
+from testlib.assertsql import CompiledSQL
from orm import _base, _fixtures
class EagerTest(_fixtures.FixtureTest):
d = sess.query(Node).filter_by(data='n1').options(eagerload('children.children')).first()
# test that the query isn't wrapping the initial query for eager loading.
- # testing only sqlite for now since the query text is slightly different on other
- # dialects
- if testing.against('sqlite'):
- self.assert_sql(testing.db, go, [
- (
- "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes "
- "WHERE nodes.data = :data_1 ORDER BY nodes.id LIMIT 1 OFFSET 0",
- {'data_1': 'n1'}
- ),
- ])
+ self.assert_sql_execution(testing.db, go,
+ CompiledSQL(
+ "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS nodes_data FROM nodes "
+ "WHERE nodes.data = :data_1 ORDER BY nodes.id LIMIT 1 OFFSET 0",
+ {'data_1': 'n1'}
+ )
+ )
@testing.fails_on('maxdb')
@testing.resolve_artifact_names
sess.clear()
- # test that eager loading doesnt modify parent mapper
- def go():
- u = sess.query(User).filter_by(id=8).one()
- eq_(u.id, 8)
- eq_(len(u.addresses), 3)
- assert "tbl_row_count" not in self.capture_sql(testing.db, go)
+ u = sess.query(User).filter_by(id=8).one()
+ eq_(u.id, 8)
+ eq_(len(u.addresses), 3)
@testing.fails_on('maxdb')
@testing.resolve_artifact_names
sess = create_session()
q = sess.query(Order).order_by(Order.id).options(defer('user_id'))
- self.sql_eq_(q.all, [
+ def go():
+ q.all()[0].user_id
+
+ self.sql_eq_(go, [
("SELECT orders.id AS orders_id, "
"orders.address_id AS orders_address_id, "
"orders.description AS orders_description, "
"FROM orders ORDER BY orders.id", {}),
("SELECT orders.user_id AS orders_user_id "
"FROM orders WHERE orders.id = :param_1",
- {'param_1':3})])
+ {'param_1':1})])
sess.clear()
q2 = q.options(sa.orm.undefer('user_id'))
from orm import _base, _fixtures
from engine import _base as engine_base
import pickleable
-
+from testlib.assertsql import AllOf, CompiledSQL
class UnitOfWorkTest(object):
pass
objects[2].email_address = 'imnew@foo.bar'
objects[3].user = User()
objects[3].user.name = 'imnewlyadded'
- self.assert_sql(testing.db,
+ self.assert_sql_execution(testing.db,
session.flush,
- [
- ("INSERT INTO users (name) VALUES (:name)",
- {'name': 'imnewlyadded'} ),
-
- {"UPDATE addresses SET email_address=:email_address "
- "WHERE addresses.id = :addresses_id":
- lambda ctx: {'email_address': 'imnew@foo.bar',
- 'addresses_id': objects[2].id},
- "UPDATE addresses SET user_id=:user_id "
- "WHERE addresses.id = :addresses_id":
- lambda ctx: {'user_id': objects[3].user.id,
- 'addresses_id': objects[3].id}},
- ],
- with_sequences=[
- ("INSERT INTO users (id, name) VALUES (:id, :name)",
- lambda ctx:{'name': 'imnewlyadded',
- 'id':ctx.last_inserted_ids()[0]}),
- {"UPDATE addresses SET email_address=:email_address "
- "WHERE addresses.id = :addresses_id":
- lambda ctx: {'email_address': 'imnew@foo.bar',
- 'addresses_id': objects[2].id},
- ("UPDATE addresses SET user_id=:user_id "
- "WHERE addresses.id = :addresses_id"):
- lambda ctx: {'user_id': objects[3].user.id,
- 'addresses_id': objects[3].id}}])
+ CompiledSQL("INSERT INTO users (name) VALUES (:name)",
+ {'name': 'imnewlyadded'} ),
+
+ AllOf(
+ CompiledSQL("UPDATE addresses SET email_address=:email_address "
+ "WHERE addresses.id = :addresses_id",
+ lambda ctx: {'email_address': 'imnew@foo.bar',
+ 'addresses_id': objects[2].id}),
+ CompiledSQL("UPDATE addresses SET user_id=:user_id "
+ "WHERE addresses.id = :addresses_id",
+ lambda ctx: {'user_id': objects[3].user.id,
+ 'addresses_id': objects[3].id})
+ )
+ )
l = sa.select([users, addresses],
sa.and_(users.c.id==addresses.c.user_id,
k = Keyword()
k.name = 'yellow'
objects[5].keywords.append(k)
- self.assert_sql(testing.db, session.flush, [
- {"UPDATE items SET description=:description "
- "WHERE items.id = :items_id":
- {'description': 'item4updated',
- 'items_id': objects[4].id},
- "INSERT INTO keywords (name) "
- "VALUES (:name)":
- {'name': 'yellow'}},
- ("INSERT INTO item_keywords (item_id, keyword_id) "
- "VALUES (:item_id, :keyword_id)",
- lambda ctx: [{'item_id': objects[5].id,
- 'keyword_id': k.id}])],
- with_sequences = [
- {"UPDATE items SET description=:description "
- "WHERE items.id = :items_id":
- {'description': 'item4updated',
- 'items_id': objects[4].id},
- "INSERT INTO keywords (id, name) "
- "VALUES (:id, :name)":
- lambda ctx: {'name': 'yellow',
- 'id':ctx.last_inserted_ids()[0]}},
- ("INSERT INTO item_keywords (item_id, keyword_id) "
- "VALUES (:item_id, :keyword_id)",
- lambda ctx: [{'item_id': objects[5].id,
- 'keyword_id': k.id}])])
+ self.assert_sql_execution(
+ testing.db,
+ session.flush,
+ AllOf(
+ CompiledSQL("UPDATE items SET description=:description "
+ "WHERE items.id = :items_id",
+ {'description': 'item4updated',
+ 'items_id': objects[4].id},
+ ),
+ CompiledSQL("INSERT INTO keywords (name) "
+ "VALUES (:name)",
+ {'name': 'yellow'},
+ )
+ ),
+ CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) "
+ "VALUES (:item_id, :keyword_id)",
+ lambda ctx: [{'item_id': objects[5].id,
+ 'keyword_id': k.id}])
+ )
objects[2].keywords.append(k)
dkid = objects[5].keywords[1].id
del objects[5].keywords[1]
- self.assert_sql(testing.db, session.flush, [
- ("DELETE FROM item_keywords "
- "WHERE item_keywords.item_id = :item_id AND "
- "item_keywords.keyword_id = :keyword_id",
- [{'item_id': objects[5].id, 'keyword_id': dkid}]),
- ("INSERT INTO item_keywords (item_id, keyword_id) "
- "VALUES (:item_id, :keyword_id)",
- lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}]
- )])
+ self.assert_sql_execution(
+ testing.db,
+ session.flush,
+ CompiledSQL("DELETE FROM item_keywords "
+ "WHERE item_keywords.item_id = :item_id AND "
+ "item_keywords.keyword_id = :keyword_id",
+ [{'item_id': objects[5].id, 'keyword_id': dkid}]),
+ CompiledSQL("INSERT INTO item_keywords (item_id, keyword_id) "
+ "VALUES (:item_id, :keyword_id)",
+ lambda ctx: [{'item_id': objects[2].id, 'keyword_id': k.id}]
+ ))
session.delete(objects[3])
session.flush()
session.add_all(fixture())
- self.assert_sql(testing.db, session.flush, [
- ("INSERT INTO users (name) VALUES (:name)",
+ self.assert_sql_execution(
+ testing.db,
+ session.flush,
+ CompiledSQL("INSERT INTO users (name) VALUES (:name)",
{'name': 'u1'}),
- ("INSERT INTO users (name) VALUES (:name)",
+ CompiledSQL("INSERT INTO users (name) VALUES (:name)",
{'name': 'u2'}),
- ("INSERT INTO addresses (user_id, email_address) "
+ CompiledSQL("INSERT INTO addresses (user_id, email_address) "
"VALUES (:user_id, :email_address)",
{'user_id': 1, 'email_address': 'a1'}),
- ("INSERT INTO addresses (user_id, email_address) "
+ CompiledSQL("INSERT INTO addresses (user_id, email_address) "
"VALUES (:user_id, :email_address)",
- {'user_id': 2, 'email_address': 'a2'})],
- with_sequences = [
- ("INSERT INTO users (id, name) "
- "VALUES (:id, :name)",
- lambda ctx: {'name': 'u1', 'id':ctx.last_inserted_ids()[0]}),
- ("INSERT INTO users (id, name) "
- "VALUES (:id, :name)",
- lambda ctx: {'name': 'u2', 'id':ctx.last_inserted_ids()[0]}),
- ("INSERT INTO addresses (id, user_id, email_address) "
- "VALUES (:id, :user_id, :email_address)",
- lambda ctx:{'user_id': 1, 'email_address': 'a1',
- 'id':ctx.last_inserted_ids()[0]}),
- ("INSERT INTO addresses (id, user_id, email_address) "
- "VALUES (:id, :user_id, :email_address)",
- lambda ctx:{'user_id': 2, 'email_address': 'a2',
- 'id':ctx.last_inserted_ids()[0]})])
-
+ {'user_id': 2, 'email_address': 'a2'}),
+ )
class SaveTest3(_base.MappedTest):
def define_tables(self, metadata):
--- /dev/null
+from sqlalchemy.interfaces import ConnectionProxy
+import re
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.base import Connection
+import testing
+
+class AssertRule(object):
+ def process_execute(self, clauseelement, *multiparams, **params):
+ pass
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ pass
+
+ def is_consumed(self):
+ """Return True if this rule has been consumed, False if not.
+
+ Should raise an AssertionError if this rule's condition has definitely failed.
+
+ """
+ raise NotImplementedError()
+
+ def rule_passed(self):
+ """Return True if the last test of this rule passed, False if failed, None if no test was applied."""
+
+ raise NotImplementedError()
+
+ def consume_final(self):
+ """Return True if this rule has been consumed.
+
+ Should raise an AssertionError if this rule's condition has not been consumed or has failed.
+
+ """
+
+ if self._result is None:
+ assert False, "Rule has not been consumed"
+
+ return self.is_consumed()
+
+class SQLMatchRule(AssertRule):
+ def __init__(self):
+ self._result = None
+ self._errmsg = ""
+
+ def rule_passed(self):
+ return self._result
+
+ def is_consumed(self):
+ if self._result is None:
+ return False
+
+ assert self._result, self._errmsg
+
+ return True
+
+class ExactSQL(SQLMatchRule):
+ def __init__(self, sql, params=None):
+ SQLMatchRule.__init__(self)
+ self.sql = sql
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_statement = _process_engine_statement(statement, context)
+ _received_parameters = context.compiled_parameters
+
+ # TODO: remove this step once all unit tests
+ # are migrated, as ExactSQL should really be *exact* SQL
+ sql = _process_assertion_statement(self.sql, context)
+
+ equivalent = _received_statement == sql
+ if self.params:
+ if callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+ equivalent = equivalent and params == context.compiled_parameters
+ else:
+ params = {}
+
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for exact statement %r exact params %r, " \
+ "received %r with params %r" % (sql, params, _received_statement, _received_parameters)
+
+
+class RegexSQL(SQLMatchRule):
+ def __init__(self, regex, params=None):
+ SQLMatchRule.__init__(self)
+ self.regex = re.compile(regex)
+ self.orig_regex = regex
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_statement = _process_engine_statement(statement, context)
+ _received_parameters = context.compiled_parameters
+
+ equivalent = bool(self.regex.match(_received_statement))
+ if self.params:
+ if callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+
+ # do a positive compare only
+ for param, received in zip(params, _received_parameters):
+ for k, v in param.iteritems():
+ if k not in received or received[k] != v:
+ equivalent = False
+ break
+ else:
+ params = {}
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for regex %r partial params %r, "\
+ "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
+
+class CompiledSQL(SQLMatchRule):
+ def __init__(self, statement, params):
+ SQLMatchRule.__init__(self)
+ self.statement = statement
+ self.params = params
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ if not context:
+ return
+
+ _received_parameters = context.compiled_parameters
+
+ # recompile from the context, using the default dialect
+ compiled = context.compiled.statement.\
+ compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
+
+ _received_statement = re.sub(r'\n', '', str(compiled))
+
+ equivalent = self.statement == _received_statement
+ if self.params:
+ if callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+
+ if not isinstance(params, list):
+ params = [params]
+
+ # do a positive compare only
+ for param, received in zip(params, _received_parameters):
+ for k, v in param.iteritems():
+ if k not in received or received[k] != v:
+ equivalent = False
+ break
+ else:
+ params = {}
+
+ self._result = equivalent
+ if not self._result:
+ self._errmsg = "Testing for compiled statement %r partial params %r, " \
+ "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
+
+
+class CountStatements(AssertRule):
+ def __init__(self, count):
+ self.count = count
+ self._statement_count = 0
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ self._statement_count += 1
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ pass
+
+ def is_consumed(self):
+ return False
+
+ def consume_final(self):
+ assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count)
+ return True
+
+class AllOf(AssertRule):
+ def __init__(self, *rules):
+ self.rules = set(rules)
+
+ def process_execute(self, clauseelement, *multiparams, **params):
+ for rule in self.rules:
+ rule.process_execute(clauseelement, *multiparams, **params)
+
+ def process_cursor_execute(self, statement, parameters, context, executemany):
+ for rule in self.rules:
+ rule.process_cursor_execute(statement, parameters, context, executemany)
+
+ def is_consumed(self):
+ if not self.rules:
+ return True
+
+ for rule in list(self.rules):
+ if rule.rule_passed(): # a rule passed, move on
+ self.rules.remove(rule)
+ return len(self.rules) == 0
+
+ assert False, "No assertion rules were satisfied for statement"
+
+ def consume_final(self):
+ return len(self.rules) == 0
+
+def _process_engine_statement(query, context):
+ if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
+ query = query[:-25]
+
+ query = re.sub(r'\n', '', query)
+
+ return query
+
+def _process_assertion_statement(query, context):
+ paramstyle = context.dialect.paramstyle
+ if paramstyle == 'named':
+ pass
+ elif paramstyle =='pyformat':
+ query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+ else:
+ # positional params
+ repl = None
+ if paramstyle=='qmark':
+ repl = "?"
+ elif paramstyle=='format':
+ repl = r"%s"
+ elif paramstyle=='numeric':
+ repl = None
+ query = re.sub(r':([\w_]+)', repl, query)
+
+ return query
+
+class SQLAssert(ConnectionProxy):
+ rules = None
+
+ def add_rules(self, rules):
+ self.rules = list(rules)
+
+ def statement_complete(self):
+ for rule in self.rules:
+ if not rule.consume_final():
+ assert False, "All statements are complete, but pending assertion rules remain"
+
+ def clear_rules(self):
+ del self.rules
+
+ def execute(self, conn, execute, clauseelement, *multiparams, **params):
+ result = execute(clauseelement, *multiparams, **params)
+
+ if self.rules is not None:
+ if not self.rules:
+ assert False, "All rules have been exhausted, but further statements remain"
+ rule = self.rules[0]
+ rule.process_execute(clauseelement, *multiparams, **params)
+ if rule.is_consumed():
+ self.rules.pop(0)
+
+ return result
+
+ def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ result = execute(cursor, statement, parameters, context)
+
+ if self.rules:
+ rule = self.rules[0]
+ rule.process_cursor_execute(statement, parameters, context, executemany)
+
+ return result
+
+asserter = SQLAssert()
+
"""Produce an engine configured by --options with optional overrides."""
from sqlalchemy import create_engine
- from testlib.testing import ExecutionContextWrapper
+ from testlib.assertsql import asserter
url = url or config.db_url
options = options or config.db_opts
+ options.setdefault('proxy', asserter)
+
listeners = options.setdefault('listeners', [])
listeners.append(testing_reaper)
engine = create_engine(url, **options)
- create_context = engine.dialect.create_execution_context
- def create_exec_context(*args, **kwargs):
- return ExecutionContextWrapper(create_context(*args, **kwargs))
- engine.dialect.create_execution_context = create_exec_context
return engine
def utf8_engine(url=None, options=None):
import testlib.config as config
from testlib.compat import _function_named
+from testlib import assertsql
# Delayed imports
MetaData = None
for column_values in rows])
table.append_ddl_listener('after-create', onload)
-class TestData(object):
- """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
-
- def __init__(self):
- self.set_assert_list(None, None)
- self.sql_count = 0
- self.buffer = None
-
- def set_assert_list(self, unittest, list):
- self.unittest = unittest
- self.assert_list = list
- if list is not None:
- self.assert_list.reverse()
-
-testdata = TestData()
-
-
-class ExecutionContextWrapper(object):
- """instruments the ExecutionContext created by the Engine so that SQL expressions
- can be tracked."""
-
- def __init__(self, ctx):
- self.__dict__['ctx'] = ctx
- def __getattr__(self, key):
- return getattr(self.ctx, key)
- def __setattr__(self, key, value):
- setattr(self.ctx, key, value)
-
- trailing_underscore_pattern = re.compile(r'(\W:[\w_#]+)_\b',re.MULTILINE)
- def post_exec(self):
- ctx = self.ctx
- statement = unicode(ctx.compiled)
- statement = re.sub(r'\n', '', ctx.statement)
- if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'):
- statement = statement[:-25]
- if testdata.buffer is not None:
- testdata.buffer.write(statement + "\n")
-
- if testdata.assert_list is not None:
- assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement
- item = testdata.assert_list[-1]
- if not isinstance(item, dict):
- item = testdata.assert_list.pop()
- else:
- # asserting a dictionary of statements->parameters
- # this is to specify query assertions where the queries can be in
- # multiple orderings
- if '_converted' not in item:
- for key in item.keys():
- ckey = self.convert_statement(key)
- item[ckey] = item[key]
- if ckey != key:
- del item[key]
- item['_converted'] = True
- try:
- entry = item.pop(statement)
- if len(item) == 1:
- testdata.assert_list.pop()
- item = (statement, entry)
- except KeyError:
- assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement)
-
- (query, params) = item
- if callable(params):
- params = params(ctx)
- if params is not None and not isinstance(params, list):
- params = [params]
-
- parameters = ctx.compiled_parameters
-
- query = self.convert_statement(query)
- equivalent = ( (statement == query)
- or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
- ) \
- and \
- ( (params is None) or (params == parameters)
- or params == [dict([(k.strip('_'), v)
- for (k, v) in p.items()])
- for p in parameters]
- )
- testdata.unittest.assert_(equivalent,
- "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- testdata.sql_count += 1
- self.ctx.post_exec()
-
- def convert_statement(self, query):
- paramstyle = self.ctx.dialect.paramstyle
- if paramstyle == 'named':
- pass
- elif paramstyle =='pyformat':
- query = re.sub(r':([\w_]+)', r"%(\1)s", query)
- else:
- # positional params
- repl = None
- if paramstyle=='qmark':
- repl = "?"
- elif paramstyle=='format':
- repl = r"%s"
- elif paramstyle=='numeric':
- repl = None
- query = re.sub(r':([\w_]+)', repl, query)
- return query
-
-
def _import_by_name(name):
submodule = name.split('.')[-1]
return __import__(name, globals(), locals(), [submodule])
cls.__name__, repr(expected_item)))
return True
- def assert_sql(self, db, callable_, list, with_sequences=None):
- global testdata
- testdata = TestData()
- if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
- testdata.set_assert_list(self, with_sequences)
- else:
- testdata.set_assert_list(self, list)
+ def assert_sql_execution(self, db, callable_, *rules):
+ assertsql.asserter.add_rules(rules)
try:
callable_()
+ assertsql.asserter.statement_complete()
finally:
- testdata.set_assert_list(None, None)
+ assertsql.asserter.clear_rules()
+
+ def assert_sql(self, db, callable_, list_, with_sequences=None):
+ if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+ rules = with_sequences
+ else:
+ rules = list_
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(*[
+ assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
+ ])
+ else:
+ newrule = assertsql.ExactSQL(*rule)
+ newrules.append(newrule)
+
+ self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
- global testdata
- testdata = TestData()
- callable_()
- self.assert_(testdata.sql_count == count,
- "desired statement count %d does not match %d" % (
- count, testdata.sql_count))
-
- def capture_sql(self, db, callable_):
- global testdata
- testdata = TestData()
- buffer = StringIO()
- testdata.buffer = buffer
- try:
- callable_()
- return buffer.getvalue()
- finally:
- testdata.buffer = None
+ self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+
_otest_metadata = None
class ORMTest(TestBase, AssertsExecutionResults):