]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rework the assert_sql system so that we have a context manager to work with,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Dec 2014 23:54:52 +0000 (18:54 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Dec 2014 23:54:52 +0000 (18:54 -0500)
use events that are local to the engine and to the run and are removed afterwards.

lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/assertsql.py
lib/sqlalchemy/testing/engines.py

index bf7c27a890681d3f4951f5139a91b618dc09eb1f..66d1f3cb0c9d9d0a2d6c1845d213b775c2ec096f 100644 (file)
@@ -405,13 +405,16 @@ class AssertsExecutionResults(object):
                         cls.__name__, repr(expected_item)))
         return True
 
+    def sql_execution_asserter(self, db=None):
+        if db is None:
+            from . import db as db
+
+        return assertsql.assert_engine(db)
+
     def assert_sql_execution(self, db, callable_, *rules):
-        assertsql.asserter.add_rules(rules)
-        try:
+        with self.sql_execution_asserter(db) as asserter:
             callable_()
-            assertsql.asserter.statement_complete()
-        finally:
-            assertsql.asserter.clear_rules()
+        asserter.assert_(*rules)
 
     def assert_sql(self, db, callable_, list_, with_sequences=None):
         if (with_sequences is not None and
index bcc999fe3024f2614b3dc92f36281f1d7683d654..2ac0605a22408b9b423ad71c6989ee6427937527 100644 (file)
@@ -8,6 +8,9 @@
 from ..engine.default import DefaultDialect
 from .. import util
 import re
+import collections
+import contextlib
+from .. import event
 
 
 class AssertRule(object):
@@ -321,39 +324,78 @@ def _process_assertion_statement(query, context):
     return query
 
 
-class SQLAssert(object):
+class SQLExecuteObserved(
+    collections.namedtuple(
+        "SQLExecuteObserved", ["clauseelement", "multiparams", "params"])
+):
+    def process(self, rules):
+        if rules is not None:
+            if not rules:
+                assert False, \
+                    'All rules have been exhausted, but further '\
+                    'statements remain'
+            rule = rules[0]
+            rule.process_execute(
+                self.clauseelement, *self.multiparams, **self.params)
+            if rule.is_consumed():
+                rules.pop(0)
 
-    rules = None
 
-    def add_rules(self, rules):
-        self.rules = list(rules)
+class SQLCursorExecuteObserved(
+    collections.namedtuple(
+        "SQLCursorExecuteObserved",
+        ["statement", "parameters", "context", "executemany"])
+):
+    def process(self, rules):
+        if rules:
+            rule = rules[0]
+            rule.process_cursor_execute(
+                self.statement, self.parameters,
+                self.context, self.executemany)
 
-    def statement_complete(self):
-        for rule in self.rules:
+
+class SQLAsserter(object):
+    def __init__(self):
+        self.accumulated = []
+
+    def _close(self):
+        # safety feature in case event.remove
+        # goes haywire
+        self._final = self.accumulated
+        del self.accumulated
+
+    def assert_(self, *rules):
+        rules = list(rules)
+        for observed in self._final:
+            observed.process(rules)
+
+        for rule in rules:
             if not rule.consume_final():
                 assert False, \
                     'All statements are complete, but pending '\
                     'assertion rules remain'
 
-    def clear_rules(self):
-        del self.rules
 
-    def execute(self, conn, clauseelement, multiparams, params, result):
-        if self.rules is not None:
-            if not self.rules:
-                assert False, \
-                    'All rules have been exhausted, but further '\
-                    'statements remain'
-            rule = self.rules[0]
-            rule.process_execute(clauseelement, *multiparams, **params)
-            if rule.is_consumed():
-                self.rules.pop(0)
+@contextlib.contextmanager
+def assert_engine(engine):
+    asserter = SQLAsserter()
 
-    def cursor_execute(self, conn, cursor, statement, parameters,
-                       context, executemany):
-        if self.rules:
-            rule = self.rules[0]
-            rule.process_cursor_execute(statement, parameters, context,
-                                        executemany)
+    @event.listens_for(engine, "after_execute")
+    def execute(conn, clauseelement, multiparams, params, result):
+        asserter.accumulated.append(
+            SQLExecuteObserved(
+                clauseelement, multiparams, params))
 
-asserter = SQLAssert()
+    @event.listens_for(engine, "after_cursor_execute")
+    def cursor_execute(conn, cursor, statement, parameters,
+                       context, executemany):
+        asserter.accumulated.append(
+            SQLCursorExecuteObserved(
+                statement, parameters, context, executemany))
+
+    try:
+        yield asserter
+    finally:
+        asserter._close()
+        event.remove(engine, "after_cursor_execute", cursor_execute)
+        event.remove(engine, "after_execute", execute)
index 0f6f59401c8e12d3d131d195d927ec231832cabb..7d73e742332080c59158ee89fb8e6d1586e53502 100644 (file)
@@ -204,7 +204,6 @@ def testing_engine(url=None, options=None):
     """Produce an engine configured by --options with optional overrides."""
 
     from sqlalchemy import create_engine
-    from .assertsql import asserter
 
     if not options:
         use_reaper = True
@@ -219,8 +218,6 @@ def testing_engine(url=None, options=None):
     if isinstance(engine.pool, pool.QueuePool):
         engine.pool._timeout = 0
         engine.pool._max_overflow = 0
-    event.listen(engine, 'after_execute', asserter.execute)
-    event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
     if use_reaper:
         event.listen(engine.pool, 'connect', testing_reaper.connect)
         event.listen(engine.pool, 'checkout', testing_reaper.checkout)