]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed up testbase coverage to get module-level stuff
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Jul 2007 22:16:15 +0000 (22:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 22 Jul 2007 22:16:15 +0000 (22:16 +0000)
fixed activemapper tests

test/coverage.py
test/ext/activemapper.py
test/testbase.py

index 66e55e0c45c85b143adb27e702a307fbcf7d2121..618f962fef3ea839349fed3ba40b482433fdc55f 100644 (file)
@@ -22,7 +22,8 @@
 # interface and limitations.  See [GDR 2001-12-04b] for requirements and
 # design.
 
-r"""Usage:
+r"""\
+Usage:
 
 coverage.py -x [-p] MODULE.py [ARG1 ARG2 ...]
     Execute module, passing the given command-line arguments, collecting
@@ -54,18 +55,27 @@ coverage.py -a [-d dir] [-o dir1,dir2,...] FILE1 FILE2 ...
 Coverage data is saved in the file .coverage by default.  Set the
 COVERAGE_FILE environment variable to save it somewhere else."""
 
-__version__ = "2.6.20060823"    # see detailed history at the end of this file.
+__version__ = "2.75.20070722"    # see detailed history at the end of this file.
 
 import compiler
 import compiler.visitor
+import glob
 import os
 import re
 import string
+import symbol
 import sys
 import threading
+import token
 import types
 from socket import gethostname
 
+# Python version compatibility
+try:
+    strclass = basestring   # new to 2.3
+except:
+    strclass = str
+
 # 2. IMPLEMENTATION
 #
 # This uses the "singleton" pattern.
@@ -87,6 +97,9 @@ from socket import gethostname
 # names to increase speed.
 
 class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
+    """ A visitor for a parsed Abstract Syntax Tree which finds executable
+        statements.
+    """
     def __init__(self, statements, excluded, suite_spots):
         compiler.visitor.ASTVisitor.__init__(self)
         self.statements = statements
@@ -95,7 +108,6 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
         self.excluding_suite = 0
         
     def doRecursive(self, node):
-        self.recordNodeLine(node)
         for n in node.getChildNodes():
             self.dispatch(n)
 
@@ -131,12 +143,35 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
     def doStatement(self, node):
         self.recordLine(self.getFirstLine(node))
 
-    visitAssert = visitAssign = visitAssTuple = visitDiscard = visitPrint = \
+    visitAssert = visitAssign = visitAssTuple = visitPrint = \
         visitPrintnl = visitRaise = visitSubscript = visitDecorators = \
         doStatement
     
+    def visitPass(self, node):
+        # Pass statements have weird interactions with docstrings.  If this
+        # pass statement is part of one of those pairs, claim that the statement
+        # is on the later of the two lines.
+        l = node.lineno
+        if l:
+            lines = self.suite_spots.get(l, [l,l])
+            self.statements[lines[1]] = 1
+        
+    def visitDiscard(self, node):
+        # Discard nodes are statements that execute an expression, but then
+        # discard the results.  This includes function calls, so we can't 
+        # ignore them all.  But if the expression is a constant, the statement
+        # won't be "executed", so don't count it now.
+        if node.expr.__class__.__name__ != 'Const':
+            self.doStatement(node)
+
     def recordNodeLine(self, node):
-        return self.recordLine(node.lineno)
+        # Stmt nodes often have None, but shouldn't claim the first line of
+        # their children (because the first child might be an ignorable line
+        # like "global a").
+        if node.__class__.__name__ != 'Stmt':
+            return self.recordLine(self.getFirstLine(node))
+        else:
+            return 0
     
     def recordLine(self, lineno):
         # Returns a bool, whether the line is included or excluded.
@@ -145,7 +180,7 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
             # keyword.
             if lineno in self.suite_spots:
                 lineno = self.suite_spots[lineno][0]
-            # If we're inside an exluded suite, record that this line was
+            # If we're inside an excluded suite, record that this line was
             # excluded.
             if self.excluding_suite:
                 self.excluded[lineno] = 1
@@ -197,6 +232,8 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
         self.doSuite(node, node.body)
         self.doElse(node.body, node)
 
+    visitWhile = visitFor
+
     def visitIf(self, node):
         # The first test has to be handled separately from the rest.
         # The first test is credited to the line with the "if", but the others
@@ -206,10 +243,6 @@ class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
             self.doSuite(t, n)
         self.doElse(node.tests[-1][1], node)
 
-    def visitWhile(self, node):
-        self.doSuite(node, node.body)
-        self.doElse(node.body, node)
-
     def visitTryExcept(self, node):
         self.doSuite(node, node.body)
         for i in range(len(node.handlers)):
@@ -268,11 +301,13 @@ class coverage:
             raise CoverageException, "Only one coverage object allowed."
         self.usecache = 1
         self.cache = None
+        self.parallel_mode = False
         self.exclude_re = ''
         self.nesting = 0
         self.cstack = []
         self.xstack = []
-        self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.path.sep)
+        self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.sep)
+        self.exclude('# *pragma[: ]*[nN][oO] *[cC][oO][vV][eE][rR]')
 
     # t(f, x, y).  This method is passed to sys.settrace as a trace function.  
     # See [van Rossum 2001-07-20b, 9.2] for an explanation of sys.settrace and 
@@ -280,23 +315,24 @@ class coverage:
     # See [van Rossum 2001-07-20a, 3.2] for a description of frame and code
     # objects.
     
-    def t(self, f, w, a):                                   #pragma: no cover
+    def t(self, f, w, unused):                                   #pragma: no cover
         if w == 'line':
+            #print "Executing %s @ %d" % (f.f_code.co_filename, f.f_lineno)
             self.c[(f.f_code.co_filename, f.f_lineno)] = 1
             for c in self.cstack:
                 c[(f.f_code.co_filename, f.f_lineno)] = 1
         return self.t
     
-    def help(self, error=None):
+    def help(self, error=None):     #pragma: no cover
         if error:
             print error
             print
         print __doc__
         sys.exit(1)
 
-    def command_line(self, argv, help=None):
+    def command_line(self, argv, help_fn=None):
         import getopt
-        help = help or self.help
+        help_fn = help_fn or self.help
         settings = {}
         optmap = {
             '-a': 'annotate',
@@ -327,12 +363,12 @@ class coverage:
                 pass    # Can't get here, because getopt won't return anything unknown.
 
         if settings.get('help'):
-            help()
+            help_fn()
 
         for i in ['erase', 'execute']:
             for j in ['annotate', 'report', 'collect']:
                 if settings.get(i) and settings.get(j):
-                    help("You can't specify the '%s' and '%s' "
+                    help_fn("You can't specify the '%s' and '%s' "
                               "options at the same time." % (i, j))
 
         args_needed = (settings.get('execute')
@@ -342,18 +378,18 @@ class coverage:
                   or settings.get('collect')
                   or args_needed)
         if not action:
-            help("You must specify at least one of -e, -x, -c, -r, or -a.")
+            help_fn("You must specify at least one of -e, -x, -c, -r, or -a.")
         if not args_needed and args:
-            help("Unexpected arguments: %s" % " ".join(args))
+            help_fn("Unexpected arguments: %s" % " ".join(args))
         
-        self.get_ready(settings.get('parallel-mode'))
-        self.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
+        self.parallel_mode = settings.get('parallel-mode')
+        self.get_ready()
 
         if settings.get('erase'):
             self.erase()
         if settings.get('execute'):
             if not args:
-                help("Nothing to do.")
+                help_fn("Nothing to do.")
             sys.argv = args
             self.start()
             import __main__
@@ -387,13 +423,13 @@ class coverage:
     def get_ready(self, parallel_mode=False):
         if self.usecache and not self.cache:
             self.cache = os.environ.get(self.cache_env, self.cache_default)
-            if parallel_mode:
+            if self.parallel_mode:
                 self.cache += "." + gethostname() + "." + str(os.getpid())
             self.restore()
         self.analysis_cache = {}
         
     def start(self, parallel_mode=False):
-        self.get_ready(parallel_mode)
+        self.get_ready()
         if self.nesting == 0:                               #pragma: no cover
             sys.settrace(self.t)
             if hasattr(threading, 'settrace'):
@@ -408,12 +444,12 @@ class coverage:
                 threading.settrace(None)
 
     def erase(self):
+        self.get_ready()
         self.c = {}
         self.analysis_cache = {}
         self.cexecuted = {}
         if self.cache and os.path.exists(self.cache):
             os.remove(self.cache)
-        self.exclude_re = ""
 
     def exclude(self, re):
         if self.exclude_re:
@@ -464,11 +500,11 @@ class coverage:
 
     def collect(self):
         cache_dir, local = os.path.split(self.cache)
-        for file in os.listdir(cache_dir):
-            if not file.startswith(local):
+        for f in os.listdir(cache_dir or '.'):
+            if not f.startswith(local):
                 continue
 
-            full_path = os.path.join(cache_dir, file)
+            full_path = os.path.join(cache_dir, f)
             cexecuted = self.restore_file(full_path)
             self.merge_data(cexecuted)
 
@@ -508,6 +544,9 @@ class coverage:
 
     def canonicalize_filenames(self):
         for filename, lineno in self.c.keys():
+            if filename == '<string>':
+                # Can't do anything useful with exec'd strings, so skip them.
+                continue
             f = self.canonical_filename(filename)
             if not self.cexecuted.has_key(f):
                 self.cexecuted[f] = {}
@@ -520,17 +559,19 @@ class coverage:
         if isinstance(morf, types.ModuleType):
             if not hasattr(morf, '__file__'):
                 raise CoverageException, "Module has no __file__ attribute."
-            file = morf.__file__
+            f = morf.__file__
         else:
-            file = morf
-        return self.canonical_filename(file)
+            f = morf
+        return self.canonical_filename(f)
 
     # analyze_morf(morf).  Analyze the module or filename passed as
     # the argument.  If the source code can't be found, raise an error.
     # Otherwise, return a tuple of (1) the canonical filename of the
     # source code for the module, (2) a list of lines of statements
-    # in the source code, and (3) a list of lines of excluded statements.
-
+    # in the source code, (3) a list of lines of excluded statements,
+    # and (4), a map of line numbers to multi-line line number ranges, for
+    # statements that cross lines.
+    
     def analyze_morf(self, morf):
         if self.analysis_cache.has_key(morf):
             return self.analysis_cache[morf]
@@ -544,16 +585,53 @@ class coverage:
         elif ext != '.py':
             raise CoverageException, "File '%s' not Python source." % filename
         source = open(filename, 'r')
-        lines, excluded_lines = self.find_executable_statements(
+        lines, excluded_lines, line_map = self.find_executable_statements(
             source.read(), exclude=self.exclude_re
             )
         source.close()
-        result = filename, lines, excluded_lines
+        result = filename, lines, excluded_lines, line_map
         self.analysis_cache[morf] = result
         return result
 
+    def first_line_of_tree(self, tree):
+        while True:
+            if len(tree) == 3 and type(tree[2]) == type(1):
+                return tree[2]
+            tree = tree[1]
+    
+    def last_line_of_tree(self, tree):
+        while True:
+            if len(tree) == 3 and type(tree[2]) == type(1):
+                return tree[2]
+            tree = tree[-1]
+    
+    def find_docstring_pass_pair(self, tree, spots):
+        for i in range(1, len(tree)):
+            if self.is_string_constant(tree[i]) and self.is_pass_stmt(tree[i+1]):
+                first_line = self.first_line_of_tree(tree[i])
+                last_line = self.last_line_of_tree(tree[i+1])
+                self.record_multiline(spots, first_line, last_line)
+        
+    def is_string_constant(self, tree):
+        try:
+            return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.expr_stmt
+        except:
+            return False
+        
+    def is_pass_stmt(self, tree):
+        try:
+            return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.pass_stmt
+        except:
+            return False
+
+    def record_multiline(self, spots, i, j):
+        for l in range(i, j+1):
+            spots[l] = (i, j)
+            
     def get_suite_spots(self, tree, spots):
-        import symbol, token
+        """ Analyze a parse tree to find suite introducers which span a number
+            of lines.
+        """
         for i in range(1, len(tree)):
             if type(tree[i]) == type(()):
                 if tree[i][0] == symbol.suite:
@@ -561,7 +639,9 @@ class coverage:
                     lineno_colon = lineno_word = None
                     for j in range(i-1, 0, -1):
                         if tree[j][0] == token.COLON:
-                            lineno_colon = tree[j][2]
+                            # Colons are never executed themselves: we want the
+                            # line number of the last token before the colon.
+                            lineno_colon = self.last_line_of_tree(tree[j-1])
                         elif tree[j][0] == token.NAME:
                             if tree[j][1] == 'elif':
                                 # Find the line number of the first non-terminal
@@ -583,8 +663,18 @@ class coverage:
                     if lineno_colon and lineno_word:
                         # Found colon and keyword, mark all the lines
                         # between the two with the two line numbers.
-                        for l in range(lineno_word, lineno_colon+1):
-                            spots[l] = (lineno_word, lineno_colon)
+                        self.record_multiline(spots, lineno_word, lineno_colon)
+
+                    # "pass" statements are tricky: different versions of Python
+                    # treat them differently, especially in the common case of a
+                    # function with a doc string and a single pass statement.
+                    self.find_docstring_pass_pair(tree[i], spots)
+                    
+                elif tree[i][0] == symbol.simple_stmt:
+                    first_line = self.first_line_of_tree(tree[i])
+                    last_line = self.last_line_of_tree(tree[i])
+                    if first_line != last_line:
+                        self.record_multiline(spots, first_line, last_line)
                 self.get_suite_spots(tree[i], spots)
 
     def find_executable_statements(self, text, exclude=None):
@@ -598,10 +688,13 @@ class coverage:
                 if reExclude.search(lines[i]):
                     excluded[i+1] = 1
 
+        # Parse the code and analyze the parse tree to find out which statements
+        # are multiline, and where suites begin and end.
         import parser
         tree = parser.suite(text+'\n\n').totuple(1)
         self.get_suite_spots(tree, suite_spots)
-            
+        #print "Suite spots:", suite_spots
+        
         # Use the compiler module to parse the text and find the executable
         # statements.  We add newlines to be impervious to final partial lines.
         statements = {}
@@ -613,7 +706,7 @@ class coverage:
         lines.sort()
         excluded_lines = excluded.keys()
         excluded_lines.sort()
-        return lines, excluded_lines
+        return lines, excluded_lines, suite_spots
 
     # format_lines(statements, lines).  Format a list of line numbers
     # for printing by coalescing groups of lines as long as the lines
@@ -646,7 +739,8 @@ class coverage:
                 return "%d" % start
             else:
                 return "%d-%d" % (start, end)
-        return string.join(map(stringify, pairs), ", ")
+        ret = string.join(map(stringify, pairs), ", ")
+        return ret
 
     # Backward compatibility with version 1.
     def analysis(self, morf):
@@ -654,13 +748,17 @@ class coverage:
         return f, s, m, mf
 
     def analysis2(self, morf):
-        filename, statements, excluded = self.analyze_morf(morf)
+        filename, statements, excluded, line_map = self.analyze_morf(morf)
         self.canonicalize_filenames()
         if not self.cexecuted.has_key(filename):
             self.cexecuted[filename] = {}
         missing = []
         for line in statements:
-            if not self.cexecuted[filename].has_key(line):
+            lines = line_map.get(line, [line, line])
+            for l in range(lines[0], lines[1]+1):
+                if self.cexecuted[filename].has_key(l):
+                    break
+            else:
                 missing.append(line)
         return (filename, statements, excluded, missing,
                 self.format_lines(statements, missing))
@@ -698,6 +796,15 @@ class coverage:
     def report(self, morfs, show_missing=1, ignore_errors=0, file=None, omit_prefixes=[]):
         if not isinstance(morfs, types.ListType):
             morfs = [morfs]
+        # On windows, the shell doesn't expand wildcards.  Do it here.
+        globbed = []
+        for morf in morfs:
+            if isinstance(morf, strclass):
+                globbed.extend(glob.glob(morf))
+            else:
+                globbed.append(morf)
+        morfs = globbed
+        
         morfs = self.filter_by_prefix(morfs, omit_prefixes)
         morfs.sort(self.morf_name_compare)
 
@@ -735,8 +842,8 @@ class coverage:
                 raise
             except:
                 if not ignore_errors:
-                    type, msg = sys.exc_info()[0:2]
-                    print >>file, fmt_err % (name, type, msg)
+                    typ, msg = sys.exc_info()[0:2]
+                    print >>file, fmt_err % (name, typ, msg)
         if len(morfs) > 1:
             print >>file, "-" * len(header)
             if total_statements > 0:
@@ -816,18 +923,41 @@ class coverage:
 the_coverage = coverage()
 
 # Module functions call methods in the singleton object.
-def use_cache(*args, **kw): return the_coverage.use_cache(*args, **kw)
-def start(*args, **kw): return the_coverage.start(*args, **kw)
-def stop(*args, **kw): return the_coverage.stop(*args, **kw)
-def erase(*args, **kw): return the_coverage.erase(*args, **kw)
-def begin_recursive(*args, **kw): return the_coverage.begin_recursive(*args, **kw)
-def end_recursive(*args, **kw): return the_coverage.end_recursive(*args, **kw)
-def exclude(*args, **kw): return the_coverage.exclude(*args, **kw)
-def analysis(*args, **kw): return the_coverage.analysis(*args, **kw)
-def analysis2(*args, **kw): return the_coverage.analysis2(*args, **kw)
-def report(*args, **kw): return the_coverage.report(*args, **kw)
-def annotate(*args, **kw): return the_coverage.annotate(*args, **kw)
-def annotate_file(*args, **kw): return the_coverage.annotate_file(*args, **kw)
+def use_cache(*args, **kw): 
+    return the_coverage.use_cache(*args, **kw)
+
+def start(*args, **kw): 
+    return the_coverage.start(*args, **kw)
+
+def stop(*args, **kw): 
+    return the_coverage.stop(*args, **kw)
+
+def erase(*args, **kw): 
+    return the_coverage.erase(*args, **kw)
+
+def begin_recursive(*args, **kw): 
+    return the_coverage.begin_recursive(*args, **kw)
+
+def end_recursive(*args, **kw): 
+    return the_coverage.end_recursive(*args, **kw)
+
+def exclude(*args, **kw): 
+    return the_coverage.exclude(*args, **kw)
+
+def analysis(*args, **kw): 
+    return the_coverage.analysis(*args, **kw)
+
+def analysis2(*args, **kw): 
+    return the_coverage.analysis2(*args, **kw)
+
+def report(*args, **kw): 
+    return the_coverage.report(*args, **kw)
+
+def annotate(*args, **kw): 
+    return the_coverage.annotate(*args, **kw)
+
+def annotate_file(*args, **kw): 
+    return the_coverage.annotate_file(*args, **kw)
 
 # Save coverage data when Python exits.  (The atexit module wasn't
 # introduced until Python 2.0, so use sys.exitfunc when it's not
@@ -918,11 +1048,32 @@ if __name__ == '__main__':
 #
 # 2006-08-23 NMB Refactorings to improve testability.  Fixes to command-line
 # logic for parallel mode and collect.
+#
+# 2006-08-25 NMB "#pragma: nocover" is excluded by default.
+#
+# 2006-09-10 NMB Properly ignore docstrings and other constant expressions that
+# appear in the middle of a function, a problem reported by Tim Leslie.
+# Minor changes to avoid lint warnings.
+#
+# 2006-09-17 NMB coverage.erase() shouldn't clobber the exclude regex.
+# Change how parallel mode is invoked, and fix erase() so that it erases the
+# cache when called programmatically.
+#
+# 2007-07-21 NMB In reports, ignore code executed from strings, since we can't
+# do anything useful with it anyway.
+# Better file handling on Linux, thanks Guillaume Chazarain.
+# Better shell support on Windows, thanks Noel O'Boyle.
+# Python 2.2 support maintained, thanks Catherine Proulx.
+#
+# 2007-07-22 NMB Python 2.5 now fully supported. The method of dealing with
+# multi-line statements is now less sensitive to the exact line that Python
+# reports during execution. Pass statements are handled specially so that their
+# disappearance during execution won't throw off the measurement.
 
 # C. COPYRIGHT AND LICENCE
 #
 # Copyright 2001 Gareth Rees.  All rights reserved.
-# Copyright 2004-2006 Ned Batchelder.  All rights reserved.
+# Copyright 2004-2007 Ned Batchelder.  All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are
@@ -949,4 +1100,5 @@ if __name__ == '__main__':
 # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
 # DAMAGE.
 #
-# $Id: coverage.py 47 2006-08-24 01:08:48Z Ned $
+# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $
+
index dd967d9ec45e2b959b5c17ed2635661e68a11efd..da1726f6a6d757a5ca5a9b394a17c19c619984ec 100644 (file)
@@ -134,7 +134,7 @@ class testcase(testbase.PersistTest):
         objectstore.flush()
         objectstore.clear()
         
-        results = Person.select()
+        results = Person.query.select()
         
         self.assertEquals(len(results), 1)
         
@@ -149,24 +149,24 @@ class testcase(testbase.PersistTest):
         objectstore.flush()
         objectstore.clear()
         
-        person = Person.select()[0]
+        person = Person.query.select()[0]
         person.gender = 'F'
         objectstore.flush()
         objectstore.clear()
         self.assertEquals(person.row_version, 2)
 
-        person = Person.select()[0]
+        person = Person.query.select()[0]
         person.gender = 'M'
         objectstore.flush()
         objectstore.clear()
         self.assertEquals(person.row_version, 3)
 
         #TODO: check that a concurrent modification raises exception
-        p1 = Person.select()[0]
+        p1 = Person.query.select()[0]
         s1 = objectstore.session
         s2 = create_session()
         objectstore.context.current = s2
-        p2 = Person.select()[0]
+        p2 = Person.query.select()[0]
         p1.first_name = "jack"
         p2.first_name = "ed"
         objectstore.flush()
@@ -186,14 +186,14 @@ class testcase(testbase.PersistTest):
         objectstore.flush()
         objectstore.clear()
         
-        results = Person.select()
+        results = Person.query.select()
         self.assertEquals(len(results), 1)
         
         results[0].delete()
         objectstore.flush()
         objectstore.clear()
         
-        results = Person.select()
+        results = Person.query.select()
         self.assertEquals(len(results), 0)
     
     
@@ -205,7 +205,7 @@ class testcase(testbase.PersistTest):
         objectstore.clear()
         
         # select and make sure we get back two results
-        people = Person.select()
+        people = Person.query.select()
         self.assertEquals(len(people), 2)
                 
         # make sure that our backwards relationships work
@@ -213,7 +213,7 @@ class testcase(testbase.PersistTest):
         self.assertEquals(people[1].addresses[0].person.id, p2.id)
         
         # try a more complex select
-        results = Person.select(
+        results = Person.query.select(
             or_(
                 and_(
                     Address.c.person_id == Person.c.id,
@@ -254,12 +254,12 @@ class testcase(testbase.PersistTest):
         objectstore.flush()
         objectstore.clear()
         
-        results = Person.join('addresses').select(
+        results = Person.query.join('addresses').select(
             Address.c.postal_code.like('30075') 
         )
         self.assertEquals(len(results), 1)
 
-        self.assertEquals(Person.count(), 2)
+        self.assertEquals(Person.query.count(), 2)
 
 class testmanytomany(testbase.PersistTest):
      def setUpAll(self):
@@ -299,8 +299,8 @@ class testmanytomany(testbase.PersistTest):
          objectstore.flush()
          objectstore.clear()
 
-         foo1 = foo.get_by(name='foo1')
-         baz1 = baz.get_by(name='baz1')
+         foo1 = foo.query.get_by(name='foo1')
+         baz1 = baz.query.get_by(name='baz1')
          
          # Just checking ...
          assert (foo1.name == 'foo1')
@@ -341,15 +341,15 @@ class testselfreferential(testbase.PersistTest):
         objectstore.flush()
         objectstore.clear()
         
-        t = TreeNode.get_by(name='node1')
+        t = TreeNode.query.get_by(name='node1')
         assert (t.name == 'node1')
         assert (t.children[0].name == 'node2')
         assert (t.children[1].name == 'node3')
         assert (t.children[1].parent is t)
 
         objectstore.clear()
-        t = TreeNode.get_by(name='node3')
-        assert (t.parent is TreeNode.get_by(name='node1'))
+        t = TreeNode.query.get_by(name='node3')
+        assert (t.parent is TreeNode.query.get_by(name='node1'))
         
 if __name__ == '__main__':
     testbase.main()
index b59b1dfedd30b6ae19351e7089f9370962c444a4..c11ca1457ebc7a93c58de3a01f0be59a244bd5d2 100644 (file)
@@ -3,12 +3,10 @@ instruments SQLAlchemy dialect/engine to track SQL statements for assertion purp
 provides base test classes for common test scenarios."""
 
 import sys
+import coverage
 
 import os, unittest, StringIO, re, ConfigParser, optparse
 sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
-import sqlalchemy
-from sqlalchemy import sql, schema, engine, pool, MetaData
-from sqlalchemy.orm import clear_mappers
 
 db = None
 metadata = None
@@ -125,7 +123,12 @@ firebird=firebird://sysdba:s@localhost/tmp/test.fdb
     
     global with_coverage
     with_coverage = options.coverage
+    if with_coverage:
+        coverage.erase()
+        coverage.start()
 
+    from sqlalchemy import engine, schema
+    
     if options.serverside:
         opts['server_side_cursors'] = True
     
@@ -163,7 +166,7 @@ firebird=firebird://sysdba:s@localhost/tmp/test.fdb
     if options.log_debug is not None:
         for elem in options.log_debug:
             logging.getLogger(elem).setLevel(logging.DEBUG)
-    metadata = sqlalchemy.MetaData(db)
+    metadata = schema.MetaData(db)
     
 def unsupported(*dbs):
     """a decorator that marks a test as unsupported by one or more database implementations"""
@@ -465,14 +468,21 @@ unittest.TestLoader.suiteClass = TTestSuite
 
 parse_argv()
 
+import sqlalchemy
+from sqlalchemy import schema, MetaData, sql
+from sqlalchemy.orm import clear_mappers
                     
 def runTests(suite):
     sys.stdout = Logger()    
     runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
-    if with_coverage:
-        return cover(lambda:runner.run(suite))
-    else:
+    try:
         return runner.run(suite)
+    finally:
+        if with_coverage:
+            global echo
+            echo=True
+            coverage.stop()
+            coverage.report(list(covered_files()), show_missing=False)
 
 def covered_files():
     for rec in os.walk(os.path.dirname(sqlalchemy.__file__)):                          
@@ -480,22 +490,6 @@ def covered_files():
             if x.endswith('.py'):
                 yield os.path.join(rec[0], x)
 
-def cover(callable_):
-    import coverage
-    coverage_client = coverage.the_coverage
-    coverage_client.get_ready()
-    coverage_client.exclude('#pragma[: ]+[nN][oO] [cC][oO][vV][eE][rR]')
-    coverage_client.erase()
-    coverage_client.start()
-    try:
-        return callable_()
-    finally:
-        global echo
-        echo=True
-        coverage_client.stop()
-        coverage_client.save()
-        coverage_client.report(list(covered_files()), show_missing=False, ignore_errors=False)
-
 def main(suite=None):
     
     if not suite:
@@ -508,4 +502,3 @@ def main(suite=None):
     sys.exit(not result.wasSuccessful())
 
 
-