]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- unit tests have been migrated from unittest to nose.
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Jun 2009 21:18:24 +0000 (21:18 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 Jun 2009 21:18:24 +0000 (21:18 +0000)
See README.unittests for information on how to run
the tests.  [ticket:970]

134 files changed:
CHANGES
README.unittests
convert.py [new file with mode: 0644]
lib/sqlalchemy/test/__init__.py [new file with mode: 0644]
lib/sqlalchemy/test/assertsql.py [moved from test/testlib/assertsql.py with 100% similarity]
lib/sqlalchemy/test/config.py [new file with mode: 0644]
lib/sqlalchemy/test/engines.py [moved from test/testlib/engines.py with 96% similarity]
lib/sqlalchemy/test/noseplugin.py [new file with mode: 0644]
lib/sqlalchemy/test/orm.py [moved from test/testlib/orm.py with 96% similarity]
lib/sqlalchemy/test/pickleable.py [moved from test/pickleable.py with 90% similarity]
lib/sqlalchemy/test/profiling.py [moved from test/testlib/profiling.py with 89% similarity]
lib/sqlalchemy/test/requires.py [moved from test/testlib/requires.py with 99% similarity]
lib/sqlalchemy/test/schema.py [moved from test/testlib/schema.py with 92% similarity]
lib/sqlalchemy/test/testing.py [moved from test/testlib/testing.py with 71% similarity]
setup.cfg
setup.py
test/aaa_profiling/__init__.py [moved from test/profiling/__init__.py with 100% similarity]
test/aaa_profiling/test_compiler.py [moved from test/profiling/compiler.py with 84% similarity]
test/aaa_profiling/test_memusage.py [moved from test/profiling/memusage.py with 96% similarity]
test/aaa_profiling/test_pool.py [moved from test/profiling/pool.py with 87% similarity]
test/aaa_profiling/test_zoomark.py [moved from test/profiling/zoomark.py with 99% similarity]
test/aaa_profiling/test_zoomark_orm.py [moved from test/profiling/zoomark_orm.py with 99% similarity]
test/alltests.py [deleted file]
test/base/alltests.py [deleted file]
test/base/test_dependency.py [moved from test/base/dependency.py with 97% similarity]
test/base/test_except.py [moved from test/base/except.py with 96% similarity]
test/base/test_utils.py [moved from test/base/utils.py with 94% similarity]
test/dialect/alltests.py [deleted file]
test/dialect/test_access.py [moved from test/dialect/access.py with 86% similarity]
test/dialect/test_firebird.py [moved from test/dialect/firebird.py with 84% similarity]
test/dialect/test_informix.py [moved from test/dialect/informix.py with 88% similarity]
test/dialect/test_maxdb.py [moved from test/dialect/maxdb.py with 96% similarity]
test/dialect/test_mssql.py [moved from test/dialect/mssql.py with 92% similarity, mode: 0644]
test/dialect/test_mysql.py [moved from test/dialect/mysql.py with 97% similarity]
test/dialect/test_oracle.py [moved from test/dialect/oracle.py with 97% similarity]
test/dialect/test_postgres.py [moved from test/dialect/postgres.py with 88% similarity]
test/dialect/test_sqlite.py [moved from test/dialect/sqlite.py with 93% similarity]
test/dialect/test_sybase.py [moved from test/dialect/sybase.py with 85% similarity]
test/engine/_base.py
test/engine/alltests.py [deleted file]
test/engine/test_bind.py [moved from test/engine/bind.py with 96% similarity]
test/engine/test_ddlevents.py [moved from test/engine/ddlevents.py with 92% similarity]
test/engine/test_execute.py [moved from test/engine/execute.py with 94% similarity]
test/engine/test_metadata.py [moved from test/engine/metadata.py with 93% similarity]
test/engine/test_parseconnect.py [moved from test/engine/parseconnect.py with 98% similarity]
test/engine/test_pool.py [moved from test/engine/pool.py with 99% similarity]
test/engine/test_reconnect.py [moved from test/engine/reconnect.py with 88% similarity]
test/engine/test_reflection.py [moved from test/engine/reflection.py with 96% similarity]
test/engine/test_transaction.py [moved from test/engine/transaction.py with 96% similarity]
test/ext/alltests.py [deleted file]
test/ext/test_associationproxy.py [moved from test/ext/associationproxy.py with 97% similarity]
test/ext/test_compiler.py [moved from test/ext/compiler.py with 97% similarity]
test/ext/test_declarative.py [moved from test/ext/declarative.py with 96% similarity]
test/ext/test_orderinglist.py [moved from test/ext/orderinglist.py with 98% similarity]
test/ext/test_serializer.py [moved from test/ext/serializer.py with 89% similarity]
test/orm/_base.py
test/orm/_fixtures.py
test/orm/alltests.py [deleted file]
test/orm/inheritance/alltests.py [deleted file]
test/orm/inheritance/test_abc_inheritance.py [moved from test/orm/inheritance/abc_inheritance.py with 95% similarity]
test/orm/inheritance/test_abc_polymorphic.py [moved from test/orm/inheritance/abc_polymorphic.py with 92% similarity]
test/orm/inheritance/test_basic.py [moved from test/orm/inheritance/basic.py with 95% similarity]
test/orm/inheritance/test_concrete.py [moved from test/orm/inheritance/concrete.py with 95% similarity]
test/orm/inheritance/test_magazine.py [moved from test/orm/inheritance/magazine.py with 97% similarity]
test/orm/inheritance/test_manytomany.py [moved from test/orm/inheritance/manytomany.py with 96% similarity]
test/orm/inheritance/test_poly_linked_list.py [moved from test/orm/inheritance/poly_linked_list.py with 93% similarity]
test/orm/inheritance/test_polymorph.py [moved from test/orm/inheritance/polymorph.py with 90% similarity]
test/orm/inheritance/test_polymorph2.py [moved from test/orm/inheritance/polymorph2.py with 96% similarity]
test/orm/inheritance/test_productspec.py [moved from test/orm/inheritance/productspec.py with 98% similarity]
test/orm/inheritance/test_query.py [moved from test/orm/inheritance/query.py with 79% similarity]
test/orm/inheritance/test_selects.py [moved from test/orm/inheritance/selects.py with 88% similarity]
test/orm/inheritance/test_single.py [moved from test/orm/inheritance/single.py with 85% similarity]
test/orm/sharding/alltests.py [deleted file]
test/orm/sharding/test_shard.py [moved from test/orm/sharding/shard.py with 94% similarity]
test/orm/test_association.py [moved from test/orm/association.py with 89% similarity]
test/orm/test_assorted_eager.py [moved from test/orm/assorted_eager.py with 93% similarity]
test/orm/test_attributes.py [moved from test/orm/attributes.py with 99% similarity]
test/orm/test_bind.py [moved from test/orm/bind.py with 71% similarity]
test/orm/test_cascade.py [moved from test/orm/cascade.py with 95% similarity]
test/orm/test_collection.py [moved from test/orm/collection.py with 98% similarity]
test/orm/test_compile.py [moved from test/orm/compile.py with 97% similarity]
test/orm/test_cycles.py [moved from test/orm/cycles.py with 94% similarity]
test/orm/test_defaults.py [moved from test/orm/defaults.py with 88% similarity]
test/orm/test_deprecations.py [moved from test/orm/deprecations.py with 96% similarity]
test/orm/test_dynamic.py [moved from test/orm/dynamic.py with 96% similarity]
test/orm/test_eager_relations.py [moved from test/orm/eager_relations.py with 97% similarity]
test/orm/test_evaluator.py [moved from test/orm/evaluator.py with 86% similarity]
test/orm/test_expire.py [moved from test/orm/expire.py with 95% similarity]
test/orm/test_extendedattr.py [moved from test/orm/extendedattr.py with 87% similarity]
test/orm/test_generative.py [moved from test/orm/generative.py with 91% similarity]
test/orm/test_instrumentation.py [moved from test/orm/instrumentation.py with 94% similarity]
test/orm/test_lazy_relations.py [moved from test/orm/lazy_relations.py with 96% similarity]
test/orm/test_lazytest1.py [moved from test/orm/lazytest1.py with 88% similarity]
test/orm/test_manytomany.py [moved from test/orm/manytomany.py with 93% similarity]
test/orm/test_mapper.py [moved from test/orm/mapper.py with 97% similarity]
test/orm/test_merge.py [moved from test/orm/merge.py with 98% similarity]
test/orm/test_naturalpks.py [moved from test/orm/naturalpks.py with 84% similarity]
test/orm/test_onetoone.py [moved from test/orm/onetoone.py with 81% similarity]
test/orm/test_pickled.py [moved from test/orm/pickled.py with 80% similarity]
test/orm/test_query.py [moved from test/orm/query.py with 88% similarity]
test/orm/test_relationships.py [moved from test/orm/relationships.py with 94% similarity]
test/orm/test_scoping.py [moved from test/orm/scoping.py with 83% similarity]
test/orm/test_selectable.py [moved from test/orm/selectable.py with 71% similarity]
test/orm/test_session.py [moved from test/orm/session.py with 93% similarity]
test/orm/test_transaction.py [moved from test/orm/transaction.py with 86% similarity]
test/orm/test_unitofwork.py [moved from test/orm/unitofwork.py with 95% similarity]
test/orm/test_utils.py [moved from test/orm/utils.py with 95% similarity]
test/profiling/alltests.py [deleted file]
test/sql/_base.py
test/sql/alltests.py [deleted file]
test/sql/test_case_statement.py [moved from test/sql/case_statement.py with 94% similarity]
test/sql/test_columns.py [moved from test/sql/columns.py with 79% similarity]
test/sql/test_constraints.py [moved from test/sql/constraints.py with 94% similarity]
test/sql/test_defaults.py [moved from test/sql/defaults.py with 96% similarity]
test/sql/test_functions.py [moved from test/sql/functions.py with 97% similarity]
test/sql/test_generative.py [moved from test/sql/generative.py with 99% similarity]
test/sql/test_labels.py [moved from test/sql/labels.py with 95% similarity]
test/sql/test_query.py [moved from test/sql/query.py with 94% similarity]
test/sql/test_quote.py [moved from test/sql/quote.py with 97% similarity]
test/sql/test_rowcount.py [moved from test/sql/rowcount.py with 91% similarity]
test/sql/test_select.py [moved from test/sql/select.py with 98% similarity]
test/sql/test_selectable.py [moved from test/sql/selectable.py with 98% similarity, mode: 0644]
test/sql/test_types.py [moved from test/sql/testtypes.py with 94% similarity]
test/sql/test_unicode.py [moved from test/sql/unicode.py with 95% similarity]
test/testenv.py [deleted file]
test/testlib/__init__.py [deleted file]
test/testlib/compat.py [deleted file]
test/testlib/config.py [deleted file]
test/testlib/coverage.py [deleted file]
test/testlib/sa_unittest.py [deleted file]
test/zblog/alltests.py [deleted file]
test/zblog/mappers.py
test/zblog/tables.py
test/zblog/test_zblog.py [moved from test/zblog/tests.py with 79% similarity]

diff --git a/CHANGES b/CHANGES
index 0653bd68dae9cb27d2f9afa746391e92ee49a03c..9a4596c72fcae0e7528358c8085c02735db1b1f8 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,10 @@ CHANGES
 
 0.5.5
 =======
+- general
+    - unit tests have been migrated from unittest to nose.
+      See README.unittests for information on how to run
+      the tests.  [ticket:970]
 - orm
     - Fixed bug introduced in 0.5.4 whereby Composite types
       fail when default-holding columns are flushed.
index f70f6ab17751a02b0b2dc9fc6eb403ef1e8315e1..bfc31e28fbc7231a5bf29aee8d127943d3c5d987 100644 (file)
@@ -2,65 +2,63 @@
 SQLALCHEMY UNIT TESTS
 =====================
 
-SETUP
------
-SQLite support is required.  These instructions assume standard Python 2.4 or
-higher. 
-
-The 'test' directory must be on the PYTHONPATH.
+SQLAlchemy unit tests by default run using Python's built-in sqlite3 
+module.  If running on Python 2.4, pysqlite must be installed.
 
-cd into the SQLAlchemy distribution directory
+As of 0.5.5, unit tests are run using nose.  Documentation and
+downloads for nose are available at:
 
-In bash:
+http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html
 
-    $ export PYTHONPATH=./test/
 
-On windows:
+SQLAlchemy implements a nose plugin that must be present when tests are run.
+This plugin is available when SQLAlchemy is installed via setuptools.
 
-    C:\sa\> set PYTHONPATH=test\
+SETUP
+-----
 
-    Adjust any other use Unix-style paths in this README as needed.
+All that's required is for SQLAlchemy to be installed via setuptools.
+For example, to create a local install in a source distribution directory:
 
-The unittest framework will automatically prepend the lib/ directory to
-sys.path.  This forces the local version of SQLAlchemy to be used, bypassing
-any setuptools-installed installations (setuptools places .egg files ahead of
-plain directories, even if on PYTHONPATH, unfortunately).
+    $ export PYTHONPATH=.
+    $ python setup.py develop -d .
 
+The above will create a setuptools "development" distribution in the local
+path, which allows the Nose plugin to be available when nosetests is run.
+The plugin is enabled using the "with-sqlalchemy=True" configuration
+in setup.cfg.
 
 RUNNING ALL TESTS
 -----------------
 To run all tests:
 
-    $ python test/alltests.py
+    $ nosetests
+
+Assuming all tests pass, this is a very unexciting output.  To make it more 
+intersesting:
 
+    $ nosetests -v
 
 RUNNING INDIVIDUAL TESTS
 -------------------------
-Any unittest module can be run directly from the module file:
+Any test module can be run directly by specifying its module name:
 
-    python test/orm/mapper.py
+    $ nosetests test.orm.test_mapper
 
-To run a specific test within the module, specify it as ClassName.methodname:
+To run a specific test within the module, specify it as module:ClassName.methodname:
 
-    python test/orm/mapper.py MapperTest.testget
+    $ nosetests test.orm.test_mapper:MapperTest.test_utils
 
 
 COMMAND LINE OPTIONS
 --------------------
-Help is available via --help
+Help is available via --help:
 
-    $ python test/alltests.py --help
+    $ nosetests --help
 
-    usage: alltests.py [options] [tests...]
-
-    Options:
-      -h, --help            show this help message and exit
-      --verbose             enable stdout echoing/printing
-      --quiet               suppress output
-    [...]
-
-Command line options can applied to alltests.py or any individual test module.
-Many are available.  The most commonly used are '--db' and '--dburi'.
+The --help screen is a combination of common nose options and options which 
+the SQLAlchemy nose plugin adds.  The most commonly SQLAlchemy-specific 
+options used are '--db' and '--dburi'.
 
 
 DATABASE TARGETS
@@ -78,7 +76,7 @@ preexisting tables will interfere with the tests
 If you'll be running the tests frequently, database aliases can save a lot of
 typing.  The --dbs option lists the built-in aliases and their matching URLs:
 
-    $ python test/alltests.py --dbs
+    $ nosetests --dbs
     Available --db options (use --dburi to override)
                mysql    mysql://scott:tiger@127.0.0.1:3306/test
               oracle    oracle://scott:tiger@127.0.0.1:1521
@@ -87,7 +85,7 @@ typing.  The --dbs option lists the built-in aliases and their matching URLs:
 
 To run tests against an aliased database:
 
-    $ python test/alltests.py --db=postgres
+    $ nosetests --db=postgres
 
 To customize the URLs with your own users or hostnames, make a simple .ini
 file called `test.cfg` at the top level of the SQLAlchemy source distribution
@@ -106,7 +104,7 @@ SQLAlchemy logs its activity and debugging through Python's logging package.
 Any log target can be directed to the console with command line options, such
 as:
 
-    $ python test/orm/unitofwork.py --log-info=sqlalchemy.orm.mapper \
+    $ nosetests test.orm.unitofwork --log-info=sqlalchemy.orm.mapper \
       --log-debug=sqlalchemy.pool --log-info=sqlalchemy.engine
 
 This would log mapper configuration, connection pool checkouts, and SQL
@@ -115,21 +113,10 @@ statement execution.
 
 BUILT-IN COVERAGE REPORTING
 ------------------------------
-Coverage is tracked with coverage.py module, included in the './test/'
-directory.  Running the test suite with the --coverage switch will generate a
-local file ".coverage" containing coverage details, and a report will be
-printed to standard output with an overview of the coverage gathered from the
-last unittest run (the file is deleted between runs).
-
-After the suite has been run with --coverage, an annotated version of any
-source file can be generated, marking statements that are executed with > and
-statements that are missed with !, by running the coverage.py utility with the
-"-a" (annotate) option, such as:
-
-    $ python ./test/testlib/coverage.py -a ./lib/sqlalchemy/sql.py
+Coverage is tracked using Nose's coverage plugin.   See the nose 
+documentation for details.  Basic usage is:
 
-This will create a new annotated file ./lib/sqlalchemy/sql.py,cover. Pretty
-cool!
+    $ nosetests test.sql.test_query --with-coverage
 
 BIG COVERAGE TIP !!!  There is an issue where existing .pyc files may
 store the incorrect filepaths, which will break the coverage system.  If
diff --git a/convert.py b/convert.py
new file mode 100644 (file)
index 0000000..b574c27
--- /dev/null
@@ -0,0 +1,228 @@
+import os
+import subprocess
+import re
+
+def walk():
+    for root, dirs, files in os.walk("./test/"):
+        if root.endswith("/perf"):
+            continue
+        
+        for fname in files:
+            if not fname.endswith(".py"):
+                continue
+            if fname == "alltests.py":
+                subprocess.call(["svn", "remove", os.path.join(root, fname)])
+            elif fname.startswith("_") or fname == "__init__.py" or fname == "pickleable.py":
+                convert(os.path.join(root, fname))
+            elif not fname.startswith("test_"):
+                if os.path.exists(os.path.join(root, "test_" + fname)):
+                    os.unlink(os.path.join(root, "test_" + fname))
+                subprocess.call(["svn", "rename", os.path.join(root, fname), os.path.join(root, "test_" + fname)])
+                convert(os.path.join(root, "test_" + fname))
+
+
+def convert(fname):
+    lines = list(file(fname))
+    replaced = []
+    flags = {}
+    
+    while lines:
+        for reg, handler in handlers:
+            m = reg.match(lines[0])
+            if m:
+                handler(lines, replaced, flags)
+                break
+    
+    post_handler(lines, replaced, flags)
+    f = file(fname, 'w')
+    f.write("".join(replaced))
+    f.close()
+
+handlers = []
+
+
+def post_handler(lines, replaced, flags):
+    imports = []
+    if "needs_eq" in flags:
+        imports.append("eq_")
+    if "needs_assert_raises" in flags:
+        imports += ["assert_raises", "assert_raises_message"]
+    if imports:
+        for i, line in enumerate(replaced):
+            if "import" in line:
+                replaced.insert(i, "from sqlalchemy.test.testing import %s\n" % ", ".join(imports))
+                break
+    
+def remove_line(lines, replaced, flags):
+    lines.pop(0)
+    
+handlers.append((re.compile(r"import testenv; testenv\.configure_for_tests"), remove_line))
+handlers.append((re.compile(r"(.*\s)?import sa_unittest"), remove_line))
+
+
+def import_testlib_sa(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("import testlib.sa", "import sqlalchemy")
+    replaced.append(line)
+handlers.append((re.compile("import testlib\.sa"), import_testlib_sa))
+
+def from_testlib_sa(lines, replaced, flags):
+    line = lines.pop(0)
+    while True:
+        if line.endswith("\\\n"):
+            line = line[0:-2] + lines.pop(0)
+        else:
+            break
+    
+    components = re.compile(r'from testlib\.sa import (.*)').match(line)
+    if components:
+        components = re.split(r"\s*,\s*", components.group(1))
+        line = "from sqlalchemy import %s\n" % (", ".join(c for c in components if c not in ("Table", "Column")))
+        replaced.append(line)
+        if "Table" in components:
+            replaced.append("from sqlalchemy.test.schema import Table\n")
+        if "Column" in components:
+            replaced.append("from sqlalchemy.test.schema import Column\n")
+        return
+        
+    line = line.replace("testlib.sa", "sqlalchemy")
+    replaced.append(line)
+handlers.append((re.compile("from testlib\.sa.*import"), from_testlib_sa))
+
+def from_testlib(lines, replaced, flags):
+    line = lines.pop(0)
+    
+    components = re.compile(r'from testlib import (.*)').match(line)
+    if components:
+        components = re.split(r"\s*,\s*", components.group(1))
+        if "sa" in components:
+            replaced.append("import sqlalchemy as sa\n")
+            replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa")))
+            return
+        elif "sa as tsa" in components:
+            replaced.append("import sqlalchemy as tsa\n")
+            replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa")))
+            return
+    
+    line = line.replace("testlib", "sqlalchemy.test")
+    replaced.append(line)
+handlers.append((re.compile(r"from testlib"), from_testlib))
+
+def from_orm(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("from orm import", "from test.orm import")
+    line = line.replace("from orm.", "from test.orm.")
+    replaced.append(line)
+handlers.append((re.compile(r'from orm( import|\.)'), from_orm))
+    
+def assert_equals(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("self.assertEquals", "eq_")
+    line = line.replace("self.assertEqual", "eq_")
+    replaced.append(line)
+    flags["needs_eq"] = True
+handlers.append((re.compile(r"\s*self\.assertEqual(s)?"), assert_equals))
+
+def assert_raises(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("self.assertRaisesMessage", "assert_raises_message")
+    line = line.replace("self.assertRaises", "assert_raises")
+    replaced.append(line)
+    flags["needs_assert_raises"] = True
+handlers.append((re.compile(r"\s*self\.assertRaises(Message)?"), assert_raises))
+
+def setup_all(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setUpAll\(self\)\:").match(line).group(1)
+    replaced.append("%s@classmethod\n" % whitespace)
+    replaced.append("%sdef setup_class(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setUpAll\(self\)"), setup_all))
+
+def teardown_all(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def tearDownAll\(self\)\:").match(line).group(1)
+    replaced.append("%s@classmethod\n" % whitespace)
+    replaced.append("%sdef teardown_class(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def tearDownAll\(self\)"), teardown_all))
+
+def setup(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setUp\(self\)\:").match(line).group(1)
+    replaced.append("%sdef setup(self):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setUp\(self\)"), setup))
+
+def teardown(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def tearDown\(self\)\:").match(line).group(1)
+    replaced.append("%sdef teardown(self):\n" % whitespace)
+handlers.append((re.compile(r"\s*def tearDown\(self\)"), teardown))
+    
+def define_tables(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def define_tables").match(line).group(1)
+    replaced.append("%s@classmethod\n" % whitespace)
+    replaced.append("%sdef define_tables(cls, metadata):\n" % whitespace)
+handlers.append((re.compile(r"\s*def define_tables\(self, metadata\)"), define_tables))
+
+def setup_mappers(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setup_mappers").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef setup_mappers(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setup_mappers\(self\)"), setup_mappers))
+
+def setup_classes(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setup_classes").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef setup_classes(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setup_classes\(self\)"), setup_classes))
+
+def insert_data(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def insert_data").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef insert_data(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def insert_data\(self\)"), insert_data))
+
+def fixtures(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def fixtures").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef fixtures(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def fixtures\(self\)"), fixtures))
+    
+    
+def call_main(lines, replaced, flags):
+    replaced.pop(-1)
+    lines.pop(0)
+handlers.append((re.compile(r"\s+testenv\.main\(\)"), call_main))
+
+def default(lines, replaced, flags):
+    replaced.append(lines.pop(0))
+handlers.append((re.compile(r".*"), default))
+
+
+if __name__ == '__main__':
+    convert("test/orm/inheritance/abc_inheritance.py")
+#    walk()
diff --git a/lib/sqlalchemy/test/__init__.py b/lib/sqlalchemy/test/__init__.py
new file mode 100644 (file)
index 0000000..d69cede
--- /dev/null
@@ -0,0 +1,26 @@
+"""Testing environment and utilities.
+
+This package contains base classes and routines used by 
+the unit tests.   Tests are based on Nose and bootstrapped
+by noseplugin.NoseSQLAlchemy.
+
+"""
+
+from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config
+from sqlalchemy.test.schema import Column, Table
+from sqlalchemy.test.testing import \
+     AssertsCompiledSQL, \
+     AssertsExecutionResults, \
+     ComparesTables, \
+     TestBase, \
+     rowset
+
+
+__all__ = ('testing',
+            'Column', 'Table',
+           'rowset',
+           'TestBase', 'AssertsExecutionResults',
+           'AssertsCompiledSQL', 'ComparesTables',
+           'engines', 'profiling', 'pickleable')
+
+
diff --git a/lib/sqlalchemy/test/config.py b/lib/sqlalchemy/test/config.py
new file mode 100644 (file)
index 0000000..6ea5667
--- /dev/null
@@ -0,0 +1,177 @@
+import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
+logging = None
+
+__all__ = 'parser', 'configure', 'options',
+
+db = None
+db_label, db_url, db_opts = None, None, {}
+
+options = None
+file_config = None
+
+base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
+firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
+maxdb=maxdb://MONA:RED@/maxdb1
+"""
+
+def _log(option, opt_str, value, parser):
+    global logging
+    if not logging:
+        import logging
+        logging.basicConfig()
+
+    if opt_str.endswith('-info'):
+        logging.getLogger(value).setLevel(logging.INFO)
+    elif opt_str.endswith('-debug'):
+        logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+    print "Available --db options (use --dburi to override)"
+    for macro in sorted(file_config.options('db')):
+        print "%20s\t%s" % (macro, file_config.get('db', macro))
+    sys.exit(0)
+
+def _server_side_cursors(options, opt_str, value, parser):
+    db_opts['server_side_cursors'] = True
+
+def _engine_strategy(options, opt_str, value, parser):
+    if value:
+        db_opts['strategy'] = value
+
+class _ordered_map(object):
+    def __init__(self):
+        self._keys = list()
+        self._data = dict()
+
+    def __setitem__(self, key, value):
+        if key not in self._keys:
+            self._keys.append(key)
+        self._data[key] = value
+
+    def __iter__(self):
+        for key in self._keys:
+            yield self._data[key]
+
+# at one point in refactoring, modules were injecting into the config
+# process.  this could probably just become a list now.
+post_configure = _ordered_map()
+
+def _engine_uri(options, file_config):
+    global db_label, db_url
+    db_label = 'sqlite'
+    if options.dburi:
+        db_url = options.dburi
+        db_label = db_url[:db_url.index(':')]
+    elif options.db:
+        db_label = options.db
+        db_url = None
+
+    if db_url is None:
+        if db_label not in file_config.options('db'):
+            raise RuntimeError(
+                "Unknown engine.  Specify --dbs for known engines.")
+        db_url = file_config.get('db', db_label)
+post_configure['engine_uri'] = _engine_uri
+
+def _require(options, file_config):
+    if not(options.require or
+           (file_config.has_section('require') and
+            file_config.items('require'))):
+        return
+
+    try:
+        import pkg_resources
+    except ImportError:
+        raise RuntimeError("setuptools is required for version requirements")
+
+    cmdline = []
+    for requirement in options.require:
+        pkg_resources.require(requirement)
+        cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
+
+    if file_config.has_section('require'):
+        for label, requirement in file_config.items('require'):
+            if not label == db_label or label.startswith('%s.' % db_label):
+                continue
+            seen = [c for c in cmdline if requirement.startswith(c)]
+            if seen:
+                continue
+            pkg_resources.require(requirement)
+post_configure['require'] = _require
+
+def _engine_pool(options, file_config):
+    if options.mockpool:
+        from sqlalchemy import pool
+        db_opts['poolclass'] = pool.AssertionPool
+post_configure['engine_pool'] = _engine_pool
+
+def _create_testing_engine(options, file_config):
+    from sqlalchemy.test import engines, testing
+    global db
+    db = engines.testing_engine(db_url, db_opts)
+    testing.db = db
+post_configure['create_engine'] = _create_testing_engine
+
+def _prep_testing_database(options, file_config):
+    from sqlalchemy.test import engines
+    from sqlalchemy import schema
+
+    try:
+        # also create alt schemas etc. here?
+        if options.dropfirst:
+            e = engines.utf8_engine()
+            existing = e.table_names()
+            if existing:
+                print "Dropping existing tables in database: " + db_url
+                try:
+                    print "Tables: %s" % ', '.join(existing)
+                except:
+                    pass
+                print "Abort within 5 seconds..."
+                time.sleep(5)
+                md = schema.MetaData(e, reflect=True)
+                md.drop_all()
+            e.dispose()
+    except (KeyboardInterrupt, SystemExit):
+        raise
+    except Exception, e:
+        warnings.warn(RuntimeWarning(
+            "Error checking for existing tables in testing "
+            "database: %s" % e))
+post_configure['prep_db'] = _prep_testing_database
+
+def _set_table_options(options, file_config):
+    from sqlalchemy.test import schema
+
+    table_options = schema.table_options
+    for spec in options.tableopts:
+        key, value = spec.split('=')
+        table_options[key] = value
+
+    if options.mysql_engine:
+        table_options['mysql_engine'] = options.mysql_engine
+post_configure['table_options'] = _set_table_options
+
+def _reverse_topological(options, file_config):
+    if options.reversetop:
+        from sqlalchemy.orm import unitofwork
+        from sqlalchemy import topological
+        class RevQueueDepSort(topological.QueueDependencySorter):
+            def __init__(self, tuples, allitems):
+                self.tuples = list(tuples)
+                self.allitems = list(allitems)
+                self.tuples.reverse()
+                self.allitems.reverse()
+        topological.QueueDependencySorter = RevQueueDepSort
+        unitofwork.DependencySorter = RevQueueDepSort
+post_configure['topological'] = _reverse_topological
+
similarity index 96%
rename from test/testlib/engines.py
rename to lib/sqlalchemy/test/engines.py
index 4068f43d0a9f5a877a196acb9cc05e9baf8984d9..f0001978bf4214428fe592fc2d773bf2435ff833 100644 (file)
@@ -1,7 +1,7 @@
 import sys, types, weakref
 from collections import deque
-from testlib import config
-from testlib.compat import _function_named, callable
+import config
+from sqlalchemy.util import function_named, callable
 
 class ConnectionKiller(object):
     def __init__(self):
@@ -44,7 +44,7 @@ def assert_conns_closed(fn):
             fn(*args, **kw)
         finally:
             testing_reaper.assert_all_closed()
-    return _function_named(decorated, fn.__name__)
+    return function_named(decorated, fn.__name__)
 
 def rollback_open_connections(fn):
     """Decorator that rolls back all open connections after fn execution."""
@@ -54,7 +54,7 @@ def rollback_open_connections(fn):
             fn(*args, **kw)
         finally:
             testing_reaper.rollback_all()
-    return _function_named(decorated, fn.__name__)
+    return function_named(decorated, fn.__name__)
 
 def close_open_connections(fn):
     """Decorator that closes all connections after fn execution."""
@@ -64,7 +64,7 @@ def close_open_connections(fn):
             fn(*args, **kw)
         finally:
             testing_reaper.close_all()
-    return _function_named(decorated, fn.__name__)
+    return function_named(decorated, fn.__name__)
 
 def all_dialects():
     import sqlalchemy.databases as d
@@ -104,7 +104,7 @@ def testing_engine(url=None, options=None):
     """Produce an engine configured by --options with optional overrides."""
 
     from sqlalchemy import create_engine
-    from testlib.assertsql import asserter
+    from sqlalchemy.test.assertsql import asserter
 
     url = url or config.db_url
     options = options or config.db_opts
diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py
new file mode 100644 (file)
index 0000000..263d2d7
--- /dev/null
@@ -0,0 +1,156 @@
+import logging
+import os
+import re
+import sys
+import time
+import warnings
+import ConfigParser
+import StringIO
+from config import db, db_label, db_url, file_config, base_config, \
+                           post_configure, \
+                           _list_dbs, _server_side_cursors, _engine_strategy, \
+                           _engine_uri, _require, _engine_pool, \
+                           _create_testing_engine, _prep_testing_database, \
+                           _set_table_options, _reverse_topological, _log
+from sqlalchemy.test import testing, config, requires
+from nose.plugins import Plugin
+from nose.util import tolist
+import nose.case
+
+log = logging.getLogger('nose.plugins.sqlalchemy')
+
+class NoseSQLAlchemy(Plugin):
+    """
+    Handles the setup and extra properties required for testing SQLAlchemy
+    """
+    enabled = True
+    name = 'sqlalchemy'
+    score = 100
+
+    def options(self, parser, env=os.environ):
+        Plugin.options(self, parser, env)
+        opt = parser.add_option
+        #opt("--verbose", action="store_true", dest="verbose",
+            #help="enable stdout echoing/printing")
+        #opt("--quiet", action="store_true", dest="quiet", help="suppress output")
+        opt("--log-info", action="callback", type="string", callback=_log,
+            help="turn on info logging for <LOG> (multiple OK)")
+        opt("--log-debug", action="callback", type="string", callback=_log,
+            help="turn on debug logging for <LOG> (multiple OK)")
+        opt("--require", action="append", dest="require", default=[],
+            help="require a particular driver or module version (multiple OK)")
+        opt("--db", action="store", dest="db", default="sqlite",
+            help="Use prefab database uri")
+        opt('--dbs', action='callback', callback=_list_dbs,
+            help="List available prefab dbs")
+        opt("--dburi", action="store", dest="dburi",
+            help="Database uri (overrides --db)")
+        opt("--dropfirst", action="store_true", dest="dropfirst",
+            help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")
+        opt("--mockpool", action="store_true", dest="mockpool",
+            help="Use mock pool (asserts only one connection used)")
+        opt("--enginestrategy", action="callback", type="string",
+            callback=_engine_strategy,
+            help="Engine strategy (plain or threadlocal, defaults to plain)")
+        opt("--reversetop", action="store_true", dest="reversetop", default=False,
+            help="Reverse the collection ordering for topological sorts (helps "
+                  "reveal dependency issues)")
+        opt("--unhashable", action="store_true", dest="unhashable", default=False,
+            help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
+        opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
+            help="Disallow SQLAlchemy from performing == on mapped test objects.")
+        opt("--truthless", action="store_true", dest="truthless", default=False,
+            help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
+        opt("--serverside", action="callback", callback=_server_side_cursors,
+            help="Turn on server side cursors for PG")
+        opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
+            help="Use the specified MySQL storage engine for all tables, default is "
+                 "a db-default/InnoDB combo.")
+        opt("--table-option", action="append", dest="tableopts", default=[],
+            help="Add a dialect-specific table option, key=value")
+
+        global file_config
+        file_config = ConfigParser.ConfigParser()
+        file_config.readfp(StringIO.StringIO(base_config))
+        file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+        config.file_config = file_config
+        
+    def configure(self, options, conf):
+        Plugin.configure(self, options, conf)
+
+        import testing, requires
+        testing.db = db
+        testing.requires = requires
+
+        # Lazy setup of other options (post coverage)
+        for fn in post_configure:
+            fn(options, file_config)
+        
+    def describeTest(self, test):
+        return ""
+        
+    def wantClass(self, cls):
+        """Return true if you want the main test selector to collect
+        tests from this class, false if you don't, and None if you don't
+        care.
+
+        :Parameters:
+           cls : class
+             The class being examined by the selector
+
+        """
+
+        if not issubclass(cls, testing.TestBase):
+            return False
+        else:
+            if (hasattr(cls, '__whitelist__') and
+                testing.db.name in cls.__whitelist__):
+                return True
+            else:
+                return not self.__should_skip_for(cls)
+    
+    def __should_skip_for(self, cls):
+        if hasattr(cls, '__requires__'):
+            def test_suite(): return 'ok'
+            for requirement in cls.__requires__:
+                check = getattr(requires, requirement)
+                if check(test_suite)() != 'ok':
+                    # The requirement will perform messaging.
+                    return True
+        if (hasattr(cls, '__unsupported_on__') and
+            testing.db.name in cls.__unsupported_on__):
+            print "'%s' unsupported on DB implementation '%s'" % (
+                cls.__class__.__name__, testing.db.name)
+            return True
+        if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)):
+            print "'%s' unsupported on DB implementation '%s'" % (
+                cls.__class__.__name__, testing.db.name)
+            return True
+        if (getattr(cls, '__skip_if__', False)):
+            for c in getattr(cls, '__skip_if__'):
+                if c():
+                    print "'%s' skipped by %s" % (
+                        cls.__class__.__name__, c.__name__)
+                    return True
+        for rule in getattr(cls, '__excluded_on__', ()):
+            if testing._is_excluded(*rule):
+                print "'%s' unsupported on DB %s version %s" % (
+                    cls.__class__.__name__, testing.db.name,
+                    _server_version())
+                return True
+        return False
+
+    #def begin(self):
+        #pass
+
+    def beforeTest(self, test):
+        testing.resetwarnings()
+
+    def afterTest(self, test):
+        testing.resetwarnings()
+        
+    #def handleError(self, test, err):
+        #pass
+
+    #def finalize(self, result=None):
+        #pass
similarity index 96%
rename from test/testlib/orm.py
rename to lib/sqlalchemy/test/orm.py
index 22d62460114452fd98f0aebdee59227166569486..7ec13c55599deeb1a70e5bf2f87e7091479d28dc 100644 (file)
@@ -1,8 +1,6 @@
 import inspect, re
-from testlib import config, testing
-
-sa = None
-orm = None
+import config, testing
+from sqlalchemy import orm
 
 __all__ = 'mapper',
 
@@ -93,10 +91,6 @@ def _make_blocker(method_name, fallback):
     return method
 
 def mapper(type_, *args, **kw):
-    global orm
-    if orm is None:
-        from sqlalchemy import orm
-
     forbidden = [
         ('__hash__', 'unhashable', lambda s: id(s)),
         ('__eq__', 'noncomparable', lambda s, o: s is o),
similarity index 90%
rename from test/pickleable.py
rename to lib/sqlalchemy/test/pickleable.py
index ffb22f3a2478935942e29829cce7964865da7810..9794e424db4e783948110cd9671da002b609ab68 100644 (file)
@@ -1,5 +1,9 @@
-"""since the cPickle module as of py2.4 uses erroneous relative imports, define the various
-picklable classes here so we can test PickleType stuff without issue."""
+"""
+
+some objects used for pickle tests, declared in their own module so that they
+are easily pickleable.
+
+"""
 
 
 class Foo(object):
similarity index 89%
rename from test/testlib/profiling.py
rename to lib/sqlalchemy/test/profiling.py
index 89db3301110d5c73d33967aadd95ecea7df6a298..ca4b31cbd8c1744973f3b190cc5552d7402fde3c 100644 (file)
@@ -1,8 +1,13 @@
-"""Profiling support for unit and performance tests."""
+"""Profiling support for unit and performance tests.
+
+These are special purpose profiling methods which operate
+in a more fine-grained way than nose's profiling plugin.
+
+"""
 
 import os, sys
-from testlib.compat import _function_named
-import testlib.config
+from sqlalchemy.util import function_named
+import config
 
 __all__ = 'profiled', 'function_call_count', 'conditional_call_count'
 
@@ -43,12 +48,8 @@ def profiled(target=None, **target_opts):
             elapsed, load_stats, result = _profile(
                 filename, fn, *args, **kw)
 
-            if not testlib.config.options.quiet:
-                print "Profiled target '%s', wall time: %.2f seconds" % (
-                    target, elapsed)
-
             report = target_opts.get('report', profile_config['report'])
-            if report and testlib.config.options.verbose:
+            if report:
                 sort_ = target_opts.get('sort', profile_config['sort'])
                 limit = target_opts.get('limit', profile_config['limit'])
                 print "Profile report for target '%s' (%s)" % (
@@ -63,7 +64,7 @@ def profiled(target=None, **target_opts):
                 #stats.print_callers()
             os.unlink(filename)
             return result
-        return _function_named(profiled, fn.__name__)
+        return function_named(profiled, fn.__name__)
     return decorator
 
 def function_call_count(count=None, versions={}, variance=0.05):
@@ -113,10 +114,9 @@ def function_call_count(count=None, versions={}, variance=0.05):
                 stats = stat_loader()
                 calls = stats.total_calls
 
-                if testlib.config.options.verbose:
-                    stats.sort_stats('calls', 'cumulative')
-                    stats.print_stats()
-                    #stats.print_callers()
+                stats.sort_stats('calls', 'cumulative')
+                stats.print_stats()
+                #stats.print_callers()
                 deviance = int(count * variance)
                 if (calls < (count - deviance) or
                     calls > (count + deviance)):
@@ -129,7 +129,7 @@ def function_call_count(count=None, versions={}, variance=0.05):
             finally:
                 if os.path.exists(filename):
                     os.unlink(filename)
-        return _function_named(counted, fn.__name__)
+        return function_named(counted, fn.__name__)
     return decorator
 
 def conditional_call_count(discriminator, categories):
@@ -155,7 +155,7 @@ def conditional_call_count(discriminator, categories):
 
             rewrapped = function_call_count(*criteria)(fn)
             return rewrapped(*args, **kw)
-        return _function_named(at_runtime, fn.__name__)
+        return function_named(at_runtime, fn.__name__)
     return decorator
 
 
similarity index 99%
rename from test/testlib/requires.py
rename to lib/sqlalchemy/test/requires.py
index b20929a83b419d1acead2fbbecd6a214cae75220..b23b8620da054d762016ba9453aa6d2d8c3f7a14 100644 (file)
@@ -5,7 +5,7 @@ target database.
 
 """
 
-from testlib.testing import \
+from testing import \
      _block_unconditionally as no_support, \
      _chain_decorators_on, \
      exclude, \
similarity index 92%
rename from test/testlib/schema.py
rename to lib/sqlalchemy/test/schema.py
index 7009fd65d8d220eb13ed5b85a9798bd2c4845590..f96805fe4947f9f57b750d169aecec14aff9ce35 100644 (file)
@@ -1,6 +1,9 @@
-from testlib import testing
+"""Enhanced versions of schema.Table and schema.Column which establish
+desired state for different backends.
+"""
 
-schema = None
+from sqlalchemy.test import testing
+from sqlalchemy import schema
 
 __all__ = 'Table', 'Column',
 
@@ -9,10 +12,6 @@ table_options = {}
 def Table(*args, **kw):
     """A schema.Table wrapper/hook for dialect-specific tweaks."""
 
-    global schema
-    if schema is None:
-        from sqlalchemy import schema
-
     test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
                       if k.startswith('test_')])
 
@@ -65,10 +64,6 @@ def Table(*args, **kw):
 def Column(*args, **kw):
     """A schema.Column wrapper/hook for dialect-specific tweaks."""
 
-    global schema
-    if schema is None:
-        from sqlalchemy import schema
-
     test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
                       if k.startswith('test_')])
 
similarity index 71%
rename from test/testlib/testing.py
rename to lib/sqlalchemy/test/testing.py
index 408dda79f1eab74ad1fea68d16b03a8ab4a873d4..36c7d340a3bc8f80a3fe98cf1e6ce1442928294e 100644 (file)
@@ -1,28 +1,17 @@
 """TestCase and TestSuite artifacts and testing decorators."""
 
-# monkeypatches unittest.TestLoader.suiteClass at import time
-
 import itertools
 import operator
 import re
 import sys
 import types
-from testlib import sa_unittest as unittest
 import warnings
 from cStringIO import StringIO
 
-import testlib.config as config
-from testlib.compat import _function_named, callable
-
-# Delayed imports
-MetaData = None
-Session = None
-clear_mappers = None
-sa_exc = None
-schema = None
-sqltypes = None
-util = None
+from sqlalchemy.test import config, assertsql
+from sqlalchemy.util import function_named
 
+from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema
 
 _ops = { '<': operator.lt,
          '>': operator.gt,
@@ -68,7 +57,7 @@ def fails_if(callable_):
                     raise AssertionError(
                         "Unexpected success for '%s' (condition: %s)" %
                         (fn_name, description))
-        return _function_named(maybe, fn_name)
+        return function_named(maybe, fn_name)
     return decorate
 
 
@@ -89,7 +78,7 @@ def future(fn):
         else:
             raise AssertionError(
                 "Unexpected success for future test '%s'" % fn_name)
-    return _function_named(decorated, fn_name)
+    return function_named(decorated, fn_name)
 
 def fails_on(dbs, reason):
     """Mark a test as expected to fail on the specified database 
@@ -118,7 +107,7 @@ def fails_on(dbs, reason):
                     raise AssertionError(
                         "Unexpected success for '%s' on DB implementation '%s'" %
                         (fn_name, config.db.name))
-        return _function_named(maybe, fn_name)
+        return function_named(maybe, fn_name)
     return decorate
 
 def fails_on_everything_except(*dbs):
@@ -145,7 +134,7 @@ def fails_on_everything_except(*dbs):
                     raise AssertionError(
                         "Unexpected success for '%s' on DB implementation '%s'" %
                         (fn_name, config.db.name))
-        return _function_named(maybe, fn_name)
+        return function_named(maybe, fn_name)
     return decorate
 
 def crashes(db, reason):
@@ -168,7 +157,7 @@ def crashes(db, reason):
                 return True
             else:
                 return fn(*args, **kw)
-        return _function_named(maybe, fn_name)
+        return function_named(maybe, fn_name)
     return decorate
 
 def _block_unconditionally(db, reason):
@@ -192,7 +181,7 @@ def _block_unconditionally(db, reason):
                 return True
             else:
                 return fn(*args, **kw)
-        return _function_named(maybe, fn_name)
+        return function_named(maybe, fn_name)
     return decorate
 
 
@@ -221,7 +210,7 @@ def exclude(db, op, spec, reason):
                 return True
             else:
                 return fn(*args, **kw)
-        return _function_named(maybe, fn_name)
+        return function_named(maybe, fn_name)
     return decorate
 
 def _should_carp_about_exclusion(reason):
@@ -281,7 +270,7 @@ def skip_if(predicate, reason=None):
                 return True
             else:
                 return fn(*args, **kw)
-        return _function_named(maybe, fn_name)
+        return function_named(maybe, fn_name)
     return decorate
 
 def emits_warning(*messages):
@@ -299,10 +288,6 @@ def emits_warning(*messages):
     # - update: jython looks ok, it uses cpython's module
     def decorate(fn):
         def safe(*args, **kw):
-            global sa_exc
-            if sa_exc is None:
-                import sqlalchemy.exc as sa_exc
-
             # todo: should probably be strict about this, too
             filters = [dict(action='ignore',
                             category=sa_exc.SAPendingDeprecationWarning)]
@@ -320,7 +305,7 @@ def emits_warning(*messages):
                 return fn(*args, **kw)
             finally:
                 resetwarnings()
-        return _function_named(safe, fn.__name__)
+        return function_named(safe, fn.__name__)
     return decorate
 
 def emits_warning_on(db, *warnings):
@@ -344,7 +329,7 @@ def emits_warning_on(db, *warnings):
                 else:
                     wrapped = emits_warning(*warnings)(fn)
                     return wrapped(*args, **kw)
-        return _function_named(maybe, fn.__name__)
+        return function_named(maybe, fn.__name__)
     return decorate
 
 def uses_deprecated(*messages):
@@ -361,10 +346,6 @@ def uses_deprecated(*messages):
 
     def decorate(fn):
         def safe(*args, **kw):
-            global sa_exc
-            if sa_exc is None:
-                import sqlalchemy.exc as sa_exc
-
             # todo: should probably be strict about this, too
             filters = [dict(action='ignore',
                             category=sa_exc.SAPendingDeprecationWarning)]
@@ -387,16 +368,12 @@ def uses_deprecated(*messages):
                 return fn(*args, **kw)
             finally:
                 resetwarnings()
-        return _function_named(safe, fn.__name__)
+        return function_named(safe, fn.__name__)
     return decorate
 
 def resetwarnings():
     """Reset warning behavior to testing defaults."""
 
-    global sa_exc
-    if sa_exc is None:
-        import sqlalchemy.exc as sa_exc
-
     warnings.filterwarnings('ignore',
                             category=sa_exc.SAPendingDeprecationWarning) 
     warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
@@ -474,7 +451,23 @@ def startswith_(a, fragment, msg=None):
     assert a.startswith(fragment), msg or "%r does not start with %r" % (
         a, fragment)
 
+def assert_raises(except_cls, callable_, *args, **kw):
+    try:
+        callable_(*args, **kw)
+        assert False, "Callable did not raise an exception"
+    except except_cls, e:
+        pass
 
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+    try:
+        callable_(*args, **kwargs)
+        assert False, "Callable did not raise an exception"
+    except except_cls, e:
+        assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
+
+def fail(msg):
+    assert False, msg
+    
 def fixture(table, columns, *rows):
     """Insert data into table after creation."""
     def onload(event, schema_item, connection):
@@ -484,43 +477,6 @@ def fixture(table, columns, *rows):
                                     for column_values in rows])
     table.append_ddl_listener('after-create', onload)
 
-def _import_by_name(name):
-    submodule = name.split('.')[-1]
-    return __import__(name, globals(), locals(), [submodule])
-
-class CompositeModule(types.ModuleType):
-    """Merged attribute access for multiple modules."""
-
-    # break the habit
-    __all__ = ()
-
-    def __init__(self, name, *modules, **overrides):
-        """Construct a new lazy composite of modules.
-
-        Modules may be string names or module-like instances.  Individual
-        attribute overrides may be specified as keyword arguments for
-        convenience.
-
-        The constructed module will resolve attribute access in reverse order:
-        overrides, then each member of reversed(modules).  Modules specified
-        by name will be loaded lazily when encountered in attribute
-        resolution.
-
-        """
-        types.ModuleType.__init__(self, name)
-        self.__modules = list(reversed(modules))
-        for key, value in overrides.iteritems():
-            setattr(self, key, value)
-
-    def __getattr__(self, key):
-        for idx, mod in enumerate(self.__modules):
-            if isinstance(mod, basestring):
-                self.__modules[idx] = mod = _import_by_name(mod)
-            if hasattr(mod, key):
-                return getattr(mod, key)
-        raise AttributeError(key)
-
-
 def resolve_artifact_names(fn):
     """Decorator, augment function globals with tables and classes.
 
@@ -546,7 +502,7 @@ def resolve_artifact_names(fn):
             fn.func_code, context, fn.func_name, fn.func_defaults,
             fn.func_closure)
         return rebound(*args, **kwargs)
-    return _function_named(resolved, fn.func_name)
+    return function_named(resolved, fn.func_name)
 
 class adict(dict):
     """Dict keys available as attributes.  Shadows."""
@@ -560,7 +516,7 @@ class adict(dict):
         return tuple([self[key] for key in keys])
 
 
-class TestBase(unittest.TestCase):
+class TestBase(object):
     # A sequence of database names to always run, regardless of the
     # constraints below.
     __whitelist__ = ()
@@ -579,37 +535,11 @@ class TestBase(unittest.TestCase):
     # skipped.
     __skip_if__ = None
 
-
     _artifact_registries = ()
 
-    _sa_first_test = False
-    _sa_last_test = False
-
-    def __init__(self, *args, **params):
-        unittest.TestCase.__init__(self, *args, **params)
-
-    def setUpAll(self):
-        pass
-
-    def tearDownAll(self):
-        pass
-
-    def shortDescription(self):
-        """overridden to not return docstrings"""
-        return None
-
-    def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
-        try:
-            callable_(*args, **kwargs)
-            assert False, "Callable did not raise an exception"
-        except except_cls, e:
-            assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
-
-    if not hasattr(unittest.TestCase, 'assertTrue'):
-        assertTrue = unittest.TestCase.failUnless
-    if not hasattr(unittest.TestCase, 'assertFalse'):
-        assertFalse = unittest.TestCase.failIf
-
+    def assert_(self, val, msg=None):
+        assert val, msg
+        
 class AssertsCompiledSQL(object):
     def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None):
         if dialect is None:
@@ -626,25 +556,20 @@ class AssertsCompiledSQL(object):
 
         cc = re.sub(r'\n', '', str(c))
 
-        self.assertEquals(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
 
         if checkparams is not None:
-            self.assertEquals(c.construct_params(params), checkparams)
+            eq_(c.construct_params(params), checkparams)
 
 class ComparesTables(object):
     def assert_tables_equal(self, table, reflected_table):
-        global sqltypes, schema
-        if sqltypes is None:
-            import sqlalchemy.types as sqltypes
-        if schema is None:
-            import sqlalchemy.schema as schema
         base_mro = sqltypes.TypeEngine.__mro__
         assert len(table.c) == len(reflected_table.c)
         for c, reflected_c in zip(table.c, reflected_table.c):
-            self.assertEquals(c.name, reflected_c.name)
+            eq_(c.name, reflected_c.name)
             assert reflected_c is reflected_table.c[c.name]
-            self.assertEquals(c.primary_key, reflected_c.primary_key)
-            self.assertEquals(c.nullable, reflected_c.nullable)
+            eq_(c.primary_key, reflected_c.primary_key)
+            eq_(c.nullable, reflected_c.nullable)
             assert len(
                 set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
                 set(type(c.type).__mro__).difference(base_mro)
@@ -652,9 +577,9 @@ class ComparesTables(object):
             ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
 
             if isinstance(c.type, sqltypes.String):
-                self.assertEquals(c.type.length, reflected_c.type.length)
+                eq_(c.type.length, reflected_c.type.length)
 
-            self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+            eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
             if c.default:
                 assert isinstance(reflected_c.server_default,
                                   schema.FetchedValue)
@@ -704,10 +629,6 @@ class AssertsExecutionResults(object):
         numbers of rows that the test suite manipulates.
         """
 
-        global util
-        if util is None:
-            from sqlalchemy import util
-
         class frozendict(dict):
             def __hash__(self):
                 return id(self)
@@ -716,11 +637,11 @@ class AssertsExecutionResults(object):
         expected = set([frozendict(e) for e in expected])
 
         for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
-            self.fail('Unexpected type "%s", expected "%s"' % (
+            fail('Unexpected type "%s", expected "%s"' % (
                 type(wrong).__name__, cls.__name__))
 
         if len(found) != len(expected):
-            self.fail('Unexpected object count "%s", expected "%s"' % (
+            fail('Unexpected object count "%s", expected "%s"' % (
                 len(found), len(expected)))
 
         NOVALUE = object()
@@ -743,13 +664,12 @@ class AssertsExecutionResults(object):
                     found.remove(found_item)
                     break
             else:
-                self.fail(
+                fail(
                     "Expected %s instance with attributes %s not found." % (
                     cls.__name__, repr(expected_item)))
         return True
 
     def assert_sql_execution(self, db, callable_, *rules):
-        from testlib import assertsql
         assertsql.asserter.add_rules(rules)
         try:
             callable_()
@@ -758,8 +678,6 @@ class AssertsExecutionResults(object):
             assertsql.asserter.clear_rules()
             
     def assert_sql(self, db, callable_, list_, with_sequences=None):
-        from testlib import assertsql
-
         if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
             rules = with_sequences
         else:
@@ -778,142 +696,6 @@ class AssertsExecutionResults(object):
         self.assert_sql_execution(db, callable_, *newrules)
 
     def assert_sql_count(self, db, callable_, count):
-        from testlib import assertsql
         self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
 
 
-
-class TTestSuite(unittest.TestSuite):
-    """A TestSuite with once per TestCase setUpAll() and tearDownAll()"""
-
-    def __init__(self, tests=()):
-        if len(tests) > 0 and isinstance(tests[0], TestBase):
-            self._initTest = tests[0]
-        else:
-            self._initTest = None
-
-        for t in tests:
-            if isinstance(t, TestBase):
-                t._sa_first_test = True
-                break
-        for t in reversed(tests):
-            if isinstance(t, TestBase):
-                t._sa_last_test = True
-                break
-        unittest.TestSuite.__init__(self, tests)
-
-    def run(self, result):
-        init = getattr(self, '_initTest', None)
-        if init is not None:
-            if (hasattr(init, '__whitelist__') and
-                config.db.name in init.__whitelist__):
-                pass
-            else:
-                if self.__should_skip_for(init):
-                    return True
-            try:
-                resetwarnings()
-                init.setUpAll()
-            except:
-                # skip tests if global setup fails
-                ex = self.__exc_info()
-                for test in self._tests:
-                    result.addError(test, ex)
-                return False
-        try:
-            resetwarnings()
-            for test in self._tests:
-                if result.shouldStop:
-                    break
-                test(result)
-            return result
-        finally:
-            try:
-                resetwarnings()
-                if init is not None:
-                    init.tearDownAll()
-            except:
-                result.addError(init, self.__exc_info())
-                pass
-
-    def __should_skip_for(self, cls):
-        if hasattr(cls, '__requires__'):
-            global requires
-            if requires is None:
-                from testing import requires
-            def test_suite(): return 'ok'
-            for requirement in cls.__requires__:
-                check = getattr(requires, requirement)
-                if check(test_suite)() != 'ok':
-                    # The requirement will perform messaging.
-                    return True
-        if (hasattr(cls, '__unsupported_on__') and
-            config.db.name in cls.__unsupported_on__):
-            print "'%s' unsupported on DB implementation '%s'" % (
-                cls.__class__.__name__, config.db.name)
-            return True
-        if (getattr(cls, '__only_on__', None) not in (None,config.db.name)):
-            print "'%s' unsupported on DB implementation '%s'" % (
-                cls.__class__.__name__, config.db.name)
-            return True
-        if (getattr(cls, '__skip_if__', False)):
-            for c in getattr(cls, '__skip_if__'):
-                if c():
-                    print "'%s' skipped by %s" % (
-                        cls.__class__.__name__, c.__name__)
-                    return True
-        for rule in getattr(cls, '__excluded_on__', ()):
-            if _is_excluded(*rule):
-                print "'%s' unsupported on DB %s version %s" % (
-                    cls.__class__.__name__, config.db.name,
-                    _server_version())
-                return True
-        return False
-
-
-    def __exc_info(self):
-        """Return a version of sys.exc_info() with the traceback frame
-           minimised; usually the top level of the traceback frame is not
-           needed.
-           ripped off out of unittest module since its double __
-        """
-        exctype, excvalue, tb = sys.exc_info()
-        if sys.platform[:4] == 'java': ## tracebacks look different in Jython
-            return (exctype, excvalue, tb)
-        return (exctype, excvalue, tb)
-
-# monkeypatch
-unittest.TestLoader.suiteClass = TTestSuite
-
-
-class DevNullWriter(object):
-    def write(self, msg):
-        pass
-    def flush(self):
-        pass
-
-def runTests(suite):
-    verbose = config.options.verbose
-    quiet = config.options.quiet
-    orig_stdout = sys.stdout
-
-    try:
-        if not verbose or quiet:
-            sys.stdout = DevNullWriter()
-        runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
-        return runner.run(suite)
-    finally:
-        if not verbose or quiet:
-            sys.stdout = orig_stdout
-
-def main(suite=None):
-    if not suite:
-        if sys.argv[1:]:
-            suite =unittest.TestLoader().loadTestsFromNames(
-                sys.argv[1:], __import__('__main__'))
-        else:
-            suite = unittest.TestLoader().loadTestsFromModule(
-                __import__('__main__'))
-
-    result = runTests(suite)
-    sys.exit(not result.wasSuccessful())
index 01bb954499e4eebf51604784cd75a64aa82b7b5a..25ee974dba0446e4e15604bb1b2d3857c6d0ff69 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,6 @@
 [egg_info]
 tag_build = dev
 tag_svn_revision = true
+
+[nosetests]
+with-sqlalchemy = true
\ No newline at end of file
index 6a24677652b6ca40f12982f30380b663b92a626b..3d65f022e0694778e369a875aff5e61e47bb2ee7 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -31,6 +31,13 @@ setup(name = "SQLAlchemy",
       packages = find_packages('lib'),
       package_dir = {'':'lib'},
       license = "MIT License",
+
+      entry_points = {
+          'nose.plugins.0.10': [
+              'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy',
+              ]
+          },
+      
       long_description = """\
 SQLAlchemy is:
 
similarity index 84%
rename from test/profiling/compiler.py
rename to test/aaa_profiling/test_compiler.py
index 26260068a622948edd147c6557cf576444d50d17..3e4274d47dd3c686992d71ed4560f728168c4b47 100644 (file)
@@ -1,10 +1,10 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
 
 
 class CompileTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global t1, t2, metadata
         metadata = MetaData()
         t1 = Table('t1', metadata,
@@ -28,5 +28,3 @@ class CompileTest(TestBase, AssertsExecutionResults):
         s = select([t1], t1.c.c2==t2.c.c1)
         s.compile()
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 96%
rename from test/profiling/memusage.py
rename to test/aaa_profiling/test_memusage.py
index ccafc7bd7e99862c8fb97a3f60e4d1220b510945..70a3cf8cd68ffbc4d5e86c8351ba43f307874e6c 100644 (file)
@@ -1,14 +1,16 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import gc
 from sqlalchemy.orm import mapper, relation, create_session, clear_mappers, sessionmaker
 from sqlalchemy.orm.mapper import _mapper_registry
 from sqlalchemy.orm.session import _sessions
 import operator
-from testlib import testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, PickleType
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
 import sqlalchemy as sa
 from sqlalchemy.sql import column
-from orm import _base
+from test.orm import _base
 
 
 class A(_base.ComparableEntity):
@@ -52,7 +54,7 @@ def assert_no_mappers():
     assert len(_mapper_registry) == 0
 
 class EnsureZeroed(_base.ORMTest):
-    def setUp(self):
+    def setup(self):
         _sessions.clear()
         _mapper_registry.clear()
 
@@ -106,7 +108,7 @@ class MemUsageTest(EnsureZeroed):
             sess.expunge_all()
 
             alist = sess.query(A).all()
-            self.assertEquals(
+            eq_(
                 [
                     A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]),
                     A(col2="a2", bs=[]),
@@ -157,7 +159,7 @@ class MemUsageTest(EnsureZeroed):
             sess.expunge_all()
 
             alist = sess.query(A).order_by(A.col1).all()
-            self.assertEquals(
+            eq_(
                 [
                     A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]),
                     A(col2="a2", bs=[]),
@@ -217,7 +219,7 @@ class MemUsageTest(EnsureZeroed):
             sess.expunge_all()
 
             alist = sess.query(A).order_by(A.col1).all()
-            self.assertEquals(
+            eq_(
                 [
                     A(), A(), B(col3='b1'), B(col3='b2')
                 ],
@@ -281,7 +283,7 @@ class MemUsageTest(EnsureZeroed):
             sess.expunge_all()
 
             alist = sess.query(A).order_by(A.col1).all()
-            self.assertEquals(
+            eq_(
                 [
                     A(bs=[B(col2='b1')]), A(bs=[B(col2='b2')])
                 ],
@@ -398,5 +400,3 @@ class MemUsageTest(EnsureZeroed):
             cast.compile(dialect=dialect)
         go()
         
-if __name__ == '__main__':
-    testenv.main()
similarity index 87%
rename from test/profiling/pool.py
rename to test/aaa_profiling/test_pool.py
index f3f69222c0eeb7efbe76e931bd06195cf42bdcd1..7bb61deb28d0151c88b0e659029df4afe9c5685d 100644 (file)
@@ -1,6 +1,5 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy.pool import QueuePool
 
 
@@ -9,7 +8,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults):
         def close(self):
             pass
 
-    def setUp(self):
+    def setup(self):
         global pool
         pool = QueuePool(creator=self.Connection,
                          pool_size=3, max_overflow=-1,
@@ -39,5 +38,3 @@ class QueuePoolTest(TestBase, AssertsExecutionResults):
         c2 = go()
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 99%
rename from test/profiling/zoomark.py
rename to test/aaa_profiling/test_zoomark.py
index c9f3d9df80d6bb87291d4159d70fab7a680422c8..be29318964a8613bce44d8e72ffb6246b9f696a6 100644 (file)
@@ -6,9 +6,8 @@ An adaptation of Robert Brewers' ZooMark speed tests.
 import datetime
 import sys
 import time
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
 
 ITERATIONS = 1
 
@@ -356,5 +355,3 @@ class ZooMarkTest(TestBase):
         self.test_baseline_8_drop()
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 99%
rename from test/profiling/zoomark_orm.py
rename to test/aaa_profiling/test_zoomark_orm.py
index 5d7192261d61867845f362eca3786ac0f85eb9ec..57e1e24049c37e2efa92a80a63049324d0213ca8 100644 (file)
@@ -6,10 +6,9 @@ An adaptation of Robert Brewers' ZooMark speed tests.
 import datetime
 import sys
 import time
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
 
 ITERATIONS = 1
 
@@ -318,5 +317,3 @@ class ZooMarkTest(TestBase):
         self.test_baseline_7_drop()
 
 
-if __name__ == '__main__':
-    testenv.main()
diff --git a/test/alltests.py b/test/alltests.py
deleted file mode 100644 (file)
index b014bc9..0000000
+++ /dev/null
@@ -1,24 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-import orm.alltests as orm
-import base.alltests as base
-import sql.alltests as sql
-import engine.alltests as engine
-import dialect.alltests as dialect
-import ext.alltests as ext
-import zblog.alltests as zblog
-import profiling.alltests as profiling
-
-# The profiling tests are sensitive to foibles of CPython VM state, so
-# run them first.  Ideally, each should be run in a fresh interpreter.
-
-def suite():
-    alltests = unittest.TestSuite()
-    for suite in (profiling, base, engine, sql, dialect, orm, ext, zblog):
-        alltests.addTest(suite.suite())
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
diff --git a/test/base/alltests.py b/test/base/alltests.py
deleted file mode 100644 (file)
index 3fef623..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
-    modules_to_test = (
-        # core utilities
-        'base.dependency',
-        'base.utils',
-        'base.except',
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
similarity index 97%
rename from test/base/dependency.py
rename to test/base/test_dependency.py
index 8fcd093b25679611bd87e5fdd14589e7696e38b5..0457d552a4eef7d22f6535b1aea38e9b1d3a2053 100644 (file)
@@ -1,6 +1,5 @@
-import testenv; testenv.configure_for_tests()
 import sqlalchemy.topological as topological
-from testlib import TestBase
+from sqlalchemy.test import TestBase
 
 
 class DependencySortTest(TestBase):
@@ -185,5 +184,3 @@ class DependencySortTest(TestBase):
         head = topological.sort_as_tree(tuples, [])
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 96%
rename from test/base/except.py
rename to test/base/test_except.py
index 3f4d654771e7ef96ae47b304747cc73f9b14222c..efb18a153c980d1d862e01c74f3e8eef3bb46c6e 100644 (file)
@@ -1,8 +1,7 @@
 """Tests exceptions and DB-API exception wrapping."""
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
 import exceptions as stdlib_exceptions
 from sqlalchemy import exc as sa_exceptions
+from sqlalchemy.test import TestBase
 
 
 class Error(stdlib_exceptions.StandardError):
@@ -18,7 +17,7 @@ class OutOfSpec(DatabaseError):
     pass
 
 
-class WrapTest(unittest.TestCase):
+class WrapTest(TestBase):
     def test_db_error_normal(self):
         try:
             raise sa_exceptions.DBAPIError.instance(
@@ -118,5 +117,3 @@ class WrapTest(unittest.TestCase):
             self.assert_(True)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 94%
rename from test/base/utils.py
rename to test/base/test_utils.py
index bc3fc028384a84ef97ed6d444f814da2ec3f08da..39561e9682eae3d3b9b07fd77871bc2133f1a76d 100644 (file)
@@ -1,8 +1,8 @@
-import testenv; testenv.configure_for_tests()
-import copy, threading, unittest
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import copy, threading
 from sqlalchemy import util, sql, exc
-from testlib import TestBase
-from testlib.testing import eq_, is_, ne_
+from sqlalchemy.test import TestBase
+from sqlalchemy.test.testing import eq_, is_, ne_
 
 class OrderedDictTest(TestBase):
     def test_odict(self):
@@ -159,7 +159,7 @@ class HashEqOverride(object):
             return True
 
 
-class IdentitySetTest(unittest.TestCase):
+class IdentitySetTest(TestBase):
     def assert_eq(self, identityset, expected_iterable):
         expected = sorted([id(o) for o in expected_iterable])
         found = sorted([id(o) for o in identityset])
@@ -205,7 +205,7 @@ class IdentitySetTest(unittest.TestCase):
         ids.discard(o1)
         ids.add(o1)
         ids.remove(o1)
-        self.assertRaises(KeyError, ids.remove, o1)
+        assert_raises(KeyError, ids.remove, o1)
 
         eq_(ids.copy(), ids)
 
@@ -260,8 +260,8 @@ class IdentitySetTest(unittest.TestCase):
         except TypeError:
             assert True
 
-        self.assertRaises(TypeError, cmp, ids)
-        self.assertRaises(TypeError, hash, ids)
+        assert_raises(TypeError, cmp, ids)
+        assert_raises(TypeError, hash, ids)
 
     def test_difference(self):
         os1 = util.IdentitySet([1,2,3])
@@ -271,12 +271,12 @@ class IdentitySetTest(unittest.TestCase):
 
         eq_(os1 - os2, util.IdentitySet([1, 2]))
         eq_(os2 - os1, util.IdentitySet([4, 5]))
-        self.assertRaises(TypeError, lambda: os1 - s2)
-        self.assertRaises(TypeError, lambda: os1 - [3, 4, 5])
-        self.assertRaises(TypeError, lambda: s1 - os2)
-        self.assertRaises(TypeError, lambda: s1 - [3, 4, 5])
+        assert_raises(TypeError, lambda: os1 - s2)
+        assert_raises(TypeError, lambda: os1 - [3, 4, 5])
+        assert_raises(TypeError, lambda: s1 - os2)
+        assert_raises(TypeError, lambda: s1 - [3, 4, 5])
 
-class OrderedIdentitySetTest(unittest.TestCase):
+class OrderedIdentitySetTest(TestBase):
     
     def assert_eq(self, identityset, expected_iterable):
         expected = [id(o) for o in expected_iterable]
@@ -303,7 +303,7 @@ class OrderedIdentitySetTest(unittest.TestCase):
         eq_(s1.union(s2).intersection(s3), [a, d, f])
 
 
-class DictlikeIteritemsTest(unittest.TestCase):
+class DictlikeIteritemsTest(TestBase):
     baseline = set([('a', 1), ('b', 2), ('c', 3)])
 
     def _ok(self, instance):
@@ -311,7 +311,7 @@ class DictlikeIteritemsTest(unittest.TestCase):
         eq_(set(iterator), self.baseline)
 
     def _notok(self, instance):
-        self.assertRaises(TypeError,
+        assert_raises(TypeError,
                           util.dictlike_iteritems,
                           instance)
 
@@ -638,7 +638,7 @@ class WeakIdentityMappingTest(TestBase):
 
     def test_update(self):
         data, wim = self._fixture()
-        self.assertRaises(NotImplementedError, wim.update)
+        assert_raises(NotImplementedError, wim.update)
 
     def test_weak_clear(self):
         data, wim = self._fixture()
@@ -838,13 +838,13 @@ class AsInterfaceTest(TestBase):
 
     def test_instance(self):
         obj = object()
-        self.assertRaises(TypeError, util.as_interface, obj,
+        assert_raises(TypeError, util.as_interface, obj,
                           cls=self.Something)
 
-        self.assertRaises(TypeError, util.as_interface, obj,
+        assert_raises(TypeError, util.as_interface, obj,
                           methods=('foo'))
 
-        self.assertRaises(TypeError, util.as_interface, obj,
+        assert_raises(TypeError, util.as_interface, obj,
                           cls=self.Something, required=('foo'))
 
         obj = self.Something()
@@ -860,25 +860,25 @@ class AsInterfaceTest(TestBase):
 
         for obj in partial, slotted:
             eq_(obj, util.as_interface(obj, cls=self.Something))
-            self.assertRaises(TypeError, util.as_interface, obj,
+            assert_raises(TypeError, util.as_interface, obj,
                               methods=('foo'))
             eq_(obj, util.as_interface(obj, methods=('bar',)))
             eq_(obj, util.as_interface(obj, cls=self.Something,
                                        required=('bar',)))
-            self.assertRaises(TypeError, util.as_interface, obj,
+            assert_raises(TypeError, util.as_interface, obj,
                               cls=self.Something, required=('foo',))
 
-            self.assertRaises(TypeError, util.as_interface, obj,
+            assert_raises(TypeError, util.as_interface, obj,
                               cls=self.Something, required=self.Something)
 
     def test_dict(self):
         obj = {}
 
-        self.assertRaises(TypeError, util.as_interface, obj,
+        assert_raises(TypeError, util.as_interface, obj,
                           cls=self.Something)
-        self.assertRaises(TypeError, util.as_interface, obj,
+        assert_raises(TypeError, util.as_interface, obj,
                           methods=('foo'))
-        self.assertRaises(TypeError, util.as_interface, obj,
+        assert_raises(TypeError, util.as_interface, obj,
                           cls=self.Something, required=('foo'))
 
         def assertAdapted(obj, *methods):
@@ -911,13 +911,13 @@ class AsInterfaceTest(TestBase):
         res = util.as_interface(obj, methods=('foo', 'bar'), required=('foo',))
         assertAdapted(res, 'foo', 'bar')
 
-        self.assertRaises(TypeError, util.as_interface, obj, methods=('foo',))
+        assert_raises(TypeError, util.as_interface, obj, methods=('foo',))
 
-        self.assertRaises(TypeError, util.as_interface, obj,
+        assert_raises(TypeError, util.as_interface, obj,
                           methods=('foo', 'bar', 'baz'), required=('baz',))
 
         obj = {'foo': 123}
-        self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something)
+        assert_raises(TypeError, util.as_interface, obj, cls=self.Something)
 
 
 class TestClassHierarchy(TestBase):
@@ -955,5 +955,3 @@ class TestClassHierarchy(TestBase):
         eq_(set(util.class_hierarchy(A)), set((A, B, object)))
 
         
-if __name__ == "__main__":
-    testenv.main()
diff --git a/test/dialect/alltests.py b/test/dialect/alltests.py
deleted file mode 100644 (file)
index 0defb6a..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
-    modules_to_test = (
-        'dialect.access',
-        'dialect.firebird',
-        'dialect.informix',
-        'dialect.maxdb',
-        'dialect.mssql',
-        'dialect.mysql',
-        'dialect.oracle',
-        'dialect.postgres',
-        'dialect.sqlite',
-        'dialect.sybase',
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
similarity index 86%
rename from test/dialect/access.py
rename to test/dialect/test_access.py
index 57af45a9d6ac1642ff4cef5c50e32e542d0318f6..0ea8d9a61ad41689d90797954ad7ed8613637532 100644 (file)
@@ -1,8 +1,7 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy import sql
 from sqlalchemy.databases import access
-from testlib import *
+from sqlalchemy.test import *
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
@@ -30,5 +29,3 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 84%
rename from test/dialect/firebird.py
rename to test/dialect/test_firebird.py
index 5a0109dcc4258053fd7361d458f893f79390c022..fa608c9a18e5c0761cf582b352ab38dc3b780b22 100644 (file)
@@ -1,9 +1,9 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 from sqlalchemy import *
 from sqlalchemy.databases import firebird
 from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.sql import table, column
-from testlib import *
+from sqlalchemy.test import *
 
 
 class DomainReflectionTest(TestBase, AssertsExecutionResults):
@@ -11,7 +11,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
 
     __only_on__ = 'firebird'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         con = testing.db.connect()
         try:
             con.execute('CREATE DOMAIN int_domain AS INTEGER DEFAULT 42 NOT NULL')
@@ -38,7 +39,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
                            NEW.question = gen_id(gen_testtable_id, 1);
                        END''')
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         con = testing.db.connect()
         con.execute('DROP TABLE testtable')
         con.execute('DROP DOMAIN int_domain')
@@ -50,22 +52,22 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
     def test_table_is_reflected(self):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True)
-        self.assertEquals(set(table.columns.keys()),
+        eq_(set(table.columns.keys()),
                           set(['question', 'answer', 'remark', 'photo', 'd', 't', 'dt']),
                           "Columns of reflected table didn't equal expected columns")
-        self.assertEquals(table.c.question.primary_key, True)
-        self.assertEquals(table.c.question.sequence.name, 'gen_testtable_id')
-        self.assertEquals(table.c.question.type.__class__, firebird.FBInteger)
-        self.assertEquals(table.c.question.server_default.arg.text, "42")
-        self.assertEquals(table.c.answer.type.__class__, firebird.FBString)
-        self.assertEquals(table.c.answer.server_default.arg.text, "'no answer'")
-        self.assertEquals(table.c.remark.type.__class__, firebird.FBText)
-        self.assertEquals(table.c.remark.server_default.arg.text, "''")
-        self.assertEquals(table.c.photo.type.__class__, firebird.FBBinary)
+        eq_(table.c.question.primary_key, True)
+        eq_(table.c.question.sequence.name, 'gen_testtable_id')
+        eq_(table.c.question.type.__class__, firebird.FBInteger)
+        eq_(table.c.question.server_default.arg.text, "42")
+        eq_(table.c.answer.type.__class__, firebird.FBString)
+        eq_(table.c.answer.server_default.arg.text, "'no answer'")
+        eq_(table.c.remark.type.__class__, firebird.FBText)
+        eq_(table.c.remark.server_default.arg.text, "''")
+        eq_(table.c.photo.type.__class__, firebird.FBBinary)
         # The following assume a Dialect 3 database
-        self.assertEquals(table.c.d.type.__class__, firebird.FBDate)
-        self.assertEquals(table.c.t.type.__class__, firebird.FBTime)
-        self.assertEquals(table.c.dt.type.__class__, firebird.FBDateTime)
+        eq_(table.c.d.type.__class__, firebird.FBDate)
+        eq_(table.c.t.type.__class__, firebird.FBTime)
+        eq_(table.c.dt.type.__class__, firebird.FBDateTime)
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
@@ -140,10 +142,10 @@ class ReturningTest(TestBase, AssertsExecutionResults):
             table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
 
             result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
-            self.assertEqual(result.fetchall(), [(1,)])
+            eq_(result.fetchall(), [(1,)])
 
             result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            self.assertEqual(result2.fetchall(), [(1,True),(2,False)])
+            eq_(result2.fetchall(), [(1,True),(2,False)])
         finally:
             table.drop()
 
@@ -159,19 +161,19 @@ class ReturningTest(TestBase, AssertsExecutionResults):
         try:
             result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False})
 
-            self.assertEqual(result.fetchall(), [(1,)])
+            eq_(result.fetchall(), [(1,)])
 
             # Multiple inserts only return the last row
             result2 = table.insert(firebird_returning=[table]).execute(
                  [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
 
-            self.assertEqual(result2.fetchall(), [(3,3,True)])
+            eq_(result2.fetchall(), [(3,3,True)])
 
             result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False})
-            self.assertEqual([dict(row) for row in result3], [{'ID':4}])
+            eq_([dict(row) for row in result3], [{'ID':4}])
 
             result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons')
-            self.assertEqual([dict(row) for row in result4], [{'PERSONS': 10}])
+            eq_([dict(row) for row in result4], [{'PERSONS': 10}])
         finally:
             table.drop()
 
@@ -188,10 +190,10 @@ class ReturningTest(TestBase, AssertsExecutionResults):
             table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
 
             result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute()
-            self.assertEqual(result.fetchall(), [(1,)])
+            eq_(result.fetchall(), [(1,)])
 
             result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            self.assertEqual(result2.fetchall(), [(2,False),])
+            eq_(result2.fetchall(), [(2,False),])
         finally:
             table.drop()
 
@@ -224,5 +226,3 @@ class MiscFBTests(TestBase):
         assert len(version) == 3, "Got strange version info: %s" % repr(version)
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 88%
rename from test/dialect/informix.py
rename to test/dialect/test_informix.py
index 1fbbaa0cb485dc9c4d6351d70aeb04c06ab11279..86a4e751d41ab75d8752c76e16d2345380152a2e 100644 (file)
@@ -1,7 +1,6 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.databases import informix
-from testlib import *
+from sqlalchemy.test import *
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
@@ -19,5 +18,3 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(t1.update().values({t1.c.col1 : t1.c.col1 + 1}), 'UPDATE t1 SET col1=(t1.col1 + ?)')
         
         
-if __name__ == "__main__":
-    testenv.main()
similarity index 96%
rename from test/dialect/maxdb.py
rename to test/dialect/test_maxdb.py
index c2daf8959aeecb979b677ac87f06d315227f45d3..033a05533f1aaf0aebf27b80b85d0afb6c1cd8c4 100644 (file)
@@ -1,12 +1,12 @@
 """MaxDB-specific tests."""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import StringIO, sys
 from sqlalchemy import *
 from sqlalchemy import exc, sql
 from decimal import Decimal
 from sqlalchemy.databases import maxdb
-from testlib import *
+from sqlalchemy.test import *
 
 
 # TODO
@@ -43,12 +43,12 @@ class ReflectionTest(TestBase, AssertsExecutionResults):
             cols = ['d1','d2','n1','i1']
             t.insert().execute(dict(zip(cols,vals)))
             roundtrip = list(t.select().execute())
-            self.assertEquals(roundtrip, [tuple([1] + vals)])
+            eq_(roundtrip, [tuple([1] + vals)])
 
             t.insert().execute(dict(zip(['id'] + cols,
                                         [2] + list(roundtrip[0][1:]))))
             roundtrip2 = list(t.select(order_by=t.c.id).execute())
-            self.assertEquals(roundtrip2, [tuple([1] + vals),
+            eq_(roundtrip2, [tuple([1] + vals),
                                            tuple([2] + vals)])
         finally:
             try:
@@ -233,8 +233,6 @@ class DBAPITest(TestBase, AssertsExecutionResults):
 
     def test_modulo_operator(self):
         st = str(select([sql.column('col') % 5]).compile(testing.db))
-        self.assertEquals(st, 'SELECT mod(col, ?) FROM DUAL')
+        eq_(st, 'SELECT mod(col, ?) FROM DUAL')
 
 
-if __name__ == "__main__":
-    testenv.main()
old mode 100755 (executable)
new mode 100644 (file)
similarity index 92%
rename from test/dialect/mssql.py
rename to test/dialect/test_mssql.py
index 50f9594..5e2c9a6
@@ -1,14 +1,14 @@
 # -*- encoding: utf-8
-import testenv; testenv.configure_for_tests()
-import datetime, os, pickleable, re
+from sqlalchemy.test.testing import eq_
+import datetime, os, re
 from sqlalchemy import *
 from sqlalchemy import types, exc
 from sqlalchemy.orm import *
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import mssql
 import sqlalchemy.engine.url as url
-from testlib import *
-from testlib.testing import eq_
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
@@ -162,7 +162,8 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mssql'
     __dialect__ = mssql.MSSQLDialect()
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, cattable
         metadata = MetaData(testing.db)
 
@@ -172,10 +173,10 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL):
             PrimaryKeyConstraint('id', name='PK_cattable'),
         )
 
-    def setUp(self):
+    def setup(self):
         metadata.create_all()
 
-    def tearDown(self):
+    def teardown(self):
         metadata.drop_all()
 
     def test_compiled(self):
@@ -185,12 +186,12 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL):
         cattable.insert().values(id=9, description='Python').execute()
 
         cats = cattable.select().order_by(cattable.c.id).execute()
-        self.assertEqual([(9, 'Python')], list(cats))
+        eq_([(9, 'Python')], list(cats))
 
         result = cattable.insert().values(description='PHP').execute()
-        self.assertEqual([10], result.last_inserted_ids())
+        eq_([10], result.last_inserted_ids())
         lastcat = cattable.select().order_by(desc(cattable.c.id)).execute()
-        self.assertEqual((10, 'PHP'), lastcat.fetchone())
+        eq_((10, 'PHP'), lastcat.fetchone())
 
     def test_executemany(self):
         cattable.insert().execute([
@@ -201,7 +202,7 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL):
         ])
 
         cats = cattable.select().order_by(cattable.c.id).execute()
-        self.assertEqual([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats))
+        eq_([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats))
 
         cattable.insert().execute([
             {'description': 'PHP'},
@@ -209,7 +210,7 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL):
         ])
 
         lastcats = cattable.select().order_by(desc(cattable.c.id)).limit(2).execute()
-        self.assertEqual([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats))
+        eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats))
 
 
 class ReflectionTest(TestBase):
@@ -330,7 +331,8 @@ class Foo(object):
 class GenerativeQueryTest(TestBase):
     __only_on__ = 'mssql'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global foo, metadata
         metadata = MetaData(testing.db)
         foo = Table('foo', metadata,
@@ -347,7 +349,8 @@ class GenerativeQueryTest(TestBase):
             sess.add(Foo(bar=i, range=i%10))
         sess.flush()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
         clear_mappers()
 
@@ -361,7 +364,7 @@ class GenerativeQueryTest(TestBase):
 
 class SchemaTest(TestBase):
 
-    def setUp(self):
+    def setup(self):
         t = Table('sometable', MetaData(),
             Column('pk_column', Integer),
             Column('test_column', String)
@@ -418,7 +421,8 @@ class MatchTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mssql'
     __skip_if__ = (full_text_search_missing, )
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, cattable, matchtable
         metadata = MetaData(testing.db)
         
@@ -456,7 +460,8 @@ class MatchTest(TestBase, AssertsCompiledSQL):
         ])
         DDL("WAITFOR DELAY '00:00:05'").execute(bind=engines.testing_engine())
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
         connection = testing.db.connect()
         connection.execute("DROP FULLTEXT CATALOG Catalog")
@@ -467,44 +472,44 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 
     def test_simple_match(self):
         results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([2, 5], [r.id for r in results])
+        eq_([2, 5], [r.id for r in results])
 
     def test_simple_match_with_apostrophe(self):
         results = matchtable.select().where(matchtable.c.title.match('"Matz''s"')).execute().fetchall()
-        self.assertEquals([3], [r.id for r in results])
+        eq_([3], [r.id for r in results])
 
     def test_simple_prefix_match(self):
         results = matchtable.select().where(matchtable.c.title.match('"nut*"')).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results])
+        eq_([5], [r.id for r in results])
 
     def test_simple_inflectional_match(self):
         results = matchtable.select().where(matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")')).execute().fetchall()
-        self.assertEquals([2], [r.id for r in results])
+        eq_([2], [r.id for r in results])
 
     def test_or_match(self):
         results1 = matchtable.select().where(or_(matchtable.c.title.match('nutshell'), 
                                                  matchtable.c.title.match('ruby'))
                                             ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([3, 5], [r.id for r in results1])
+        eq_([3, 5], [r.id for r in results1])
         results2 = matchtable.select().where(matchtable.c.title.match('nutshell OR ruby'), 
                                             ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([3, 5], [r.id for r in results2])    
+        eq_([3, 5], [r.id for r in results2])    
 
     def test_and_match(self):
         results1 = matchtable.select().where(and_(matchtable.c.title.match('python'), 
                                                   matchtable.c.title.match('nutshell'))
                                             ).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results1])
+        eq_([5], [r.id for r in results1])
         results2 = matchtable.select().where(matchtable.c.title.match('python AND nutshell'), 
                                             ).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results2])
+        eq_([5], [r.id for r in results2])
 
     def test_match_across_joins(self):
         results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id, 
                                             or_(cattable.c.description.match('Ruby'), 
                                                 matchtable.c.title.match('nutshell')))
                                            ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([1, 3, 5], [r.id for r in results])
+        eq_([1, 3, 5], [r.id for r in results])
 
 
 class ParseConnectTest(TestBase, AssertsCompiledSQL):
@@ -514,77 +519,78 @@ class ParseConnectTest(TestBase, AssertsCompiledSQL):
         u = url.make_url('mssql://mydsn')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
+        eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
 
     def test_pyodbc_connect_old_style_dsn_trusted(self):
         u = url.make_url('mssql:///?dsn=mydsn')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
+        eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
 
     def test_pyodbc_connect_dsn_non_trusted(self):
         u = url.make_url('mssql://username:password@mydsn')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
+        eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_connect_dsn_extra(self):
         u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
+        eq_([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
 
     def test_pyodbc_connect(self):
         u = url.make_url('mssql://username:password@hostspec/database')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+        eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_connect_comma_port(self):
         u = url.make_url('mssql://username:password@hostspec:12345/database')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
+        eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_connect_config_port(self):
         u = url.make_url('mssql://username:password@hostspec/database?port=12345')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
+        eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
 
     def test_pyodbc_extra_connect(self):
         u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
+        eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
 
     def test_pyodbc_odbc_connect(self):
         u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+        eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_odbc_connect_with_dsn(self):
         u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
+        eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
 
     def test_pyodbc_odbc_connect_ignores_other_values(self):
         u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
-        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+        eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
 
 class TypesTest(TestBase):
     __only_on__ = 'mssql'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global numeric_table, metadata
         metadata = MetaData(testing.db)
 
-    def tearDown(self):
+    def teardown(self):
         metadata.drop_all()
 
     def test_decimal_notation(self):
@@ -611,7 +617,7 @@ class TypesTest(TestBase):
                 numeric_table.insert().execute(numericcol=value)
 
             for value in select([numeric_table.c.numericcol]).execute():
-                self.assertTrue(value[0] in test_items, "%s not in test_items" % value[0])
+                assert value[0] in test_items, "%s not in test_items" % value[0]
 
         except Exception, e:
             raise e
@@ -763,7 +769,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
 
             t.insert().execute(adate=d1, adatetime=d2, atime=t1)
 
-            self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
+            eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
 
         finally:
             t.drop(checkfirst=True)
@@ -1072,7 +1078,8 @@ def colspec(c):
 
 class BinaryTest(TestBase, AssertsExecutionResults):
     """Test the Binary and VarBinary types"""
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global binary_table, MyPickleType
 
         class MyPickleType(types.TypeDecorator):
@@ -1102,10 +1109,11 @@ class BinaryTest(TestBase, AssertsExecutionResults):
         )
         binary_table.create()
 
-    def tearDown(self):
+    def teardown(self):
         binary_table.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         binary_table.drop()
 
     def test_binary(self):
@@ -1124,23 +1132,21 @@ class BinaryTest(TestBase, AssertsExecutionResults):
             text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db)
         ):
             l = stmt.execute().fetchall()
-            self.assertEquals(list(stream1), list(l[0]['data']))
+            eq_(list(stream1), list(l[0]['data']))
 
             paddedstream = list(stream1[0:100])
             paddedstream.extend(['\x00'] * (100 - len(paddedstream)))
-            self.assertEquals(paddedstream, list(l[0]['data_slice']))
+            eq_(paddedstream, list(l[0]['data_slice']))
 
-            self.assertEquals(list(stream2), list(l[1]['data']))
-            self.assertEquals(list(stream2), list(l[1]['data_image']))
-            self.assertEquals(testobj1, l[0]['pickled'])
-            self.assertEquals(testobj2, l[1]['pickled'])
-            self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata)
-            self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff')
+            eq_(list(stream2), list(l[1]['data']))
+            eq_(list(stream2), list(l[1]['data_image']))
+            eq_(testobj1, l[0]['pickled'])
+            eq_(testobj2, l[1]['pickled'])
+            eq_(testobj3.moredata, l[0]['mypickle'].moredata)
+            eq_(l[0]['mypickle'].stuff, 'this is the right stuff')
 
     def load_stream(self, name, len=3000):
-        f = os.path.join(os.path.dirname(testenv.__file__), name)
+        f = os.path.join(os.path.dirname(__file__), "..", name)
         return file(f).read(len)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 97%
rename from test/dialect/mysql.py
rename to test/dialect/test_mysql.py
index fa8a85ec453d312c20b8928af5e2ea204efd62b8..8adb2d71c53c20036cfb7dab7cdc8a1082ef7e3a 100644 (file)
@@ -1,10 +1,10 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import sets
 from sqlalchemy import *
 from sqlalchemy import sql, exc
 from sqlalchemy.databases import mysql
-from testlib.testing import eq_
-from testlib import *
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test import *
 
 
 class TypesTest(TestBase, AssertsExecutionResults):
@@ -522,7 +522,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
                 [set_table.c.s3],
                 set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute())
             found = set([frozenset(row[0]) for row in rows])
-            self.assertEquals(found,
+            eq_(found,
                               set([frozenset(['5']), frozenset(['5', '7'])]))
         finally:
             meta.drop_all()
@@ -812,7 +812,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
         if got != wanted:
             print "Expected %s" % wanted
             print "Found %s" % got
-        self.assertEqual(got, wanted)
+        eq_(got, wanted)
 
 
 class SQLTest(TestBase, AssertsCompiledSQL):
@@ -831,24 +831,24 @@ class SQLTest(TestBase, AssertsCompiledSQL):
                 kw['prefixes'] = prefixes
             return str(select(['q'], **kw).compile(dialect=dialect))
 
-        self.assertEqual(gen(None), 'SELECT q')
-        self.assertEqual(gen(True), 'SELECT DISTINCT q')
-        self.assertEqual(gen(1), 'SELECT DISTINCT q')
-        self.assertEqual(gen('diSTInct'), 'SELECT DISTINCT q')
-        self.assertEqual(gen('DISTINCT'), 'SELECT DISTINCT q')
+        eq_(gen(None), 'SELECT q')
+        eq_(gen(True), 'SELECT DISTINCT q')
+        eq_(gen(1), 'SELECT DISTINCT q')
+        eq_(gen('diSTInct'), 'SELECT DISTINCT q')
+        eq_(gen('DISTINCT'), 'SELECT DISTINCT q')
 
         # Standard SQL
-        self.assertEqual(gen('all'), 'SELECT ALL q')
-        self.assertEqual(gen('distinctrow'), 'SELECT DISTINCTROW q')
+        eq_(gen('all'), 'SELECT ALL q')
+        eq_(gen('distinctrow'), 'SELECT DISTINCTROW q')
 
         # Interaction with MySQL prefix extensions
-        self.assertEqual(
+        eq_(
             gen(None, ['straight_join']),
             'SELECT straight_join q')
-        self.assertEqual(
+        eq_(
             gen('all', ['HIGH_PRIORITY SQL_SMALL_RESULT']),
             'SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL q')
-        self.assertEqual(
+        eq_(
             gen(True, ['high_priority', sql.text('sql_cache')]),
             'SELECT high_priority sql_cache DISTINCT q')
 
@@ -997,7 +997,7 @@ class SQLTest(TestBase, AssertsCompiledSQL):
 
 
 class RawReflectionTest(TestBase):
-    def setUp(self):
+    def setup(self):
         self.dialect = mysql.dialect()
         self.reflector = mysql.MySQLSchemaReflector(
             self.dialect.identifier_preparer)
@@ -1059,7 +1059,8 @@ class ExecutionTest(TestBase):
 class MatchTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mysql'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, cattable, matchtable
         metadata = MetaData(testing.db)
 
@@ -1096,7 +1097,8 @@ class MatchTest(TestBase, AssertsCompiledSQL):
              'category_id': 1}
         ])
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_expression(self):
@@ -1110,14 +1112,14 @@ class MatchTest(TestBase, AssertsCompiledSQL):
                    order_by(matchtable.c.id).
                    execute().
                    fetchall())
-        self.assertEquals([2, 5], [r.id for r in results])
+        eq_([2, 5], [r.id for r in results])
 
     def test_simple_match_with_apostrophe(self):
         results = (matchtable.select().
                    where(matchtable.c.title.match('"Matz''s"')).
                    execute().
                    fetchall())
-        self.assertEquals([3], [r.id for r in results])
+        eq_([3], [r.id for r in results])
 
     def test_or_match(self):
         results1 = (matchtable.select().
@@ -1126,13 +1128,13 @@ class MatchTest(TestBase, AssertsCompiledSQL):
                     order_by(matchtable.c.id).
                     execute().
                     fetchall())
-        self.assertEquals([3, 5], [r.id for r in results1])
+        eq_([3, 5], [r.id for r in results1])
         results2 = (matchtable.select().
                     where(matchtable.c.title.match('nutshell ruby')).
                     order_by(matchtable.c.id).
                     execute().
                     fetchall())
-        self.assertEquals([3, 5], [r.id for r in results2])
+        eq_([3, 5], [r.id for r in results2])
 
 
     def test_and_match(self):
@@ -1141,12 +1143,12 @@ class MatchTest(TestBase, AssertsCompiledSQL):
                                matchtable.c.title.match('nutshell'))).
                     execute().
                     fetchall())
-        self.assertEquals([5], [r.id for r in results1])
+        eq_([5], [r.id for r in results1])
         results2 = (matchtable.select().
                     where(matchtable.c.title.match('+python +nutshell')).
                     execute().
                     fetchall())
-        self.assertEquals([5], [r.id for r in results2])
+        eq_([5], [r.id for r in results2])
 
     def test_match_across_joins(self):
         results = (matchtable.select().
@@ -1156,12 +1158,10 @@ class MatchTest(TestBase, AssertsCompiledSQL):
                    order_by(matchtable.c.id).
                    execute().
                    fetchall())
-        self.assertEquals([1, 3, 5], [r.id for r in results])
+        eq_([1, 3, 5], [r.id for r in results])
 
 
 def colspec(c):
     return testing.db.dialect.schemagenerator(testing.db.dialect,
         testing.db, None, None).get_column_specification(c)
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 97%
rename from test/dialect/oracle.py
rename to test/dialect/test_oracle.py
index 2186f22595b30739e0f5b0429dd9cd30b0671d50..16175c85121d44c93683ed07f882aa5c1b99c935 100644 (file)
@@ -1,19 +1,20 @@
 # coding: utf-8
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 from sqlalchemy import *
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import oracle
-from testlib import *
-from testlib.testing import eq_
-from testlib.engines import testing_engine
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.engines import testing_engine
 import os
 
 
 class OutParamTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'oracle'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         testing.db.execute("""
 create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number, z_out OUT varchar) IS
   retval number;
@@ -29,7 +30,8 @@ create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT numb
         result = testing.db.execute(text("begin foo(:x_in, :x_out, :y_out, :z_out); end;", bindparams=[bindparam('x_in', Numeric), outparam('x_out', Numeric), outparam('y_out', Numeric), outparam('z_out', String)]), x_in=5)
         assert result.out_parameters == {'x_out':10, 'y_out':75, 'z_out':None}, result.out_parameters
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
          testing.db.execute("DROP PROCEDURE foo")
 
 
@@ -200,7 +202,7 @@ class MultiSchemaTest(TestBase, AssertsCompiledSQL):
         try:
             parent.insert().execute({'pid':1})
             child.insert().execute({'cid':1, 'pid':1})
-            self.assertEquals(child.select().execute().fetchall(), [(1, 1)])
+            eq_(child.select().execute().fetchall(), [(1, 1)])
         finally:
             meta.drop_all()
 
@@ -217,7 +219,7 @@ class MultiSchemaTest(TestBase, AssertsCompiledSQL):
         try:
             parent.insert().execute({'pid':1})
             child.insert().execute({'cid':1, 'pid':1})
-            self.assertEquals(child.select().execute().fetchall(), [(1, 1)])
+            eq_(child.select().execute().fetchall(), [(1, 1)])
         finally:
             meta.drop_all()
 
@@ -346,7 +348,8 @@ class TypesTest(TestBase, AssertsCompiledSQL):
 class BufferedColumnTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'oracle'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global binary_table, stream, meta
         meta = MetaData(testing.db)
         binary_table = Table('binary_table', meta, 
@@ -360,18 +363,19 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL):
         for i in range(1, 11):
             binary_table.insert().execute(id=i, data=stream)
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         meta.drop_all()
 
     def test_fetch(self):
-        self.assertEquals(
+        eq_(
             binary_table.select().execute().fetchall() ,
             [(i, stream) for i in range(1, 11)], 
         )
 
     def test_fetch_single_arraysize(self):
         eng = testing_engine(options={'arraysize':1})
-        self.assertEquals(
+        eq_(
             eng.execute(binary_table.select()).fetchall(),
             [(i, stream) for i in range(1, 11)], 
         )
@@ -393,5 +397,3 @@ class ExecuteTest(TestBase):
     def test_basic(self):
         assert testing.db.execute("/*+ this is a comment */ SELECT 1 FROM DUAL").fetchall() == [(1,)]
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 88%
rename from test/dialect/postgres.py
rename to test/dialect/test_postgres.py
index 2dfbe018ccf99af0e1dbf97cfab92d0e911b03a4..8ca714badc79033c9441c59562a87a369f46a271 100644 (file)
@@ -1,11 +1,11 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy import exc
 from sqlalchemy.databases import postgres
 from sqlalchemy.engine.strategies import MockEngineStrategy
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy.sql import table, column
 
 
@@ -86,10 +86,10 @@ class ReturningTest(TestBase, AssertsExecutionResults):
             table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
 
             result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute()
-            self.assertEqual(result.fetchall(), [(1,)])
+            eq_(result.fetchall(), [(1,)])
 
             result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
-            self.assertEqual(result2.fetchall(), [(1,True),(2,False)])
+            eq_(result2.fetchall(), [(1,True),(2,False)])
         finally:
             table.drop()
 
@@ -105,22 +105,22 @@ class ReturningTest(TestBase, AssertsExecutionResults):
         try:
             result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False})
 
-            self.assertEqual(result.fetchall(), [(1,)])
+            eq_(result.fetchall(), [(1,)])
 
             @testing.fails_on('postgres', 'Known limitation of psycopg2')
             def test_executemany():
                 # return value is documented as failing with psycopg2/executemany
                 result2 = table.insert(postgres_returning=[table]).execute(
                      [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
-                self.assertEqual(result2.fetchall(), [(2, 2, False), (3,3,True)])
+                eq_(result2.fetchall(), [(2, 2, False), (3,3,True)])
             
             test_executemany()
             
             result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
-            self.assertEqual([dict(row) for row in result3], [{'double_id':8}])
+            eq_([dict(row) for row in result3], [{'double_id':8}])
 
             result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
-            self.assertEqual([dict(row) for row in result4], [{'persons': 10}])
+            eq_([dict(row) for row in result4], [{'persons': 10}])
         finally:
             table.drop()
 
@@ -128,11 +128,12 @@ class ReturningTest(TestBase, AssertsExecutionResults):
 class InsertTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgres'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata
         metadata = MetaData(testing.db)
 
-    def tearDown(self):
+    def teardown(self):
         metadata.drop_all()
         metadata.tables.clear()
 
@@ -397,7 +398,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
 
     __only_on__ = 'postgres'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         con = testing.db.connect()
         for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42',
                     'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'):
@@ -410,7 +412,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
         con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         con = testing.db.connect()
         con.execute('DROP TABLE testtable')
         con.execute('DROP TABLE alt_schema.testtable')
@@ -421,32 +424,32 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
     def test_table_is_reflected(self):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True)
-        self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
-        self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger)
+        eq_(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
+        eq_(table.c.answer.type.__class__, postgres.PGInteger)
 
     def test_domain_is_reflected(self):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True)
-        self.assertEquals(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value")
-        self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.")
+        eq_(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value")
+        assert not table.columns.answer.nullable, "Expected reflected column to not be nullable."
 
     def test_table_is_reflected_alt_schema(self):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True, schema='alt_schema')
-        self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
-        self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger)
+        eq_(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
+        eq_(table.c.anything.type.__class__, postgres.PGInteger)
 
     def test_schema_domain_is_reflected(self):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True, schema='alt_schema')
-        self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
-        self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
+        eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
+        assert table.columns.answer.nullable, "Expected reflected column to be nullable."
 
     def test_crosschema_domain_is_reflected(self):
         metadata = MetaData(testing.db)
         table = Table('crosschema', metadata, autoload=True)
-        self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
-        self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
+        eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
+        assert table.columns.answer.nullable, "Expected reflected column to be nullable."
 
     def test_unknown_types(self):
         from sqlalchemy.databases import postgres
@@ -455,7 +458,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         postgres.ischema_names = {}
         try:
             m2 = MetaData(testing.db)
-            self.assertRaises(exc.SAWarning, Table, "testtable", m2, autoload=True)
+            assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True)
 
             @testing.emits_warning('Did not recognize type')
             def warns():
@@ -519,15 +522,15 @@ class MiscTest(TestBase, AssertsExecutionResults):
         t = Table('mytable', MetaData(testing.db),
                   Column('id', Integer, primary_key=True),
                   Column('a', String(8)))
-        self.assertEquals(
+        eq_(
             str(t.select(distinct=t.c.a)),
             'SELECT DISTINCT ON (mytable.a) mytable.id, mytable.a \n'
             'FROM mytable')
-        self.assertEquals(
+        eq_(
             str(t.select(distinct=['id','a'])),
             'SELECT DISTINCT ON (id, a) mytable.id, mytable.a \n'
             'FROM mytable')
-        self.assertEquals(
+        eq_(
             str(t.select(distinct=[t.c.id, t.c.a])),
             'SELECT DISTINCT ON (mytable.id, mytable.a) mytable.id, mytable.a \n'
             'FROM mytable')
@@ -616,14 +619,14 @@ class MiscTest(TestBase, AssertsExecutionResults):
             users.insert().execute(id=3, name='name3')
             users.insert().execute(id=4, name='name4')
 
-            self.assertEquals(users.select().where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
-            self.assertEquals(users.select(use_labels=True).where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
+            eq_(users.select().where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
+            eq_(users.select(use_labels=True).where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
 
             users.delete().where(users.c.id==3).execute()
-            self.assertEquals(users.select().where(users.c.name=='name3').execute().fetchall(), [])
+            eq_(users.select().where(users.c.name=='name3').execute().fetchall(), [])
 
             users.update().where(users.c.name=='name4').execute(name='newname')
-            self.assertEquals(users.select(use_labels=True).where(users.c.id==4).execute().fetchall(), [(4, 'newname')])
+            eq_(users.select(use_labels=True).where(users.c.id==4).execute().fetchall(), [(4, 'newname')])
 
         finally:
             users.drop()
@@ -733,7 +736,8 @@ class TimezoneTest(TestBase, AssertsExecutionResults):
 
     __only_on__ = 'postgres'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global tztable, notztable, metadata
         metadata = MetaData(testing.db)
 
@@ -749,7 +753,8 @@ class TimezoneTest(TestBase, AssertsExecutionResults):
             Column("name", String(20)),
         )
         metadata.create_all()
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_with_timezone(self):
@@ -769,7 +774,8 @@ class TimezoneTest(TestBase, AssertsExecutionResults):
 class ArrayTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgres'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, arrtable
         metadata = MetaData(testing.db)
 
@@ -779,47 +785,48 @@ class ArrayTest(TestBase, AssertsExecutionResults):
             Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False)
         )
         metadata.create_all()
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_reflect_array_column(self):
         metadata2 = MetaData(testing.db)
         tbl = Table('arrtable', metadata2, autoload=True)
-        self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray))
-        self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray))
-        self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer))
-        self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String))
+        assert isinstance(tbl.c.intarr.type, postgres.PGArray)
+        assert isinstance(tbl.c.strarr.type, postgres.PGArray)
+        assert isinstance(tbl.c.intarr.type.item_type, Integer)
+        assert isinstance(tbl.c.strarr.type.item_type, String)
 
     def test_insert_array(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         results = arrtable.select().execute().fetchall()
-        self.assertEquals(len(results), 1)
-        self.assertEquals(results[0]['intarr'], [1,2,3])
-        self.assertEquals(results[0]['strarr'], ['abc','def'])
+        eq_(len(results), 1)
+        eq_(results[0]['intarr'], [1,2,3])
+        eq_(results[0]['strarr'], ['abc','def'])
         arrtable.delete().execute()
 
     def test_array_where(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
         results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall()
-        self.assertEquals(len(results), 1)
-        self.assertEquals(results[0]['intarr'], [1,2,3])
+        eq_(len(results), 1)
+        eq_(results[0]['intarr'], [1,2,3])
         arrtable.delete().execute()
 
     def test_array_concat(self):
         arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
         results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall()
-        self.assertEquals(len(results), 1)
-        self.assertEquals(results[0][0], [1,2,3,4,5,6])
+        eq_(len(results), 1)
+        eq_(results[0][0], [1,2,3,4,5,6])
         arrtable.delete().execute()
 
     def test_array_subtype_resultprocessor(self):
         arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']])
         arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6'])
         results = arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall()
-        self.assertEquals(len(results), 2)
-        self.assertEquals(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
-        self.assertEquals(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
+        eq_(len(results), 2)
+        eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
+        eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
         arrtable.delete().execute()
 
     def test_array_mutability(self):
@@ -839,23 +846,23 @@ class ArrayTest(TestBase, AssertsExecutionResults):
         sess.flush()
         sess.expunge_all()
         foo = sess.query(Foo).get(1)
-        self.assertEquals(foo.intarr, [1,2,3])
+        eq_(foo.intarr, [1,2,3])
 
         foo.intarr.append(4)
         sess.flush()
         sess.expunge_all()
         foo = sess.query(Foo).get(1)
-        self.assertEquals(foo.intarr, [1,2,3,4])
+        eq_(foo.intarr, [1,2,3,4])
 
         foo.intarr = []
         sess.flush()
         sess.expunge_all()
-        self.assertEquals(foo.intarr, [])
+        eq_(foo.intarr, [])
 
         foo.intarr = None
         sess.flush()
         sess.expunge_all()
-        self.assertEquals(foo.intarr, None)
+        eq_(foo.intarr, None)
 
         # Errors in r4217:
         foo = Foo()
@@ -872,16 +879,18 @@ class TimeStampTest(TestBase, AssertsExecutionResults):
         connection = engine.connect()
         s = select([func.TIMESTAMP("12/25/07").label("ts")])
         result = connection.execute(s).fetchone()
-        self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0))
+        eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0))
 
 class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgres'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global ss_engine
         ss_engine = engines.testing_engine(options={'server_side_cursors':True})
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         ss_engine.dispose()
 
     def test_uses_ss(self):
@@ -906,12 +915,12 @@ class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
             nextid = ss_engine.execute(Sequence('test_table_id_seq'))
             test_table.insert().execute(id=nextid, data='data2')
 
-            self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')])
+            eq_(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')])
 
             test_table.update().where(test_table.c.id==2).values(data=test_table.c.data + ' updated').execute()
-            self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')])
+            eq_(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')])
             test_table.delete().execute()
-            self.assertEquals(test_table.count().scalar(), 0)
+            eq_(test_table.count().scalar(), 0)
         finally:
             test_table.drop(checkfirst=True)
 
@@ -921,7 +930,8 @@ class SpecialTypesTest(TestBase, ComparesTables):
     __only_on__ = 'postgres'
     __excluded_on__ = (('postgres', '<', (8, 3, 0)),)
     
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, table
         metadata = MetaData(testing.db)
         
@@ -935,7 +945,8 @@ class SpecialTypesTest(TestBase, ComparesTables):
         
         metadata.create_all()
     
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
     
     def test_reflection(self):
@@ -949,7 +960,8 @@ class MatchTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'postgres'
     __excluded_on__ = (('postgres', '<', (8, 3, 0)),)
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, cattable, matchtable
         metadata = MetaData(testing.db)
 
@@ -976,7 +988,8 @@ class MatchTest(TestBase, AssertsCompiledSQL):
             {'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1}
         ])
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_expression(self):
@@ -984,42 +997,40 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 
     def test_simple_match(self):
         results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([2, 5], [r.id for r in results])
+        eq_([2, 5], [r.id for r in results])
 
     def test_simple_match_with_apostrophe(self):
         results = matchtable.select().where(matchtable.c.title.match("Matz''s")).execute().fetchall()
-        self.assertEquals([3], [r.id for r in results])
+        eq_([3], [r.id for r in results])
 
     def test_simple_derivative_match(self):
         results = matchtable.select().where(matchtable.c.title.match('nutshells')).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results])
+        eq_([5], [r.id for r in results])
 
     def test_or_match(self):
         results1 = matchtable.select().where(or_(matchtable.c.title.match('nutshells'), 
                                                  matchtable.c.title.match('rubies'))
                                             ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([3, 5], [r.id for r in results1])
+        eq_([3, 5], [r.id for r in results1])
         results2 = matchtable.select().where(matchtable.c.title.match('nutshells | rubies'), 
                                             ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([3, 5], [r.id for r in results2])
+        eq_([3, 5], [r.id for r in results2])
         
 
     def test_and_match(self):
         results1 = matchtable.select().where(and_(matchtable.c.title.match('python'), 
                                                   matchtable.c.title.match('nutshells'))
                                             ).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results1])
+        eq_([5], [r.id for r in results1])
         results2 = matchtable.select().where(matchtable.c.title.match('python & nutshells'), 
                                             ).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results2])
+        eq_([5], [r.id for r in results2])
 
     def test_match_across_joins(self):
         results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id, 
                                             or_(cattable.c.description.match('Ruby'), 
                                                 matchtable.c.title.match('nutshells')))
                                            ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([1, 3, 5], [r.id for r in results])
+        eq_([1, 3, 5], [r.id for r in results])
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 93%
rename from test/dialect/sqlite.py
rename to test/dialect/test_sqlite.py
index d01be3521d4b786ee60c652ccb5bbaa233d5cb56..eb4581e20fcca7aa1249183ae3e757668d535038 100644 (file)
@@ -1,11 +1,11 @@
 """SQLite-specific tests."""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import datetime
 from sqlalchemy import *
 from sqlalchemy import exc, sql
 from sqlalchemy.databases import sqlite
-from testlib import *
+from sqlalchemy.test import *
 
 
 class TestTypes(TestBase, AssertsExecutionResults):
@@ -34,22 +34,22 @@ class TestTypes(TestBase, AssertsExecutionResults):
             meta.drop_all()
 
     def test_string_dates_raise(self):
-        self.assertRaises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30)))
+        assert_raises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30)))
     
     def test_time_microseconds(self):
         dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125)  # 125 usec
-        self.assertEquals(str(dt), '2008-06-27 12:00:00.000125')
+        eq_(str(dt), '2008-06-27 12:00:00.000125')
         sldt = sqlite.SLDateTime()
         bp = sldt.bind_processor(None)
-        self.assertEquals(bp(dt), '2008-06-27 12:00:00.000125')
+        eq_(bp(dt), '2008-06-27 12:00:00.000125')
         
         rp = sldt.result_processor(None)
-        self.assertEquals(rp(bp(dt)), dt)
+        eq_(rp(bp(dt)), dt)
         
         sldt.__legacy_microseconds__ = True
         bp = sldt.bind_processor(None)
-        self.assertEquals(bp(dt), '2008-06-27 12:00:00.125')
-        self.assertEquals(rp(bp(dt)), dt)
+        eq_(bp(dt), '2008-06-27 12:00:00.125')
+        eq_(rp(bp(dt)), dt)
 
     def test_no_convert_unicode(self):
         """test no utf-8 encoding occurs"""
@@ -163,7 +163,7 @@ class TestDefaults(TestBase, AssertsExecutionResults):
             rt = Table('t_defaults', m2, autoload=True)
             expected = [c[1] for c in specs]
             for i, reflected in enumerate(rt.c):
-                self.assertEquals(reflected.server_default.arg.text, expected[i])
+                eq_(reflected.server_default.arg.text, expected[i])
         finally:
             m.drop_all()
 
@@ -184,7 +184,7 @@ class TestDefaults(TestBase, AssertsExecutionResults):
 
             rt = Table('r_defaults', m, autoload=True)
             for i, reflected in enumerate(rt.c):
-                self.assertEquals(reflected.server_default.arg.text, expected[i])
+                eq_(reflected.server_default.arg.text, expected[i])
         finally:
             db.execute("DROP TABLE r_defaults")
 
@@ -258,7 +258,7 @@ class DialectTest(TestBase, AssertsExecutionResults):
                                schema='alt_schema')
             meta.create_all(cx)
 
-            self.assertEquals(dialect.table_names(cx, 'alt_schema'),
+            eq_(dialect.table_names(cx, 'alt_schema'),
                               ['created'])
             assert len(alt_master.c) > 0
 
@@ -350,7 +350,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
                 table.insert().execute()
 
                 rows = table.select().execute().fetchall()
-                self.assertEquals(len(rows), wanted)
+                eq_(len(rows), wanted)
         finally:
             table.drop()
 
@@ -362,7 +362,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
     @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support')
     def test_empty_insert_pk2(self):
-        self.assertRaises(
+        assert_raises(
             exc.DBAPIError,
             self._test_empty_insert,
             Table('b', MetaData(testing.db),
@@ -371,7 +371,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
 
     @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support')
     def test_empty_insert_pk3(self):
-        self.assertRaises(
+        assert_raises(
             exc.DBAPIError,
             self._test_empty_insert,
             Table('c', MetaData(testing.db),
@@ -429,7 +429,8 @@ class MatchTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'sqlite'
     __skip_if__ = (full_text_search_missing, )
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, cattable, matchtable
         metadata = MetaData(testing.db)
         
@@ -465,7 +466,8 @@ class MatchTest(TestBase, AssertsCompiledSQL):
             {'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1}
         ])
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_expression(self):
@@ -473,29 +475,27 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 
     def test_simple_match(self):
         results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([2, 5], [r.id for r in results])
+        eq_([2, 5], [r.id for r in results])
 
     def test_simple_prefix_match(self):
         results = matchtable.select().where(matchtable.c.title.match('nut*')).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results])
+        eq_([5], [r.id for r in results])
 
     def test_or_match(self):
         results2 = matchtable.select().where(matchtable.c.title.match('nutshell OR ruby'), 
                                             ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([3, 5], [r.id for r in results2])
+        eq_([3, 5], [r.id for r in results2])
         
 
     def test_and_match(self):
         results2 = matchtable.select().where(matchtable.c.title.match('python nutshell'), 
                                             ).execute().fetchall()
-        self.assertEquals([5], [r.id for r in results2])
+        eq_([5], [r.id for r in results2])
 
     def test_match_across_joins(self):
         results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id, 
                                             cattable.c.description.match('Ruby'))
                                            ).order_by(matchtable.c.id).execute().fetchall()
-        self.assertEquals([1, 3], [r.id for r in results])
+        eq_([1, 3], [r.id for r in results])
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 85%
rename from test/dialect/sybase.py
rename to test/dialect/test_sybase.py
index 32b9904d8a5d6bf22dd49deb41b6c29bd3a0b48b..37de91d1c4f64e12ee8f0fa7159ad1d941ed5ff5 100644 (file)
@@ -1,8 +1,7 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy import sql
 from sqlalchemy.databases import sybase
-from testlib import *
+from sqlalchemy.test import *
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
@@ -27,5 +26,3 @@ class CompileTest(TestBase, AssertsCompiledSQL):
 
 
 
-if __name__ == "__main__":
-    testenv.main()
index 3c31d378ad1aa0644e190f91bd1b177b0aefc538..ec91243d24d775da407ceb0a60995b3042d1ec89 100644 (file)
@@ -1,5 +1,6 @@
-from testlib import sa, testing
-from testlib.testing import adict
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy.test.testing import adict
 
 
 class TablesTest(testing.TestBase):
@@ -27,41 +28,38 @@ class TablesTest(testing.TestBase):
     tables = None
     other_artifacts = None
 
-    def setUpAll(self):
-        if self.run_setup_bind is None:
-            assert self.bind is not None
-        assert self.run_deletes in (None, 'each')
-        if self.run_inserts == 'once':
-            assert self.run_deletes is None
+    @classmethod
+    def setup_class(cls):
+        if cls.run_setup_bind is None:
+            assert cls.bind is not None
+        assert cls.run_deletes in (None, 'each')
+        if cls.run_inserts == 'once':
+            assert cls.run_deletes is None
 
-        cls = self.__class__
         if cls.tables is None:
             cls.tables = adict()
         if cls.other_artifacts is None:
             cls.other_artifacts = adict()
 
-        if self.bind is None:
-            setattr(type(self), 'bind', self.setup_bind())
-
-        if self.metadata is None:
-            setattr(type(self), 'metadata', sa.MetaData())
+        if cls.bind is None:
+            setattr(cls, 'bind', cls.setup_bind())
 
-        if self.metadata.bind is None:
-            self.metadata.bind = self.bind
+        if cls.metadata is None:
+            setattr(cls, 'metadata', sa.MetaData())
 
-        if self.run_define_tables:
-            self.define_tables(self.metadata)
-            self.metadata.create_all()
-            self.tables.update(self.metadata.tables)
+        if cls.metadata.bind is None:
+            cls.metadata.bind = cls.bind
 
-        if self.run_inserts:
-            self._load_fixtures()
-            self.insert_data()
+        if cls.run_define_tables == 'once':
+            cls.define_tables(cls.metadata)
+            cls.metadata.create_all()
+            cls.tables.update(cls.metadata.tables)
 
-    def setUp(self):
-        if self._sa_first_test:
-            return
+        if cls.run_inserts == 'once':
+            cls._load_fixtures()
+            cls.insert_data()
 
+    def setup(self):
         cls = self.__class__
 
         if self.setup_bind == 'each':
@@ -79,7 +77,7 @@ class TablesTest(testing.TestBase):
             self._load_fixtures()
             self.insert_data()
 
-    def tearDown(self):
+    def teardown(self):
         # no need to run deletes if tables are recreated on setup
         if self.run_define_tables != 'each' and self.run_deletes:
             for table in reversed(self.metadata.sorted_tables):
@@ -92,33 +90,39 @@ class TablesTest(testing.TestBase):
         if self.run_dispose_bind == 'each':
             self.dispose_bind(self.bind)
 
-    def tearDownAll(self):
-        self.metadata.drop_all()
+    @classmethod
+    def teardown_class(cls):
+        cls.metadata.drop_all()
 
-        if self.dispose_bind:
-            self.dispose_bind(self.bind)
+        if cls.dispose_bind:
+            cls.dispose_bind(cls.bind)
 
-        self.metadata.bind = None
+        cls.metadata.bind = None
 
-        if self.run_setup_bind is not None:
-            self.bind = None
+        if cls.run_setup_bind is not None:
+            cls.bind = None
 
-    def setup_bind(self):
+    @classmethod
+    def setup_bind(cls):
         return testing.db
 
-    def dispose_bind(self, bind):
+    @classmethod
+    def dispose_bind(cls, bind):
         if hasattr(bind, 'dispose'):
             bind.dispose()
         elif hasattr(bind, 'close'):
             bind.close()
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         raise NotImplementedError()
 
-    def fixtures(self):
+    @classmethod
+    def fixtures(cls):
         return {}
 
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         pass
 
     def sql_count_(self, count, fn):
@@ -147,14 +151,17 @@ class TablesTest(testing.TestBase):
 class AltEngineTest(testing.TestBase):
     engine = None
 
-    def setUpAll(self):
-        type(self).engine = self.create_engine()
-        testing.TestBase.setUpAll(self)
-
-    def tearDownAll(self):
-        testing.TestBase.tearDownAll(self)
-        self.engine.dispose()
-        type(self).engine = None
-
-    def create_engine(self):
+    @classmethod
+    def setup_class(cls):
+        cls.engine = cls.create_engine()
+        super(AltEngineTest, cls).setup_class()
+        
+    @classmethod
+    def teardown_class(cls):
+        cls.engine.dispose()
+        cls.engine = None
+        super(AltEngineTest, cls).teardown_class()
+        
+    @classmethod
+    def create_engine(cls):
         raise NotImplementedError
diff --git a/test/engine/alltests.py b/test/engine/alltests.py
deleted file mode 100644 (file)
index ed722aa..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
-    modules_to_test = (
-        # connectivity, execution
-        'engine.parseconnect',
-        'engine.pool',
-        'engine.bind',
-        'engine.reconnect',
-        'engine.execute',
-        'engine.metadata',
-        'engine.transaction',
-
-        # schema/tables
-        'engine.reflection',
-        'engine.ddlevents',
-
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
similarity index 96%
rename from test/engine/bind.py
rename to test/engine/test_bind.py
index 5b8605aada8f34462db6c9520e22865acd12c8f8..7fd3009bca383f5055f82c1472c49ed3ec94f258 100644 (file)
@@ -1,11 +1,14 @@
 """tests the "bind" attribute/argument across schema and SQL,
 including the deprecated versions of these arguments"""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 from sqlalchemy import engine, exc
 from sqlalchemy import MetaData, ThreadLocalMetaData
-from testlib.sa import Table, Column, Integer, text
-from testlib import sa, testing
+from sqlalchemy import Integer, text
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as sa
+from sqlalchemy.test import testing
 
 
 class BindTest(testing.TestBase):
@@ -43,7 +46,7 @@ class BindTest(testing.TestBase):
                 meth()
                 assert False
             except exc.UnboundExecutionError, e:
-                self.assertEquals(
+                eq_(
                     str(e),
                     "The MetaData "
                     "is not bound to an Engine or Connection.  "
@@ -61,7 +64,7 @@ class BindTest(testing.TestBase):
                 meth()
                 assert False
             except exc.UnboundExecutionError, e:
-                self.assertEquals(
+                eq_(
                     str(e),
                     "The Table 'test_table' "
                     "is not bound to an Engine or Connection.  "
@@ -85,7 +88,7 @@ class BindTest(testing.TestBase):
                 meth()
                 assert False
             except exc.UnboundExecutionError, e:
-                self.assertEquals(
+                eq_(
                     str(e),
                     "The Table 'test_table' "
                     "is not bound to an Engine or Connection.  "
@@ -219,5 +222,3 @@ class BindTest(testing.TestBase):
             metadata.drop_all(bind=testing.db)
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 92%
rename from test/engine/ddlevents.py
rename to test/engine/test_ddlevents.py
index 8274c63476ea01ea1e45de32e749bba074163429..5716006d93c54359ffa620e35e15d60bb4b25ffb 100644 (file)
@@ -1,9 +1,11 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy.schema import DDL
 from sqlalchemy import create_engine
-from testlib.sa import MetaData, Table, Column, Integer, String
-import testlib.sa as tsa
-from testlib import TestBase, testing, engines
+from sqlalchemy import MetaData, Integer, String
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, testing, engines
 
 
 class DDLEventTest(TestBase):
@@ -37,7 +39,7 @@ class DDLEventTest(TestBase):
             assert bind is self.bind
             self.state = action
 
-    def setUp(self):
+    def setup(self):
         self.bind = engines.mock_engine()
         self.metadata = MetaData()
         self.table = Table('t', self.metadata, Column('id', Integer))
@@ -174,14 +176,14 @@ class DDLEventTest(TestBase):
         fn = lambda *a: None
 
         table.append_ddl_listener('before-create', fn)
-        self.assertRaises(LookupError, table.append_ddl_listener, 'blah', fn)
+        assert_raises(LookupError, table.append_ddl_listener, 'blah', fn)
 
         metadata.append_ddl_listener('before-create', fn)
-        self.assertRaises(LookupError, metadata.append_ddl_listener, 'blah', fn)
+        assert_raises(LookupError, metadata.append_ddl_listener, 'blah', fn)
 
 
 class DDLExecutionTest(TestBase):
-    def setUp(self):
+    def setup(self):
         self.engine = engines.mock_engine()
         self.metadata = MetaData(self.engine)
         self.users = Table('users', self.metadata,
@@ -303,19 +305,19 @@ class DDLTest(TestBase):
 
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
 
-        self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
-        self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t')
-        self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
-        self.assertEquals(ddl._expand(insane_schema, bind),
+        eq_(ddl._expand(sane_alone, bind), '-t-t')
+        eq_(ddl._expand(sane_schema, bind), 's-t-s.t')
+        eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
+        eq_(ddl._expand(insane_schema, bind),
                           '"s s"-"t t"-"s s"."t t"')
 
         # overrides are used piece-meal and verbatim.
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
                   context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
-        self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
-        self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
-        self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
-        self.assertEquals(ddl._expand(insane_schema, bind),
+        eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
+        eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
+        eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
+        eq_(ddl._expand(insane_schema, bind),
                           'S S-T T-"s s"."t t"-b')
     def test_filter(self):
         cx = self.mock_engine()
@@ -338,5 +340,3 @@ class DDLTest(TestBase):
         assert repr(DDL('s', on='engine', context={'a':1}))
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 94%
rename from test/engine/execute.py
rename to test/engine/test_execute.py
index 515c99d309bd9d44d55ec44459fb0dcac7cd8d1d..08bf80fe2f26c8758137c820e5f0eb72f6ed0011 100644 (file)
@@ -1,15 +1,17 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import re
 from sqlalchemy.interfaces import ConnectionProxy
-from testlib.sa import MetaData, Table, Column, Integer, String, INT, \
-     VARCHAR, func, bindparam
-import testlib.sa as tsa
-from testlib import TestBase, testing, engines
+from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, bindparam
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, testing, engines
 
 
 users, metadata = None, None
 class ExecuteTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global users, metadata
         metadata = MetaData(testing.db)
         users = Table('users', metadata,
@@ -18,9 +20,10 @@ class ExecuteTest(TestBase):
         )
         metadata.create_all()
 
-    def tearDown(self):
+    def teardown(self):
         testing.db.connect().execute(users.delete())
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite')
@@ -82,7 +85,7 @@ class ExecuteTest(TestBase):
     def test_empty_insert(self):
         """test that execute() interprets [] as a list with no params"""
         result = testing.db.execute(users.insert().values(user_name=bindparam('name')), [])
-        self.assertEquals(result.rowcount, 1)
+        eq_(result.rowcount, 1)
 
 class ProxyConnectionTest(TestBase):
     @testing.fails_on('firebird', 'Data type unknown')
@@ -162,5 +165,3 @@ class ProxyConnectionTest(TestBase):
             assert_stmts(cursor, cursor_stmts)
     
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 93%
rename from test/engine/metadata.py
rename to test/engine/test_metadata.py
index c8fc6f7e0fdf4b925f052a0cc4ffde5ec0f2ba08..024d1b854f88bdd84c1ebaa9ec151e4e3a0544bc 100644 (file)
@@ -1,11 +1,12 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import pickle
 from sqlalchemy import MetaData
-from testlib.sa import Table, Column, Integer, String, UniqueConstraint, \
-     CheckConstraint, ForeignKey
-import testlib.sa as tsa
-from testlib import TestBase, ComparesTables, testing, engines
-from testlib.testing import eq_
+from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, ComparesTables, testing, engines
+from sqlalchemy.test.testing import eq_
 
 class MetaDataTest(TestBase, ComparesTables):
     def test_metadata_connect(self):
@@ -137,13 +138,13 @@ class MetaDataTest(TestBase, ComparesTables):
         
         
     def test_nonexistent(self):
-        self.assertRaises(tsa.exc.NoSuchTableError, Table,
+        assert_raises(tsa.exc.NoSuchTableError, Table,
                           'fake_table',
                           MetaData(testing.db), autoload=True)
 
 
 class TableOptionsTest(TestBase):
-    def setUp(self):
+    def setup(self):
         self.engine = engines.mock_engine()
         self.metadata = MetaData(self.engine)
 
@@ -160,5 +161,3 @@ class TableOptionsTest(TestBase):
         table2.create()
         assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)]
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 98%
rename from test/engine/parseconnect.py
rename to test/engine/test_parseconnect.py
index c82ca6d58d38568a468523dd46232e21127465cd..6b7ac37b20f7d14ae4fe46c36679268380be3698 100644 (file)
@@ -1,9 +1,8 @@
-import testenv; testenv.configure_for_tests()
 import ConfigParser, StringIO
 import sqlalchemy.engine.url as url
 from sqlalchemy import create_engine, engine_from_config
-import testlib.sa as tsa
-from testlib import TestBase
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase
 
 
 class ParseConnectTest(TestBase):
@@ -229,5 +228,3 @@ class MockCursor(object):
         pass
 mock_dbapi = MockDBAPI()
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 99%
rename from test/engine/pool.py
rename to test/engine/test_pool.py
index b712e24128d9fdff472b45b59dcbda3571e29000..43a0fc38b790fc5893ef15497874437559820322 100644 (file)
@@ -1,8 +1,7 @@
-import testenv; testenv.configure_for_tests()
 import threading, time, gc
 from sqlalchemy import pool, interfaces
-import testlib.sa as tsa
-from testlib import TestBase
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase
 
 
 mcid = 1
@@ -37,10 +36,11 @@ mock_dbapi = MockDBAPI()
 
 
 class PoolTestBase(TestBase):    
-    def setUp(self):
+    def setup(self):
         pool.clear_managers()
         
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
        pool.clear_managers()
 
 class PoolTest(PoolTestBase):
@@ -662,5 +662,3 @@ class NullPoolTest(PoolTestBase):
     
         
     
-if __name__ == "__main__":
-    testenv.main()
similarity index 88%
rename from test/engine/reconnect.py
rename to test/engine/test_reconnect.py
index 4f383d2dde6e5f88c7e51c522e82ce194978082e..3a525c2a702e1f835c160f0e3a4d30add230b7c4 100644 (file)
@@ -1,8 +1,10 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import weakref
-from testlib.sa import select, MetaData, Table, Column, Integer, String, pool
-import testlib.sa as tsa
-from testlib import TestBase, testing, engines
+from sqlalchemy import select, MetaData, Integer, String, pool
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, testing, engines
 import time
 import gc
 
@@ -47,7 +49,7 @@ class MockCursor(object):
 
 db, dbapi = None, None
 class MockReconnectTest(TestBase):
-    def setUp(self):
+    def setup(self):
         global db, dbapi
         dbapi = MockDBAPI()
 
@@ -176,17 +178,17 @@ class MockReconnectTest(TestBase):
 
 engine = None
 class RealReconnectTest(TestBase):
-    def setUp(self):
+    def setup(self):
         global engine
         engine = engines.reconnecting_engine()
 
-    def tearDown(self):
+    def teardown(self):
         engine.dispose()
 
     def test_reconnect(self):
         conn = engine.connect()
 
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
 
         engine.test_shutdown()
@@ -202,7 +204,7 @@ class RealReconnectTest(TestBase):
         assert conn.invalidated
 
         assert conn.invalidated
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
 
         # one more time
@@ -214,7 +216,7 @@ class RealReconnectTest(TestBase):
             if not e.connection_invalidated:
                 raise
         assert conn.invalidated
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
 
         conn.close()
@@ -222,7 +224,7 @@ class RealReconnectTest(TestBase):
     def test_null_pool(self):
         engine = engines.reconnecting_engine(options=dict(poolclass=pool.NullPool))
         conn = engine.connect()
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
         engine.test_shutdown()
         try:
@@ -233,12 +235,12 @@ class RealReconnectTest(TestBase):
                 raise
         assert not conn.closed
         assert conn.invalidated
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
         
     def test_close(self):
         conn = engine.connect()
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
 
         engine.test_shutdown()
@@ -252,14 +254,14 @@ class RealReconnectTest(TestBase):
 
         conn.close()
         conn = engine.connect()
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
 
     def test_with_transaction(self):
         conn = engine.connect()
 
         trans = conn.begin()
 
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
 
         engine.test_shutdown()
@@ -295,7 +297,7 @@ class RealReconnectTest(TestBase):
         assert not trans.is_active
 
         assert conn.invalidated
-        self.assertEquals(conn.execute(select([1])).scalar(), 1)
+        eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
 
 class RecycleTest(TestBase):
@@ -304,19 +306,19 @@ class RecycleTest(TestBase):
             engine = engines.reconnecting_engine(options={'pool_recycle':1, 'pool_threadlocal':threadlocal})
         
             conn = engine.contextual_connect()
-            self.assertEquals(conn.execute(select([1])).scalar(), 1)
+            eq_(conn.execute(select([1])).scalar(), 1)
             conn.close()
 
             engine.test_shutdown()
             time.sleep(2)
     
             conn = engine.contextual_connect()
-            self.assertEquals(conn.execute(select([1])).scalar(), 1)
+            eq_(conn.execute(select([1])).scalar(), 1)
             conn.close()
     
 meta, table, engine = None, None, None
 class InvalidateDuringResultTest(TestBase):
-    def setUp(self):
+    def setup(self):
         global meta, table, engine
         engine = engines.reconnecting_engine()
         meta = MetaData(engine)
@@ -328,7 +330,7 @@ class InvalidateDuringResultTest(TestBase):
             [{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
         )
 
-    def tearDown(self):
+    def teardown(self):
         meta.drop_all()
         engine.dispose()
 
@@ -350,5 +352,3 @@ class InvalidateDuringResultTest(TestBase):
 
         assert conn.invalidated
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 96%
rename from test/engine/reflection.py
rename to test/engine/test_reflection.py
index d8412237fbdc1b53c1d110ef8010ebc2b30a7b8b..ea80776a6a1bb4217fb7542a3738a9c722468537 100644 (file)
@@ -1,8 +1,11 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import StringIO, unicodedata
 import sqlalchemy as sa
-from testlib.sa import MetaData, Table, Column
-from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
+from sqlalchemy import MetaData
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, ComparesTables, testing, engines
 
 
 metadata, users = None, None
@@ -62,7 +65,7 @@ class ReflectionTest(TestBase, ComparesTables):
             foo = Table('foo', meta2, autoload=True,
                         include_columns=['b', 'f', 'e'])
             # test that cols come back in original order
-            self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
+            eq_([c.name for c in foo.c], ['b', 'e', 'f'])
             for c in ('b', 'f', 'e'):
                 assert c in foo.c
             for c in ('a', 'c', 'd'):
@@ -73,7 +76,7 @@ class ReflectionTest(TestBase, ComparesTables):
             foo = Table('foo', meta3, autoload=True)
             foo = Table('foo', meta3, include_columns=['b', 'f', 'e'],
                         useexisting=True)
-            self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
+            eq_([c.name for c in foo.c], ['b', 'e', 'f'])
             for c in ('b', 'f', 'e'):
                 assert c in foo.c
             for c in ('a', 'c', 'd'):
@@ -103,7 +106,7 @@ class ReflectionTest(TestBase, ComparesTables):
         dialect_module.ischema_names = {}
         try:
             m2 = MetaData(testing.db)
-            self.assertRaises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
+            assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
 
             @testing.emits_warning('Did not recognize type')
             def warns():
@@ -282,7 +285,7 @@ class ReflectionTest(TestBase, ComparesTables):
             a2 = Table('a', m2, include_columns=['z'], autoload=True)
             b2 = Table('b', m2, autoload=True)
             
-            self.assertRaises(tsa.exc.NoReferencedColumnError, a2.join, b2)
+            assert_raises(tsa.exc.NoReferencedColumnError, a2.join, b2)
         finally:
             meta.drop_all()
         
@@ -405,7 +408,7 @@ class ReflectionTest(TestBase, ComparesTables):
             Column('slot', sa.String(128)),
             )
             
-        self.assertRaisesMessage(tsa.exc.InvalidRequestError, "Could not find table 'pkgs' with which to generate a foreign key", metadata.create_all)
+        assert_raises_message(tsa.exc.InvalidRequestError, "Could not find table 'pkgs' with which to generate a foreign key", metadata.create_all)
 
     def test_composite_pks(self):
         """test reflection of a composite primary key"""
@@ -608,7 +611,8 @@ class ReflectionTest(TestBase, ComparesTables):
             m1.drop_all()
 
 class CreateDropTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, users
         metadata = MetaData()
         users = Table('users', metadata,
@@ -656,16 +660,16 @@ class CreateDropTest(TestBase):
 
     def test_createdrop(self):
         metadata.create_all(bind=testing.db)
-        self.assertEqual( testing.db.has_table('items'), True )
-        self.assertEqual( testing.db.has_table('email_addresses'), True )
+        eq_( testing.db.has_table('items'), True )
+        eq_( testing.db.has_table('email_addresses'), True )
         metadata.create_all(bind=testing.db)
-        self.assertEqual( testing.db.has_table('items'), True )
+        eq_( testing.db.has_table('items'), True )
 
         metadata.drop_all(bind=testing.db)
-        self.assertEqual( testing.db.has_table('items'), False )
-        self.assertEqual( testing.db.has_table('email_addresses'), False )
+        eq_( testing.db.has_table('items'), False )
+        eq_( testing.db.has_table('email_addresses'), False )
         metadata.drop_all(bind=testing.db)
-        self.assertEqual( testing.db.has_table('items'), False )
+        eq_( testing.db.has_table('items'), False )
 
     def test_tablenames(self):
         metadata.create_all(bind=testing.db)
@@ -800,7 +804,8 @@ class SchemaTest(TestBase):
 
 
 class HasSequenceTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, users
         metadata = MetaData()
         users = Table('users', metadata,
@@ -811,10 +816,8 @@ class HasSequenceTest(TestBase):
     @testing.requires.sequences
     def test_hassequence(self):
         metadata.create_all(bind=testing.db)
-        self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True)
+        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True)
         metadata.drop_all(bind=testing.db)
-        self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
+        eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 96%
rename from test/engine/transaction.py
rename to test/engine/test_transaction.py
index 1fa3856108524d5f01e41b60a183c9a56fb5c0e1..7d40adf6d0b0a64a8ceefdbac8ce9f6939066806 100644 (file)
@@ -1,13 +1,15 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import sys, time, threading
-from testlib.sa import create_engine, MetaData, Table, Column, INT, VARCHAR, \
-     Sequence, select, Integer, String, func, text
-from testlib import TestBase, testing
+from sqlalchemy import create_engine, MetaData, INT, VARCHAR, Sequence, select, Integer, String, func, text
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.test import TestBase, testing
 
 
 users, metadata = None, None
 class TransactionTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global users, metadata
         metadata = MetaData()
         users = Table('query_users', metadata,
@@ -17,9 +19,10 @@ class TransactionTest(TestBase):
         )
         users.create(testing.db)
 
-    def tearDown(self):
+    def teardown(self):
         testing.db.connect().execute(users.delete())
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         users.drop(testing.db)
 
     def test_commits(self):
@@ -166,7 +169,7 @@ class TransactionTest(TestBase):
         connection.execute(users.insert(), user_id=3, user_name='user3')
         transaction.commit()
 
-        self.assertEquals(
+        eq_(
             connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             [(1,),(3,)]
         )
@@ -183,7 +186,7 @@ class TransactionTest(TestBase):
         connection.execute(users.insert(), user_id=3, user_name='user3')
         transaction.commit()
 
-        self.assertEquals(
+        eq_(
             connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             [(1,),(2,),(3,)]
         )
@@ -202,7 +205,7 @@ class TransactionTest(TestBase):
         connection.execute(users.insert(), user_id=4, user_name='user4')
         transaction.commit()
 
-        self.assertEquals(
+        eq_(
             connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             [(1,),(4,)]
         )
@@ -230,7 +233,7 @@ class TransactionTest(TestBase):
         transaction.prepare()
         transaction.rollback()
 
-        self.assertEquals(
+        eq_(
             connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             [(1,),(2,)]
         )
@@ -264,7 +267,7 @@ class TransactionTest(TestBase):
 
         transaction.commit()
 
-        self.assertEquals(
+        eq_(
             connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             [(1,),(2,),(5,)]
         )
@@ -285,19 +288,17 @@ class TransactionTest(TestBase):
         connection.close()
         connection2 = testing.db.connect()
 
-        self.assertEquals(
+        eq_(
             connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             []
         )
 
         recoverables = connection2.recover_twophase()
-        self.assertTrue(
-            transaction.xid in recoverables
-        )
+        assert transaction.xid in recoverables
 
         connection2.commit_prepared(transaction.xid, recover=True)
 
-        self.assertEquals(
+        eq_(
             connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             [(1,)]
         )
@@ -327,16 +328,18 @@ class TransactionTest(TestBase):
         xa.commit()
 
         result = conn.execute(select([users.c.user_name]).order_by(users.c.user_id))
-        self.assertEqual(result.fetchall(), [('user1',),('user4',)])
+        eq_(result.fetchall(), [('user1',),('user4',)])
 
         conn.close()
 
 class AutoRollbackTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata
         metadata = MetaData()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all(testing.db)
 
     def test_rollback_deadlock(self):
@@ -368,17 +371,19 @@ class ExplicitAutoCommitTest(TestBase):
 
     __only_on__ = 'postgres'
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, foo
         metadata = MetaData(testing.db)
         foo = Table('foo', metadata, Column('id', Integer, primary_key=True), Column('data', String(100)))
         metadata.create_all()
         testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql")
 
-    def tearDown(self):
+    def teardown(self):
         foo.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         testing.db.execute("drop function insert_foo(varchar)")
         metadata.drop_all()
 
@@ -437,7 +442,8 @@ class ExplicitAutoCommitTest(TestBase):
 
 tlengine = None
 class TLTransactionTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global users, metadata, tlengine
         tlengine = create_engine(testing.db.url, strategy='threadlocal')
         metadata = MetaData()
@@ -447,15 +453,16 @@ class TLTransactionTest(TestBase):
             test_needs_acid=True,
         )
         users.create(tlengine)
-    def tearDown(self):
+    def teardown(self):
         tlengine.execute(users.delete())
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         users.drop(tlengine)
         tlengine.dispose()
 
     def test_nested_unsupported(self):
-        self.assertRaises(NotImplementedError, tlengine.contextual_connect().begin_nested)
-        self.assertRaises(NotImplementedError, tlengine.begin_nested)
+        assert_raises(NotImplementedError, tlengine.contextual_connect().begin_nested)
+        assert_raises(NotImplementedError, tlengine.begin_nested)
         
     def test_connection_close(self):
         """test that when connections are closed for real, transactions are rolled back and disposed."""
@@ -688,14 +695,15 @@ class TLTransactionTest(TestBase):
         tlengine.prepare()
         tlengine.rollback()
 
-        self.assertEquals(
+        eq_(
             tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
             [(1,),(2,)]
         )
 
 counters = None
 class ForUpdateTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global counters, metadata
         metadata = MetaData()
         counters = Table('forupdate_counters', metadata,
@@ -704,9 +712,10 @@ class ForUpdateTest(TestBase):
             test_needs_acid=True,
         )
         counters.create(testing.db)
-    def tearDown(self):
+    def teardown(self):
         testing.db.connect().execute(counters.delete())
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         counters.drop(testing.db)
 
     def increment(self, count, errors, update_style=True, delay=0.005):
@@ -829,5 +838,3 @@ class ForUpdateTest(TestBase):
         self.assert_(len(errors) != 0)
 
 
-if __name__ == "__main__":
-    testenv.main()
diff --git a/test/ext/alltests.py b/test/ext/alltests.py
deleted file mode 100644 (file)
index 9f5353e..0000000
+++ /dev/null
@@ -1,36 +0,0 @@
-import testenv; testenv.configure_for_tests()
-import doctest, sys
-
-from testlib import sa_unittest as unittest
-
-
-def suite():
-    unittest_modules = (
-        'ext.declarative',
-        'ext.orderinglist',
-        'ext.associationproxy',
-        'ext.serializer',
-        'ext.compiler',
-        )
-
-    if sys.version_info < (2, 4):
-        doctest_modules = ()
-    else:
-        doctest_modules = (
-            ('sqlalchemy.ext.orderinglist', {'optionflags': doctest.ELLIPSIS}),
-            ('sqlalchemy.ext.sqlsoup', {})
-            )
-
-    alltests = unittest.TestSuite()
-    for name in unittest_modules:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    for name, opts in doctest_modules:
-        alltests.addTest(doctest.DocTestSuite(name, **opts))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
similarity index 97%
rename from test/ext/associationproxy.py
rename to test/ext/test_associationproxy.py
index 821ed90721f51d3bc2f612cbbe53a2abb00e6b4f..742f98baf870431f4b504a5b9d006f837cbfdb5b 100644 (file)
@@ -1,10 +1,10 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import gc
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm.collections import collection
 from sqlalchemy.ext.associationproxy import *
-from testlib import *
+from sqlalchemy.test import *
 
 
 class DictCollection(dict):
@@ -34,7 +34,7 @@ class ObjectCollection(object):
         return iter(self.values)
 
 class _CollectionOperations(TestBase):
-    def setUp(self):
+    def setup(self):
         collection_class = self.collection_class
 
         metadata = MetaData(testing.db)
@@ -77,7 +77,7 @@ class _CollectionOperations(TestBase):
         self.session = create_session()
         self.Parent, self.Child = Parent, Child
 
-    def tearDown(self):
+    def teardown(self):
         self.metadata.drop_all()
 
     def roundtrip(self, obj):
@@ -189,7 +189,7 @@ class _CollectionOperations(TestBase):
         self.assert_(p1.children == after)
         self.assert_([c.name for c in p1._children] == after)
 
-        self.assertRaises(TypeError, set, [p1.children])
+        assert_raises(TypeError, set, [p1.children])
 
         p1.children *= 0
         after = []
@@ -342,7 +342,7 @@ class CustomDictTest(DictTest):
         except TypeError:
             self.assert_(True)
 
-        self.assertRaises(TypeError, set, [p1.children])
+        assert_raises(TypeError, set, [p1.children])
 
 
 class SetTest(_CollectionOperations):
@@ -458,7 +458,7 @@ class SetTest(_CollectionOperations):
         except TypeError:
             self.assert_(True)
 
-        self.assertRaises(TypeError, set, [p1.children])
+        assert_raises(TypeError, set, [p1.children])
 
 
     def test_set_comparisons(self):
@@ -473,19 +473,19 @@ class SetTest(_CollectionOperations):
                       set(['c','d']), set(['e', 'f', 'g']),
                       set()):
 
-            self.assertEqual(p1.children.union(other),
+            eq_(p1.children.union(other),
                              control.union(other))
-            self.assertEqual(p1.children.difference(other),
+            eq_(p1.children.difference(other),
                              control.difference(other))
-            self.assertEqual((p1.children - other),
+            eq_((p1.children - other),
                              (control - other))
-            self.assertEqual(p1.children.intersection(other),
+            eq_(p1.children.intersection(other),
                              control.intersection(other))
-            self.assertEqual(p1.children.symmetric_difference(other),
+            eq_(p1.children.symmetric_difference(other),
                              control.symmetric_difference(other))
-            self.assertEqual(p1.children.issubset(other),
+            eq_(p1.children.issubset(other),
                              control.issubset(other))
-            self.assertEqual(p1.children.issuperset(other),
+            eq_(p1.children.issuperset(other),
                              control.issuperset(other))
 
             self.assert_((p1.children == other)  ==  (control == other))
@@ -714,7 +714,7 @@ class ScalarTest(TestBase):
 
 
 class LazyLoadTest(TestBase):
-    def setUp(self):
+    def setup(self):
         metadata = MetaData(testing.db)
 
         parents_table = Table('Parent', metadata,
@@ -748,7 +748,7 @@ class LazyLoadTest(TestBase):
         self.Parent, self.Child = Parent, Child
         self.table = parents_table
 
-    def tearDown(self):
+    def teardown(self):
         self.metadata.drop_all()
 
     def roundtrip(self, obj):
@@ -823,7 +823,7 @@ class LazyLoadTest(TestBase):
 
 
 class ReconstitutionTest(TestBase):
-    def setUp(self):
+    def setup(self):
         metadata = MetaData(testing.db)
         parents = Table('parents', metadata,
                         Column('id', Integer, primary_key=True,
@@ -852,7 +852,7 @@ class ReconstitutionTest(TestBase):
         self.metadata = metadata
         self.Parent = Parent
 
-    def tearDown(self):
+    def teardown(self):
         self.metadata.drop_all()
 
     def test_weak_identity_map(self):
@@ -883,5 +883,3 @@ class ReconstitutionTest(TestBase):
         assert set(p_copy.kids) == set(['c1', 'c2']), p.kids
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 97%
rename from test/ext/compiler.py
rename to test/ext/test_compiler.py
index 370ea62ab0f4cac434cdbf195a14bca65af20f4b..ce2549099822d26eaedc5a26019da5729d71ca7f 100644 (file)
@@ -1,9 +1,8 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.sql.expression import ClauseElement, ColumnClause
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import table, column
-from testlib import *
+from sqlalchemy.test import *
 
 class UserDefinedTest(TestBase, AssertsCompiledSQL):
 
@@ -122,5 +121,3 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
             "DROP THINGY",
         )
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 96%
rename from test/ext/declarative.py
rename to test/ext/test_declarative.py
index f5130b2153ace2d7450e6d64ec1571bc55779744..c49c00cec0d267c8aca0c70016c1b02042a56cf9 100644 (file)
@@ -1,21 +1,24 @@
-import testenv; testenv.configure_for_tests()
 
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy.ext import declarative as decl
 from sqlalchemy import exc
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index
-from testlib.sa.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred
+from sqlalchemy.test.testing import eq_
 
 
-from orm._base import ComparableEntity, MappedTest
+from test.orm._base import ComparableEntity, MappedTest
 
 class DeclarativeTestBase(testing.TestBase, testing.AssertsExecutionResults):
-    def setUp(self):
+    def setup(self):
         global Base
         Base = decl.declarative_base(testing.db)
 
-    def tearDown(self):
+    def teardown(self):
         clear_mappers()
         Base.metadata.drop_all()
     
@@ -64,7 +67,7 @@ class DeclarativeTest(DeclarativeTestBase):
         def go():
             class User(Base):
                 id = Column('id', Integer, primary_key=True)
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go)
+        assert_raises_message(sa.exc.InvalidRequestError, "does not have a __table__", go)
 
     def test_cant_add_columns(self):
         t = Table('t', Base.metadata, Column('id', Integer, primary_key=True), Column('data', String))
@@ -73,7 +76,7 @@ class DeclarativeTest(DeclarativeTestBase):
                 __table__ = t
                 foo = Column(Integer, primary_key=True)
         # can't specify new columns not already in the table
-        self.assertRaisesMessage(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go)
+        assert_raises_message(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go)
 
         # regular re-mapping works tho
         class Bar(Base):
@@ -144,7 +147,7 @@ class DeclarativeTest(DeclarativeTestBase):
         sess.add(u1)
         sess.flush()
         sess.expunge_all()
-        self.assertEquals(sess.query(User).filter(User.name == 'ed').one(),
+        eq_(sess.query(User).filter(User.name == 'ed').one(),
             User(name='ed', addresses=[Address(email='xyz'), Address(email='def'), Address(email='abc')])
         )
         
@@ -152,7 +155,7 @@ class DeclarativeTest(DeclarativeTestBase):
             __tablename__ = 'foo'
             id = Column(Integer, primary_key=True)
             rel = relation("User", primaryjoin="User.addresses==Foo.id")
-        self.assertRaisesMessage(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers)
+        assert_raises_message(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers)
 
     def test_string_dependency_resolution_in_backref(self):
         class User(Base, ComparableEntity):
@@ -206,7 +209,7 @@ class DeclarativeTest(DeclarativeTestBase):
         sess.add(u1)
         sess.flush()
         sess.expunge_all()
-        self.assertEquals(sess.query(User).filter(User.name == 'ed').one(),
+        eq_(sess.query(User).filter(User.name == 'ed').one(),
             User(name='ed', addresses=[Address(email='abc'), Address(email='def'), Address(email='xyz')])
         )
             
@@ -224,7 +227,7 @@ class DeclarativeTest(DeclarativeTestBase):
 
         # this used to raise an error when accessing User.id but that's no longer the case
         # since we got rid of _CompileOnAttr.
-        self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+        assert_raises(sa.exc.ArgumentError, compile_mappers)
         
     def test_nice_dependency_error_works_with_hasattr(self):
         class User(Base):
@@ -235,7 +238,7 @@ class DeclarativeTest(DeclarativeTestBase):
         # hasattr() on a compile-loaded attribute
         hasattr(User.addresses, 'property')
         # the exeption is preserved
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
+        assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
 
     def test_custom_base(self):
         class MyBase(object):
@@ -255,7 +258,7 @@ class DeclarativeTest(DeclarativeTestBase):
         i = Index('my_index', User.name)
         
         # compile fails due to the nonexistent Addresses relation
-        self.assertRaises(sa.exc.InvalidRequestError, compile_mappers)
+        assert_raises(sa.exc.InvalidRequestError, compile_mappers)
         
         # index configured
         assert i in User.__table__.indexes
@@ -440,7 +443,7 @@ class DeclarativeTest(DeclarativeTestBase):
                 id = Column('id', Integer, primary_key=True),
                 name = Column('name', String(50))
             assert False
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Mapper Mapper|User|users could not assemble any primary key",
             define)
@@ -1204,7 +1207,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
                 __mapper_args__ = {'polymorphic_identity':'engineer'}
                 primary_language = Column('primary_language', String(50))
                 foo_bar = Column(Integer, primary_key=True)
-        self.assertRaisesMessage(sa.exc.ArgumentError, "place primary key", go)
+        assert_raises_message(sa.exc.ArgumentError, "place primary key", go)
         
     def test_single_no_table_args(self):
         class Person(Base, ComparableEntity):
@@ -1219,7 +1222,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
                 __mapper_args__ = {'polymorphic_identity':'engineer'}
                 primary_language = Column('primary_language', String(50))
                 __table_args__ = ()
-        self.assertRaisesMessage(sa.exc.ArgumentError, "place __table_args__", go)
+        assert_raises_message(sa.exc.ArgumentError, "place __table_args__", go)
         
     def test_concrete(self):
         engineers = Table('engineers', Base.metadata,
@@ -1270,10 +1273,11 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
         )
         
         
-def produce_test(inline, stringbased):
+def _produce_test(inline, stringbased):
     class ExplicitJoinTest(MappedTest):
     
-        def define_tables(self, metadata):
+        @classmethod
+        def define_tables(cls, metadata):
             global User, Address
             Base = decl.declarative_base(metadata=metadata)
 
@@ -1300,7 +1304,8 @@ def produce_test(inline, stringbased):
                 else:
                     Address.user = relation(User, primaryjoin=User.id==Address.user_id, backref="addresses")
 
-        def insert_data(self):
+        @classmethod
+        def insert_data(cls):
             params = [dict(zip(('id', 'name'), column_values)) for column_values in 
                 [(7, 'jack'),
                 (8, 'ed'),
@@ -1337,12 +1342,13 @@ def produce_test(inline, stringbased):
 
 for inline in (True, False):
     for stringbased in (True, False):
-        testclass = produce_test(inline, stringbased)
+        testclass = _produce_test(inline, stringbased)
         exec("%s = testclass" % testclass.__name__)
         del testclass
         
 class DeclarativeReflectionTest(testing.TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global reflection_metadata
         reflection_metadata = MetaData(testing.db)
 
@@ -1364,15 +1370,16 @@ class DeclarativeReflectionTest(testing.TestBase):
 
         reflection_metadata.create_all()
 
-    def setUp(self):
+    def setup(self):
         global Base
         Base = decl.declarative_base(testing.db)
 
-    def tearDown(self):
+    def teardown(self):
         for t in reversed(reflection_metadata.sorted_tables):
             t.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         reflection_metadata.drop_all()
 
     def test_basic(self):
@@ -1436,7 +1443,7 @@ class DeclarativeReflectionTest(testing.TestBase):
         eq_(a1, Address(email='two'))
         eq_(a1.user, User(nom='u1'))
 
-        self.assertRaises(TypeError, User, name='u3')
+        assert_raises(TypeError, User, name='u3')
 
     def test_supplied_fk(self):
         meta = MetaData(testing.db)
similarity index 98%
rename from test/ext/orderinglist.py
rename to test/ext/test_orderinglist.py
index c111a02de6463fbe6ad20ace39e946c151d5de39..4adc779606efc83232743f82415b4566f6728324 100644 (file)
@@ -1,9 +1,8 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.ext.orderinglist import *
-from testlib.testing import eq_
-from testlib import *
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test import *
 
 
 metadata = None
@@ -39,7 +38,7 @@ def alpha_ordering(index, collection):
     return s
 
 class OrderingListTest(TestBase):
-    def setUp(self):
+    def setup(self):
         global metadata, slides_table, bullets_table, Slide, Bullet
         slides_table, bullets_table = None, None
         Slide, Bullet = None, None
@@ -87,7 +86,7 @@ class OrderingListTest(TestBase):
 
         metadata.create_all()
 
-    def tearDown(self):
+    def teardown(self):
         metadata.drop_all()
 
     def test_append_no_reorder(self):
@@ -399,5 +398,3 @@ class OrderingListTest(TestBase):
             self.assert_(alpha[li].position == pos)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 89%
rename from test/ext/serializer.py
rename to test/ext/test_serializer.py
index 048eccdfd1f04f95e982d78bd3f37544683fbd36..b8a8e3fef9d27dba92dfca42abfc8e823ada895b 100644 (file)
@@ -1,13 +1,15 @@
-import testenv; testenv.configure_for_tests()
 
 from sqlalchemy.ext import serializer
 from sqlalchemy import exc
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, select, desc, func, util
-from testlib.sa.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, select, desc, func, util
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased
+from sqlalchemy.test.testing import eq_
 
-from orm._base import ComparableEntity, MappedTest
+from test.orm._base import ComparableEntity, MappedTest
 
 
 class User(ComparableEntity):
@@ -21,7 +23,8 @@ class SerializeTest(MappedTest):
     run_inserts = 'once'
     run_deletes = None
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global users, addresses
         users = Table('users', metadata, 
             Column('id', Integer, primary_key=True),
@@ -33,7 +36,8 @@ class SerializeTest(MappedTest):
             Column('user_id', Integer, ForeignKey('users.id')),
         )
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         global Session
         Session = scoped_session(sessionmaker())
 
@@ -44,7 +48,8 @@ class SerializeTest(MappedTest):
 
         compile_mappers()
         
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         params = [dict(zip(('id', 'name'), column_values)) for column_values in 
             [(7, 'jack'),
             (8, 'ed'),
index 9e599a6f16a3312e72f20bb0556d374899cd58fb..8d695e912b99b505eba8504ac62f8d594080886e 100644 (file)
@@ -2,9 +2,10 @@ import gc
 import inspect
 import sys
 import types
-from testlib import config, sa, testing
-from testlib.testing import resolve_artifact_names, adict
-from testlib.compat import _function_named
+import sqlalchemy as sa
+from sqlalchemy.test import config, testing
+from sqlalchemy.test.testing import resolve_artifact_names, adict
+from sqlalchemy.util import function_named
 
 
 _repr_stack = set()
@@ -95,7 +96,8 @@ class ComparableEntity(BasicEntity):
 class ORMTest(testing.TestBase, testing.AssertsExecutionResults):
     __requires__ = ('subqueries',)
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         sa.orm.session.Session.close_all()
         sa.orm.clear_mappers()
         # TODO: ensure mapper registry is empty
@@ -124,18 +126,18 @@ class MappedTest(ORMTest):
     classes = None
     other_artifacts = None
 
-    def setUpAll(self):
-        if self.run_setup_classes == 'each':
-            assert self.run_setup_mappers != 'once'
+    @classmethod
+    def setup_class(cls):
+        if cls.run_setup_classes == 'each':
+            assert cls.run_setup_mappers != 'once'
 
-        assert self.run_deletes in (None, 'each')
-        if self.run_inserts == 'once':
-            assert self.run_deletes is None
+        assert cls.run_deletes in (None, 'each')
+        if cls.run_inserts == 'once':
+            assert cls.run_deletes is None
 
-        assert not hasattr(self, 'keep_mappers')
-        assert not hasattr(self, 'keep_data')
+        assert not hasattr(cls, 'keep_mappers')
+        assert not hasattr(cls, 'keep_data')
 
-        cls = self.__class__
         if cls.tables is None:
             cls.tables = adict()
         if cls.classes is None:
@@ -143,35 +145,32 @@ class MappedTest(ORMTest):
         if cls.other_artifacts is None:
             cls.other_artifacts = adict()
 
-        if self.metadata is None:
-            setattr(type(self), 'metadata', sa.MetaData())
+        if cls.metadata is None:
+            setattr(cls, 'metadata', sa.MetaData())
 
-        if self.metadata.bind is None:
-            self.metadata.bind = getattr(self, 'engine', config.db)
+        if cls.metadata.bind is None:
+            cls.metadata.bind = getattr(cls, 'engine', config.db)
 
-        if self.run_define_tables:
-            self.define_tables(self.metadata)
-            self.metadata.create_all()
-            self.tables.update(self.metadata.tables)
+        if cls.run_define_tables == 'once':
+            cls.define_tables(cls.metadata)
+            cls.metadata.create_all()
+            cls.tables.update(cls.metadata.tables)
 
-        if self.run_setup_classes:
+        if cls.run_setup_classes == 'once':
             baseline = subclasses(BasicEntity)
-            self.setup_classes()
-            self._register_new_class_artifacts(baseline)
+            cls.setup_classes()
+            cls._register_new_class_artifacts(baseline)
 
-        if self.run_setup_mappers:
+        if cls.run_setup_mappers == 'once':
             baseline = subclasses(BasicEntity)
-            self.setup_mappers()
-            self._register_new_class_artifacts(baseline)
+            cls.setup_mappers()
+            cls._register_new_class_artifacts(baseline)
 
-        if self.run_inserts:
-            self._load_fixtures()
-            self.insert_data()
-
-    def setUp(self):
-        if self._sa_first_test:
-            return
+        if cls.run_inserts == 'once':
+            cls._load_fixtures()
+            cls.insert_data()
 
+    def setup(self):
         if self.run_define_tables == 'each':
             self.tables.clear()
             self.metadata.drop_all()
@@ -195,7 +194,7 @@ class MappedTest(ORMTest):
             self._load_fixtures()
             self.insert_data()
 
-    def tearDown(self):
+    def teardown(self):
         sa.orm.session.Session.close_all()
 
         # some tests create mappers in the test bodies
@@ -213,26 +212,32 @@ class MappedTest(ORMTest):
                     print >> sys.stderr, "Error emptying table %s: %r" % (
                         table, ex)
 
-    def tearDownAll(self):
-        for cls in self.classes.values():
-            self.unregister_class(cls)
-        ORMTest.tearDownAll(self)
-        self.metadata.drop_all()
-        self.metadata.bind = None
+    @classmethod
+    def teardown_class(cls):
+        for cl in cls.classes.values():
+            cls.unregister_class(cl)
+        ORMTest.teardown_class()
+        cls.metadata.drop_all()
+        cls.metadata.bind = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         raise NotImplementedError()
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         pass
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         pass
 
-    def fixtures(self):
+    @classmethod
+    def fixtures(cls):
         return {}
 
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         pass
 
     def sql_count_(self, count, fn):
@@ -260,15 +265,16 @@ class MappedTest(ORMTest):
         if name[0].isupper:
             delattr(cls, name)
         del cls.classes[name]
-
-    def _load_fixtures(self):
+    
+    @classmethod
+    def _load_fixtures(cls):
         headers, rows = {}, {}
-        for table, data in self.fixtures().iteritems():
+        for table, data in cls.fixtures().iteritems():
             if isinstance(table, basestring):
-                table = self.tables[table]
+                table = cls.tables[table]
             headers[table] = data[0]
             rows[table] = data[1:]
-        for table in self.metadata.sorted_tables:
+        for table in cls.metadata.sorted_tables:
             if table not in headers:
                 continue
             table.bind.execute(
index f036b92b2a5b825a1f0fb243f588aacc765210d8..14709ec43328709570a3273d112b9e725f0dbe51 100644 (file)
@@ -1,7 +1,9 @@
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import attributes
-from testlib.testing import fixture
-from orm import _base
+from sqlalchemy import MetaData, Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import attributes
+from sqlalchemy.test.testing import fixture
+from test.orm import _base
 
 __all__ = ()
 
@@ -227,34 +229,21 @@ class FixtureTest(_base.MappedTest):
                            Address=Address,
                            Dingaling=Dingaling)
 
-    def setUpAll(self):
-        assert not hasattr(self, 'refresh_data')
-        assert not hasattr(self, 'only_tables')
-        #refresh_data = False
-        #only_tables = False
-
-        #if type(self) is not FixtureTest:
-        #    setattr(type(self), 'classes', _base.adict(self.classes))
-
-        #if self.run_setup_classes:
-        #    for cls in self.classes.values():
-        #        self.register_class(cls)
-        super(FixtureTest, self).setUpAll()
-
-        #if not self.only_tables and self.keep_data:
-        #    _registry.load()
-
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         pass
 
-    def setup_classes(self):
-        for cls in self.fixture_classes.values():
-            self.register_class(cls)
+    @classmethod
+    def setup_classes(cls):
+        for cl in cls.fixture_classes.values():
+            cls.register_class(cl)
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         pass
 
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         _load_fixtures()
 
 
diff --git a/test/orm/alltests.py b/test/orm/alltests.py
deleted file mode 100644 (file)
index 9458ca5..0000000
+++ /dev/null
@@ -1,60 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-import inheritance.alltests as inheritance
-import sharding.alltests as sharding
-
-def suite():
-    modules_to_test = (
-        'orm.attributes',
-        'orm.bind',
-        'orm.extendedattr',
-        'orm.instrumentation',
-        'orm.query',
-        'orm.lazy_relations',
-        'orm.eager_relations',
-        'orm.mapper',
-        'orm.expire',
-        'orm.selectable',
-        'orm.collection',
-        'orm.generative',
-        'orm.lazytest1',
-        'orm.assorted_eager',
-
-        'orm.naturalpks',
-        'orm.defaults',
-        'orm.unitofwork',
-        'orm.session',
-        'orm.transaction',
-        'orm.scoping',
-        'orm.cascade',
-        'orm.relationships',
-        'orm.association',
-        'orm.merge',
-        'orm.pickled',
-        'orm.utils',
-
-        'orm.cycles',
-
-        'orm.compile',
-        'orm.manytomany',
-        'orm.onetoone',
-        'orm.dynamic',
-
-        'orm.evaluator',
-
-        'orm.deprecations',
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    alltests.addTest(inheritance.suite())
-    alltests.addTest(sharding.suite())
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py
deleted file mode 100644 (file)
index 41f0521..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
-    modules_to_test = (
-        'orm.inheritance.basic',
-        'orm.inheritance.query',
-        'orm.inheritance.manytomany',
-        'orm.inheritance.single',
-        'orm.inheritance.concrete',
-        'orm.inheritance.polymorph',
-        'orm.inheritance.polymorph2',
-        'orm.inheritance.poly_linked_list',
-        'orm.inheritance.abc_polymorphic',
-        'orm.inheritance.abc_inheritance',
-        'orm.inheritance.productspec',
-        'orm.inheritance.magazine',
-        'orm.inheritance.selects',
-
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
similarity index 95%
rename from test/orm/inheritance/abc_inheritance.py
rename to test/orm/inheritance/test_abc_inheritance.py
index ee324e381101978c2a640085a43ff59c780c66c6..4e55cf70eaa906f81bb71e3e831693c3d5a93a20 100644 (file)
@@ -1,10 +1,9 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
 
-from testlib import testing
-from orm import _base
+from sqlalchemy.test import testing
+from test.orm import _base
 
 
 def produce_test(parent, child, direction):
@@ -12,9 +11,10 @@ def produce_test(parent, child, direction):
     relationship between two of the classes, using either one-to-many or
     many-to-one."""
     class ABCTest(_base.MappedTest):
-        def define_tables(self, meta):
+        @classmethod
+        def define_tables(cls, metadata):
             global ta, tb, tc
-            ta = ["a", meta]
+            ta = ["a", metadata]
             ta.append(Column('id', Integer, primary_key=True)),
             ta.append(Column('a_data', String(30)))
             if "a"== parent and direction == MANYTOONE:
@@ -23,7 +23,7 @@ def produce_test(parent, child, direction):
                 ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo")))
             ta = Table(*ta)
 
-            tb = ["b", meta]
+            tb = ["b", metadata]
             tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, ))
 
             tb.append(Column('b_data', String(30)))
@@ -34,7 +34,7 @@ def produce_test(parent, child, direction):
                 tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo")))
             tb = Table(*tb)
 
-            tc = ["c", meta]
+            tc = ["c", metadata]
             tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, ))
 
             tc.append(Column('c_data', String(30)))
@@ -45,14 +45,14 @@ def produce_test(parent, child, direction):
                 tc.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo")))
             tc = Table(*tc)
 
-        def tearDown(self):
+        def teardown(self):
             if direction == MANYTOONE:
                 parent_table = {"a":ta, "b":tb, "c": tc}[parent]
                 parent_table.update(values={parent_table.c.child_id:None}).execute()
             elif direction == ONETOMANY:
                 child_table = {"a":ta, "b":tb, "c": tc}[child]
                 child_table.update(values={child_table.c.parent_id:None}).execute()
-            super(ABCTest, self).tearDown()
+            super(ABCTest, self).teardown()
 
         def test_roundtrip(self):
             parent_table = {"a":ta, "b":tb, "c": tc}[parent]
@@ -167,5 +167,4 @@ for parent in ["a", "b", "c"]:
             exec("%s = testclass" % testclass.__name__)
             del testclass
 
-if __name__ == "__main__":
-    testenv.main()
+del produce_test
\ No newline at end of file
similarity index 92%
rename from test/orm/inheritance/abc_polymorphic.py
rename to test/orm/inheritance/test_abc_polymorphic.py
index 6fabbb24c2bb8c6de69785197a80c4ccea6a02ec..8cad8ed78163f1328290f35c77eb9c83d0d35492 100644 (file)
@@ -1,13 +1,13 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy import util
 from sqlalchemy.orm import *
 
-from testlib import _function_named
-from orm import _base, _fixtures
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
 
 class ABCTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global a, b, c
         a = Table('a', metadata,
             Column('id', Integer, primary_key=True),
@@ -78,7 +78,7 @@ class ABCTest(_base.MappedTest):
                 C(cdata='c2', bdata='c2', adata='c2'),
             ] == sess.query(C).all()
 
-        test_roundtrip = _function_named(
+        test_roundtrip = function_named(
             test_roundtrip, 'test_%s' % fetchtype)
         return test_roundtrip
 
@@ -86,5 +86,3 @@ class ABCTest(_base.MappedTest):
     test_none = make_test('none')
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 95%
rename from test/orm/inheritance/basic.py
rename to test/orm/inheritance/test_basic.py
index 150874477b271b488612f4a3d040db7afcfdda24..fc4aae17d5c82ab4a49297924268e0840bc1bd8f 100644 (file)
@@ -1,17 +1,17 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy import exc as sa_exc, util
 from sqlalchemy.orm import *
 from sqlalchemy.orm import exc as orm_exc
 
-#from testlib import *
-#from testlib import fixtures
-from testlib import _function_named, testing, engines
-from orm import _base, _fixtures
+from sqlalchemy.test import testing, engines
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
 
 class O2MTest(_base.MappedTest):
     """deals with inheritance and one-to-many relationships"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global foo, bar, blub
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_seq', optional=True),
@@ -69,7 +69,8 @@ class O2MTest(_base.MappedTest):
         self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
 
 class FalseDiscriminatorTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global t1
         t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False))
         
@@ -87,7 +88,8 @@ class FalseDiscriminatorTest(_base.MappedTest):
         assert isinstance(sess.query(Foo).one(), Bar)
         
 class PolymorphicSynonymTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global t1, t2
         t1 = Table('t1', metadata,
                    Column('id', Integer, primary_key=True),
@@ -118,8 +120,8 @@ class PolymorphicSynonymTest(_base.MappedTest):
         sess.add(at2)
         sess.flush()
         sess.expunge_all()
-        self.assertEquals(sess.query(T2).filter(T2.info=='at2').one(), at2)
-        self.assertEquals(at2.info, "THE INFO IS:at2")
+        eq_(sess.query(T2).filter(T2.info=='at2').one(), at2)
+        eq_(at2.info, "THE INFO IS:at2")
         
     
 class CascadeTest(_base.MappedTest):
@@ -127,7 +129,8 @@ class CascadeTest(_base.MappedTest):
     cascading along the path of the instance's mapper, not
     the base mapper."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global t1, t2, t3, t4
         t1= Table('t1', metadata,
             Column('id', Integer, primary_key=True),
@@ -191,7 +194,8 @@ class CascadeTest(_base.MappedTest):
         sess.flush()
 
 class GetTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global foo, bar, blub
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_seq', optional=True),
@@ -209,7 +213,7 @@ class GetTest(_base.MappedTest):
             Column('bar_id', Integer, ForeignKey('bar.id')),
             Column('data', String(20)))
 
-    def create_test(polymorphic, name):
+    def _create_test(polymorphic, name):
         def test_get(self):
             class Foo(object):
                 pass
@@ -271,16 +275,17 @@ class GetTest(_base.MappedTest):
 
                 self.assert_sql_count(testing.db, go, 3)
 
-        test_get = _function_named(test_get, name)
+        test_get = function_named(test_get, name)
         return test_get
 
-    test_get_polymorphic = create_test(True, 'test_get_polymorphic')
-    test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
+    test_get_polymorphic = _create_test(True, 'test_get_polymorphic')
+    test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic')
 
 class EagerLazyTest(_base.MappedTest):
     """tests eager load/lazy load of child items off inheritance mappers, tests that
     LazyLoader constructs the right query condition."""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global foo, bar, bar_foo
         foo = Table('foo', metadata,
                     Column('id', Integer, Sequence('foo_seq', optional=True),
@@ -325,7 +330,8 @@ class EagerLazyTest(_base.MappedTest):
 
 class FlushTest(_base.MappedTest):
     """test dependency sorting among inheriting mappers"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global users, roles, user_roles, admins
         users = Table('users', metadata,
             Column('id', Integer, primary_key=True),
@@ -413,7 +419,8 @@ class FlushTest(_base.MappedTest):
         assert user_roles.count().scalar() == 1
 
 class VersioningTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global base, subtable, stuff
         base = Table('base', metadata,
             Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ),
@@ -522,7 +529,8 @@ class DistinctPKTest(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global person_table, employee_table, Person, Employee
 
         person_table = Table("persons", metadata,
@@ -542,7 +550,8 @@ class DistinctPKTest(_base.MappedTest):
 
         class Employee(Person): pass
 
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         person_insert = person_table.insert()
         person_insert.execute(id=1, name='alice')
         person_insert.execute(id=2, name='bob')
@@ -593,7 +602,8 @@ class DistinctPKTest(_base.MappedTest):
 
 class SyncCompileTest(_base.MappedTest):
     """test that syncrules compile properly on custom inherit conds"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global _a_table, _b_table, _c_table
 
         _a_table = Table('a', metadata,
@@ -660,7 +670,8 @@ class SyncCompileTest(_base.MappedTest):
 class OverrideColKeyTest(_base.MappedTest):
     """test overriding of column attributes."""
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global base, subtable
         
         base = Table('base', metadata, 
@@ -686,7 +697,7 @@ class OverrideColKeyTest(_base.MappedTest):
         
         # Sub gets a "base_id" property using the "base_id"
         # column of both tables.
-        self.assertEquals(
+        eq_(
             class_mapper(Sub).get_property('base_id').columns,
             [base.c.base_id, subtable.c.base_id]
         )
@@ -710,7 +721,7 @@ class OverrideColKeyTest(_base.MappedTest):
             'id':[base.c.base_id, subtable.c.base_id]
         })
 
-        self.assertEquals(
+        eq_(
             class_mapper(Sub).get_property('id').columns,
             [base.c.base_id, subtable.c.base_id]
         )
@@ -733,12 +744,12 @@ class OverrideColKeyTest(_base.MappedTest):
         })
         mapper(Sub, subtable, inherits=Base)
         
-        self.assertEquals(
+        eq_(
             class_mapper(Sub).get_property('id').columns,
             [base.c.base_id]
         )
 
-        self.assertEquals(
+        eq_(
             class_mapper(Sub).get_property('base_id').columns,
             [subtable.c.base_id]
         )
@@ -782,7 +793,7 @@ class OverrideColKeyTest(_base.MappedTest):
         # it has its own "id" property.  Sub's "id" property 
         # gets joined normally with the extra column.
         
-        self.assertEquals(
+        eq_(
             class_mapper(Sub).get_property('id').columns,
             [base.c.base_id, subtable.c.base_id]
         )
@@ -892,7 +903,8 @@ class OptimizedLoadTest(_base.MappedTest):
     a column in the join condition is not available.
     
     """
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global base, sub
         base = Table('base', metadata,
             Column('id', Integer, primary_key=True),
@@ -931,7 +943,8 @@ class OptimizedLoadTest(_base.MappedTest):
         assert s1.sub == 's1sub'
 
 class PKDiscriminatorTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         parents = Table('parents', metadata,
                            Column('id', Integer, primary_key=True),
                            Column('name', String(60)))
@@ -974,7 +987,8 @@ class PKDiscriminatorTest(_base.MappedTest):
         
         
 class DeleteOrphanTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global single, parent
         single = Table('single', metadata,
             Column('id', Integer, primary_key=True),
@@ -1007,9 +1021,7 @@ class DeleteOrphanTest(_base.MappedTest):
         sess = create_session()
         s1 = SubClass(data='s1')
         sess.add(s1)
-        self.assertRaisesMessage(orm_exc.FlushError, 
+        assert_raises_message(orm_exc.FlushError, 
             "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush)
         
     
-if __name__ == "__main__":
-    testenv.main()
similarity index 95%
rename from test/orm/inheritance/concrete.py
rename to test/orm/inheritance/test_concrete.py
index 6cdaed7e6934bcf9200c41217cdb2e6cdf2e0194..4a884cb86c71a72d8588ce406c70103e048c8e76 100644 (file)
@@ -1,12 +1,13 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm import exc as orm_exc
-from testlib import *
-from testlib import sa, testing
-from orm import _base
+from sqlalchemy.test import *
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from test.orm import _base
 from sqlalchemy.orm import attributes
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
 
 class Employee(object):
     def __init__(self, name):
@@ -42,7 +43,8 @@ class Company(object):
 
 
 class ConcreteTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global managers_table, engineers_table, hackers_table, companies, employees_table
 
         companies = Table('companies', metadata,
@@ -103,7 +105,7 @@ class ConcreteTest(_base.MappedTest):
 
         manager = session.query(Manager).one()
         session.expire(manager, ['manager_data'])
-        self.assertEquals(manager.manager_data, "knows how to manage things")
+        eq_(manager.manager_data, "knows how to manage things")
 
     def test_multi_level_no_base(self):
         pjoin = polymorphic_union({
@@ -144,8 +146,8 @@ class ConcreteTest(_base.MappedTest):
         assert 'name' not in attributes.instance_state(hacker).expired_attributes
         assert 'nickname' not in attributes.instance_state(hacker).expired_attributes
         def go():
-            self.assertEquals(jerry.name, "Jerry")
-            self.assertEquals(hacker.nickname, "Badass")
+            eq_(jerry.name, "Jerry")
+            eq_(hacker.nickname, "Badass")
         self.assert_sql_count(testing.db, go, 0)
         
         session.expunge_all()
@@ -194,8 +196,8 @@ class ConcreteTest(_base.MappedTest):
         session.flush()
 
         def go():
-            self.assertEquals(jerry.name, "Jerry")
-            self.assertEquals(hacker.nickname, "Badass")
+            eq_(jerry.name, "Jerry")
+            eq_(hacker.nickname, "Badass")
         self.assert_sql_count(testing.db, go, 0)
 
         session.expunge_all()
@@ -315,7 +317,8 @@ class ConcreteTest(_base.MappedTest):
         self.assert_sql_count(testing.db, go, 1)
 
 class PropertyInheritanceTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('a_table', metadata,
             Column('id', Integer, primary_key=True),
             Column('some_c_id', Integer, ForeignKey('c_table.id')),
@@ -332,7 +335,8 @@ class PropertyInheritanceTest(_base.MappedTest):
             
         )
         
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class A(_base.ComparableEntity):
             pass
 
@@ -352,7 +356,7 @@ class PropertyInheritanceTest(_base.MappedTest):
 
         b = B()
         c = C()
-        self.assertRaises(AttributeError, setattr, b, 'some_c', c)
+        assert_raises(AttributeError, setattr, b, 'some_c', c)
 
         clear_mappers()
         mapper(A, a_table, properties={
@@ -361,7 +365,7 @@ class PropertyInheritanceTest(_base.MappedTest):
         mapper(B, b_table,inherits=A, concrete=True)
         mapper(C, c_table)
         b = B()
-        self.assertRaises(AttributeError, setattr, b, 'a_id', 3)
+        assert_raises(AttributeError, setattr, b, 'a_id', 3)
 
         clear_mappers()
         mapper(A, a_table, properties={
@@ -392,8 +396,8 @@ class PropertyInheritanceTest(_base.MappedTest):
         b1 = B(some_c=c1, bname='b1')
         b2 = B(some_c=c1, bname='b2')
         
-        self.assertRaises(AttributeError, setattr, b1, 'aname', 'foo')
-        self.assertRaises(AttributeError, getattr, A, 'bname')
+        assert_raises(AttributeError, setattr, b1, 'aname', 'foo')
+        assert_raises(AttributeError, getattr, A, 'bname')
         
         assert c2.many_a == [a2]
         assert c1.many_a == [a1]
@@ -463,7 +467,8 @@ class PropertyInheritanceTest(_base.MappedTest):
         
     
 class ColKeysTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global offices_table, refugees_table
         refugees_table = Table('refugee', metadata,
            Column('refugee_fid', Integer, primary_key=True),
@@ -473,7 +478,8 @@ class ColKeysTest(_base.MappedTest):
            Column('office_fid', Integer, primary_key=True),
            Column('office_name', Unicode(30), key='name'))
     
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         refugees_table.insert().execute(
             dict(refugee_fid=1, name=u"refugee1"),
             dict(refugee_fid=2, name=u"refugee2")
@@ -511,5 +517,3 @@ class ColKeysTest(_base.MappedTest):
         eq_(sess.query(Office).get(1).name, "office1")
         eq_(sess.query(Office).get(2).name, "office2")
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 97%
rename from test/orm/inheritance/magazine.py
rename to test/orm/inheritance/test_magazine.py
index 34374c887e3c4019c98a36dec25638f74924617c..06730125113c505901fa640ae580de81ede529cb 100644 (file)
@@ -1,9 +1,9 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 
-from testlib import testing, _function_named
-from orm import _base
+from sqlalchemy.test import testing
+from sqlalchemy.util import function_named
+from test.orm import _base
 
 class BaseObject(object):
     def __init__(self, *args, **kwargs):
@@ -70,7 +70,8 @@ class ClassifiedPage(MagazinePage):
 
 
 class MagazineTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global publication_table, issue_table, location_table, location_name_table, magazine_table, \
         page_table, magazine_page_table, classified_page_table, page_size_table
 
@@ -208,7 +209,7 @@ def generate_round_trip_test(use_unions=False, use_joins=False):
         print [page, page2, page3]
         assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages)
 
-    test_roundtrip = _function_named(
+    test_roundtrip = function_named(
         test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions"))
     setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip)
 
@@ -216,5 +217,3 @@ for (use_union, use_join) in [(True, False), (False, True), (False, False)]:
     generate_round_trip_test(use_union, use_join)
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 96%
rename from test/orm/inheritance/manytomany.py
rename to test/orm/inheritance/test_manytomany.py
index 5dbf69ba565f25259b6f65abf46dd6256d09d190..f7e676bbbcf76d1b3b3a9692d33222269f499cbc 100644 (file)
@@ -1,14 +1,15 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 from sqlalchemy import *
 from sqlalchemy.orm import *
 
-from testlib import testing
-from orm import _base
+from sqlalchemy.test import testing
+from test.orm import _base
 
 
 class InheritTest(_base.MappedTest):
     """deals with inheritance and many-to-many relationships"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global principals
         global users
         global groups
@@ -67,7 +68,8 @@ class InheritTest(_base.MappedTest):
 
 class InheritTest2(_base.MappedTest):
     """deals with inheritance and many-to-many relationships"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global foo, bar, foo_bar
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_id_seq', optional=True),
@@ -140,7 +142,8 @@ class InheritTest2(_base.MappedTest):
 
 class InheritTest3(_base.MappedTest):
     """deals with inheritance and many-to-many relationships"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global foo, bar, blub, bar_foo, blub_bar, blub_foo
 
         # the 'data' columns are to appease SQLite which cant handle a blank INSERT
@@ -196,7 +199,7 @@ class InheritTest3(_base.MappedTest):
         l = sess.query(Bar).all()
         print repr(l[0]) + repr(l[0].foos)
         found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos]))
-        self.assertEqual(found, compare)
+        eq_(found, compare)
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     def testadvanced(self):
@@ -244,5 +247,3 @@ class InheritTest3(_base.MappedTest):
         self.assert_(repr(x) == compare)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 93%
rename from test/orm/inheritance/poly_linked_list.py
rename to test/orm/inheritance/test_poly_linked_list.py
index 2cf0519494b51c07f1ca624b7450596b604995ee..67b543f31c6ab6255fe467a5390a3e8b04a0c39d 100644 (file)
@@ -1,15 +1,15 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 
-from orm import _base
-from testlib import testing
+from test.orm import _base
+from sqlalchemy.test import testing
 
 
 class PolymorphicCircularTest(_base.MappedTest):
     run_setup_mappers = 'once'
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global Table1, Table1B, Table2, Table3,  Data
         table1 = Table('table1', metadata,
                        Column('id', Integer, primary_key=True),
@@ -115,26 +115,26 @@ class PolymorphicCircularTest(_base.MappedTest):
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     def testone(self):
-        self.do_testlist([Table1, Table2, Table1, Table2])
+        self._testlist([Table1, Table2, Table1, Table2])
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     def testtwo(self):
-        self.do_testlist([Table3])
+        self._testlist([Table3])
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     def testthree(self):
-        self.do_testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1])
+        self._testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1])
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     def testfour(self):
-        self.do_testlist([
+        self._testlist([
                 Table2('t2', [Data('data1'), Data('data2')]),
                 Table1('t1', []),
                 Table3('t3', [Data('data3')]),
                 Table1B('t1b', [Data('data4'), Data('data5')])
                 ])
 
-    def do_testlist(self, classes):
+    def _testlist(self, classes):
         sess = create_session( )
 
         # create objects in a linked list
@@ -195,5 +195,3 @@ class PolymorphicCircularTest(_base.MappedTest):
         # everything should match !
         assert original == forwards == backwards
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 90%
rename from test/orm/inheritance/polymorph.py
rename to test/orm/inheritance/test_polymorph.py
index 81f6c82a1ee211e431cdfd58f93b9821a1a5297f..cd3b2d89e31909fa5d6d1466b4819a6d4a952719 100644 (file)
@@ -1,11 +1,12 @@
 """tests basic polymorphic mapper loading/saving, minimal relations"""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy.orm import exc as orm_exc
-from testlib import _function_named, Column, testing
-from orm import _fixtures, _base
+from sqlalchemy.test import Column, testing
+from sqlalchemy.util import function_named
+from test.orm import _fixtures, _base
 
 class Person(_fixtures.Base):
     pass
@@ -19,7 +20,8 @@ class Company(_fixtures.Base):
     pass
 
 class PolymorphTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global companies, people, engineers, managers, boss
 
         companies = Table('companies', metadata,
@@ -83,7 +85,7 @@ class InsertOrderTest(PolymorphTest):
         session.add(c)
         session.flush()
         session.expunge_all()
-        self.assertEquals(session.query(Company).get(c.company_id), c)
+        eq_(session.query(Company).get(c.company_id), c)
 
 class RelationToSubclassTest(PolymorphTest):
     def test_basic(self):
@@ -115,13 +117,13 @@ class RelationToSubclassTest(PolymorphTest):
         sess.flush()
         sess.expunge_all()
 
-        self.assertEquals(sess.query(Company).filter_by(company_id=c.company_id).one(), c)
+        eq_(sess.query(Company).filter_by(company_id=c.company_id).one(), c)
         assert c.managers[0].company is c
 
 class RoundTripTest(PolymorphTest):
     pass
 
-def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic):
+def _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic):
     """generates a round trip test.
 
     include_base - whether or not to include the base 'person' type in the union.
@@ -205,15 +207,15 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with
         session.flush()
         session.expunge_all()
         
-        self.assertEquals(session.query(Person).get(dilbert.person_id), dilbert)
+        eq_(session.query(Person).get(dilbert.person_id), dilbert)
         session.expunge_all()
 
-        self.assertEquals(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert)
+        eq_(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert)
         session.expunge_all()
 
         def go():
             cc = session.query(Company).get(c.company_id)
-            self.assertEquals(cc.employees, employees)
+            eq_(cc.employees, employees)
             
         if not lazy_relation:
             if with_polymorphic != 'none':
@@ -229,14 +231,14 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with
         
         # test selecting from the query, using the base mapped table (people) as the selection criterion.
         # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join"
-        self.assertEquals(
+        eq_(
             session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first(),
             dilbert
         )
 
         assert session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first().person_id
 
-        self.assertEquals(
+        eq_(
             session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first(),
             dilbert
         )
@@ -268,22 +270,22 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with
         # test standalone orphans
         daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})
         session.add(daboss)
-        self.assertRaises(orm_exc.FlushError, session.flush)
+        assert_raises(orm_exc.FlushError, session.flush)
         c = session.query(Company).first()
         daboss.company = c
         manager_list = [e for e in c.employees if isinstance(e, Manager)]
         session.flush()
         session.expunge_all()
 
-        self.assertEquals(session.query(Manager).order_by(Manager.person_id).all(), manager_list)
+        eq_(session.query(Manager).order_by(Manager.person_id).all(), manager_list)
         c = session.query(Company).first()
         
         session.delete(c)
         session.flush()
         
-        self.assertEquals(people.count().scalar(), 0)
+        eq_(people.count().scalar(), 0)
         
-    test_roundtrip = _function_named(
+    test_roundtrip = function_named(
         test_roundtrip, "test_%s%s%s_%s" % (
           (lazy_relation and "lazy" or "eager"),
           (include_base and "_inclbase" or ""),
@@ -296,9 +298,7 @@ for lazy_relation in [True, False]:
         for with_polymorphic in ['unions', 'joins', 'auto', 'none']:
             if with_polymorphic == 'unions':
                 for include_base in [True, False]:
-                    generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic)
+                    _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic)
             else:
-                generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic)
+                _generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic)
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 96%
rename from test/orm/inheritance/polymorph2.py
rename to test/orm/inheritance/test_polymorph2.py
index aec162b75ca46df1995f9da5205b3019704ff919..51b6d4970a5a92c6c55e623b69e5bb3e39967b50 100644 (file)
@@ -2,14 +2,15 @@
 inheritance setups for which we maintain compatibility.
 """
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 from sqlalchemy import *
 from sqlalchemy import util
 from sqlalchemy.orm import *
 
-from testlib import _function_named, TestBase, AssertsExecutionResults, testing
-from orm import _base, _fixtures
-from testlib.testing import eq_
+from sqlalchemy.test import TestBase, AssertsExecutionResults, testing
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
 
 class AttrSettable(object):
     def __init__(self, **kwargs):
@@ -20,7 +21,8 @@ class AttrSettable(object):
 
 class RelationTest1(_base.MappedTest):
     """test self-referential relationships on polymorphic mappers"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, managers
 
         people = Table('people', metadata,
@@ -35,11 +37,11 @@ class RelationTest1(_base.MappedTest):
            Column('manager_name', String(50))
            )
 
-    def tearDown(self):
+    def teardown(self):
         people.update(values={people.c.manager_id:None}).execute()
-        super(RelationTest1, self).tearDown()
+        super(RelationTest1, self).teardown()
 
-    def testparentrefsdescendant(self):
+    def test_parent_refs_descendant(self):
         class Person(AttrSettable):
             pass
         class Manager(Person):
@@ -55,7 +57,7 @@ class RelationTest1(_base.MappedTest):
         mapper(Manager, managers, inherits=Person,
                inherit_condition=people.c.person_id==managers.c.person_id)
         
-        self.assertEquals(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)])
+        eq_(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)])
         
         session = create_session()
         p = Person(name='some person')
@@ -70,7 +72,7 @@ class RelationTest1(_base.MappedTest):
         print p, m, p.manager
         assert p.manager is m
 
-    def testdescendantrefsparent(self):
+    def test_descendant_refs_parent(self):
         class Person(AttrSettable):
             pass
         class Manager(Person):
@@ -99,7 +101,8 @@ class RelationTest1(_base.MappedTest):
 
 class RelationTest2(_base.MappedTest):
     """test self-referential relationships on polymorphic mappers"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, managers, data
         people = Table('people', metadata,
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
@@ -194,7 +197,8 @@ class RelationTest2(_base.MappedTest):
 
 class RelationTest3(_base.MappedTest):
     """test self-referential relationships on polymorphic mappers"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, managers, data
         people = Table('people', metadata,
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
@@ -212,7 +216,7 @@ class RelationTest3(_base.MappedTest):
            Column('data', String(30))
            )
 
-def generate_test(jointype="join1", usedata=False):
+def _generate_test(jointype="join1", usedata=False):
     def do_test(self):
         class Person(AttrSettable):
             pass
@@ -287,19 +291,20 @@ def generate_test(jointype="join1", usedata=False):
             assert p.data.data == 'ps data'
             assert m.data.data == 'ms data'
 
-    do_test = _function_named(
-        do_test, 'test_relationonbaseclass_%s_%s' % (
+    do_test = function_named(
+        do_test, 'test_relation_on_base_class_%s_%s' % (
         jointype, data and "nodata" or "data"))
     return do_test
 
 for jointype in ["join1", "join2", "join3", "join4"]:
     for data in (True, False):
-        func = generate_test(jointype, data)
+        func = _generate_test(jointype, data)
         setattr(RelationTest3, func.__name__, func)
-
+del func
 
 class RelationTest4(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, engineers, managers, cars
         people = Table('people', metadata,
            Column('person_id', Integer, primary_key=True),
@@ -411,7 +416,8 @@ class RelationTest4(_base.MappedTest):
         assert c.car_id==car1.car_id
 
 class RelationTest5(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, engineers, managers, cars
         people = Table('people', metadata,
            Column('person_id', Integer, primary_key=True),
@@ -472,7 +478,8 @@ class RelationTest5(_base.MappedTest):
 
 class RelationTest6(_base.MappedTest):
     """test self-referential relationships on a single joined-table inheritance mapper"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, managers, data
         people = Table('people', metadata,
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
@@ -514,7 +521,8 @@ class RelationTest6(_base.MappedTest):
         assert m.colleague is m2
 
 class RelationTest7(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, engineers, managers, cars, offroad_cars
         cars = Table('cars', metadata,
                 Column('car_id', Integer, primary_key=True),
@@ -613,7 +621,8 @@ class RelationTest7(_base.MappedTest):
             assert p.car_id == p.car.car_id
 
 class RelationTest8(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global taggable, users
         taggable = Table('taggable', metadata,
                          Column('id', Integer, primary_key=True),
@@ -658,7 +667,8 @@ class RelationTest8(_base.MappedTest):
         )
         
 class GenerativeTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         #  cars---owned by---  people (abstract) --- has a --- status
         #   |                  ^    ^                            |
         #   |                  |    |                            |
@@ -693,9 +703,10 @@ class GenerativeTest(TestBase, AssertsExecutionResults):
 
         metadata.create_all()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
-    def tearDown(self):
+    def teardown(self):
         clear_mappers()
         for t in reversed(metadata.sorted_tables):
             t.delete().execute()
@@ -784,7 +795,8 @@ class GenerativeTest(TestBase, AssertsExecutionResults):
         assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
 
 class MultiLevelTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global table_Employee, table_Engineer, table_Manager
         table_Employee = Table( 'Employee', metadata,
             Column( 'name', type_= String(100), ),
@@ -861,7 +873,8 @@ class MultiLevelTest(_base.MappedTest):
         assert session.query( Manager).all() == [c]
 
 class ManyToManyPolyTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global base_item_table, item_table, base_item_collection_table, collection_table
         base_item_table = Table(
             'base_item', metadata,
@@ -911,7 +924,8 @@ class ManyToManyPolyTest(_base.MappedTest):
         class_mapper(BaseItem)
 
 class CustomPKTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global t1, t2
         t1 = Table('t1', metadata,
             Column('id', Integer, primary_key=True),
@@ -994,7 +1008,8 @@ class CustomPKTest(_base.MappedTest):
         sess.flush()
 
 class InheritingEagerTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, employees, tags, peopleTags
 
         people = Table('people', metadata,
@@ -1055,7 +1070,8 @@ class InheritingEagerTest(_base.MappedTest):
         assert len(instance.tags) == 2
 
 class MissingPolymorphicOnTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global tablea, tableb, tablec, tabled
         tablea = Table('tablea', metadata, 
             Column('id', Integer, primary_key=True),
@@ -1101,7 +1117,5 @@ class MissingPolymorphicOnTest(_base.MappedTest):
         sess.add(d)
         sess.flush()
         sess.expunge_all()
-        self.assertEquals(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')])
+        eq_(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')])
         
-if __name__ == "__main__":
-    testenv.main()
similarity index 98%
rename from test/orm/inheritance/productspec.py
rename to test/orm/inheritance/test_productspec.py
index b6a8c514685cfd1912b0f9fb86cbcb25da6b1a44..b2bcb85d54b570b465729508dc492a635fb29b09 100644 (file)
@@ -1,16 +1,16 @@
-import testenv; testenv.configure_for_tests()
 from datetime import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
 
 
-from testlib import testing
-from orm import _base
+from sqlalchemy.test import testing
+from test.orm import _base
 
 
 class InheritTest(_base.MappedTest):
     """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global products_table, specification_table, documents_table
         global Product, Detail, Assembly, SpecLine, Document, RasterDocument
 
@@ -316,5 +316,3 @@ class InheritTest(_base.MappedTest):
         print new
         assert orig == new  == '<Assembly a1> specification=[<SpecLine 1.0 <Detail d1>>] documents=[<Document doc1>, <RasterDocument doc2>]'
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 79%
rename from test/orm/inheritance/query.py
rename to test/orm/inheritance/test_query.py
index 58d2054558619e13d4e433ae2e0141209c79310a..5b57e8f4575e2fe2a6ef62b70012b64904fdda31 100644 (file)
@@ -1,13 +1,13 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.engine import default
 
-from testlib import AssertsCompiledSQL, testing
-from orm import _base, _fixtures
-from testlib.testing import eq_
+from sqlalchemy.test import AssertsCompiledSQL, testing
+from test.orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
 
 class Company(_fixtures.Base):
     pass
@@ -27,13 +27,14 @@ class Machine(_fixtures.Base):
 class Paperwork(_fixtures.Base):
     pass
 
-def make_test(select_type):
+def _produce_test(select_type):
     class PolymorphicQueryTest(_base.MappedTest, AssertsCompiledSQL):
         run_inserts = 'once'
         run_setup_mappers = 'once'
         run_deletes = None
         
-        def define_tables(self, metadata):
+        @classmethod
+        def define_tables(cls, metadata):
             global companies, people, engineers, managers, boss, paperwork, machines
 
             companies = Table('companies', metadata,
@@ -128,7 +129,8 @@ def make_test(select_type):
             mapper(Paperwork, paperwork)
         
 
-        def insert_data(self):
+        @classmethod
+        def insert_data(cls):
             global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2
 
             c1 = Company(name="MegaCorp, Inc.")
@@ -178,14 +180,14 @@ def make_test(select_type):
             
             sess = create_session()
             def go():
-                self.assertEquals(sess.query(Person).all(), all_employees)
+                eq_(sess.query(Person).all(), all_employees)
             self.assert_sql_count(testing.db, go, {'':14, 'Polymorphic':9}.get(select_type, 10))
 
         def test_primary_eager_aliasing(self):
             sess = create_session()
             
             def go():
-                self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
+                eq_(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
             self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4))
 
             sess = create_session()
@@ -194,7 +196,7 @@ def make_test(select_type):
             assert sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().subquery().count().scalar() == 2
 
             def go():
-                self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
+                eq_(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
             self.assert_sql_count(testing.db, go, 3)
             
             
@@ -203,9 +205,9 @@ def make_test(select_type):
             
             # for all mappers, ensure the primary key has been calculated as just the "person_id"
             # column
-            self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
-            self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
-            self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
+            eq_(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+            eq_(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+            eq_(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
         
         def test_multi_join(self):
             sess = create_session()
@@ -216,8 +218,8 @@ def make_test(select_type):
             q = sess.query(Company, Person, c, e).join((Person, Company.employees)).join((e, c.employees)).\
                     filter(Person.name=='dilbert').filter(e.name=='wally')
             
-            self.assertEquals(q.count(), 1)
-            self.assertEquals(q.all(), [
+            eq_(q.count(), 1)
+            eq_(q.all(), [
                 (
                     Company(company_id=1,name=u'MegaCorp, Inc.'), 
                     Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
@@ -228,99 +230,99 @@ def make_test(select_type):
             
         def test_filter_on_subclass(self):
             sess = create_session()
-            self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert"))
+            eq_(sess.query(Engineer).all()[0], Engineer(name="dilbert"))
 
-            self.assertEquals(sess.query(Engineer).first(), Engineer(name="dilbert"))
+            eq_(sess.query(Engineer).first(), Engineer(name="dilbert"))
 
-            self.assertEquals(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert"))
+            eq_(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert"))
 
-            self.assertEquals(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert"))
+            eq_(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert"))
 
-            self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
+            eq_(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
         
-            self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
+            eq_(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
 
         def test_join_from_polymorphic(self):
             sess = create_session()
 
             for aliased in (True, False):
-                self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
+                eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
 
-                self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
+                eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
 
-                self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
+                eq_(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
 
-                self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+                eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
 
         def test_join_from_with_polymorphic(self):
             sess = create_session()
 
             for aliased in (True, False):
                 sess.expunge_all()
-                self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
+                eq_(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
 
                 sess.expunge_all()
-                self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
+                eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
 
                 sess.expunge_all()
-                self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+                eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
     
         def test_join_to_polymorphic(self):
             sess = create_session()
-            self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
+            eq_(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
 
-            self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
+            eq_(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
 
         def test_polymorphic_any(self):
             sess = create_session()
 
-            self.assertEquals(
+            eq_(
                 sess.query(Company).\
                     filter(Company.employees.any(Person.name=='vlad')).all(), [c2]
             )
             
             # test that the aliasing on "Person" does not bleed into the
             # EXISTS clause generated by any()
-            self.assertEquals(
+            eq_(
                 sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\
                     filter(Company.employees.any(Person.name=='wally')).all(), [c1]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\
                     filter(Company.employees.any(Person.name=='vlad')).all(), []
             )
             
-            self.assertEquals(
+            eq_(
                 sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(),
                 c2
                 )
             
             calias = aliased(Company)
-            self.assertEquals(
+            eq_(
                 sess.query(calias).filter(calias.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(),
                 c2
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(),
                 c1
                 )
-            self.assertEquals(
+            eq_(
                 sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(),
                 c1
                 )
 
             if select_type != '':
-                self.assertEquals(
+                eq_(
                     sess.query(Person).filter(Engineer.machines.any(Machine.name=="Commodore 64")).all(), [e2, e3]
                 )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1]
             )
             
-            self.assertEquals(
+            eq_(
                 sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(),
                 c2
                 )
@@ -328,48 +330,48 @@ def make_test(select_type):
         def test_join_from_columns_or_subclass(self):
             sess = create_session()
 
-            self.assertEquals(
+            eq_(
                 sess.query(Manager.name).order_by(Manager.name).all(),
                 [(u'dogbert',), (u'pointy haired boss',)]
             )
             
-            self.assertEquals(
+            eq_(
                 sess.query(Manager.name).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(),
                 [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Person.name).join((Paperwork, Person.paperwork)).order_by(Person.name).all(),
                 [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)]
             )
             
-            self.assertEquals(
+            eq_(
                 sess.query(Person.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Person.name).all(),
                 [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)]
             )
             
-            self.assertEquals(
+            eq_(
                 sess.query(Manager).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(),
                 [m1, b1]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Manager.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(),
                 [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Manager.person_id).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(),
                 [(4,), (4,), (3,)]
             )
             
-            self.assertEquals(
+            eq_(
                 sess.query(Manager.name, Paperwork.description).join((Paperwork, Manager.person_id==Paperwork.person_id)).all(),
                 [(u'pointy haired boss', u'review #1'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3')]
             )
             
             malias = aliased(Manager)
-            self.assertEquals(
+            eq_(
                 sess.query(malias.name).join((paperwork, malias.person_id==paperwork.c.person_id)).all(),
                 [(u'pointy haired boss',), (u'dogbert',), (u'dogbert',)]
             )
@@ -391,9 +393,9 @@ def make_test(select_type):
             sess = create_session()
             
             
-            self.assertRaises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork)
-            self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss)
-            self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person)
+            assert_raises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork)
+            assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss)
+            assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person)
             
             # compare to entities without related collections to prevent additional lazy SQL from firing on 
             # loaded entities
@@ -404,32 +406,32 @@ def make_test(select_type):
                 Manager(name="dogbert", manager_name="dogbert", status="regular manager"),
                 Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
             ]
-            self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+            eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
             
             
             def go():
-                self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
+                eq_(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
             self.assert_sql_count(testing.db, go, 1)
             
             sess.expunge_all()
             def go():
-                self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+                eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
             self.assert_sql_count(testing.db, go, 1)
 
             sess.expunge_all()
             def go():
-                self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations)
+                eq_(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations)
             self.assert_sql_count(testing.db, go, 3)
 
             sess.expunge_all()
             def go():
-                self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations)
+                eq_(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations)
             self.assert_sql_count(testing.db, go, 3)
             
             sess.expunge_all()
             def go():
                 # limit the polymorphic join down to just "Person", overriding select_table
-                self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations)
+                eq_(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations)
             self.assert_sql_count(testing.db, go, 6)
         
         def test_relation_to_polymorphic(self):
@@ -449,14 +451,14 @@ def make_test(select_type):
             
             def go():
                 # test load Companies with lazy load to 'employees'
-                self.assertEquals(sess.query(Company).all(), assert_result)
+                eq_(sess.query(Company).all(), assert_result)
             self.assert_sql_count(testing.db, go, {'':9, 'Polymorphic':4}.get(select_type, 5))
         
             sess = create_session()
             def go():
                 # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer).  eagerloader doesn't
                 # pick up on the "of_type()" as of yet.
-                self.assertEquals(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result)
+                eq_(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result)
             
             # in the case of select_type='', the eagerload doesn't take in this case; 
             # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines"            
@@ -466,78 +468,78 @@ def make_test(select_type):
             sess = create_session()
             def go():
                 # test load People with eagerload to engineers + machines
-                self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), 
+                eq_(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), 
                 [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])]
                 )
             self.assert_sql_count(testing.db, go, 1)
             
         def test_join_to_subclass(self):
             sess = create_session()
-            self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+            eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
 
             if select_type == '':
-                self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
-                self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+                eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
+                eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
                 
                 ealias = aliased(Engineer)
-                self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1])
+                eq_(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1])
 
-                self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
-                self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
-                self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
-                self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+                eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
+                eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+                eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
+                eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
             else:
-                self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
-                self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1])
-                self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3])
-                self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
-                self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2])
-                self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+                eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
+                eq_(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1])
+                eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3])
+                eq_(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+                eq_(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2])
+                eq_(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
             
             # non-polymorphic
-            self.assertEquals(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3])
-            self.assertEquals(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+            eq_(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3])
+            eq_(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
 
             # here's the new way
-            self.assertEquals(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1])
-            self.assertEquals(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+            eq_(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1])
+            eq_(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
 
         def test_join_through_polymorphic(self):
 
             sess = create_session()
 
             for aliased in (True, False):
-                self.assertEquals(
+                eq_(
                     sess.query(Company).\
                         join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(),
                     [c1]
                 )
 
-                self.assertEquals(
+                eq_(
                     sess.query(Company).\
                         join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(),
                     [c1, c2]
                 )
 
-                self.assertEquals(
+                eq_(
                     sess.query(Company).\
                         join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(),
                     [c1]
                 )
         
-                self.assertEquals(
+                eq_(
                     sess.query(Company).\
                         join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(),
                     [c1, c2]
                 )
 
-                self.assertEquals(
+                eq_(
                     sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\
                         join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(),
                     [c1]
                 )
 
-                self.assertEquals(
+                eq_(
                     sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\
                         join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(),
                     [c1, c2]
@@ -549,14 +551,14 @@ def make_test(select_type):
             # ORMJoin using regular table foreign key connections.  Engineer
             # is expressed as "(select * people join engineers) as anon_1"
             # so the join is contained.
-            self.assertEquals(
+            eq_(
                 sess.query(Company).join(Engineer).filter(Engineer.engineer_name=='vlad').one(),
                 c2
             )
 
             # same, using explicit join condition.  Query.join() must adapt the on clause
             # here to match the subquery wrapped around "people join engineers".
-            self.assertEquals(
+            eq_(
                 sess.query(Company).join((Engineer, Company.company_id==Engineer.company_id)).filter(Engineer.engineer_name=='vlad').one(),
                 c2
             )
@@ -565,17 +567,17 @@ def make_test(select_type):
         def test_filter_on_baseclass(self):
             sess = create_session()
 
-            self.assertEquals(sess.query(Person).all(), all_employees)
+            eq_(sess.query(Person).all(), all_employees)
 
-            self.assertEquals(sess.query(Person).first(), all_employees[0])
+            eq_(sess.query(Person).first(), all_employees[0])
         
-            self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
+            eq_(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
     
         def test_from_alias(self):
             sess = create_session()
             
             palias = aliased(Person)
-            self.assertEquals(
+            eq_(
                 sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(),
                 [e1, e2]
             )
@@ -586,7 +588,7 @@ def make_test(select_type):
             c1_employees = [e1, e2, b1, m1]
             
             palias = aliased(Person)
-            self.assertEquals(
+            eq_(
                 sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
                     filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), 
                 [
@@ -596,7 +598,7 @@ def make_test(select_type):
                 ]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
                     filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(), 
                 [
@@ -614,30 +616,30 @@ def make_test(select_type):
             # the subquery and usually results in recursion overflow errors within the adaption.
             subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar()
             
-            self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
+            eq_(sess.query(Person).filter(Person.person_id==subq).one(), e1)
             
         def test_mixed_entities(self):
             sess = create_session()
 
-            self.assertEquals(
+            eq_(
                 sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
                 [(u'Elbonia, Inc.', 
                     Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
                 [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
                     u'Elbonia, Inc.')]
             )
             
             
-            self.assertEquals(
+            eq_(
                 sess.query(Manager.name).all(), 
                 [('pointy haired boss', ), ('dogbert',)]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Manager.name + " foo").all(), 
                 [('pointy haired boss foo', ), ('dogbert foo',)]
             )
@@ -647,12 +649,12 @@ def make_test(select_type):
             assert row.primary_language == 'java'
             
 
-            self.assertEquals(
+            eq_(
                 sess.query(Engineer.name, Engineer.primary_language).all(),
                 [(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Boss.name, Boss.golf_swing).all(),
                 [(u'pointy haired boss', u'fore')]
             )
@@ -670,18 +672,18 @@ def make_test(select_type):
             #    []
             # )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
                 [(u'vlad',u'Elbonia, Inc.')]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(),
                 [(u'java',), (u'c++',), (u'cobol',)]
             )
 
             if select_type != '':
-                self.assertEquals(
+                eq_(
                     sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(),
                     [
                     (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'), 
@@ -690,20 +692,20 @@ def make_test(select_type):
                     ]
                 )
             
-                self.assertEquals(
+                eq_(
                     sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(),
                     [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')]
                 )
 
             palias = aliased(Person)
-            self.assertEquals(
+            eq_(
                 sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
                 [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
                     u'Elbonia, Inc.', 
                     Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
                 [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
                     u'Elbonia, Inc.', 
@@ -711,13 +713,13 @@ def make_test(select_type):
                 ]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
                 [(u'vlad', u'Elbonia, Inc.', u'dilbert')]
             )
             
             palias = aliased(Person)
-            self.assertEquals(
+            eq_(
                 sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
                     filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), 
                 [(u'manager', u'dogbert', u'engineer', u'dilbert'), 
@@ -725,7 +727,7 @@ def make_test(select_type):
                 (u'manager', u'dogbert', u'boss', u'pointy haired boss')]
             )
         
-            self.assertEquals(
+            eq_(
                 sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(), 
                 [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'), 
                 (u'dogbert', u'review #3'), 
@@ -737,17 +739,17 @@ def make_test(select_type):
             )
 
             if select_type != '':
-                self.assertEquals(
+                eq_(
                     sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(), 
                     [(1, )]
                 )
             
-            self.assertEquals(
+            eq_(
                 sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(),
                 [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
             )
 
-            self.assertEquals(
+            eq_(
                 sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(),
                 [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
             )
@@ -757,7 +759,7 @@ def make_test(select_type):
     return PolymorphicQueryTest
 
 for select_type in ('', 'Polymorphic', 'Unions', 'AliasedJoins', 'Joins'):
-    testclass = make_test(select_type)
+    testclass = _produce_test(select_type)
     exec("%s = testclass" % testclass.__name__)
     
 del testclass
@@ -765,7 +767,8 @@ del testclass
 class SelfReferentialTestJoinedToBase(_base.MappedTest):
     run_setup_mappers = 'once'
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, engineers
         people = Table('people', metadata,
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
@@ -778,7 +781,8 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest):
            Column('reports_to_id', Integer, ForeignKey('people.person_id'))
           )
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
         mapper(Engineer, engineers, inherits=Person, 
           inherit_condition=engineers.c.person_id==people.c.person_id,
@@ -796,7 +800,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest):
         sess.flush()
         sess.expunge_all()
         
-        self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert'))
+        eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert'))
 
     def test_oftype_aliases_in_exists(self):
         e1 = Engineer(name='dilbert', primary_language='java')
@@ -805,7 +809,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest):
         sess.add_all([e1, e2])
         sess.flush()
         
-        self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2)
+        eq_(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2)
         
     def test_join(self):
         p1 = Person(name='dogbert')
@@ -816,14 +820,15 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest):
         sess.flush()
         sess.expunge_all()
         
-        self.assertEquals(
+        eq_(
             sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), 
             Engineer(name='dilbert'))
 
 class SelfReferentialJ2JTest(_base.MappedTest):
     run_setup_mappers = 'once'
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, engineers, managers
         people = Table('people', metadata,
            Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
@@ -840,7 +845,8 @@ class SelfReferentialJ2JTest(_base.MappedTest):
             Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
         )
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
         mapper(Manager, managers, inherits=Person, polymorphic_identity='manager')
         
@@ -859,7 +865,7 @@ class SelfReferentialJ2JTest(_base.MappedTest):
         sess.flush()
         sess.expunge_all()
 
-        self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert'))
+        eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert'))
 
     def test_join(self):
         m1 = Manager(name='dogbert')
@@ -870,7 +876,7 @@ class SelfReferentialJ2JTest(_base.MappedTest):
         sess.flush()
         sess.expunge_all()
 
-        self.assertEquals(
+        eq_(
             sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), 
             Engineer(name='dilbert'))
     
@@ -886,17 +892,17 @@ class SelfReferentialJ2JTest(_base.MappedTest):
         sess.expunge_all()
 
         # filter aliasing applied to Engineer doesn't whack Manager
-        self.assertEquals(
+        eq_(
             sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(),
             [m1]
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(),
             [m2]
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(),
             [
                 (m2, e2),
@@ -919,12 +925,12 @@ class SelfReferentialJ2JTest(_base.MappedTest):
         sess.flush()
         sess.expunge_all()
 
-        self.assertEquals(
+        eq_(
             sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), 
             []
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), 
             [m1]
         )
@@ -936,7 +942,8 @@ class M2MFilterTest(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, engineers, organizations, engineers_to_org
         
         organizations = Table('organizations', metadata,
@@ -958,7 +965,8 @@ class M2MFilterTest(_base.MappedTest):
            Column('primary_language', String(50)),
           )
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         global Organization
         class Organization(_fixtures.Base):
             pass
@@ -970,7 +978,8 @@ class M2MFilterTest(_base.MappedTest):
         mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
         mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
     
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         e1 = Engineer(name='e1')
         e2 = Engineer(name='e2')
         e3 = Engineer(name='e3')
@@ -989,20 +998,21 @@ class M2MFilterTest(_base.MappedTest):
         e1 = sess.query(Person).filter(Engineer.name=='e1').one()
         
         # this works
-        self.assertEquals(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')])
+        eq_(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')])
 
         # this had a bug
-        self.assertEquals(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')])
+        eq_(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')])
     
     def test_any(self):
         sess = create_session()
-        self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
-        self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+        eq_(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+        eq_(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
 
 class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL):
     run_setup_mappers = 'once'
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global Parent, Child1, Child2
 
         Base = declarative_base(metadata=metadata)
@@ -1101,5 +1111,3 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL):
         
         assert q.first() is c1
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 88%
rename from test/orm/inheritance/selects.py
rename to test/orm/inheritance/test_selects.py
index e54a0ad13f92328b7cbfe92646ec8be0c68be38b..a151af4fa29f129b655935a00f257cc673bdc9f7 100644 (file)
@@ -1,14 +1,14 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
 
-from testlib import testing
-from orm._fixtures import Base
-from orm._base import MappedTest
+from sqlalchemy.test import testing
+from test.orm._fixtures import Base
+from test.orm._base import MappedTest
 
 
 class InheritingSelectablesTest(MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global foo, bar, baz
         foo = Table('foo', metadata,
                     Column('a', String(30), primary_key=1),
@@ -49,5 +49,3 @@ class InheritingSelectablesTest(MappedTest):
         assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all()
         assert [Bar(), Bar()] == s.query(Bar).all()
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 85%
rename from test/orm/inheritance/single.py
rename to test/orm/inheritance/test_single.py
index 7aee250318f7c89231728d744dcb9126a5778ae9..70582688576f7cfaea463d0d3e96cff18de92d50 100644 (file)
@@ -1,14 +1,15 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 from sqlalchemy import *
 from sqlalchemy.orm import *
 
-from testlib import testing
-from orm import _fixtures
-from orm._base import MappedTest, ComparableEntity
+from sqlalchemy.test import testing
+from test.orm import _fixtures
+from test.orm._base import MappedTest, ComparableEntity
 
 
 class SingleInheritanceTest(MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('employees', metadata,
             Column('employee_id', Integer, primary_key=True),
             Column('name', String(50)),
@@ -22,7 +23,8 @@ class SingleInheritanceTest(MappedTest):
               Column('name', String(50)),
         )
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Employee(ComparableEntity):
             pass
         class Manager(Employee):
@@ -32,8 +34,9 @@ class SingleInheritanceTest(MappedTest):
         class JuniorEngineer(Engineer):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Employee, employees, polymorphic_on=employees.c.type)
         mapper(Manager, inherits=Employee, polymorphic_identity='manager')
         mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
@@ -57,7 +60,7 @@ class SingleInheritanceTest(MappedTest):
         
         m1 = session.query(Manager).one()
         session.expire(m1, ['manager_data'])
-        self.assertEquals(m1.manager_data, "knows how to manage things")
+        eq_(m1.manager_data, "knows how to manage things")
 
         row = session.query(Engineer.name, Engineer.employee_id).filter(Engineer.name=='Kurt').first()
         assert row.name == 'Kurt'
@@ -75,32 +78,32 @@ class SingleInheritanceTest(MappedTest):
         session.flush()
 
         ealias = aliased(Engineer)
-        self.assertEquals(
+        eq_(
             session.query(Manager, ealias).all(), 
             [(m1, e1), (m1, e2)]
         )
     
-        self.assertEquals(
+        eq_(
             session.query(Manager.name).all(),
             [("Tom",)]
         )
 
-        self.assertEquals(
+        eq_(
             session.query(Manager.name, ealias.name).all(),
             [("Tom", "Kurt"), ("Tom", "Ed")]
         )
 
-        self.assertEquals(
+        eq_(
             session.query(func.upper(Manager.name), func.upper(ealias.name)).all(),
             [("TOM", "KURT"), ("TOM", "ED")]
         )
 
-        self.assertEquals(
+        eq_(
             session.query(Manager).add_entity(ealias).all(),
             [(m1, e1), (m1, e2)]
         )
         
-        self.assertEquals(
+        eq_(
             session.query(Manager.name).add_column(ealias.name).all(),
             [("Tom", "Kurt"), ("Tom", "Ed")]
         )
@@ -121,7 +124,7 @@ class SingleInheritanceTest(MappedTest):
         sess.add_all([m1, m2, e1, e2])
         sess.flush()
         
-        self.assertEquals(
+        eq_(
             sess.query(Manager).select_from(employees.select().limit(10)).all(), 
             [m1, m2]
         )
@@ -136,12 +139,12 @@ class SingleInheritanceTest(MappedTest):
         sess.add_all([m1, m2, e1, e2])
         sess.flush()
 
-        self.assertEquals(sess.query(Manager).count(), 2)
-        self.assertEquals(sess.query(Engineer).count(), 2)
-        self.assertEquals(sess.query(Employee).count(), 4)
+        eq_(sess.query(Manager).count(), 2)
+        eq_(sess.query(Engineer).count(), 2)
+        eq_(sess.query(Employee).count(), 4)
         
-        self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
-        self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
+        eq_(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
+        eq_(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
 
     @testing.resolve_artifact_names
     def test_type_filtering(self):
@@ -180,7 +183,8 @@ class SingleInheritanceTest(MappedTest):
 
 
 class RelationToSingleTest(MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('employees', metadata,
             Column('employee_id', Integer, primary_key=True),
             Column('name', String(50)),
@@ -195,7 +199,8 @@ class RelationToSingleTest(MappedTest):
             Column('name', String(50)),
         )
     
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Company(ComparableEntity):
             pass
             
@@ -229,14 +234,14 @@ class RelationToSingleTest(MappedTest):
         sess.add_all([c1, c2, m1, m2, e1, e2])
         sess.commit()
         sess.expunge_all()
-        self.assertEquals(
+        eq_(
             sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(),
             [
                 Company(name='c1'),
             ]
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(),
             [
                 Company(name='c1'),
@@ -267,11 +272,11 @@ class RelationToSingleTest(MappedTest):
         sess.add_all([c1, c2, m1, m2, e1, e2])
         sess.commit()
 
-        self.assertEquals(c1.engineers, [e2])
-        self.assertEquals(c2.engineers, [e1])
+        eq_(c1.engineers, [e2])
+        eq_(c2.engineers, [e1])
         
         sess.expunge_all()
-        self.assertEquals(sess.query(Company).order_by(Company.name).all(), 
+        eq_(sess.query(Company).order_by(Company.name).all(), 
             [
                 Company(name='c1', engineers=[JuniorEngineer(name='Ed')]),
                 Company(name='c2', engineers=[Engineer(name='Kurt')])
@@ -280,7 +285,7 @@ class RelationToSingleTest(MappedTest):
 
         # eager load join should limit to only "Engineer"
         sess.expunge_all()
-        self.assertEquals(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), 
+        eq_(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), 
             [
                 Company(name='c1', engineers=[JuniorEngineer(name='Ed')]),
                 Company(name='c2', engineers=[Engineer(name='Kurt')])
@@ -289,7 +294,7 @@ class RelationToSingleTest(MappedTest):
 
         # join() to Company.engineers, Employee as the requested entity
         sess.expunge_all()
-        self.assertEquals(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(),
+        eq_(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(),
             [
                 (Company(name='c1'), JuniorEngineer(name='Ed')),
                 (Company(name='c2'), Engineer(name='Kurt'))
@@ -299,7 +304,7 @@ class RelationToSingleTest(MappedTest):
         # join() to Company.engineers, Engineer as the requested entity.
         # this actually applies the IN criterion twice which is less than ideal.
         sess.expunge_all()
-        self.assertEquals(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(),
+        eq_(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(),
             [
                 (Company(name='c1'), JuniorEngineer(name='Ed')),
                 (Company(name='c2'), Engineer(name='Kurt'))
@@ -308,7 +313,7 @@ class RelationToSingleTest(MappedTest):
 
         # join() to Company.engineers without any Employee/Engineer entity
         sess.expunge_all()
-        self.assertEquals(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
+        eq_(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
             [
                 Company(name='c2')
             ]
@@ -323,7 +328,7 @@ class RelationToSingleTest(MappedTest):
         @testing.fails_on_everything_except()
         def go():
             sess.expunge_all()
-            self.assertEquals(sess.query(Company).\
+            eq_(sess.query(Company).\
                 filter(Company.company_id==Engineer.company_id).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
                 [
                     Company(name='c2')
@@ -332,7 +337,8 @@ class RelationToSingleTest(MappedTest):
         go()
         
 class SingleOnJoinedTest(MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global persons_table, employees_table
         
         persons_table = Table('persons', metadata,
@@ -366,31 +372,29 @@ class SingleOnJoinedTest(MappedTest):
         sess.flush()
         sess.expunge_all()
         
-        self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [
+        eq_(sess.query(Person).order_by(Person.person_id).all(), [
             Person(name='p1'),
             Employee(name='e1', employee_data='ed1'),
             Manager(name='m1', employee_data='ed2', manager_data='md1')
         ])
         sess.expunge_all()
 
-        self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [
+        eq_(sess.query(Employee).order_by(Person.person_id).all(), [
             Employee(name='e1', employee_data='ed1'),
             Manager(name='m1', employee_data='ed2', manager_data='md1')
         ])
         sess.expunge_all()
 
-        self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [
+        eq_(sess.query(Manager).order_by(Person.person_id).all(), [
             Manager(name='m1', employee_data='ed2', manager_data='md1')
         ])
         sess.expunge_all()
         
         def go():
-            self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [
+            eq_(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [
                 Person(name='p1'),
                 Employee(name='e1', employee_data='ed1'),
                 Manager(name='m1', employee_data='ed2', manager_data='md1')
             ])
         self.assert_sql_count(testing.db, go, 1)
     
-if __name__ == '__main__':
-    testenv.main()
diff --git a/test/orm/sharding/alltests.py b/test/orm/sharding/alltests.py
deleted file mode 100644 (file)
index 09fa862..0000000
+++ /dev/null
@@ -1,18 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
-    modules_to_test = (
-        'orm.sharding.shard',
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
similarity index 94%
rename from test/orm/sharding/shard.py
rename to test/orm/sharding/test_shard.py
index 10aaee131b57656dcca47f16acaa8ca04c72b375..89e23fb759139bff3a55c8aa9aa6b665eac5608c 100644 (file)
@@ -1,17 +1,17 @@
-import testenv; testenv.configure_for_tests()
 import datetime, os
 from sqlalchemy import *
 from sqlalchemy import sql
 from sqlalchemy.orm import *
 from sqlalchemy.orm.shard import ShardedSession
 from sqlalchemy.sql import operators
-from testlib import *
-from testlib.testing import eq_
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
 
 # TODO: ShardTest can be turned into a base for further subclasses
 
 class ShardTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global db1, db2, db3, db4, weather_locations, weather_reports
 
         db1 = create_engine('sqlite:///shard1.db')
@@ -48,16 +48,18 @@ class ShardTest(TestBase):
 
         db1.execute(ids.insert(), nextid=1)
 
-        self.setup_session()
-        self.setup_mappers()
+        cls.setup_session()
+        cls.setup_mappers()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         for db in (db1, db2, db3, db4):
             db.connect().invalidate()
         for i in range(1,5):
             os.remove("shard%d.db" % i)
 
-    def setup_session(self):
+    @classmethod
+    def setup_session(cls):
         global create_session
 
         shard_lookup = {
@@ -104,7 +106,8 @@ class ShardTest(TestBase):
         }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
 
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         global WeatherLocation, Report
 
         class WeatherLocation(object):
@@ -159,5 +162,3 @@ class ShardTest(TestBase):
 
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 89%
rename from test/orm/association.py
rename to test/orm/test_association.py
index d9265ffb104b3b09ccc9e0b67ffb43fd9e808b49..ee7fb7af94ac29744abadc92298fd24693b40ecb 100644 (file)
@@ -1,17 +1,19 @@
-import testenv; testenv.configure_for_tests()
 
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
-from testlib.testing import eq_
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
+from sqlalchemy.test.testing import eq_
 
 
 class AssociationTest(_base.MappedTest):
     run_setup_classes = 'once'
     run_setup_mappers = 'once'
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('items', metadata,
             Column('item_id', Integer, primary_key=True),
             Column('name', String(40)))
@@ -23,7 +25,8 @@ class AssociationTest(_base.MappedTest):
             Column('keyword_id', Integer, primary_key=True),
             Column('name', String(40)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Item(_base.BasicEntity):
             def __init__(self, name):
                 self.name = name
@@ -45,9 +48,10 @@ class AssociationTest(_base.MappedTest):
                 return "KeywordAssociation itemid=%d keyword=%r data=%s" % (
                     self.item_id, self.keyword, self.data)
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
-        items, item_keywords, keywords = self.tables.get_all(
+    def setup_mappers(cls):
+        items, item_keywords, keywords = cls.tables.get_all(
             'items', 'item_keywords', 'keywords')
 
         mapper(Keyword, keywords)
@@ -133,13 +137,11 @@ class AssociationTest(_base.MappedTest):
         item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc'))
         sess.add_all((item1, item2))
         sess.flush()
-        eq_(self.tables.item_keywords.count().scalar(), 3)
+        eq_(item_keywords.count().scalar(), 3)
 
         sess.delete(item1)
         sess.delete(item2)
         sess.flush()
-        eq_(self.tables.item_keywords.count().scalar(), 0)
+        eq_(item_keywords.count().scalar(), 0)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 93%
rename from test/orm/assorted_eager.py
rename to test/orm/test_assorted_eager.py
index 8dc95fa5b23e9895640f0ad40d7d05eaf2f759ff..09f0075479aaca48c28207714ecf4bce408bbeab 100644 (file)
@@ -3,21 +3,25 @@
 Derived from mailing list-reported problems and trac tickets.
 
 """
-import testenv; testenv.configure_for_tests()
 import datetime
 
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, backref, create_session
-from testlib.testing import eq_
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
 
 
 class EagerTest(_base.MappedTest):
     run_deletes = None
     run_inserts = "once"
+    run_setup_mappers = "once"
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         # determine a literal value for "false" based on the dialect
         # FIXME: this DefaultClause setup is bogus.
 
@@ -30,7 +34,7 @@ class EagerTest(_base.MappedTest):
             false = text('FALSE')
         else:
             false = str(False)
-        self.other_artifacts['false'] = false
+        cls.other_artifacts['false'] = false
 
         Table('owners', metadata ,
               Column('id', Integer, primary_key=True, nullable=False),
@@ -55,30 +59,32 @@ class EagerTest(_base.MappedTest):
               Column('someoption', sa.Boolean, server_default=false,
                      nullable=False))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Owner(_base.BasicEntity):
             pass
 
         class Category(_base.BasicEntity):
             pass
 
-        class Test(_base.BasicEntity):
+        class Thing(_base.BasicEntity):
             pass
 
         class Option(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Owner, owners)
 
         mapper(Category, categories)
 
         mapper(Option, options, properties=dict(
             owner=relation(Owner),
-            test=relation(Test)))
+            test=relation(Thing)))
 
-        mapper(Test, tests, properties=dict(
+        mapper(Thing, tests, properties=dict(
             owner=relation(Owner, backref='tests'),
             category=relation(Category),
             owner_option=relation(Option,
@@ -87,16 +93,17 @@ class EagerTest(_base.MappedTest):
                 foreign_keys=[options.c.test_id, options.c.owner_id],
                 uselist=False)))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         session = create_session()
 
         o = Owner()
         c = Category(name='Some Category')
         session.add_all((
-            Test(owner=o, category=c),
-            Test(owner=o, category=c, owner_option=Option(someoption=True)),
-            Test(owner=o, category=c, owner_option=Option())))
+            Thing(owner=o, category=c),
+            Thing(owner=o, category=c, owner_option=Option(someoption=True)),
+            Thing(owner=o, category=c, owner_option=Option())))
 
         session.flush()
 
@@ -129,7 +136,7 @@ class EagerTest(_base.MappedTest):
     @testing.resolve_artifact_names
     def test_withouteagerload(self):
         s = create_session()
-        l = (s.query(Test).
+        l = (s.query(Thing).
              select_from(tests.outerjoin(options,
                                          sa.and_(tests.c.id == options.c.test_id,
                                                  tests.c.owner_id ==
@@ -150,7 +157,7 @@ class EagerTest(_base.MappedTest):
 
         """
         s = create_session()
-        q=s.query(Test).options(sa.orm.eagerload('category'))
+        q=s.query(Thing).options(sa.orm.eagerload('category'))
 
         l=(q.select_from(tests.outerjoin(options,
                                          sa.and_(tests.c.id ==
@@ -168,7 +175,7 @@ class EagerTest(_base.MappedTest):
     def test_dslish(self):
         """test the same as witheagerload except using generative"""
         s = create_session()
-        q = s.query(Test).options(sa.orm.eagerload('category'))
+        q = s.query(Thing).options(sa.orm.eagerload('category'))
         l = q.filter (
             sa.and_(tests.c.owner_id == 1,
                     sa.or_(options.c.someoption == None,
@@ -182,7 +189,7 @@ class EagerTest(_base.MappedTest):
     @testing.resolve_artifact_names
     def test_without_outerjoin_literal(self):
         s = create_session()
-        q = s.query(Test).options(sa.orm.eagerload('category'))
+        q = s.query(Thing).options(sa.orm.eagerload('category'))
         l = (q.filter(
             (tests.c.owner_id==1) &
             ('options.someoption is null or options.someoption=%s' % false)).
@@ -194,7 +201,7 @@ class EagerTest(_base.MappedTest):
     @testing.resolve_artifact_names
     def test_withoutouterjoin(self):
         s = create_session()
-        q = s.query(Test).options(sa.orm.eagerload('category'))
+        q = s.query(Thing).options(sa.orm.eagerload('category'))
         l = q.filter(
             (tests.c.owner_id==1) &
             ((options.c.someoption==None) | (options.c.someoption==False))
@@ -205,7 +212,8 @@ class EagerTest(_base.MappedTest):
 
 
 class EagerTest2(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('left', metadata,
             Column('id', Integer, ForeignKey('middle.id'), primary_key=True),
             Column('data', String(50), primary_key=True))
@@ -218,7 +226,8 @@ class EagerTest2(_base.MappedTest):
             Column('id', Integer, ForeignKey('middle.id'), primary_key=True),
             Column('data', String(50), primary_key=True))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Left(_base.BasicEntity):
             def __init__(self, data):
                 self.data = data
@@ -231,8 +240,9 @@ class EagerTest2(_base.MappedTest):
             def __init__(self, data):
                 self.data = data
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         # set up bi-directional eager loads
         mapper(Left, left)
         mapper(Right, right)
@@ -267,7 +277,8 @@ class EagerTest2(_base.MappedTest):
 class EagerTest3(_base.MappedTest):
     """Eager loading combined with nested SELECT statements, functions, and aggregates."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('datas', metadata,
               Column('id', Integer, primary_key=True, nullable=False),
               Column('a', Integer, nullable=False))
@@ -283,7 +294,8 @@ class EagerTest3(_base.MappedTest):
               Column('data_id', Integer, ForeignKey('datas.id')),
               Column('somedata', Integer, nullable=False ))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Data(_base.BasicEntity):
             pass
 
@@ -349,7 +361,8 @@ class EagerTest3(_base.MappedTest):
 
 class EagerTest4(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('departments', metadata,
               Column('department_id', Integer, primary_key=True),
               Column('name', String(50)))
@@ -360,7 +373,8 @@ class EagerTest4(_base.MappedTest):
               Column('department_id', Integer,
                      ForeignKey('departments.department_id')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Department(_base.BasicEntity):
             pass
 
@@ -401,7 +415,8 @@ class EagerTest4(_base.MappedTest):
 class EagerTest5(_base.MappedTest):
     """Construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('base', metadata,
               Column('uid', String(30), primary_key=True),
               Column('x', String(30)))
@@ -421,7 +436,8 @@ class EagerTest5(_base.MappedTest):
               Column('uid', String(30), ForeignKey('base.uid')),
               Column('comment', String(30)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Base(_base.BasicEntity):
             def __init__(self, uid, x):
                 self.uid = uid
@@ -486,7 +502,8 @@ class EagerTest5(_base.MappedTest):
 
 class EagerTest6(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('design_types', metadata,
             Column('design_type_id', Integer, primary_key=True))
 
@@ -506,7 +523,8 @@ class EagerTest6(_base.MappedTest):
               Column('part_id', Integer, ForeignKey('parts.part_id')),
               Column('design_id', Integer, ForeignKey('design.design_id')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Part(_base.BasicEntity):
             pass
 
@@ -552,7 +570,8 @@ class EagerTest6(_base.MappedTest):
 
 
 class EagerTest7(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('companies', metadata,
               Column('company_id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -584,7 +603,8 @@ class EagerTest7(_base.MappedTest):
               Column('code', String(20)),
               Column('qty', Integer))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Company(_base.ComparableEntity):
             pass
 
@@ -699,7 +719,8 @@ class EagerTest7(_base.MappedTest):
 
 class EagerTest8(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('prj', metadata,
               Column('id', Integer, primary_key=True),
               Column('created', sa.DateTime ),
@@ -731,8 +752,9 @@ class EagerTest8(_base.MappedTest):
               Column('name', sa.Unicode(20)),
               Column('display_name', sa.Unicode(20)))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def fixtures(self):
+    def fixtures(cls):
         return dict(
             prj=(('id',),
                  (1,)),
@@ -746,7 +768,8 @@ class EagerTest8(_base.MappedTest):
             task=(('title', 'task_type_id', 'status_id', 'prj_id'),
                   (u'task 1', 1, 1, 1)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Task_Type(_base.BasicEntity):
             pass
 
@@ -788,7 +811,8 @@ class EagerTest9(_base.MappedTest):
     throughout the query setup/mapper instances process.
 
     """
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('accounts', metadata,
             Column('account_id', Integer, primary_key=True),
             Column('name', String(40)))
@@ -805,7 +829,8 @@ class EagerTest9(_base.MappedTest):
             Column('transaction_id', Integer,
                    ForeignKey('transactions.transaction_id')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Account(_base.BasicEntity):
             pass
 
@@ -815,8 +840,9 @@ class EagerTest9(_base.MappedTest):
         class Entry(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Account, accounts)
 
         mapper(Transaction, transactions)
@@ -874,5 +900,3 @@ class EagerTest9(_base.MappedTest):
 
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 99%
rename from test/orm/attributes.py
rename to test/orm/test_attributes.py
index 7c116fcf78034e54ff931dabb6c014c40115071a..3b1b42dadcdf6334c27d2b0d328f79e2ef3b560f 100644 (file)
@@ -1,12 +1,11 @@
-import testenv; testenv.configure_for_tests()
 import pickle
 import sqlalchemy.orm.attributes as attributes
 from sqlalchemy.orm.collections import collection
 from sqlalchemy.orm.interfaces import AttributeExtension
 from sqlalchemy import exc as sa_exc
-from testlib import *
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
 import gc
 
 # global for pickling tests
@@ -15,12 +14,12 @@ MyTest2 = None
 
 
 class AttributesTest(_base.ORMTest):
-    def setUp(self):
+    def setup(self):
         global MyTest, MyTest2
         class MyTest(object): pass
         class MyTest2(object): pass
 
-    def tearDown(self):
+    def teardown(self):
         global MyTest, MyTest2
         MyTest, MyTest2 = None, None
 
@@ -588,7 +587,7 @@ class BackrefTest(_base.ORMTest):
         self.assert_(p.jack is None)
 
 class PendingBackrefTest(_base.ORMTest):
-    def setUp(self):
+    def setup(self):
         global Post, Blog, called, lazy_load
 
         class Post(object):
@@ -1327,5 +1326,3 @@ class ListenerTest(_base.ORMTest):
         assert f1.barset.pop().data == "some bar appended"
     
     
-if __name__ == "__main__":
-    testenv.main()
similarity index 71%
rename from test/orm/bind.py
rename to test/orm/test_bind.py
index 33d028d22ec74777fc09ee7d2f2acc747720c8e4..9b1c20b6056395f07046b2fb8132dfe7ff823554 100644 (file)
@@ -1,23 +1,29 @@
-import testenv; testenv.configure_for_tests()
-from testlib.sa import MetaData, Table, Column, Integer
-from testlib.sa.orm import mapper, create_session
-from testlib import sa, testing
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy import MetaData, Integer
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, create_session
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from test.orm import _base
 
 
 class BindTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('test_table', metadata,
               Column('id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
               Column('data', Integer))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Foo(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         meta = MetaData()
         test_table.tometadata(meta)
 
@@ -44,12 +50,10 @@ class BindTest(_base.MappedTest):
     def test_session_unbound(self):
         sess = create_session()
         sess.add(Foo())
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.UnboundExecutionError,
             ('Could not locate a bind configured on Mapper|Foo|test_table '
              'or this Session'),
             sess.flush)
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 95%
rename from test/orm/cascade.py
rename to test/orm/test_cascade.py
index c827a85cedec8eaf6d84f9c83fe36f77d9cd1622..d0a7b9ded6956fd4b10a319f4beb7b7fb030ded9 100644 (file)
@@ -1,18 +1,21 @@
-import testenv; testenv.configure_for_tests()
 
-from testlib.sa import Table, Column, Integer, String, ForeignKey, Sequence, exc as sa_exc
-from testlib.sa.orm import mapper, relation, create_session, class_mapper, backref
-from testlib.sa.orm import attributes, exc as orm_exc
-from testlib import testing
-from testlib.testing import eq_
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref
+from sqlalchemy.orm import attributes, exc as orm_exc
+from sqlalchemy.test import testing
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
 
 
 class O2MCascadeTest(_fixtures.FixtureTest):
     run_inserts = None
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Address, addresses)
         mapper(User, users, properties = dict(
             addresses = relation(Address, cascade="all, delete-orphan", backref="user"),
@@ -188,8 +191,9 @@ class O2MCascadeTest(_fixtures.FixtureTest):
 class O2OCascadeTest(_fixtures.FixtureTest):
     run_inserts = None
     
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Address, addresses)
         mapper(User, users, properties = {
             'address':relation(Address, backref=backref("user", single_parent=True), uselist=False)
@@ -200,7 +204,7 @@ class O2OCascadeTest(_fixtures.FixtureTest):
         a1 = Address(email_address='some address')
         u1 = User(name='u1', address=a1)
         
-        self.assertRaises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1)
+        assert_raises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1)
         
         a2 = Address(email_address='asd')
         u1.address = a2
@@ -212,8 +216,9 @@ class O2OCascadeTest(_fixtures.FixtureTest):
 class O2MBackrefTest(_fixtures.FixtureTest):
     run_inserts = None
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users, properties = dict(
             orders = relation(
                 mapper(Order, orders), cascade="all, delete-orphan", backref="user")
@@ -316,8 +321,9 @@ class NoSaveCascadeTest(_fixtures.FixtureTest):
 class O2MCascadeNoOrphanTest(_fixtures.FixtureTest):
     run_inserts = None
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users, properties = dict(
             orders = relation(
                 mapper(Order, orders), cascade="all")
@@ -342,7 +348,8 @@ class O2MCascadeNoOrphanTest(_fixtures.FixtureTest):
 
 
 class M2OCascadeTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("extra", metadata,
             Column("id", Integer, Sequence("extra_id_seq", optional=True),
                    primary_key=True),
@@ -359,7 +366,8 @@ class M2OCascadeTest(_base.MappedTest):
             Column('name', String(40)),
             Column('pref_id', Integer, ForeignKey('prefs.id')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_fixtures.Base):
             pass
         class Pref(_fixtures.Base):
@@ -367,8 +375,9 @@ class M2OCascadeTest(_base.MappedTest):
         class Extra(_fixtures.Base):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Extra, extra)
         mapper(Pref, prefs, properties=dict(
             extra = relation(Extra, cascade="all, delete")
@@ -377,8 +386,9 @@ class M2OCascadeTest(_base.MappedTest):
             pref = relation(Pref, lazy=False, cascade="all, delete-orphan", single_parent=True  )
         ))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()]))
         u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()]))
         u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()]))
@@ -447,7 +457,8 @@ class M2OCascadeTest(_base.MappedTest):
             [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")])
 
 class M2OCascadeDeleteTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
               Column('id', Integer, primary_key=True),
               Column('data', String(50)),
@@ -460,7 +471,8 @@ class M2OCascadeDeleteTest(_base.MappedTest):
               Column('id', Integer, primary_key=True),
               Column('data', String(50)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class T1(_fixtures.Base):
             pass
         class T2(_fixtures.Base):
@@ -468,8 +480,9 @@ class M2OCascadeDeleteTest(_base.MappedTest):
         class T3(_fixtures.Base):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(T1, t1, properties={'t2': relation(T2, cascade="all")})
         mapper(T2, t2, properties={'t3': relation(T3, cascade="all")})
         mapper(T3, t3)
@@ -565,7 +578,8 @@ class M2OCascadeDeleteTest(_base.MappedTest):
 
 class M2OCascadeDeleteOrphanTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
               Column('id', Integer, primary_key=True),
               Column('data', String(50)),
@@ -578,7 +592,8 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest):
               Column('id', Integer, primary_key=True),
               Column('data', String(50)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class T1(_fixtures.Base):
             pass
         class T2(_fixtures.Base):
@@ -586,8 +601,9 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest):
         class T3(_fixtures.Base):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(T1, t1, properties=dict(
             t2=relation(T2, cascade="all, delete-orphan", single_parent=True)))
         mapper(T2, t2, properties=dict(
@@ -655,7 +671,7 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest):
         
         y = T2(data='T2a')
         x = T1(data='T1a', t2=y)
-        self.assertRaises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y)
+        assert_raises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y)
 
     @testing.resolve_artifact_names
     def test_single_parent_backref(self):
@@ -666,7 +682,7 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest):
         x = T2(data='T2a', t3=y)
 
         # cant attach the T3 to another T2
-        self.assertRaises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y)
+        assert_raises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y)
         
         # set via backref tho is OK, unsets from previous parent
         # first
@@ -677,7 +693,8 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest):
         assert x.t3 is None
 
 class M2MCascadeTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('a', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(30)),
@@ -703,7 +720,8 @@ class M2MCascadeTest(_base.MappedTest):
               
               )
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class A(_fixtures.Base):
             pass
         class B(_fixtures.Base):
@@ -784,7 +802,7 @@ class M2MCascadeTest(_base.MappedTest):
         b1 =B(data='b1')
         a1 = A(data='a1', bs=[b1])
         
-        self.assertRaises(sa_exc.InvalidRequestError,
+        assert_raises(sa_exc.InvalidRequestError,
                 A, data='a2', bs=[b1]
             )
 
@@ -804,7 +822,7 @@ class M2MCascadeTest(_base.MappedTest):
         b1 =B(data='b1')
         a1 = A(data='a1', bs=[b1])
         
-        self.assertRaises(
+        assert_raises(
             sa_exc.InvalidRequestError,
             A, data='a2', bs=[b1]
         )
@@ -817,7 +835,8 @@ class M2MCascadeTest(_base.MappedTest):
 class UnsavedOrphansTest(_base.MappedTest):
     """Pending entities that are orphans"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users', metadata,
             Column('user_id', Integer,
                    Sequence('user_id_seq', optional=True),
@@ -831,7 +850,8 @@ class UnsavedOrphansTest(_base.MappedTest):
             Column('user_id', Integer, ForeignKey('users.user_id')),
             Column('email_address', String(40)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_fixtures.Base):
             pass
         class Address(_fixtures.Base):
@@ -900,7 +920,8 @@ class UnsavedOrphansTest(_base.MappedTest):
 class UnsavedOrphansTest2(_base.MappedTest):
     """same test as UnsavedOrphans only three levels deep"""
 
-    def define_tables(self, meta):
+    @classmethod
+    def define_tables(cls, meta):
         Table('orders', meta,
             Column('id', Integer, Sequence('order_id_seq'),
                    primary_key=True),
@@ -958,7 +979,8 @@ class UnsavedOrphansTest2(_base.MappedTest):
 class UnsavedOrphansTest3(_base.MappedTest):
     """test not expunging double parents"""
 
-    def define_tables(self, meta):
+    @classmethod
+    def define_tables(cls, meta):
         Table('sales_reps', meta,
             Column('sales_rep_id', Integer,
                    Sequence('sales_rep_id_seq'),
@@ -1062,7 +1084,8 @@ class UnsavedOrphansTest3(_base.MappedTest):
 class DoubleParentOrphanTest(_base.MappedTest):
     """test orphan detection for an entity with two parent relations"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('addresses', metadata,
             Column('address_id', Integer, primary_key=True),
             Column('street', String(30)),
@@ -1133,7 +1156,8 @@ class DoubleParentOrphanTest(_base.MappedTest):
             assert True
 
 class CollectionAssignmentOrphanTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('table_a', metadata,
               Column('id', Integer, primary_key=True),
               Column('name', String(30)))
@@ -1181,7 +1205,8 @@ class PartialFlushTest(_base.MappedTest):
     """test cascade behavior as it relates to object lists passed to flush().
     
     """
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("base", metadata,
             Column("id", Integer, primary_key=True),
             Column("descr", String(50))
@@ -1288,5 +1313,3 @@ class PartialFlushTest(_base.MappedTest):
         assert c1 not in sess.new
         assert c2 in sess.new
         
-if __name__ == "__main__":
-    testenv.main()
similarity index 98%
rename from test/orm/collection.py
rename to test/orm/test_collection.py
index 23f643597ac2529dc4a1ba44ea7a01f78a6b546c..12ff25c460ca8efeb4d2a68ba7e3cc97ed3f9181 100644 (file)
@@ -1,17 +1,19 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import sys
 from operator import and_
 
 import sqlalchemy.orm.collections as collections
 from sqlalchemy.orm.collections import collection
 
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa import util, exc as sa_exc
-from testlib.sa.orm import create_session, mapper, relation, \
-    attributes
-from orm import _base
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy import util, exc as sa_exc
+from sqlalchemy.orm import create_session, mapper, relation,     attributes
+from test.orm import _base
+from sqlalchemy.test.testing import eq_
 
 class Canary(sa.orm.interfaces.AttributeExtension):
     def __init__(self):
@@ -45,12 +47,14 @@ class CollectionsTest(_base.ORMTest):
         def __repr__(self):
             return str((id(self), self.a, self.b, self.c))
 
-    def setUpAll(self):
-        attributes.register_class(self.Entity)
+    @classmethod
+    def setup_class(cls):
+        attributes.register_class(cls.Entity)
 
-    def tearDownAll(self):
-        attributes.unregister_class(self.Entity)
-        _base.ORMTest.tearDownAll(self)
+    @classmethod
+    def teardown_class(cls):
+        attributes.unregister_class(cls.Entity)
+        super(CollectionsTest, cls).teardown_class()
 
     _entity_id = 1
 
@@ -937,7 +941,7 @@ class CollectionsTest(_base.ORMTest):
                 pass
             self.assert_(obj.attr is not real_dict)
             self.assert_('badkey' not in obj.attr)
-            self.assertEquals(set(collections.collection_adapter(obj.attr)),
+            eq_(set(collections.collection_adapter(obj.attr)),
                               set([e2]))
             self.assert_(e3 not in canary.added)
         else:
@@ -945,13 +949,13 @@ class CollectionsTest(_base.ORMTest):
             obj.attr = real_dict
             self.assert_(obj.attr is not real_dict)
             self.assert_('keyignored1' not in obj.attr)
-            self.assertEquals(set(collections.collection_adapter(obj.attr)),
+            eq_(set(collections.collection_adapter(obj.attr)),
                               set([e3]))
             self.assert_(e2 in canary.removed)
             self.assert_(e3 in canary.added)
 
         obj.attr = typecallable()
-        self.assertEquals(list(collections.collection_adapter(obj.attr)), [])
+        eq_(list(collections.collection_adapter(obj.attr)), [])
 
         e4 = creator()
         try:
@@ -1336,7 +1340,8 @@ class CollectionsTest(_base.ORMTest):
 
 class DictHelpersTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('parents', metadata,
               Column('id', Integer, primary_key=True),
               Column('label', String(128)))
@@ -1348,7 +1353,8 @@ class DictHelpersTest(_base.MappedTest):
               Column('b', String(128)),
               Column('c', String(128)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Parent(_base.BasicEntity):
             def __init__(self, label=None):
                 self.label = label
@@ -1378,7 +1384,7 @@ class DictHelpersTest(_base.MappedTest):
         p = session.query(Parent).get(pid)
 
         
-        self.assertEquals(set(p.children.keys()), set(['foo', 'bar']))
+        eq_(set(p.children.keys()), set(['foo', 'bar']))
         cid = p.children['foo'].id
 
         collections.collection_adapter(p.children).append_with_event(
@@ -1519,7 +1525,8 @@ class DictHelpersTest(_base.MappedTest):
 # remove if so
 class CustomCollectionsTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('sometable', metadata,
               Column('col1',Integer, primary_key=True),
               Column('data', String(30)))
@@ -1830,5 +1837,3 @@ class InstrumentationTest(_base.ORMTest):
         instrumented = collections._instrument_class(Touchy)
         assert True
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 97%
rename from test/orm/compile.py
rename to test/orm/test_compile.py
index 7c9bed4ecc7039e213b36da292c32ad196dcf7ef..7a5b636157238007b8da0db560c145c79ade1acf 100644 (file)
@@ -1,15 +1,14 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
-from testlib import *
-from orm import _base
+from sqlalchemy.test import *
+from test.orm import _base
 
 
 class CompileTest(_base.ORMTest):
     """test various mapper compilation scenarios"""
 
-    def tearDown(self):
+    def teardown(self):
         clear_mappers()
 
     def testone(self):
@@ -182,5 +181,3 @@ class CompileTest(_base.ORMTest):
             assert str(e).index("Error creating backref") > -1
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 94%
rename from test/orm/cycles.py
rename to test/orm/test_cycles.py
index 3e3636085247fcfabf22297d954f66581908414d..fe77b360187e894d7a692ec8706dd84a203f9009 100644 (file)
@@ -5,19 +5,21 @@ T1<->T2, with o2m or m2o between them, and a third T3 with o2m/m2o to one/both
 T1/T2.
 
 """
-import testenv; testenv.configure_for_tests()
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, backref, create_session
-from testlib.testing import eq_
-from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf
-from orm import _base
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, create_session
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf
+from test.orm import _base
 
 
 class SelfReferentialTest(_base.MappedTest):
     """A self-referential mapper with an additional list of child objects."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
               Column('c1', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -29,7 +31,8 @@ class SelfReferentialTest(_base.MappedTest):
               Column('c1id', Integer, ForeignKey('t1.c1')),
               Column('data', String(20)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class C1(_base.BasicEntity):
             def __init__(self, data=None):
                 self.data = data
@@ -132,20 +135,23 @@ class SelfReferentialTest(_base.MappedTest):
 class SelfReferentialNoPKTest(_base.MappedTest):
     """A self-referential relationship that joins on a column other than the primary key column"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('item', metadata,
            Column('id', Integer, primary_key=True),
            Column('uuid', String(32), unique=True, nullable=False),
            Column('parent_uuid', String(32), ForeignKey('item.uuid'),
                   nullable=True))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class TT(_base.BasicEntity):
             def __init__(self):
                 self.uuid = hex(id(self))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(TT, item, properties={
             'children': relation(
                 TT,
@@ -181,7 +187,8 @@ class SelfReferentialNoPKTest(_base.MappedTest):
 
 
 class InheritTestOne(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("parent", metadata,
             Column("id", Integer, primary_key=True),
             Column("parent_data", String(50)),
@@ -199,7 +206,8 @@ class InheritTestOne(_base.MappedTest):
                    nullable=False),
             Column("child2_data", String(50)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Parent(_base.BasicEntity):
             pass
 
@@ -209,8 +217,9 @@ class InheritTestOne(_base.MappedTest):
         class Child2(Parent):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Parent, parent)
         mapper(Child1, child1, inherits=Parent)
         mapper(Child2, child2, inherits=Parent, properties=dict(
@@ -250,7 +259,8 @@ class InheritTestTwo(_base.MappedTest):
 
     """
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('a', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(30)),
@@ -266,7 +276,8 @@ class InheritTestTwo(_base.MappedTest):
             Column('aid', Integer,
                    ForeignKey('a.id', use_alter=True, name="foo")))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class A(_base.BasicEntity):
             pass
 
@@ -297,7 +308,8 @@ class InheritTestTwo(_base.MappedTest):
 class BiDirectionalManyToOneTest(_base.MappedTest):
     run_define_tables = 'each'
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(30)),
@@ -313,7 +325,8 @@ class BiDirectionalManyToOneTest(_base.MappedTest):
             Column('t1id', Integer, ForeignKey('t1.id'), nullable=False),
             Column('t2id', Integer, ForeignKey('t2.id'), nullable=False))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class T1(_base.BasicEntity):
             pass
         class T2(_base.BasicEntity):
@@ -321,8 +334,9 @@ class BiDirectionalManyToOneTest(_base.MappedTest):
         class T3(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(T1, t1, properties={
             't2':relation(T2, primaryjoin=t1.c.t2id == t2.c.id)})
         mapper(T2, t2, properties={
@@ -385,7 +399,8 @@ class BiDirectionalOneToManyTest(_base.MappedTest):
 
     run_define_tables = 'each'
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
               Column('c1', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -397,7 +412,8 @@ class BiDirectionalOneToManyTest(_base.MappedTest):
               Column('c2', Integer,
                      ForeignKey('t1.c1', use_alter=True, name='t1c1_fk')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class C1(_base.BasicEntity):
             pass
 
@@ -434,7 +450,8 @@ class BiDirectionalOneToManyTest2(_base.MappedTest):
 
     run_define_tables = 'each'
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
               Column('c1', Integer, primary_key=True),
               Column('c2', Integer, ForeignKey('t2.c1')),
@@ -452,7 +469,8 @@ class BiDirectionalOneToManyTest2(_base.MappedTest):
               Column('data', String(20)),
               test_needs_autoincrement=True)
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class C1(_base.BasicEntity):
             pass
 
@@ -462,8 +480,9 @@ class BiDirectionalOneToManyTest2(_base.MappedTest):
         class C1Data(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(C2, t2, properties={
             'c1s': relation(C1,
                             primaryjoin=t2.c.c1 == t1.c.c2,
@@ -508,7 +527,8 @@ class OneToManyManyToOneTest(_base.MappedTest):
     """
     run_define_tables = 'each'
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('ball', metadata,
               Column('id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -522,7 +542,8 @@ class OneToManyManyToOneTest(_base.MappedTest):
               Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
               Column('data', String(30)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Person(_base.BasicEntity):
             pass
 
@@ -709,7 +730,8 @@ class OneToManyManyToOneTest(_base.MappedTest):
 class SelfReferentialPostUpdateTest(_base.MappedTest):
     """Post_update on a single self-referential mapper"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('node', metadata,
               Column('id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -721,7 +743,8 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
               Column('next_sibling_id', Integer,
                      ForeignKey('node.id'), nullable=True))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Node(_base.BasicEntity):
             def __init__(self, path=''):
                 self.path = path
@@ -815,13 +838,15 @@ class SelfReferentialPostUpdateTest(_base.MappedTest):
 
 class SelfReferentialPostUpdateTest2(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("a_table", metadata,
               Column("id", Integer(), primary_key=True),
               Column("fui", String(128)),
               Column("b", Integer(), ForeignKey("a_table.id")))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class A(_base.BasicEntity):
             pass
 
@@ -858,5 +883,3 @@ class SelfReferentialPostUpdateTest2(_base.MappedTest):
         assert f2.foo is f1
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 88%
rename from test/orm/defaults.py
rename to test/orm/test_defaults.py
index 8dc1925195d40c66eb098e07c6ad6add645133e3..b063780ac72b0059e38dec2fb22712446dcf8b3c 100644 (file)
@@ -1,16 +1,19 @@
-import testenv; testenv.configure_for_tests()
 
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
+from sqlalchemy.test.testing import eq_
 
 
 class TriggerDefaultsTest(_base.MappedTest):
     __requires__ = ('row_triggers',)
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         dt = Table('dt', metadata,
                    Column('id', Integer, primary_key=True),
                    Column('col1', String(20)),
@@ -63,12 +66,14 @@ class TriggerDefaultsTest(_base.MappedTest):
         sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt)
 
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Default(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Default, dt)
 
     @testing.resolve_artifact_names
@@ -107,7 +112,8 @@ class TriggerDefaultsTest(_base.MappedTest):
         eq_(d1.col4, 'up')
 
 class ExcludedDefaultsTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         dt = Table('dt', metadata,
                    Column('id', Integer, primary_key=True),
                    Column('col1', String(20), default="hello"),
@@ -125,5 +131,3 @@ class ExcludedDefaultsTest(_base.MappedTest):
         sess.flush()
         eq_(dt.select().execute().fetchall(), [(1, "hello")])
     
-if __name__ == "__main__":
-    testenv.main()
similarity index 96%
rename from test/orm/deprecations.py
rename to test/orm/test_deprecations.py
index 483e8f556befd2b440cb58e20194ad6ab1ff9c8e..00d64119eac172fbffc962a6994ca658a308fe97 100644 (file)
@@ -5,11 +5,12 @@ modern (i.e. not deprecated) alternative to them.  The tests snippets here can
 be migrated directly to the wiki, docs, etc.
 
 """
-import testenv; testenv.configure_for_tests()
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, func
-from testlib.sa.orm import mapper, relation, create_session, sessionmaker
-from orm import _base
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, sessionmaker
+from test.orm import _base
 
 
 class QueryAlternativesTest(_base.MappedTest):
@@ -44,7 +45,8 @@ class QueryAlternativesTest(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users_table', metadata,
               Column('id', Integer, primary_key=True),
               Column('name', String(64)))
@@ -56,21 +58,24 @@ class QueryAlternativesTest(_base.MappedTest):
               Column('purpose', String(16)),
               Column('bounces', Integer, default=0))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.BasicEntity):
             pass
 
         class Address(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users_table, properties=dict(
             addresses=relation(Address, backref='user'),
             ))
         mapper(Address, addresses_table)
 
-    def fixtures(self):
+    @classmethod
+    def fixtures(cls):
         return dict(
             users_table=(
             ('id', 'name'),
@@ -479,5 +484,3 @@ class QueryAlternativesTest(_base.MappedTest):
         assert len(users) == 1 and users[0].name == 'ed'
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 96%
rename from test/orm/dynamic.py
rename to test/orm/test_dynamic.py
index 3bd94b7c0ebc2dd7d54ca50b90065ab5a79b0a57..f2089a4351b2e9f6ddc4693ef3d1ba74409fd15d 100644 (file)
@@ -1,13 +1,15 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import operator
 from sqlalchemy.orm import dynamic_loader, backref
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func
-from testlib.sa.orm import mapper, relation, create_session, Query, attributes
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, desc, select, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, Query, attributes
 from sqlalchemy.orm.dynamic import AppenderMixin
-from testlib.testing import eq_
-from testlib.compat import _function_named
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
 
 
 class DynamicTest(_fixtures.FixtureTest):
@@ -281,7 +283,7 @@ class SessionTest(_fixtures.FixtureTest):
         sess.flush()
 
         from sqlalchemy.orm import attributes
-        self.assertEquals(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
+        eq_(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
 
         sess.expunge_all()
 
@@ -452,7 +454,7 @@ class SessionTest(_fixtures.FixtureTest):
         sess.close()
 
 
-def create_backref_test(autoflush, saveuser):
+def _create_backref_test(autoflush, saveuser):
 
     @testing.resolve_artifact_names
     def test_backref(self):
@@ -487,17 +489,18 @@ def create_backref_test(autoflush, saveuser):
             sess.flush()
         self.assert_(list(u.addresses) == [])
 
-    test_backref = _function_named(
+    test_backref = function_named(
         test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""),
                                     (saveuser and "_saveuser" or "_savead")))
     setattr(SessionTest, test_backref.__name__, test_backref)
 
 for autoflush in (False, True):
     for saveuser in (False, True):
-        create_backref_test(autoflush, saveuser)
+        _create_backref_test(autoflush, saveuser)
 
 class DontDereferenceTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users', metadata,
               Column('id', Integer, primary_key=True),
               Column('name', String(40)),
@@ -509,8 +512,9 @@ class DontDereferenceTest(_base.MappedTest):
               Column('email_address', String(100), nullable=False),
               Column('user_id', Integer, ForeignKey('users.id')))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class User(_base.ComparableEntity):
             pass
 
@@ -555,5 +559,3 @@ class DontDereferenceTest(_base.MappedTest):
         eq_(query3(), [Address(email_address='joe@joesdomain.example')])
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 97%
rename from test/orm/eager_relations.py
rename to test/orm/test_eager_relations.py
index 87c2442cc42bbb778a9418a37f94e2c6fac7f959..384e0472f6c1d5862caa0dacaa371687304113d3 100644 (file)
@@ -1,13 +1,16 @@
 """basic tests of eager loaded attributes"""
 
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
+from sqlalchemy.test.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
 from sqlalchemy.orm import eagerload, deferred, undefer
-from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func
-from testlib.sa.orm import mapper, relation, create_session, lazyload, aliased
-from testlib.testing import eq_
-from testlib.assertsql import CompiledSQL
-from orm import _base, _fixtures
+from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.assertsql import CompiledSQL
+from test.orm import _base, _fixtures
 import datetime
 
 class EagerTest(_fixtures.FixtureTest):
@@ -23,7 +26,7 @@ class EagerTest(_fixtures.FixtureTest):
         q = sess.query(User)
 
         assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all()
-        self.assertEquals(self.static.user_address_result, q.order_by(User.id).all())
+        eq_(self.static.user_address_result, q.order_by(User.id).all())
 
     @testing.resolve_artifact_names
     def test_late_compile(self):
@@ -287,7 +290,7 @@ class EagerTest(_fixtures.FixtureTest):
         assert sa.orm.class_mapper(Address).get_property('user').lazy is False
 
         sess = create_session()
-        self.assertEquals(self.static.user_address_result, sess.query(User).order_by(User.id).all())
+        eq_(self.static.user_address_result, sess.query(User).order_by(User.id).all())
 
     @testing.resolve_artifact_names
     def test_double(self):
@@ -615,13 +618,13 @@ class EagerTest(_fixtures.FixtureTest):
 
         def go():
             o1 = sess.query(Order).options(lazyload('address')).filter(Order.id==5).one()
-            self.assertEquals(o1.address, None)
+            eq_(o1.address, None)
         self.assert_sql_count(testing.db, go, 2)
         
         sess.expunge_all()
         def go():
             o1 = sess.query(Order).filter(Order.id==5).one()
-            self.assertEquals(o1.address, None)
+            eq_(o1.address, None)
         self.assert_sql_count(testing.db, go, 1)
         
     @testing.resolve_artifact_names
@@ -817,7 +820,8 @@ class AddEntityTest(_fixtures.FixtureTest):
         self.assert_sql_count(testing.db, go, 1)
 
 class OrderBySecondaryTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('m2m', metadata,
               Column('id', Integer, primary_key=True),
               Column('aid', Integer, ForeignKey('a.id')),
@@ -830,7 +834,8 @@ class OrderBySecondaryTest(_base.MappedTest):
               Column('id', Integer, primary_key=True),
               Column('data', String(50)))
 
-    def fixtures(self):
+    @classmethod
+    def fixtures(cls):
         return dict(
             a=(('id', 'data'),
                (1, 'a1'),
@@ -865,7 +870,8 @@ class OrderBySecondaryTest(_base.MappedTest):
 
 
 class SelfReferentialEagerTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('nodes', metadata,
               Column('id', Integer, sa.Sequence('node_id_seq', optional=True),
                      primary_key=True),
@@ -980,7 +986,7 @@ class SelfReferentialEagerTest(_base.MappedTest):
         sess.expunge_all()
 
         def go():
-            self.assertEquals
+            eq_
                 Node(data='n1', children=[Node(data='n11'), Node(data='n12')]),
                 sess.query(Node).order_by(Node.id).first(),
                 )
@@ -1079,7 +1085,8 @@ class SelfReferentialEagerTest(_base.MappedTest):
         self.assert_sql_count(testing.db, go, 3)
 
 class MixedSelfReferentialEagerTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('a_table', metadata,
                        Column('id', Integer, primary_key=True)
                        )
@@ -1091,8 +1098,9 @@ class MixedSelfReferentialEagerTest(_base.MappedTest):
                        Column('parent_b2_id', Integer, ForeignKey('b_table.id')))
 
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class A(_base.ComparableEntity):
             pass
         class B(_base.ComparableEntity):
@@ -1113,8 +1121,9 @@ class MixedSelfReferentialEagerTest(_base.MappedTest):
                             )
         });
     
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3))
         b_table.insert().execute(
             dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None),
@@ -1149,7 +1158,8 @@ class MixedSelfReferentialEagerTest(_base.MappedTest):
         self.assert_sql_count(testing.db, go, 1)
         
 class SelfReferentialM2MEagerTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('widget', metadata,
             Column('id', Integer, primary_key=True),
             Column('name', sa.Unicode(40), nullable=False, unique=True),
@@ -1189,8 +1199,9 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     run_inserts = 'once'
     run_deletes = None
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user'),
             'orders':relation(Order, backref='user'), # o2m, m2o
@@ -1323,7 +1334,8 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
 class CyclicalInheritingEagerTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
             Column('c1', Integer, primary_key=True),
             Column('c2', String(30)),
@@ -1361,7 +1373,8 @@ class CyclicalInheritingEagerTest(_base.MappedTest):
         create_session().query(SubT).all()
 
 class SubqueryTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users_table', metadata,
             Column('id', Integer, primary_key=True),
             Column('name', String(16))
@@ -1449,7 +1462,8 @@ class CorrelatedSubqueryTest(_base.MappedTest):
     
     """
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         users = Table('users', metadata,
             Column('id', Integer, primary_key=True),
             Column('name', String(50))
@@ -1460,8 +1474,9 @@ class CorrelatedSubqueryTest(_base.MappedTest):
             Column('date', Date),
             Column('user_id', Integer, ForeignKey('users.id')))
     
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         users.insert().execute(
             {'id':1, 'name':'user1'},
             {'id':2, 'name':'user2'},
@@ -1591,5 +1606,3 @@ class CorrelatedSubqueryTest(_base.MappedTest):
         self.assert_sql_count(testing.db, go, 1)
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 86%
rename from test/orm/evaluator.py
rename to test/orm/test_evaluator.py
index 3527c93d77011a15f3247e1ce206ab8727ddbd41..af6a3f89e3130e2d18d56131b0318274dc93fe92 100644 (file)
@@ -1,10 +1,12 @@
 """Evluating SQL expressions on ORM objects"""
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, String, Integer, select
-from testlib.sa.orm import mapper, create_session
-from testlib.testing import eq_
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import String, Integer, select
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
 
 from sqlalchemy import and_, or_, not_
 from sqlalchemy.orm import evaluator
@@ -20,17 +22,20 @@ def eval_eq(clause, testcases=None):
     return testeval
 
 class EvaluateTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users', metadata,
               Column('id', Integer, primary_key=True),
               Column('name', String(64)))
     
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.ComparableEntity):
             pass
     
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users)
     
     @testing.resolve_artifact_names
@@ -90,5 +95,3 @@ class EvaluateTest(_base.MappedTest):
             (User(id=None, name=None), None),
         ])
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 95%
rename from test/orm/expire.py
rename to test/orm/test_expire.py
index c11fb69dfec3da1d5e595513b88859e956415bcb..65934989788a013879909c3fd18fe397c5d95501 100644 (file)
@@ -1,11 +1,14 @@
 """Attribute/instance expiration, deferral of attributes, etc."""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import gc
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, exc as sa_exc
-from testlib.sa.orm import mapper, relation, create_session, attributes, deferred
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, attributes, deferred
+from test.orm import _base, _fixtures
 
 
 class ExpireTest(_fixtures.FixtureTest):
@@ -56,7 +59,7 @@ class ExpireTest(_fixtures.FixtureTest):
         u = s.query(User).get(7)
         s.expunge_all()
 
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u)
+        assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u)
 
     @testing.resolve_artifact_names
     def test_get_refreshes(self):
@@ -69,7 +72,7 @@ class ExpireTest(_fixtures.FixtureTest):
             u = s.query(User).get(10)  # get() refreshes
         self.assert_sql_count(testing.db, go, 1)
         def go():
-            self.assertEquals(u.name, 'chuck')  # attributes unexpired
+            eq_(u.name, 'chuck')  # attributes unexpired
         self.assert_sql_count(testing.db, go, 0)
         def go():
             u = s.query(User).get(10)  # expire flag reset, so not expired
@@ -86,7 +89,7 @@ class ExpireTest(_fixtures.FixtureTest):
         # add it back
         s.add(u)
         # nope, raises ObjectDeletedError
-        self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name')
+        assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name')
 
         # do a get()/remove u from session again
         assert s.query(User).get(10) is None
@@ -97,7 +100,7 @@ class ExpireTest(_fixtures.FixtureTest):
         assert u in s
         # but now its back, rollback has occured, the _remove_newly_deleted
         # is reverted
-        self.assertEquals(u.name, 'chuck')
+        eq_(u.name, 'chuck')
 
     @testing.resolve_artifact_names
     def test_deferred(self):
@@ -122,7 +125,7 @@ class ExpireTest(_fixtures.FixtureTest):
         s = create_session(autoflush=True, autocommit=False)
         u = s.query(User).get(8)
         adlist = u.addresses
-        self.assertEquals(adlist, [
+        eq_(adlist, [
             Address(email_address='ed@bettyboop.com'), 
             Address(email_address='ed@lala.com'),
             Address(email_address='ed@wood.com'), 
@@ -130,7 +133,7 @@ class ExpireTest(_fixtures.FixtureTest):
         a1 = u.addresses[2]
         a1.email_address = 'aaaaa'
         s.expire(u, ['addresses'])
-        self.assertEquals(u.addresses, [
+        eq_(u.addresses, [
             Address(email_address='aaaaa'), 
             Address(email_address='ed@bettyboop.com'), 
             Address(email_address='ed@lala.com'),
@@ -146,10 +149,10 @@ class ExpireTest(_fixtures.FixtureTest):
         mapper(Address, addresses)
         s = create_session(autoflush=True, autocommit=False)
         u = s.query(User).get(8)
-        self.assertRaisesMessage(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses'])
+        assert_raises_message(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses'])
         
         # in contrast to a regular query with no columns
-        self.assertRaisesMessage(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all)
+        assert_raises_message(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all)
         
     @testing.resolve_artifact_names
     def test_refresh_cancels_expire(self):
@@ -161,7 +164,7 @@ class ExpireTest(_fixtures.FixtureTest):
 
         def go():
             u = s.query(User).get(7)
-            self.assertEquals(u.name, 'jack')
+            eq_(u.name, 'jack')
         self.assert_sql_count(testing.db, go, 0)
 
     @testing.resolve_artifact_names
@@ -187,7 +190,7 @@ class ExpireTest(_fixtures.FixtureTest):
 
         sess.expire(u, attribute_names=['name'])
         sess.expunge(u)
-        self.assertRaises(sa.exc.UnboundExecutionError, getattr, u, 'name')
+        assert_raises(sa.exc.UnboundExecutionError, getattr, u, 'name')
 
     @testing.resolve_artifact_names
     def test_pending_raises(self):
@@ -197,7 +200,7 @@ class ExpireTest(_fixtures.FixtureTest):
         sess = create_session()
         u = User(id=15)
         sess.add(u)
-        self.assertRaises(sa.exc.InvalidRequestError, sess.expire, u, ['name'])
+        assert_raises(sa.exc.InvalidRequestError, sess.expire, u, ['name'])
 
     @testing.resolve_artifact_names
     def test_no_instance_key(self):
@@ -668,14 +671,15 @@ class ExpireTest(_fixtures.FixtureTest):
 
         userlist = sess.query(User).order_by(User.id).all()
         u = userlist[1]
-        self.assertEquals(self.static.user_address_result, userlist)
+        eq_(self.static.user_address_result, userlist)
         assert len(list(sess)) == 9
 
 class PolymorphicExpireTest(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global people, engineers, Person, Engineer
 
         people = Table('people', metadata,
@@ -690,14 +694,16 @@ class PolymorphicExpireTest(_base.MappedTest):
            Column('status', String(30)),
           )
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Person(_base.ComparableEntity):
             pass
         class Engineer(Person):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         people.insert().execute(
             {'person_id':1, 'name':'person1', 'type':'person'},
             {'person_id':2, 'name':'engineer1', 'type':'engineer'},
@@ -745,7 +751,7 @@ class PolymorphicExpireTest(_base.MappedTest):
             assert e1.status == 'new engineer'
             assert e2.status == 'old engineer'
         self.assert_sql_count(testing.db, go, 2)
-        self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1']))
+        eq_(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1']))
 
 class ExpiredPendingTest(_fixtures.FixtureTest):
     run_define_tables = 'once'
@@ -837,7 +843,7 @@ class RefreshTest(_fixtures.FixtureTest):
         s = create_session()
         u = s.query(User).get(7)
         s.expunge_all()
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
+        assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
 
     @testing.resolve_artifact_names
     def test_refresh_expired(self):
@@ -908,5 +914,3 @@ class RefreshTest(_fixtures.FixtureTest):
 
         s.refresh(u)
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 87%
rename from test/orm/extendedattr.py
rename to test/orm/test_extendedattr.py
index aec6c181f26c6af1c72828b566f6268b33478892..e0c64bf64a89ee1905adf68ce8cafa085a01ab49 100644 (file)
@@ -1,4 +1,4 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import pickle
 from sqlalchemy import util
 import sqlalchemy.orm.attributes as attributes
@@ -6,8 +6,8 @@ from sqlalchemy.orm.collections import collection
 from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
 from sqlalchemy.orm import clear_mappers
 from sqlalchemy.orm import InstrumentationManager
-from testlib import *
-from orm import _base
+from sqlalchemy.test import *
+from test.orm import _base
 
 class MyTypesManager(InstrumentationManager):
 
@@ -100,7 +100,8 @@ class MyClass(object):
             del self._goofy_dict[key]
 
 class UserDefinedExtensionTest(_base.ORMTest):
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         clear_mappers()
         attributes._install_lookup_strategy(util.symbol('native'))
 
@@ -161,30 +162,30 @@ class UserDefinedExtensionTest(_base.ORMTest):
             assert Foo in attributes.instrumentation_registry._state_finders
             f = Foo()
             attributes.instance_state(f).expire_attributes(None)
-            self.assertEquals(f.a, "this is a")
-            self.assertEquals(f.b, 12)
+            eq_(f.a, "this is a")
+            eq_(f.b, 12)
 
             f.a = "this is some new a"
             attributes.instance_state(f).expire_attributes(None)
-            self.assertEquals(f.a, "this is a")
-            self.assertEquals(f.b, 12)
+            eq_(f.a, "this is a")
+            eq_(f.b, 12)
 
             attributes.instance_state(f).expire_attributes(None)
             f.a = "this is another new a"
-            self.assertEquals(f.a, "this is another new a")
-            self.assertEquals(f.b, 12)
+            eq_(f.a, "this is another new a")
+            eq_(f.b, 12)
 
             attributes.instance_state(f).expire_attributes(None)
-            self.assertEquals(f.a, "this is a")
-            self.assertEquals(f.b, 12)
+            eq_(f.a, "this is a")
+            eq_(f.b, 12)
 
             del f.a
-            self.assertEquals(f.a, None)
-            self.assertEquals(f.b, 12)
+            eq_(f.a, None)
+            eq_(f.b, 12)
 
             attributes.instance_state(f).commit_all(attributes.instance_dict(f))
-            self.assertEquals(f.a, None)
-            self.assertEquals(f.b, 12)
+            eq_(f.a, None)
+            eq_(f.b, 12)
 
     def test_inheritance(self):
         """tests that attributes are polymorphic"""
@@ -265,27 +266,27 @@ class UserDefinedExtensionTest(_base.ORMTest):
             f1 = Foo()
             f1.name = 'f1'
 
-            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ()))
+            eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ()))
 
             b1 = Bar()
             b1.name = 'b1'
             f1.bars.append(b1)
-            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+            eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
 
             attributes.instance_state(f1).commit_all(attributes.instance_dict(f1))
             attributes.instance_state(b1).commit_all(attributes.instance_dict(b1))
 
-            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
-            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
+            eq_(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
+            eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
 
             f1.name = 'f1mod'
             b2 = Bar()
             b2.name = 'b2'
             f1.bars.append(b2)
-            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1']))
-            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
+            eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1']))
+            eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
             f1.bars.remove(b1)
-            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
+            eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
 
     def test_null_instrumentation(self):
         class Foo(MyBaseClass):
@@ -311,9 +312,9 @@ class UserDefinedExtensionTest(_base.ORMTest):
         assert attributes.manager_of_class(None) is None
 
         assert attributes.instance_state(k) is not None
-        self.assertRaises((AttributeError, KeyError),
+        assert_raises((AttributeError, KeyError),
                           attributes.instance_state, u)
-        self.assertRaises((AttributeError, KeyError),
+        assert_raises((AttributeError, KeyError),
                           attributes.instance_state, None)
 
 
similarity index 91%
rename from test/orm/generative.py
rename to test/orm/test_generative.py
index 99523674141c5e061308b01bb9d996f4aea1f82c..0efc1814ed6e0adb632f205358d1be31387db2ae 100644 (file)
@@ -1,28 +1,34 @@
-import testenv; testenv.configure_for_tests()
-from testlib import testing, sa
-from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func
+from sqlalchemy.test.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, MetaData, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
 from sqlalchemy.orm import mapper, relation, create_session
-from testlib.testing import eq_
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
 
 
 class GenerativeQueryTest(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('foo', metadata,
               Column('id', Integer, sa.Sequence('foo_id_seq'), primary_key=True),
               Column('bar', Integer),
               Column('range', Integer))
 
-    def fixtures(self):
+    @classmethod
+    def fixtures(cls):
         rows = tuple([(i, i % 10) for i in range(100)])
         foo_data = (('bar', 'range'),) + rows
         return dict(foo=foo_data)
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class Foo(_base.BasicEntity):
             pass
 
@@ -131,7 +137,8 @@ class GenerativeQueryTest(_base.MappedTest):
 
 class GenerativeTest2(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('Table1', metadata,
               Column('id', Integer, primary_key=True))
         Table('Table2', metadata,
@@ -139,8 +146,9 @@ class GenerativeTest2(_base.MappedTest):
                      primary_key=True),
               Column('num', Integer, primary_key=True))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class Obj1(_base.BasicEntity):
             pass
         class Obj2(_base.BasicEntity):
@@ -149,7 +157,8 @@ class GenerativeTest2(_base.MappedTest):
         mapper(Obj1, Table1)
         mapper(Obj2, Table2)
 
-    def fixtures(self):
+    @classmethod
+    def fixtures(cls):
         return dict(
             Table1=(('id',),
                     (1,),
@@ -182,8 +191,9 @@ class RelationsTest(_fixtures.FixtureTest):
     run_inserts = 'once'
     run_deletes = None
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users, properties={
             'orders':relation(mapper(Order, orders, properties={
                 'addresses':relation(mapper(Address, addresses))}))})
@@ -232,7 +242,8 @@ class RelationsTest(_fixtures.FixtureTest):
 
 class CaseSensitiveTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('Table1', metadata,
             Column('ID', Integer, primary_key=True))
         Table('Table2', metadata,
@@ -240,8 +251,9 @@ class CaseSensitiveTest(_base.MappedTest):
                      primary_key=True),
               Column('NUM', Integer, primary_key=True))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class Obj1(_base.BasicEntity):
             pass
         class Obj2(_base.BasicEntity):
@@ -250,7 +262,8 @@ class CaseSensitiveTest(_base.MappedTest):
         mapper(Obj1, Table1)
         mapper(Obj2, Table2)
 
-    def fixtures(self):
+    @classmethod
+    def fixtures(cls):
         return dict(
             Table1=(('ID',),
                     (1,),
@@ -272,8 +285,6 @@ class CaseSensitiveTest(_base.MappedTest):
         res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1))
         assert res.count() == 3
         res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)).distinct()
-        self.assertEqual(res.count(), 1)
+        eq_(res.count(), 1)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 94%
rename from test/orm/instrumentation.py
rename to test/orm/test_instrumentation.py
index fd15420d0ad8b20dd929dde1072bcba7aba9de1b..b4c8f8601c0d1475ecb2c157a1c8d4b8e718487d 100644 (file)
@@ -1,11 +1,13 @@
-import testenv; testenv.configure_for_tests()
 
-from testlib import sa
-from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util
-from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers
-from testlib.testing import eq_, ne_
-from testlib.compat import _function_named
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy import MetaData, Integer, ForeignKey, util
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers
+from sqlalchemy.test.testing import eq_, ne_
+from sqlalchemy.util import function_named
+from test.orm import _base
 
 
 def modifies_instrumentation_finders(fn):
@@ -16,7 +18,7 @@ def modifies_instrumentation_finders(fn):
         finally:
             del attributes.instrumentation_finders[:]
             attributes.instrumentation_finders.extend(pristine)
-    return _function_named(decorated, fn.func_name)
+    return function_named(decorated, fn.func_name)
 
 def with_lookup_strategy(strategy):
     def decorate(fn):
@@ -26,7 +28,7 @@ def with_lookup_strategy(strategy):
                 return fn(*args, **kw)
             finally:
                 attributes._install_lookup_strategy(sa.util.symbol('native'))
-        return _function_named(wrapped, fn.func_name)
+        return function_named(wrapped, fn.func_name)
     return decorate
 
 
@@ -459,10 +461,10 @@ class MapperInitTest(_base.ORMTest):
         m = mapper(A, self.fixture())
 
         # B is not mapped in the current implementation
-        self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B)
+        assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, B)
 
         # C is not mapped in the current implementation
-        self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C)
+        assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, C)
 
 class InstrumentationCollisionTest(_base.ORMTest):
     def test_none(self):
@@ -486,7 +488,7 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
 
-        self.assertRaises(TypeError, attributes.register_class, B)
+        assert_raises(TypeError, attributes.register_class, B)
 
     def test_single_up(self):
 
@@ -497,7 +499,7 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
         attributes.register_class(B)
-        self.assertRaises(TypeError, attributes.register_class, A)
+        assert_raises(TypeError, attributes.register_class, A)
 
     def test_diamond_b1(self):
         mgr_factory = lambda cls: attributes.ClassManager(cls)
@@ -508,7 +510,7 @@ class InstrumentationCollisionTest(_base.ORMTest):
             __sa_instrumentation_manager__ = mgr_factory
         class C(object): pass
 
-        self.assertRaises(TypeError, attributes.register_class, B1)
+        assert_raises(TypeError, attributes.register_class, B1)
 
     def test_diamond_b2(self):
         mgr_factory = lambda cls: attributes.ClassManager(cls)
@@ -519,7 +521,7 @@ class InstrumentationCollisionTest(_base.ORMTest):
             __sa_instrumentation_manager__ = mgr_factory
         class C(object): pass
 
-        self.assertRaises(TypeError, attributes.register_class, B2)
+        assert_raises(TypeError, attributes.register_class, B2)
 
     def test_diamond_c_b(self):
         mgr_factory = lambda cls: attributes.ClassManager(cls)
@@ -531,7 +533,7 @@ class InstrumentationCollisionTest(_base.ORMTest):
         class C(object): pass
 
         attributes.register_class(C)
-        self.assertRaises(TypeError, attributes.register_class, B1)
+        assert_raises(TypeError, attributes.register_class, B1)
 
 
 class OnLoadTest(_base.ORMTest):
@@ -557,7 +559,8 @@ class OnLoadTest(_base.ORMTest):
         finally:
             del A
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         clear_mappers()
         attributes._install_lookup_strategy(util.symbol('native'))
 
@@ -593,7 +596,7 @@ class NativeInstrumentationTest(_base.ORMTest):
         sa = attributes.ClassManager.STATE_ATTR
         ma = attributes.ClassManager.MANAGER_ATTR
 
-        fails = lambda method, attr: self.assertRaises(
+        fails = lambda method, attr: assert_raises(
             KeyError, getattr(manager, method), attr, property())
 
         fails('install_member', sa)
@@ -609,7 +612,7 @@ class NativeInstrumentationTest(_base.ORMTest):
 
         class T(object): pass
 
-        self.assertRaises(KeyError, mapper, T, t)
+        assert_raises(KeyError, mapper, T, t)
 
     @with_lookup_strategy(sa.util.symbol('native'))
     def test_mapped_managerattr(self):
@@ -618,7 +621,7 @@ class NativeInstrumentationTest(_base.ORMTest):
                   Column(attributes.ClassManager.MANAGER_ATTR, Integer))
 
         class T(object): pass
-        self.assertRaises(KeyError, mapper, T, t)
+        assert_raises(KeyError, mapper, T, t)
 
 
 class MiscTest(_base.ORMTest):
@@ -761,5 +764,3 @@ class FinderTest(_base.ORMTest):
         eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 96%
rename from test/orm/lazy_relations.py
rename to test/orm/test_lazy_relations.py
index b5c3b3669eac7ac0060bc5e94fa25aa6616ed213..819f29911ebf154dc09b983a6fdae7d0807686cb 100644 (file)
@@ -1,14 +1,17 @@
 """basic tests of lazy loaded attributes"""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import datetime
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import attributes
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from testlib.testing import eq_
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
 
 
 class LazyTest(_fixtures.FixtureTest):
@@ -35,7 +38,7 @@ class LazyTest(_fixtures.FixtureTest):
         q = sess.query(User)
         u = q.filter(users.c.id == 7).first()
         sess.expunge(u)
-        self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses')
+        assert_raises(sa_exc.InvalidRequestError, getattr, u, 'addresses')
 
     @testing.resolve_artifact_names
     def test_orderby(self):
@@ -363,6 +366,7 @@ class M2OGetTest(_fixtures.FixtureTest):
 
 class CorrelatedTest(_base.MappedTest):
 
+    @classmethod
     def define_tables(self, meta):
         Table('user_t', meta,
               Column('id', Integer, primary_key=True),
@@ -373,8 +377,9 @@ class CorrelatedTest(_base.MappedTest):
               Column('date', sa.Date),
               Column('user_id', Integer, ForeignKey('user_t.id')))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         user_t.insert().execute(
             {'id':1, 'name':'user1'},
             {'id':2, 'name':'user2'},
@@ -412,5 +417,3 @@ class CorrelatedTest(_base.MappedTest):
         ])
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 88%
rename from test/orm/lazytest1.py
rename to test/orm/test_lazytest1.py
index 5ebb8feebaa35023c7660fcd9591bd5b41376de6..f76cb32035e5ab9ca9661f5ce2d33621734e3eaf 100644 (file)
@@ -1,12 +1,15 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
 
 
 class LazyTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('infos', metadata,
               Column('pk', Integer, primary_key=True),
               Column('info', String(128)))
@@ -25,8 +28,9 @@ class LazyTest(_base.MappedTest):
               Column('start', Integer),
               Column('finish', Integer))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         infos.insert().execute(
             {'pk':1, 'info':'pk_1_info'},
             {'pk':2, 'info':'pk_2_info'},
@@ -86,5 +90,3 @@ class LazyTest(_base.MappedTest):
         assert len(info.rels[0].datas) == 3
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 93%
rename from test/orm/manytomany.py
rename to test/orm/test_manytomany.py
index 23af3bd1f8bab9f0bdf2827f5491ada81696137a..dcd547f80ccdbd498c8cc4d2b3de8a068ba7f535 100644 (file)
@@ -1,12 +1,16 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
 
 
 class M2MTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('place', metadata,
             Column('place_id', Integer, sa.Sequence('pid_seq', optional=True),
                    primary_key=True),
@@ -40,7 +44,8 @@ class M2MTest(_base.MappedTest):
               Column('pl1_id', Integer, ForeignKey('place.place_id')),
               Column('pl2_id', Integer, ForeignKey('place.place_id')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Place(_base.BasicEntity):
             def __init__(self, name=None):
                 self.name = name
@@ -70,7 +75,7 @@ class M2MTest(_base.MappedTest):
         mapper(Transition, transition, properties={
             'places':relation(Place, secondary=place_input, backref='transitions')
         })
-        self.assertRaisesMessage(sa.exc.ArgumentError, "Error creating backref",
+        assert_raises_message(sa.exc.ArgumentError, "Error creating backref",
                                  sa.orm.compile_mappers)
 
     @testing.resolve_artifact_names
@@ -187,7 +192,8 @@ class M2MTest(_base.MappedTest):
 
 
 class M2MTest2(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('student', metadata,
               Column('name', String(20), primary_key=True))
 
@@ -200,7 +206,8 @@ class M2MTest2(_base.MappedTest):
             Column('course_id', String(20), ForeignKey('course.name'),
                    primary_key=True))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Student(_base.BasicEntity):
             def __init__(self, name=''):
                 self.name = name
@@ -248,7 +255,7 @@ class M2MTest2(_base.MappedTest):
         s1.courses.append(c1)
         s1.courses.append(c1)
         sess.add(s1)
-        self.assertRaises(sa.exc.DBAPIError, sess.flush)
+        assert_raises(sa.exc.DBAPIError, sess.flush)
         
     @testing.resolve_artifact_names
     def test_delete(self):
@@ -274,7 +281,8 @@ class M2MTest2(_base.MappedTest):
         assert enroll.count().scalar() == 0
 
 class M2MTest3(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('c', metadata,
             Column('c1', Integer, primary_key = True),
             Column('c2', String(20)))
@@ -320,5 +328,3 @@ class M2MTest3(_base.MappedTest):
         # how about some data/inserts/queries/assertions for this one
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 97%
rename from test/orm/mapper.py
rename to test/orm/test_mapper.py
index 13e02a38a033eb82742e5751a4c04693a044e7e9..025b96424df8758b82ac110dd6077b86c95ee19b 100644 (file)
@@ -1,14 +1,16 @@
 """General mapper operations with an emphasis on selecting/loading."""
 
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, func
-from testlib.sa.engine import default
-from testlib.sa.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased
-from testlib.sa.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property
-from testlib.testing import eq_, AssertsCompiledSQL
-import pickleable
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing, pickleable
+from sqlalchemy import MetaData, Integer, String, ForeignKey, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.engine import default
+from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased
+from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property
+from sqlalchemy.test.testing import eq_, AssertsCompiledSQL
+from test.orm import _base, _fixtures
 
 
 class MapperTest(_fixtures.FixtureTest):
@@ -22,7 +24,7 @@ class MapperTest(_fixtures.FixtureTest):
             properties={
             'addresses':relation(Address, backref='email_address')
         })
-        self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers)
+        assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers)
 
     @testing.resolve_artifact_names
     def test_update_attr_keys(self):
@@ -74,14 +76,14 @@ class MapperTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_prop_accessor(self):
         mapper(User, users)
-        self.assertRaises(NotImplementedError,
+        assert_raises(NotImplementedError,
                           getattr, sa.orm.class_mapper(User), 'properties')
 
 
     @testing.resolve_artifact_names
     def test_bad_cascade(self):
         mapper(Address, addresses)
-        self.assertRaises(sa.exc.ArgumentError,
+        assert_raises(sa.exc.ArgumentError,
                           relation, Address, cascade="fake, all, delete-orphan")
 
     @testing.resolve_artifact_names
@@ -93,7 +95,7 @@ class MapperTest(_fixtures.FixtureTest):
         })
         
         hasattr(Address.user, 'property')
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
+        assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
     
     @testing.resolve_artifact_names
     def test_column_prefix(self):
@@ -111,7 +113,7 @@ class MapperTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_no_pks_1(self):
         s = sa.select([users.c.name]).alias('foo')
-        self.assertRaises(sa.exc.ArgumentError, mapper, User, s)
+        assert_raises(sa.exc.ArgumentError, mapper, User, s)
 
     @testing.emits_warning(
         'mapper Mapper|User|Select object creating an alias for '
@@ -119,7 +121,7 @@ class MapperTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_no_pks_2(self):
         s = sa.select([users.c.name])
-        self.assertRaises(sa.exc.ArgumentError, mapper, User, s)
+        assert_raises(sa.exc.ArgumentError, mapper, User, s)
 
     @testing.resolve_artifact_names
     def test_recompile_on_other_mapper(self):
@@ -167,9 +169,9 @@ class MapperTest(_fixtures.FixtureTest):
             create_session).extension)
 
         sess = create_session()
-        self.assertRaises(TypeError, Foo, 'one', _sa_session=sess)
+        assert_raises(TypeError, Foo, 'one', _sa_session=sess)
         eq_(len(list(sess)), 0)
-        self.assertRaises(TypeError, Foo, 'one')
+        assert_raises(TypeError, Foo, 'one')
         Foo('one', 'two', _sa_session=sess)
         eq_(len(list(sess)), 1)
 
@@ -197,7 +199,7 @@ class MapperTest(_fixtures.FixtureTest):
             raise Exception("this exception should be stated as a warning")
 
         sess.expunge = bad_expunge
-        self.assertRaises(sa.exc.SAWarning, Foo, _sa_session=sess)
+        assert_raises(sa.exc.SAWarning, Foo, _sa_session=sess)
 
     @testing.resolve_artifact_names
     def test_constructor_exc_2(self):
@@ -211,8 +213,8 @@ class MapperTest(_fixtures.FixtureTest):
 
         mapper(Foo, users)
         mapper(Bar, addresses)
-        self.assertRaises(TypeError, Foo, x=5)
-        self.assertRaises(TypeError, Bar, x=5)
+        assert_raises(TypeError, Foo, x=5)
+        assert_raises(TypeError, Bar, x=5)
 
     @testing.resolve_artifact_names
     def test_props(self):
@@ -499,7 +501,7 @@ class MapperTest(_fixtures.FixtureTest):
         # excluding the discriminator column is currently not allowed
         class Foo(Person):
             pass
-        self.assertRaises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) )
+        assert_raises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) )
     
     @testing.resolve_artifact_names
     def test_mapping_to_join(self):
@@ -643,7 +645,7 @@ class MapperTest(_fixtures.FixtureTest):
                    properties=dict(
                        name=relation(mapper(Address, addresses))))
 
-        self.assertRaises(sa.exc.ArgumentError, go)
+        assert_raises(sa.exc.ArgumentError, go)
 
     @testing.resolve_artifact_names
     def test_override_2(self):
@@ -739,7 +741,7 @@ class MapperTest(_fixtures.FixtureTest):
             mapper(User, users, properties={
                 'not_name':synonym('_name', map_column=True)})
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             ("Can't compile synonym '_name': no column on table "
              "'users' named 'not_name'"),
@@ -834,7 +836,7 @@ class MapperTest(_fixtures.FixtureTest):
             eq_(User.uc_name.method1(), "method1")
             eq_(User.uc_name.method2('x'), "method2")
 
-            self.assertRaisesMessage(
+            assert_raises_message(
                 AttributeError, 
                 "Neither 'extendedproperty' object nor 'UCComparator' object has an attribute 'nonexistent'", 
                 getattr, User.uc_name, 'nonexistent')
@@ -879,7 +881,7 @@ class MapperTest(_fixtures.FixtureTest):
             'name':sa.orm.column_property(users.c.name, comparator_factory=MyComparator)
         })
         
-        self.assertRaisesMessage(
+        assert_raises_message(
             AttributeError, 
             "Neither 'InstrumentedAttribute' object nor 'MyComparator' object has an attribute 'nonexistent'", 
             getattr, User.name, "nonexistent")
@@ -966,7 +968,7 @@ class MapperTest(_fixtures.FixtureTest):
             'addresses':relation(Address)
         })
 
-        self.assertRaises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers)
+        assert_raises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers)
 
     @testing.resolve_artifact_names
     def test_oldstyle_mixin(self):
@@ -1148,8 +1150,9 @@ class OptionsTest(_fixtures.FixtureTest):
 
 
 class DeepOptionsTest(_fixtures.FixtureTest):
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Keyword, keywords)
 
         mapper(Item, items, properties=dict(
@@ -1204,7 +1207,7 @@ class DeepOptionsTest(_fixtures.FixtureTest):
     def test_deep_options_4(self):
         sess = create_session()
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             r"Can't find entity Mapper\|Order\|orders in Query.  "
             r"Current list: \['Mapper\|User\|users'\]",
@@ -1232,7 +1235,7 @@ class ValidatorTest(_fixtures.FixtureTest):
         sess = create_session()
         u1 = User(name='ed')
         eq_(u1.name, 'ed modified')
-        self.assertRaises(AssertionError, setattr, u1, "name", "fred")
+        assert_raises(AssertionError, setattr, u1, "name", "fred")
         eq_(u1.name, 'ed modified')
         sess.add(u1)
         sess.flush()
@@ -1252,7 +1255,7 @@ class ValidatorTest(_fixtures.FixtureTest):
         mapper(Address, addresses)
         sess = create_session()
         u1 = User(name='edward')
-        self.assertRaises(AssertionError, u1.addresses.append, Address(email_address='noemail'))
+        assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail'))
         u1.addresses.append(Address(id=15, email_address='foo@bar.com'))
         sess.add(u1)
         sess.flush()
@@ -1629,7 +1632,8 @@ class DeferredTest(_fixtures.FixtureTest):
         eq_(item.description, 'item 4')
 
 class DeferredPopulationTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("thing", metadata,
             Column("id", Integer, primary_key=True),
             Column("name", String(20)))
@@ -1639,16 +1643,18 @@ class DeferredPopulationTest(_base.MappedTest):
             Column("thing_id", Integer, ForeignKey("thing.id")),
             Column("name", String(20)))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class Human(_base.BasicEntity): pass
         class Thing(_base.BasicEntity): pass
 
         mapper(Human, human, properties={"thing": relation(Thing)})
         mapper(Thing, thing, properties={"name": deferred(thing.c.name)})
     
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         thing.insert().execute([
             {"id": 1, "name": "Chair"},
         ])
@@ -1714,7 +1720,8 @@ class DeferredPopulationTest(_base.MappedTest):
         
 class CompositeTypesTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('graphs', metadata,
             Column('id', Integer, primary_key=True),
             Column('version_id', Integer, primary_key=True, nullable=True),
@@ -2246,7 +2253,8 @@ class MapperExtensionTest(_fixtures.FixtureTest):
 class RequirementsTest(_base.MappedTest):
     """Tests the contract for user classes."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('ht1', metadata,
               Column('id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -2280,9 +2288,9 @@ class RequirementsTest(_base.MappedTest):
         class OldStyle:
             pass
 
-        self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1)
+        assert_raises(sa.exc.ArgumentError, mapper, OldStyle, ht1)
 
-        self.assertRaises(sa.exc.ArgumentError, mapper, 123)
+        assert_raises(sa.exc.ArgumentError, mapper, 123)
         
         class NoWeakrefSupport(str):
             pass
@@ -2388,7 +2396,8 @@ class RequirementsTest(_base.MappedTest):
 
 class MagicNamesTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('cartographers', metadata,
               Column('id', Integer, primary_key=True),
               Column('name', String(50)),
@@ -2401,7 +2410,8 @@ class MagicNamesTest(_base.MappedTest):
               Column('state', String(2)),
               Column('data', sa.Text))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Cartographer(_base.BasicEntity):
             pass
 
@@ -2441,7 +2451,7 @@ class MagicNamesTest(_base.MappedTest):
             class T(object):
                 pass
 
-            self.assertRaisesMessage(
+            assert_raises_message(
                 KeyError,
                 ('%r: requested attribute name conflicts with '
                  'instrumentation attribute of the same name.' % reserved),
@@ -2454,7 +2464,7 @@ class MagicNamesTest(_base.MappedTest):
             class M(object):
                 pass
 
-            self.assertRaisesMessage(
+            assert_raises_message(
                 KeyError,
                 ('requested attribute name conflicts with '
                  'instrumentation attribute of the same name'),
@@ -2463,5 +2473,3 @@ class MagicNamesTest(_base.MappedTest):
 
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 98%
rename from test/orm/merge.py
rename to test/orm/test_merge.py
index fd553f2bf79455a12948b4af5e74eebf279afadb..70097cbee26a0aec2bff841f06a4b8b24c2e010c 100644 (file)
@@ -1,9 +1,10 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa.util import OrderedSet
-from testlib.sa.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property
-from testlib.testing import eq_, ne_
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy.util import OrderedSet
+from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property
+from sqlalchemy.test.testing import eq_, ne_
+from test.orm import _base, _fixtures
 
 
 class MergeTest(_fixtures.FixtureTest):
@@ -447,7 +448,7 @@ class MergeTest(_fixtures.FixtureTest):
 
         sess = create_session()
         u = User()
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+        assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
 
 
     @testing.resolve_artifact_names
@@ -732,5 +733,3 @@ class MergeTest(_fixtures.FixtureTest):
 
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 84%
rename from test/orm/naturalpks.py
rename to test/orm/test_naturalpks.py
index 8efce660c37a777f8e4828409bd850fb856526ec..1376c402e755f1c993ae8981fe6120b125609bf7 100644 (file)
@@ -2,16 +2,20 @@
 Primary key changing capabilities and passive/non-passive cascading updates.
 
 """
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
 
 class NaturalPKTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         users = Table('users', metadata,
             Column('username', String(50), primary_key=True),
             Column('fullname', String(100)),
@@ -32,7 +36,8 @@ class NaturalPKTest(_base.MappedTest):
             Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True),
             test_needs_fk=True)
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.ComparableEntity):
             pass
         class Address(_base.ComparableEntity):
@@ -62,7 +67,7 @@ class NaturalPKTest(_base.MappedTest):
 
         sess.expunge_all()
         u1 = sess.query(User).get('ed')
-        self.assertEquals(User(username='ed', fullname='jack'), u1)
+        eq_(User(username='ed', fullname='jack'), u1)
 
     @testing.resolve_artifact_names
     def test_load_after_expire(self):
@@ -81,7 +86,7 @@ class NaturalPKTest(_base.MappedTest):
         # in this case so theres no way to look it up.  criterion-
         # based session invalidation could solve this [ticket:911]
         sess.expire(u1)
-        self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username')
+        assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username')
 
         sess.expunge_all()
         assert sess.query(User).get('jack') is None
@@ -132,7 +137,7 @@ class NaturalPKTest(_base.MappedTest):
         assert u1.addresses[0].username == 'ed'
 
         sess.expunge_all()
-        self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+        eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
 
         u1 = sess.query(User).get('ed')
         u1.username = 'jack'
@@ -152,7 +157,7 @@ class NaturalPKTest(_base.MappedTest):
         sess.expunge_all()
         assert sess.query(Address).get('jack1').username is None
         u1 = sess.query(User).get('fred')
-        self.assertEquals(User(username='fred', fullname='jack'), u1)
+        eq_(User(username='fred', fullname='jack'), u1)
         
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
@@ -195,7 +200,7 @@ class NaturalPKTest(_base.MappedTest):
 
         assert a1.username == a2.username == 'ed'
         sess.expunge_all()
-        self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+        eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     def test_onetoone_passive(self):
@@ -236,7 +241,7 @@ class NaturalPKTest(_base.MappedTest):
         self.assert_sql_count(testing.db, go, 0)
 
         sess.expunge_all()
-        self.assertEquals([Address(username='ed')], sess.query(Address).all())
+        eq_([Address(username='ed')], sess.query(Address).all())
         
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     def test_bidirectional_passive(self):
@@ -265,7 +270,7 @@ class NaturalPKTest(_base.MappedTest):
 
         u1.username = 'ed'
         (ad1, ad2) = sess.query(Address).all()
-        self.assertEquals([Address(username='jack'), Address(username='jack')], [ad1, ad2])
+        eq_([Address(username='jack'), Address(username='jack')], [ad1, ad2])
         def go():
             sess.flush()
         if passive_updates:
@@ -273,9 +278,9 @@ class NaturalPKTest(_base.MappedTest):
             self.assert_sql_count(testing.db, go, 1)
         else:
             self.assert_sql_count(testing.db, go, 3)
-        self.assertEquals([Address(username='ed'), Address(username='ed')], [ad1, ad2])
+        eq_([Address(username='ed'), Address(username='ed')], [ad1, ad2])
         sess.expunge_all()
-        self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+        eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
 
         u1 = sess.query(User).get('ed')
         assert len(u1.addresses) == 2    # load addresses
@@ -289,7 +294,7 @@ class NaturalPKTest(_base.MappedTest):
         else:
             self.assert_sql_count(testing.db, go, 3)
         sess.expunge_all()
-        self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all())
+        eq_([Address(username='fred'), Address(username='fred')], sess.query(Address).all())
 
 
     @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
@@ -323,10 +328,10 @@ class NaturalPKTest(_base.MappedTest):
         r = sess.query(Item).all()
         # ComparableEntity can't handle a comparison with the backrefs
         # involved....
-        self.assertEquals(Item(itemname='item1'), r[0])
-        self.assertEquals(['jack'], [u.username for u in r[0].users])
-        self.assertEquals(Item(itemname='item2'), r[1])
-        self.assertEquals(['jack', 'fred'], [u.username for u in r[1].users])
+        eq_(Item(itemname='item1'), r[0])
+        eq_(['jack'], [u.username for u in r[0].users])
+        eq_(Item(itemname='item2'), r[1])
+        eq_(['jack', 'fred'], [u.username for u in r[1].users])
 
         u2.username='ed'
         def go():
@@ -338,29 +343,31 @@ class NaturalPKTest(_base.MappedTest):
 
         sess.expunge_all()
         r = sess.query(Item).all()
-        self.assertEquals(Item(itemname='item1'), r[0])
-        self.assertEquals(['jack'], [u.username for u in r[0].users])
-        self.assertEquals(Item(itemname='item2'), r[1])
-        self.assertEquals(['ed', 'jack'], sorted([u.username for u in r[1].users]))
+        eq_(Item(itemname='item1'), r[0])
+        eq_(['jack'], [u.username for u in r[0].users])
+        eq_(Item(itemname='item2'), r[1])
+        eq_(['ed', 'jack'], sorted([u.username for u in r[1].users]))
         
         sess.expunge_all()
         u2 = sess.query(User).get(u2.username)
         u2.username='wendy'
         sess.flush()
         r = sess.query(Item).with_parent(u2).all()
-        self.assertEquals(Item(itemname='item2'), r[0])
+        eq_(Item(itemname='item2'), r[0])
 
 
 class SelfRefTest(_base.MappedTest):
     __unsupported_on__ = 'mssql' # mssql doesn't allow ON UPDATE on self-referential keys
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('nodes', metadata,
               Column('name', String(50), primary_key=True),
               Column('parent', String(50),
                      ForeignKey('nodes.name', onupdate='cascade')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Node(_base.ComparableEntity):
             pass
 
@@ -391,7 +398,8 @@ class SelfRefTest(_base.MappedTest):
 
 
 class NonPKCascadeTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users', metadata,
             Column('id', Integer, primary_key=True),
             Column('username', String(50), unique=True),
@@ -406,7 +414,8 @@ class NonPKCascadeTest(_base.MappedTest):
                      test_needs_fk=True
                      )
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.ComparableEntity):
             pass
         class Address(_base.ComparableEntity):
@@ -433,17 +442,17 @@ class NonPKCascadeTest(_base.MappedTest):
         sess.flush()
         a1 = u1.addresses[0]
 
-        self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)])
+        eq_(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)])
 
         assert sess.query(Address).get(a1.id) is u1.addresses[0]
 
         u1.username = 'ed'
         sess.flush()
         assert u1.addresses[0].username == 'ed'
-        self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)])
+        eq_(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)])
 
         sess.expunge_all()
-        self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+        eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
 
         u1 = sess.query(User).get(u1.id)
         u1.username = 'jack'
@@ -463,13 +472,11 @@ class NonPKCascadeTest(_base.MappedTest):
         sess.flush()
         sess.expunge_all()
         a1 = sess.query(Address).get(a1.id)
-        self.assertEquals(a1.username, None)
+        eq_(a1.username, None)
 
-        self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)])
+        eq_(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)])
 
         u1 = sess.query(User).get(u1.id)
-        self.assertEquals(User(username='fred', fullname='jack'), u1)
+        eq_(User(username='fred', fullname='jack'), u1)
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 81%
rename from test/orm/onetoone.py
rename to test/orm/test_onetoone.py
index be0375e48b7628a5b25df76d8ecd216ae1ff9bb1..0d66915ea5d79230bf5cd45a61ec8f0c9dfe0c87 100644 (file)
@@ -1,12 +1,15 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
 
 
 class O2OTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('jack', metadata,
               Column('id', Integer, primary_key=True),
               Column('number', String(50)),
@@ -19,8 +22,9 @@ class O2OTest(_base.MappedTest):
               Column('description', String(100)),
               Column('jack_id', Integer, ForeignKey("jack.id")))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class Jack(_base.BasicEntity):
             pass
         class Port(_base.BasicEntity):
@@ -70,5 +74,3 @@ class O2OTest(_base.MappedTest):
         session.delete(j)
         session.flush()
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 80%
rename from test/orm/pickled.py
rename to test/orm/test_pickled.py
index 878fe931e36bfbfc3eb74cbf7eb508ee688a806c..5343cc15b940a201a1677c5b2624a1a83d301e67 100644 (file)
@@ -1,9 +1,12 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import pickle
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session, attributes
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, attributes
+from test.orm import _base, _fixtures
 
 
 User, EmailUser = None, None
@@ -28,7 +31,7 @@ class PickleTest(_fixtures.FixtureTest):
 
         sess.expunge_all()
 
-        self.assertEquals(u1, sess.query(User).get(u2.id))
+        eq_(u1, sess.query(User).get(u2.id))
 
     @testing.resolve_artifact_names
     def test_class_deferred_cols(self):
@@ -52,14 +55,14 @@ class PickleTest(_fixtures.FixtureTest):
         u2 = pickle.loads(pickle.dumps(u1))
         sess2 = create_session()
         sess2.add(u2)
-        self.assertEquals(u2.name, 'ed')
-        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+        eq_(u2.name, 'ed')
+        eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
 
         u2 = pickle.loads(pickle.dumps(u1))
         sess2 = create_session()
         u2 = sess2.merge(u2, dont_load=True)
-        self.assertEquals(u2.name, 'ed')
-        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+        eq_(u2.name, 'ed')
+        eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
 
     @testing.resolve_artifact_names
     def test_instance_deferred_cols(self):
@@ -82,22 +85,22 @@ class PickleTest(_fixtures.FixtureTest):
         u2 = pickle.loads(pickle.dumps(u1))
         sess2 = create_session()
         sess2.add(u2)
-        self.assertEquals(u2.name, 'ed')
+        eq_(u2.name, 'ed')
         assert 'addresses' not in u2.__dict__
         ad = u2.addresses[0]
         assert 'email_address' not in ad.__dict__
-        self.assertEquals(ad.email_address, 'ed@bar.com')
-        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+        eq_(ad.email_address, 'ed@bar.com')
+        eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
 
         u2 = pickle.loads(pickle.dumps(u1))
         sess2 = create_session()
         u2 = sess2.merge(u2, dont_load=True)
-        self.assertEquals(u2.name, 'ed')
+        eq_(u2.name, 'ed')
         assert 'addresses' not in u2.__dict__
         ad = u2.addresses[0]
         assert 'email_address' in ad.__dict__  # mapper options dont transmit over merge() right now
-        self.assertEquals(ad.email_address, 'ed@bar.com')
-        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+        eq_(ad.email_address, 'ed@bar.com')
+        eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
 
     @testing.resolve_artifact_names
     def test_options_with_descriptors(self):
@@ -122,7 +125,7 @@ class PickleTest(_fixtures.FixtureTest):
             sa.orm.eagerload(["addresses", User.addresses]),
         ]:
             opt2 = pickle.loads(pickle.dumps(opt))
-            self.assertEquals(opt.key, opt2.key)
+            eq_(opt.key, opt2.key)
         
         u1 = sess.query(User).options(opt).first()
         
@@ -130,7 +133,8 @@ class PickleTest(_fixtures.FixtureTest):
         
         
 class PolymorphicDeferredTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users', metadata,
             Column('id', Integer, primary_key=True),
             Column('name', String(30)),
@@ -139,7 +143,8 @@ class PolymorphicDeferredTest(_base.MappedTest):
             Column('id', Integer, ForeignKey('users.id'), primary_key=True),
             Column('email_address', String(30)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         global User, EmailUser
         class User(_base.BasicEntity):
             pass
@@ -147,10 +152,11 @@ class PolymorphicDeferredTest(_base.MappedTest):
         class EmailUser(User):
             pass
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         global User, EmailUser
         User, EmailUser = None, None
-        _base.MappedTest.tearDownAll(self)
+        super(PolymorphicDeferredTest, cls).teardown_class()
 
     @testing.resolve_artifact_names
     def test_polymorphic_deferred(self):
@@ -168,9 +174,9 @@ class PolymorphicDeferredTest(_base.MappedTest):
         sess2 = create_session()
         sess2.add(eu2)
         assert 'email_address' not in eu2.__dict__
-        self.assertEquals(eu2.email_address, 'foo@bar.com')
+        eq_(eu2.email_address, 'foo@bar.com')
 
-class CustomSetupTeardowntest(_fixtures.FixtureTest):
+class CustomSetupTeardownTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_rebuild_state(self):
         """not much of a 'test', but illustrate how to 
@@ -186,5 +192,3 @@ class CustomSetupTeardowntest(_fixtures.FixtureTest):
         attributes.manager_of_class(User).setup_instance(u2)
         assert attributes.instance_state(u2)
     
-if __name__ == '__main__':
-    testenv.main()
similarity index 88%
rename from test/orm/query.py
rename to test/orm/test_query.py
index 33c3e39d7128f4ecc2a345e23d22ba6b257bccca..66c219b10cb6443583aad7e083507ab2d5f65cfe 100644 (file)
@@ -1,4 +1,4 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import operator
 from sqlalchemy import *
 from sqlalchemy import exc as sa_exc, util
@@ -7,17 +7,18 @@ from sqlalchemy.engine import default
 from sqlalchemy.orm import *
 from sqlalchemy.orm import attributes
 
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
 
-from testlib import sa, testing, AssertsCompiledSQL, Column, engines
+import sqlalchemy as sa
+from sqlalchemy.test import testing, AssertsCompiledSQL, Column, engines
 
-from orm import _fixtures
-from orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \
+from test.orm import _fixtures
+from test.orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \
            Dingaling, item_keywords, dingalings, User, items,\
            orders, Address, users, nodes, \
             order_items, Item, Order, Node
 
-from orm import _base
+from test.orm import _base
 
 from sqlalchemy.orm.util import join, outerjoin, with_parent
 
@@ -27,7 +28,8 @@ class QueryTest(_fixtures.FixtureTest):
     run_deletes = None
 
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user', order_by=addresses.c.id),
             'orders':relation(Order, backref='user', order_by=orders.c.id), # o2m, m2o
@@ -82,8 +84,8 @@ class GetTest(QueryTest):
         s = create_session()
         
         q = s.query(User).join('addresses').filter(Address.user_id==8)
-        self.assertRaises(sa_exc.InvalidRequestError, q.get, 7)
-        self.assertRaises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19)
+        assert_raises(sa_exc.InvalidRequestError, q.get, 7)
+        assert_raises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19)
         
         # order_by()/get() doesn't raise
         s.query(User).order_by(User.id).get(8)
@@ -142,7 +144,7 @@ class GetTest(QueryTest):
             class LocalFoo(Base):
                 pass
             mapper(LocalFoo, table)
-            self.assertEquals(create_session().query(LocalFoo).get(ustring),
+            eq_(create_session().query(LocalFoo).get(ustring),
                               LocalFoo(id=ustring, data=ustring))
         finally:
             metadata.drop_all()
@@ -183,7 +185,7 @@ class GetTest(QueryTest):
     def test_query_str(self):
         s = create_session()
         q = s.query(User).filter(User.id==1)
-        self.assertEquals(
+        eq_(
             str(q).replace('\n',''), 
             'SELECT users.id AS users_id, users.name AS users_name FROM users WHERE users.id = ?'
             )
@@ -197,29 +199,29 @@ class InvalidGenerationsTest(QueryTest):
             s.query(User).offset(2),
             s.query(User).limit(2).offset(2)
         ):
-            self.assertRaises(sa_exc.InvalidRequestError, q.join, "addresses")
+            assert_raises(sa_exc.InvalidRequestError, q.join, "addresses")
 
-            self.assertRaises(sa_exc.InvalidRequestError, q.filter, User.name=='ed')
+            assert_raises(sa_exc.InvalidRequestError, q.filter, User.name=='ed')
 
-            self.assertRaises(sa_exc.InvalidRequestError, q.filter_by, name='ed')
+            assert_raises(sa_exc.InvalidRequestError, q.filter_by, name='ed')
 
-            self.assertRaises(sa_exc.InvalidRequestError, q.order_by, 'foo')
+            assert_raises(sa_exc.InvalidRequestError, q.order_by, 'foo')
 
-            self.assertRaises(sa_exc.InvalidRequestError, q.group_by, 'foo')
+            assert_raises(sa_exc.InvalidRequestError, q.group_by, 'foo')
 
-            self.assertRaises(sa_exc.InvalidRequestError, q.having, 'foo')
+            assert_raises(sa_exc.InvalidRequestError, q.having, 'foo')
     
     def test_no_from(self):
         s = create_session()
     
         q = s.query(User).select_from(users)
-        self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+        assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
 
         q = s.query(User).join('addresses')
-        self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+        assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
         
         q = s.query(User).order_by(User.id)
-        self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+        assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
         
         # this is fine, however
         q.from_self()
@@ -227,43 +229,43 @@ class InvalidGenerationsTest(QueryTest):
     def test_invalid_select_from(self):
         s = create_session()
         q = s.query(User)
-        self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id==5)
-        self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id)
+        assert_raises(sa_exc.ArgumentError, q.select_from, User.id==5)
+        assert_raises(sa_exc.ArgumentError, q.select_from, User.id)
 
     def test_invalid_from_statement(self):
         s = create_session()
         q = s.query(User)
-        self.assertRaises(sa_exc.ArgumentError, q.from_statement, User.id==5)
-        self.assertRaises(sa_exc.ArgumentError, q.from_statement, users.join(addresses))
+        assert_raises(sa_exc.ArgumentError, q.from_statement, User.id==5)
+        assert_raises(sa_exc.ArgumentError, q.from_statement, users.join(addresses))
     
     def test_invalid_column(self):
         s = create_session()
         q = s.query(User)
-        self.assertRaises(sa_exc.InvalidRequestError, q.add_column, object())
+        assert_raises(sa_exc.InvalidRequestError, q.add_column, object())
         
     def test_mapper_zero(self):
         s = create_session()
         
         q = s.query(User, Address)
-        self.assertRaises(sa_exc.InvalidRequestError, q.get, 5)
+        assert_raises(sa_exc.InvalidRequestError, q.get, 5)
         
     def test_from_statement(self):
         s = create_session()
         
         q = s.query(User).filter(User.id==5)
-        self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+        assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
 
         q = s.query(User).filter_by(id=5)
-        self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+        assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
 
         q = s.query(User).limit(5)
-        self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+        assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
 
         q = s.query(User).group_by(User.name)
-        self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+        assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
 
         q = s.query(User).order_by(User.name)
-        self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+        assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
         
 class OperatorTest(QueryTest, AssertsCompiledSQL):
     """test sql.Comparator implementation for MapperProperties"""
@@ -431,7 +433,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
                     "users.id IN (:id_1, :id_2)")
 
     def test_in_on_relation_not_supported(self):
-        self.assertRaises(NotImplementedError, Address.user.in_, [User(id=5)])
+        assert_raises(NotImplementedError, Address.user.in_, [User(id=5)])
         
     def test_between(self):
         self._test(User.id.between('a', 'b'),
@@ -705,8 +707,8 @@ class FilterTest(QueryTest):
         assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all()
 
         # m2m
-        self.assertEquals(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)])
-        self.assertEquals(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)])
+        eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)])
+        eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)])
     
     def test_filter_by(self):
         sess = create_session()
@@ -723,16 +725,16 @@ class FilterTest(QueryTest):
         sess = create_session()
         
         # o2o
-        self.assertEquals([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all())
-        self.assertEquals([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all())
+        eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all())
+        eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all())
         
         # m2o
-        self.assertEquals([Order(id=5)], sess.query(Order).filter(Order.address==None).all())
-        self.assertEquals([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all())
+        eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all())
+        eq_([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all())
         
         # o2m
-        self.assertEquals([User(id=10)], sess.query(User).filter(User.addresses==None).all())
-        self.assertEquals([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all())
+        eq_([User(id=10)], sess.query(User).filter(User.addresses==None).all())
+        eq_([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all())
 
 
 class FromSelfTest(QueryTest, AssertsCompiledSQL):
@@ -818,7 +820,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL):
     def test_multiple_entities(self):
         sess = create_session()
 
-        self.assertEquals(
+        eq_(
             sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(),
             [
                 (User(id=8), Address(id=2)),
@@ -826,7 +828,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL):
             ]
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(),
             
             #    order_by(User.id, Address.id).first(),
@@ -842,11 +844,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
         ed = s.query(User).filter(User.name=='ed')
         jack = s.query(User).filter(User.name=='jack')
         
-        self.assertEquals(fred.union(ed).order_by(User.name).all(), 
+        eq_(fred.union(ed).order_by(User.name).all(), 
             [User(name='ed'), User(name='fred')]
         )
 
-        self.assertEquals(fred.union(ed, jack).order_by(User.name).all(), 
+        eq_(fred.union(ed, jack).order_by(User.name).all(), 
             [User(name='ed'), User(name='fred'), User(name='jack')]
         )
         
@@ -857,11 +859,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
         fred = s.query(User).filter(User.name=='fred')
         ed = s.query(User).filter(User.name=='ed')
         jack = s.query(User).filter(User.name=='jack')
-        self.assertEquals(fred.intersect(ed, jack).all(), 
+        eq_(fred.intersect(ed, jack).all(), 
             []
         )
 
-        self.assertEquals(fred.union(ed).intersect(ed.union(jack)).all(), 
+        eq_(fred.union(ed).intersect(ed.union(jack)).all(), 
             [User(name='ed')]
         )
     
@@ -873,7 +875,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL):
         jack = s.query(User).filter(User.name=='jack')
 
         def go():
-            self.assertEquals(
+            eq_(
                 fred.union(ed).order_by(User.name).options(eagerload(User.addresses)).all(), 
                 [
                     User(name='ed', addresses=[Address(), Address(), Address()]), 
@@ -888,8 +890,8 @@ class AggregateTest(QueryTest):
     def test_sum(self):
         sess = create_session()
         orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
-        self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,))
-        self.assertEquals(orders.value(func.sum(Order.user_id * Order.address_id)), 79)
+        eq_(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,))
+        eq_(orders.value(func.sum(Order.user_id * Order.address_id)), 79)
 
     def test_apply(self):
         sess = create_session()
@@ -987,13 +989,13 @@ class YieldTest(QueryTest):
         q = iter(sess.query(User).yield_per(1).from_statement("select * from users"))
 
         ret = []
-        self.assertEquals(len(sess.identity_map), 0)
+        eq_(len(sess.identity_map), 0)
         ret.append(q.next())
         ret.append(q.next())
-        self.assertEquals(len(sess.identity_map), 2)
+        eq_(len(sess.identity_map), 2)
         ret.append(q.next())
         ret.append(q.next())
-        self.assertEquals(len(sess.identity_map), 4)
+        eq_(len(sess.identity_map), 4)
         try:
             q.next()
             assert False
@@ -1019,7 +1021,7 @@ class TextTest(QueryTest):
 
     def test_as_column(self):
         s = create_session()
-        self.assertRaises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name"))
+        assert_raises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name"))
 
         eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')])
 
@@ -1091,24 +1093,24 @@ class JoinTest(QueryTest):
         sess = create_session()
         
         for oalias,ialias in [(True, True), (False, False), (True, False), (False, True)]:
-            self.assertEquals(
+            eq_(
                 sess.query(User).join('orders', aliased=oalias).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description == 'item 4').all(),
                 [User(name='jack')]
             )
 
             # use middle criterion
-            self.assertEquals(
+            eq_(
                 sess.query(User).join('orders', aliased=oalias).filter(Order.user_id==9).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description=='item 4').all(),
                 []
             )
         
         orderalias = aliased(Order)
         itemalias = aliased(Item)
-        self.assertEquals(
+        eq_(
             sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(itemalias.description == 'item 4').all(),
             [User(name='jack')]
         )
-        self.assertEquals(
+        eq_(
             sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(orderalias.user_id==9).filter(itemalias.description=='item 4').all(),
             []
         )
@@ -1119,28 +1121,28 @@ class JoinTest(QueryTest):
         
         sess = create_session()
 
-        self.assertEquals(
+        eq_(
             sess.query(User).join(Address.user).filter(Address.email_address=='ed@wood.com').all(),
             [User(id=8,name=u'ed')]
         )
 
         # its actually not so controversial if you view it in terms
         # of multiple entities.
-        self.assertEquals(
+        eq_(
             sess.query(User, Address).join(Address.user).filter(Address.email_address=='ed@wood.com').all(),
             [(User(id=8,name=u'ed'), Address(email_address='ed@wood.com'))]
         )
         
         # this was the controversial part.  now, raise an error if the feature is abused.
         # before the error raise was added, this would silently work.....
-        self.assertRaises(
+        assert_raises(
             sa_exc.InvalidRequestError,
             sess.query(User).join, (Address, Address.user),
         )
 
         # but this one would silently fail 
         adalias = aliased(Address)
-        self.assertRaises(
+        assert_raises(
             sa_exc.InvalidRequestError,
             sess.query(User).join, (adalias, Address.user),
         )
@@ -1153,7 +1155,7 @@ class JoinTest(QueryTest):
         oalias2 = aliased(Order)
         result = sess.query(ualias).join((oalias1, ualias.orders), (oalias2, ualias.orders)).\
                 filter(or_(oalias1.user_id==9, oalias2.user_id==7)).all()
-        self.assertEquals(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')])
+        eq_(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')])
         
     def test_orderby_arg_bug(self):
         sess = create_session()
@@ -1163,17 +1165,17 @@ class JoinTest(QueryTest):
     def test_no_onclause(self):
         sess = create_session()
 
-        self.assertEquals(
+        eq_(
             sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
             [User(name='jack')]
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(User.name).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
             [('jack',)]
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(),
             [User(name='jack')]
         )
@@ -1181,7 +1183,7 @@ class JoinTest(QueryTest):
     def test_clause_onclause(self):
         sess = create_session()
 
-        self.assertEquals(
+        eq_(
             sess.query(User).join(
                 (Order, User.id==Order.user_id), 
                 (order_items, Order.id==order_items.c.order_id), 
@@ -1190,7 +1192,7 @@ class JoinTest(QueryTest):
             [User(name='jack')]
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(User.name).join(
                 (Order, User.id==Order.user_id), 
                 (order_items, Order.id==order_items.c.order_id), 
@@ -1200,7 +1202,7 @@ class JoinTest(QueryTest):
         )
 
         ualias = aliased(User)
-        self.assertEquals(
+        eq_(
             sess.query(ualias.name).join(
                 (Order, ualias.id==Order.user_id), 
                 (order_items, Order.id==order_items.c.order_id), 
@@ -1212,7 +1214,7 @@ class JoinTest(QueryTest):
         # explicit onclause with from_self(), means
         # the onclause must be aliased against the query's custom
         # FROM object
-        self.assertEquals(
+        eq_(
             sess.query(User).order_by(User.id).offset(2).from_self().join(
                 (Order, User.id==Order.user_id)
             ).all(),
@@ -1220,7 +1222,7 @@ class JoinTest(QueryTest):
         )
 
         # same with an explicit select_from()
-        self.assertEquals(
+        eq_(
             sess.query(User).select_from(select([users]).order_by(User.id).offset(2).alias()).join(
                 (Order, User.id==Order.user_id)
             ).all(),
@@ -1244,34 +1246,34 @@ class JoinTest(QueryTest):
         AdAlias = aliased(Address)
         q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias))
         l = q.order_by(User.id, AdAlias.id).all()
-        self.assertEquals(l, expected)
+        eq_(l, expected)
 
         sess.expunge_all()
 
         q = sess.query(User).add_entity(AdAlias)
         l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
-        self.assertEquals(l, [(user8, address3)])
+        eq_(l, [(user8, address3)])
 
         l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all()
-        self.assertEquals(l, [(user8, address3)])
+        eq_(l, [(user8, address3)])
 
         l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
-        self.assertEquals(l, [(user8, address3)])
+        eq_(l, [(user8, address3)])
 
         # this is the first test where we are joining "backwards" - from AdAlias to User even though
         # the query is against User
         q = sess.query(User, AdAlias)
         l = q.join(AdAlias.user).filter(User.name=='ed')
-        self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+        eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
 
         q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed')
-        self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+        eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
         
     def test_implicit_joins_from_aliases(self):
         sess = create_session()
         OrderAlias = aliased(Order)
 
-        self.assertEquals(
+        eq_(
             sess.query(OrderAlias).join('items').filter_by(description='item 3').\
                 order_by(OrderAlias.id).all(),
             [
@@ -1281,7 +1283,7 @@ class JoinTest(QueryTest):
             ]
         )
          
-        self.assertEquals(
+        eq_(
             sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').\
                 order_by(User.id, OrderAlias.id).all(),
             [
@@ -1314,12 +1316,12 @@ class JoinTest(QueryTest):
         q = sess.query(Order)
         q = q.add_entity(Item).select_from(join(Order, Item, 'items')).order_by(Order.id, Item.id)
         l = q.all()
-        self.assertEquals(l, expected)
+        eq_(l, expected)
 
         IAlias = aliased(Item)
         q = sess.query(Order, IAlias).select_from(join(Order, IAlias, 'items')).filter(IAlias.description=='item 3')
         l = q.all()
-        self.assertEquals(l, 
+        eq_(l, 
             [
                 (order1, item3),
                 (order2, item3),
@@ -1385,7 +1387,7 @@ class JoinTest(QueryTest):
         sess = create_session()
 
         ualias = aliased(User)
-        self.assertEquals(
+        eq_(
             sess.query(User, ualias).filter(User.id > ualias.id).order_by(desc(ualias.id), User.name).all(),
             [
                 (User(id=10,name=u'chuck'), User(id=9,name=u'fred')), 
@@ -1401,14 +1403,15 @@ class JoinTest(QueryTest):
         
         sess = create_session()
         
-        self.assertEquals(
+        eq_(
             sess.query(User.name).join((addresses, User.id==addresses.c.user_id)).order_by(User.id).all(),
             [(u'jack',), (u'ed',), (u'ed',), (u'ed',), (u'fred',)]
         )
         
         
 class MultiplePathTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global t1, t2, t1t2_1, t1t2_2
         t1 = Table('t1', metadata,
             Column('id', Integer, primary_key=True),
@@ -1440,7 +1443,7 @@ class MultiplePathTest(_base.MappedTest):
         mapper(T2, t2)
 
         q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint()
-        self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.",
+        assert_raises_message(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.",
             q.join, 't2s_2'
         )
 
@@ -1449,7 +1452,8 @@ class MultiplePathTest(_base.MappedTest):
 
 class SynonymTest(QueryTest):
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         mapper(User, users, properties={
             'name_syn':synonym('name'),
             'addresses':relation(Address),
@@ -1547,7 +1551,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         adalias = addresses.alias()
         q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias))
         def go():
-            self.assertEquals(self.static.user_address_result, q.order_by(User.id).all())
+            eq_(self.static.user_address_result, q.order_by(User.id).all())
         self.assert_sql_count(testing.db, go, 1)
         sess.expunge_all()
 
@@ -1675,35 +1679,35 @@ class MixedEntitiesTest(QueryTest):
         sel = users.select(User.id.in_([7, 8])).alias()
         q = sess.query(User)
         q2 = q.select_from(sel).values(User.name)
-        self.assertEquals(list(q2), [(u'jack',), (u'ed',)])
+        eq_(list(q2), [(u'jack',), (u'ed',)])
     
         q = sess.query(User)
         q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String))
-        self.assertEquals(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')])
+        eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')])
     
         q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address)
-        self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
+        eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
     
         q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address)
-        self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
+        eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
     
         adalias = aliased(Address)
         q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address)
-        self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
+        eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
     
         q2 = q.values(func.count(User.name))
         assert q2.next() == (4,)
 
         q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name)
-        self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')])
+        eq_(list(q2), [(u'ed', u'ed', u'ed')])
 
         # using User.xxx is alised against "sel", so this query returns nothing
         q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name)
-        self.assertEquals(list(q2), [])
+        eq_(list(q2), [])
 
         # whereas this uses users.c.xxx, is not aliased and creates a new join
         q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name)
-        self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')])
+        eq_(list(q2), [(u'ed', u'jack', u'jack')])
 
     @testing.fails_on('mssql', 'FIXME: unknown')
     def test_values_specific_order_by(self):
@@ -1715,7 +1719,7 @@ class MixedEntitiesTest(QueryTest):
         q = sess.query(User)
         u2 = aliased(User)
         q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name)
-        self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
+        eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
 
     @testing.fails_on('mssql', 'FIXME: unknown')
     def test_values_with_boolean_selects(self):
@@ -1724,7 +1728,7 @@ class MixedEntitiesTest(QueryTest):
 
         q = sess.query(User)
         q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%')))
-        self.assertEquals(list(q2), [(True, 1), (False, 3)])
+        eq_(list(q2), [(True, 1), (False, 3)])
 
     def test_correlated_subquery(self):
         """test that a subquery constructed from ORM attributes doesn't leak out 
@@ -1739,7 +1743,7 @@ class MixedEntitiesTest(QueryTest):
             label('count')
 
         # we don't want Address to be outside of the subquery here
-        self.assertEquals(
+        eq_(
             list(sess.query(User, subq)[0:3]),
             [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)]
             )
@@ -1751,7 +1755,7 @@ class MixedEntitiesTest(QueryTest):
             label('count')
 
         # we don't want Address to be outside of the subquery here
-        self.assertEquals(
+        eq_(
             list(sess.query(User, subq)[0:3]),
             [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)]
             )
@@ -1759,71 +1763,71 @@ class MixedEntitiesTest(QueryTest):
     def test_tuple_labeling(self):
         sess = create_session()
         for row in sess.query(User, Address).join(User.addresses).all():
-            self.assertEquals(set(row.keys()), set(['User', 'Address']))
-            self.assertEquals(row.User, row[0])
-            self.assertEquals(row.Address, row[1])
+            eq_(set(row.keys()), set(['User', 'Address']))
+            eq_(row.User, row[0])
+            eq_(row.Address, row[1])
         
         for row in sess.query(User.name, User.id.label('foobar')):
-            self.assertEquals(set(row.keys()), set(['name', 'foobar']))
-            self.assertEquals(row.name, row[0])
-            self.assertEquals(row.foobar, row[1])
+            eq_(set(row.keys()), set(['name', 'foobar']))
+            eq_(row.name, row[0])
+            eq_(row.foobar, row[1])
 
         for row in sess.query(User).values(User.name, User.id.label('foobar')):
-            self.assertEquals(set(row.keys()), set(['name', 'foobar']))
-            self.assertEquals(row.name, row[0])
-            self.assertEquals(row.foobar, row[1])
+            eq_(set(row.keys()), set(['name', 'foobar']))
+            eq_(row.name, row[0])
+            eq_(row.foobar, row[1])
 
         oalias = aliased(Order)
         for row in sess.query(User, oalias).join(User.orders).all():
-            self.assertEquals(set(row.keys()), set(['User']))
-            self.assertEquals(row.User, row[0])
+            eq_(set(row.keys()), set(['User']))
+            eq_(row.User, row[0])
 
         oalias = aliased(Order, name='orders')
         for row in sess.query(User, oalias).join(User.orders).all():
-            self.assertEquals(set(row.keys()), set(['User', 'orders']))
-            self.assertEquals(row.User, row[0])
-            self.assertEquals(row.orders, row[1])
+            eq_(set(row.keys()), set(['User', 'orders']))
+            eq_(row.User, row[0])
+            eq_(row.orders, row[1])
 
 
     def test_column_queries(self):
         sess = create_session()
 
-        self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)])
+        eq_(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)])
     
         sel = users.select(User.id.in_([7, 8])).alias()
         q = sess.query(User.name)
         q2 = q.select_from(sel).all()
-        self.assertEquals(list(q2), [(u'jack',), (u'ed',)])
+        eq_(list(q2), [(u'jack',), (u'ed',)])
 
-        self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [
+        eq_(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [
             (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'), 
             (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), 
             (u'fred', u'fred@fred.com')
         ])
     
-        self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), 
+        eq_(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), 
             [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)]
         )
 
-        self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), 
+        eq_(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), 
             [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
         )
 
-        self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), 
+        eq_(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), 
             [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))]
         )
     
         adalias = aliased(Address)
-        self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), 
+        eq_(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), 
             [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
         )
 
-        self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(),
+        eq_(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(),
             [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))]
         )
 
         # select from aliasing + explicit aliasing
-        self.assertEquals(
+        eq_(
             sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(),
             [
                 (User(name=u'jack',id=7), u'jack@bean.com'), 
@@ -1836,7 +1840,7 @@ class MixedEntitiesTest(QueryTest):
         )
     
         # anon + select from aliasing
-        self.assertEquals(
+        eq_(
             sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(),
             [
                 User(name=u'ed',id=8), 
@@ -1849,7 +1853,7 @@ class MixedEntitiesTest(QueryTest):
             sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
             sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
         ]:
-            self.assertEquals(
+            eq_(
 
                 q.all(),
                 [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'), 
@@ -1875,7 +1879,7 @@ class MixedEntitiesTest(QueryTest):
     
         def go():
             results = sess.query(User).limit(1).options(eagerload('addresses')).add_column(User.name).all()
-            self.assertEquals(results, [(User(name='jack'), 'jack')])
+            eq_(results, [(User(name='jack'), 'jack')])
         self.assert_sql_count(testing.db, go, 1)
     
     def test_self_referential(self):
@@ -1898,7 +1902,7 @@ class MixedEntitiesTest(QueryTest):
 
         ]:
     
-            self.assertEquals(
+            eq_(
             q.all(),
             [
                 (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), 
@@ -1924,25 +1928,25 @@ class MixedEntitiesTest(QueryTest):
         sess = create_session()
 
         selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
-        self.assertEquals(list(sess.query(User, Address).instances(selectquery.execute())), expected)
+        eq_(list(sess.query(User, Address).instances(selectquery.execute())), expected)
         sess.expunge_all()
 
         for address_entity in (Address, aliased(Address)):
             q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id)
-            self.assertEquals(q.all(), expected)
+            eq_(q.all(), expected)
             sess.expunge_all()
 
             q = sess.query(User).add_entity(address_entity)
             q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
-            self.assertEquals(q.all(), [(user8, address3)])
+            eq_(q.all(), [(user8, address3)])
             sess.expunge_all()
 
             q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
-            self.assertEquals(q.all(), [(user8, address3)])
+            eq_(q.all(), [(user8, address3)])
             sess.expunge_all()
 
             q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
-            self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)])
+            eq_(list(util.OrderedSet(q.all())), [(user8, address3)])
             sess.expunge_all()
 
     def test_aliased_multi_mappers(self):
@@ -1979,7 +1983,7 @@ class MixedEntitiesTest(QueryTest):
             assert sess.query(User).add_column(add_col).all() == expected
             sess.expunge_all()
 
-        self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
+        assert_raises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
 
     def test_add_multi_columns(self):
         """test that add_column accepts a FROM clause."""
@@ -2004,13 +2008,13 @@ class MixedEntitiesTest(QueryTest):
 
         q = sess.query(User)
         q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count'))
-        self.assertEquals(q.all(), expected)
+        eq_(q.all(), expected)
         sess.expunge_all()
     
         adalias = aliased(Address)
         q = sess.query(User)
         q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count'))
-        self.assertEquals(q.all(), expected)
+        eq_(q.all(), expected)
         sess.expunge_all()
 
         s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
@@ -2069,8 +2073,9 @@ class ImmediateTest(_fixtures.FixtureTest):
     run_inserts = 'once'
     run_deletes = None
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Address, addresses)
 
         mapper(User, users, properties=dict(
@@ -2080,25 +2085,25 @@ class ImmediateTest(_fixtures.FixtureTest):
     def test_one(self):
         sess = create_session()
 
-        self.assertRaises(sa.orm.exc.NoResultFound,
+        assert_raises(sa.orm.exc.NoResultFound,
                           sess.query(User).filter(User.id == 99).one)
 
         eq_(sess.query(User).filter(User.id == 7).one().id, 7)
 
-        self.assertRaises(sa.orm.exc.MultipleResultsFound,
+        assert_raises(sa.orm.exc.MultipleResultsFound,
                           sess.query(User).one)
 
-        self.assertRaises(
+        assert_raises(
             sa.orm.exc.NoResultFound,
             sess.query(User.id, User.name).filter(User.id == 99).one)
 
         eq_(sess.query(User.id, User.name).filter(User.id == 7).one(),
             (7, 'jack'))
 
-        self.assertRaises(sa.orm.exc.MultipleResultsFound,
+        assert_raises(sa.orm.exc.MultipleResultsFound,
                           sess.query(User.id, User.name).one)
 
-        self.assertRaises(sa.orm.exc.NoResultFound,
+        assert_raises(sa.orm.exc.NoResultFound,
                           (sess.query(User, Address).
                            join(User.addresses).
                            filter(Address.id == 99)).one)
@@ -2108,7 +2113,7 @@ class ImmediateTest(_fixtures.FixtureTest):
              filter(Address.id == 4)).one(),
             (User(id=8), Address(id=4)))
 
-        self.assertRaises(sa.orm.exc.MultipleResultsFound,
+        assert_raises(sa.orm.exc.MultipleResultsFound,
                           sess.query(User, Address).join(User.addresses).one)
 
     @testing.future
@@ -2133,7 +2138,7 @@ class ImmediateTest(_fixtures.FixtureTest):
         eq_(sess.query(User.id, User.name).filter_by(id=7).value(User.id), 7)
         eq_(sess.query(User).filter_by(id=0).value(User.id), None)
 
-        sess.bind = sa.testing.db
+        sess.bind = testing.db
         eq_(sess.query().value(sa.literal_column('1').label('x')), 1)
 
 
@@ -2149,19 +2154,19 @@ class SelectFromTest(QueryTest):
         sel = users.select(users.c.id.in_([7, 8])).alias()
         sess = create_session()
 
-        self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)])
+        eq_(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)])
 
-        self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)])
+        eq_(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)])
 
-        self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [
+        eq_(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [
             User(name='jack',id=7), User(name='ed',id=8)
         ])
 
-        self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [
+        eq_(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [
             User(name='ed',id=8), User(name='jack',id=7)
         ])
 
-        self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(),
+        eq_(sess.query(User).select_from(sel).options(eagerload('addresses')).first(),
             User(name='jack', addresses=[Address(id=1)])
         )
 
@@ -2173,7 +2178,7 @@ class SelectFromTest(QueryTest):
         sel = users.select(users.c.id.in_([7, 8]))
         sess = create_session()
 
-        self.assertEquals(sess.query(User).select_from(sel).all(),
+        eq_(sess.query(User).select_from(sel).all(),
             [
                 User(name='jack',id=7), User(name='ed',id=8)
             ]
@@ -2185,7 +2190,7 @@ class SelectFromTest(QueryTest):
         sel = users.select(users.c.id.in_([7, 8]))
         sess = create_session()
 
-        self.assertEquals(sess.query(User).select_from(sel).all(),
+        eq_(sess.query(User).select_from(sel).all(),
             [
                 User(name='jack',id=7), User(name='ed',id=8)
             ]
@@ -2200,7 +2205,7 @@ class SelectFromTest(QueryTest):
         sel = users.select(users.c.id.in_([7, 8]))
         sess = create_session()
 
-        self.assertEquals(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(),
+        eq_(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(),
             [
                 (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)),
                 (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)),
@@ -2210,7 +2215,7 @@ class SelectFromTest(QueryTest):
         )
 
         adalias = aliased(Address)
-        self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(),
+        eq_(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(),
             [
                 (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)),
                 (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)),
@@ -2238,16 +2243,16 @@ class SelectFromTest(QueryTest):
         # TODO: remove
         sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all()
 
-        self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+        eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
             User(name=u'jack',id=7)
         ])
 
-        self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+        eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
             User(name=u'jack',id=7)
         ])
 
         def go():
-            self.assertEquals(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+            eq_(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
                 User(name=u'jack',orders=[
                     Order(description=u'order 1',items=[
                         Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]),
@@ -2265,11 +2270,11 @@ class SelectFromTest(QueryTest):
 
         sess.expunge_all()
         sel2 = orders.select(orders.c.id.in_([1,2,3]))
-        self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [
+        eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [
             Order(description=u'order 1',id=1),
             Order(description=u'order 2',id=2),
         ])
-        self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [
+        eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [
             Order(description=u'order 1',id=1),
             Order(description=u'order 2',id=2),
         ])
@@ -2285,7 +2290,7 @@ class SelectFromTest(QueryTest):
         sess = create_session()
 
         def go():
-            self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(),
+            eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(),
                 [
                     User(id=7, addresses=[Address(id=1)]),
                     User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])
@@ -2295,14 +2300,14 @@ class SelectFromTest(QueryTest):
         sess.expunge_all()
 
         def go():
-            self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(),
+            eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(),
                 [User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])]
             )
         self.assert_sql_count(testing.db, go, 1)
         sess.expunge_all()
 
         def go():
-            self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
+            eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
         self.assert_sql_count(testing.db, go, 1)
 
 class CustomJoinTest(QueryTest):
@@ -2329,14 +2334,16 @@ class SelfReferentialTest(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global nodes
         nodes = Table('nodes', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('parent_id', Integer, ForeignKey('nodes.id')),
             Column('data', String(30)))
 
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         global Node
     
         class Node(Base):
@@ -2399,7 +2406,7 @@ class SelfReferentialTest(_base.MappedTest):
             filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first()
         assert node.data == 'n122'
 
-        self.assertEquals(
+        eq_(
             list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
             filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
             [('n122', 'n12', 'n1')])
@@ -2410,13 +2417,13 @@ class SelfReferentialTest(_base.MappedTest):
         n1 = aliased(Node)
 
         # using 'n1.parent' implicitly joins to unaliased Node
-        self.assertEquals(
+        eq_(
             sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(),
             [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
         )
     
         # explicit (new syntax)
-        self.assertEquals(
+        eq_(
             sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(),
             [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
         )
@@ -2426,7 +2433,7 @@ class SelfReferentialTest(_base.MappedTest):
     
         parent = aliased(Node)
         grandparent = aliased(Node)
-        self.assertEquals(
+        eq_(
             sess.query(Node, parent, grandparent).\
                 join((Node.parent, parent), (parent.parent, grandparent)).\
                     filter(Node.data=='n122').filter(parent.data=='n12').\
@@ -2434,7 +2441,7 @@ class SelfReferentialTest(_base.MappedTest):
             (Node(data='n122'), Node(data='n12'), Node(data='n1'))
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(Node, parent, grandparent).\
                 join((Node.parent, parent), (parent.parent, grandparent)).\
                     filter(Node.data=='n122').filter(parent.data=='n12').\
@@ -2443,7 +2450,7 @@ class SelfReferentialTest(_base.MappedTest):
         )
 
         # same, change order around
-        self.assertEquals(
+        eq_(
             sess.query(parent, grandparent, Node).\
                 join((Node.parent, parent), (parent.parent, grandparent)).\
                     filter(Node.data=='n122').filter(parent.data=='n12').\
@@ -2451,7 +2458,7 @@ class SelfReferentialTest(_base.MappedTest):
             (Node(data='n12'), Node(data='n1'), Node(data='n122'))
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(Node, parent, grandparent).\
                 join((Node.parent, parent), (parent.parent, grandparent)).\
                     filter(Node.data=='n122').filter(parent.data=='n12').\
@@ -2460,7 +2467,7 @@ class SelfReferentialTest(_base.MappedTest):
             (Node(data='n122'), Node(data='n12'), Node(data='n1'))
         )
 
-        self.assertEquals(
+        eq_(
             sess.query(Node, parent, grandparent).\
                 join((Node.parent, parent), (parent.parent, grandparent)).\
                     filter(Node.data=='n122').filter(parent.data=='n12').\
@@ -2472,40 +2479,41 @@ class SelfReferentialTest(_base.MappedTest):
     
     def test_any(self):
         sess = create_session()
-        self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
-        self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
-        self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
+        eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
+        eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
+        eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
 
     def test_has(self):
         sess = create_session()
     
-        self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
-        self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
-        self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
+        eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+        eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
+        eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
 
     def test_contains(self):
         sess = create_session()
     
         n122 = sess.query(Node).filter(Node.data=='n122').one()
-        self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')])
+        eq_(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')])
 
         n13 = sess.query(Node).filter(Node.data=='n13').one()
-        self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
+        eq_(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
 
     def test_eq_ne(self):
         sess = create_session()
     
         n12 = sess.query(Node).filter(Node.data=='n12').one()
-        self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+        eq_(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
     
-        self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
+        eq_(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
 
 class SelfReferentialM2MTest(_base.MappedTest):
     run_setup_mappers = 'once'
     run_inserts = 'once'
     run_deletes = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global nodes, node_to_nodes
         nodes = Table('nodes', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
@@ -2516,7 +2524,8 @@ class SelfReferentialM2MTest(_base.MappedTest):
             Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True),
             )
 
-    def insert_data(self):
+    @classmethod
+    def insert_data(cls):
         global Node
     
         class Node(Base):
@@ -2550,20 +2559,20 @@ class SelfReferentialM2MTest(_base.MappedTest):
 
     def test_any(self):
         sess = create_session()
-        self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')])
+        eq_(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')])
 
     def test_contains(self):
         sess = create_session()
         n4 = sess.query(Node).filter_by(data='n4').one()
 
-        self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')])
-        self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')])
+        eq_(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')])
+        eq_(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')])
 
     def test_explicit_join(self):
         sess = create_session()
     
         n1 = aliased(Node)
-        self.assertEquals(
+        eq_(
             sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data.in_(['n3', 'n7'])).order_by(Node.id).all(),
             [Node(data='n1'), Node(data='n2')]
         )
@@ -2575,7 +2584,7 @@ class ExternalColumnsTest(QueryTest):
 
     def test_external_columns_bad(self):
 
-        self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
+        assert_raises_message(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
             'concat': (users.c.id * 2),
         })
         clear_mappers()
@@ -2596,7 +2605,7 @@ class ExternalColumnsTest(QueryTest):
     
         sess.query(Address).options(eagerload('user')).all()
 
-        self.assertEquals(sess.query(User).all(), 
+        eq_(sess.query(User).all(), 
             [
                 User(id=7, concat=14, count=1),
                 User(id=8, concat=16, count=3),
@@ -2612,22 +2621,22 @@ class ExternalColumnsTest(QueryTest):
             Address(id=4, user=User(id=8, concat=16, count=3)),
             Address(id=5, user=User(id=9, concat=18, count=1))
         ]
-        self.assertEquals(sess.query(Address).all(), address_result)
+        eq_(sess.query(Address).all(), address_result)
 
         # run the eager version twice to test caching of aliased clauses
         for x in range(2):
             sess.expunge_all()
             def go():
-               self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result)
+               eq_(sess.query(Address).options(eagerload('user')).all(), address_result)
             self.assert_sql_count(testing.db, go, 1)
     
         ualias = aliased(User)
-        self.assertEquals(
+        eq_(
             sess.query(Address, ualias).join(('user', ualias)).all(), 
             [(address, address.user) for address in address_result]
         )
 
-        self.assertEquals(
+        eq_(
                 sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(),
                 [
                     (Address(id=1), 1),
@@ -2638,7 +2647,7 @@ class ExternalColumnsTest(QueryTest):
                 ]
             )
 
-        self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(),
+        eq_(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(),
             [
                 (Address(id=1), 14, 1),
                 (Address(id=2), 16, 3),
@@ -2649,7 +2658,7 @@ class ExternalColumnsTest(QueryTest):
         )
 
         ua = aliased(User)
-        self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
+        eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
             [
                 (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1),
                 (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3),
@@ -2659,11 +2668,11 @@ class ExternalColumnsTest(QueryTest):
             ]
         )
 
-        self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), 
+        eq_(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), 
             [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
         )
 
-        self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), 
+        eq_(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), 
             [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
         )
 
@@ -2686,17 +2695,18 @@ class ExternalColumnsTest(QueryTest):
         sess = create_session()
         def go():
             o1 = sess.query(Order).options(eagerload_all('address.user')).get(1)
-            self.assertEquals(o1.address.user.count, 1)
+            eq_(o1.address.user.count, 1)
         self.assert_sql_count(testing.db, go, 1)
 
         sess = create_session()
         def go():
             o1 = sess.query(Order).options(eagerload_all('address.user')).first()
-            self.assertEquals(o1.address.user.count, 1)
+            eq_(o1.address.user.count, 1)
         self.assert_sql_count(testing.db, go, 1)
 
 class TestOverlyEagerEquivalentCols(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global base, sub1, sub2
         base = Table('base', metadata, 
             Column('id', Integer, primary_key=True),
@@ -2747,14 +2757,15 @@ class TestOverlyEagerEquivalentCols(_base.MappedTest):
         q = sess.query(Base).outerjoin('sub2', aliased=True)
         assert sub1.c.id not in q._filter_aliases.equivalents
 
-        self.assertEquals(
+        eq_(
             sess.query(Base).join('sub1').outerjoin('sub2', aliased=True).\
                 filter(Sub1.id==1).one(),
                 b1
         )
     
 class UpdateDeleteTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users', metadata,
               Column('id', Integer, primary_key=True),
               Column('name', String(32)),
@@ -2765,15 +2776,17 @@ class UpdateDeleteTest(_base.MappedTest):
               Column('user_id', None, ForeignKey('users.id')),
               Column('title', String(32)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.ComparableEntity):
             pass
 
         class Document(_base.ComparableEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         users.insert().execute([
             dict(id=1, name='john', age=25),
             dict(id=2, name='jack', age=47),
@@ -2789,8 +2802,9 @@ class UpdateDeleteTest(_base.MappedTest):
             dict(id=3, user_id=2, title='baz'),
         ])
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users)
         mapper(Document, documents, properties={
             'user': relation(User, lazy=False, backref=backref('documents', lazy=True))
@@ -2964,17 +2978,17 @@ class UpdateDeleteTest(_base.MappedTest):
         sess = create_session(bind=testing.db, autocommit=False)
 
         rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0})
-        self.assertEquals(rowcount, 2)
+        eq_(rowcount, 2)
 
         rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
-        self.assertEquals(rowcount, 2)
+        eq_(rowcount, 2)
 
     @testing.resolve_artifact_names
     def test_delete_returns_rowcount(self):
         sess = create_session(bind=testing.db, autocommit=False)
 
         rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False)
-        self.assertEquals(rowcount, 3)
+        eq_(rowcount, 3)
 
     @testing.resolve_artifact_names
     def test_update_with_eager_relations(self):
@@ -3008,5 +3022,3 @@ class UpdateDeleteTest(_base.MappedTest):
 
         eq_(sess.query(Document.title).all(), zip(['baz']))
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 94%
rename from test/orm/relationships.py
rename to test/orm/test_relationships.py
index a0a8900b2c07a2d25072bb278555eba00e191f16..1bc074c3145d217281edadd6c20bc6c0ff7ecd74 100644 (file)
@@ -1,10 +1,13 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import datetime
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_
-from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers
-from testlib.testing import eq_, startswith_
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, MetaData, and_
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers
+from sqlalchemy.test.testing import eq_, startswith_
+from test.orm import _base, _fixtures
 
 
 class RelationTest(_base.MappedTest):
@@ -26,7 +29,8 @@ class RelationTest(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("tbl_a", metadata,
             Column("id", Integer, primary_key=True),
             Column("name", String(128)))
@@ -43,7 +47,8 @@ class RelationTest(_base.MappedTest):
             Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")),
             Column("name", String(128)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class A(_base.Entity):
             pass
         class B(_base.Entity):
@@ -53,8 +58,9 @@ class RelationTest(_base.MappedTest):
         class D(_base.Entity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(A, tbl_a, properties=dict(
             c_rows=relation(C, cascade="all, delete-orphan", backref="a_row")))
         mapper(B, tbl_b)
@@ -63,8 +69,9 @@ class RelationTest(_base.MappedTest):
         mapper(D, tbl_d, properties=dict(
             b_row=relation(B)))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         session = create_session()
         a = A(name='a1')
         b = B(name='b1')
@@ -102,7 +109,8 @@ class RelationTest2(_base.MappedTest):
     key where one column in the foreign key is 'joined to itself'.
 
     """
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('company_t', metadata,
               Column('company_id', Integer, primary_key=True),
               Column('name', sa.Unicode(30)))
@@ -218,7 +226,8 @@ class RelationTest2(_base.MappedTest):
         assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
 
 class RelationTest3(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("jobs", metadata,
               Column("jobno", sa.Unicode(15), primary_key=True),
               Column("created", sa.DateTime, nullable=False,
@@ -257,8 +266,9 @@ class RelationTest3(_base.MappedTest):
                   ["jobno", "pagename"],
                   ["pages.jobno", "pages.pagename"]))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         class Job(_base.Entity):
             def create_page(self, pagename):
                 return Page(job=self, pagename=pagename)
@@ -360,7 +370,8 @@ class RelationTest3(_base.MappedTest):
 class RelationTest4(_base.MappedTest):
     """Syncrules on foreign keys that are also primary"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("tableA", metadata,
               Column("id",Integer,primary_key=True),
               Column("foo",Integer,),
@@ -369,7 +380,8 @@ class RelationTest4(_base.MappedTest):
               Column("id",Integer,ForeignKey("tableA.id"),primary_key=True),
               test_needs_fk=True)
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class A(_base.Entity):
             pass
 
@@ -537,7 +549,8 @@ class RelationTest4(_base.MappedTest):
 class RelationTest5(_base.MappedTest):
     """Test a map to a select that relates to a map to the table."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('items', metadata,
               Column('item_policy_num', String(10), primary_key=True,
                      key='policyNum'),
@@ -605,7 +618,8 @@ class RelationTest6(_base.MappedTest):
     
     """
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('tags', metadata, Column("id", Integer, primary_key=True),
             Column("data", String(50)),
         )
@@ -659,7 +673,8 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest):
     
     """
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         subscriber_table = Table('subscriber', metadata,
            Column('id', Integer, primary_key=True),
            Column('dummy', String(10)) # to appease older sqlite version
@@ -671,8 +686,9 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest):
                  Column('type', String(1), primary_key=True),
                  )
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         subscriber_and_address = subscriber.join(address, 
                and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C'])))
 
@@ -762,7 +778,7 @@ class ManualBackrefTest(_fixtures.FixtureTest):
             'user':relation(User, back_populates='addresses')
         })
         
-        self.assertRaises(sa.exc.InvalidRequestError, compile_mappers)
+        assert_raises(sa.exc.InvalidRequestError, compile_mappers)
         
     @testing.resolve_artifact_names
     def test_invalid_target(self):
@@ -775,7 +791,7 @@ class ManualBackrefTest(_fixtures.FixtureTest):
             'dingaling':relation(Dingaling)
         })
         
-        self.assertRaisesMessage(sa.exc.ArgumentError, 
+        assert_raises_message(sa.exc.ArgumentError, 
             r"reverse_property 'dingaling' on relation User.addresses references "
             "relation Address.dingaling, which does not reference mapper Mapper\|User\|users", 
             compile_mappers)
@@ -794,7 +810,7 @@ class JoinConditionErrorTest(testing.TestBase):
             c1id = Column('c1id', Integer, ForeignKey('c1.id'))
             c2 = relation(C1, primaryjoin=C1.id)
         
-        self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+        assert_raises(sa.exc.ArgumentError, compile_mappers)
 
     def test_clauseelement_pj_false(self):
         from sqlalchemy.ext.declarative import declarative_base
@@ -808,7 +824,7 @@ class JoinConditionErrorTest(testing.TestBase):
             c1id = Column('c1id', Integer, ForeignKey('c1.id'))
             c2 = relation(C1, primaryjoin="x"=="y")
 
-        self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+        assert_raises(sa.exc.ArgumentError, compile_mappers)
         
     
     def test_fk_error_raised(self):
@@ -834,7 +850,7 @@ class JoinConditionErrorTest(testing.TestBase):
         mapper(C1, t1, properties={'c2':relation(C2)})
         mapper(C2, t3)
         
-        self.assertRaises(sa.exc.NoReferencedColumnError, compile_mappers)
+        assert_raises(sa.exc.NoReferencedColumnError, compile_mappers)
     
     def test_join_error_raised(self):
         m = MetaData()
@@ -858,15 +874,16 @@ class JoinConditionErrorTest(testing.TestBase):
         mapper(C1, t1, properties={'c2':relation(C2)})
         mapper(C2, t3)
 
-        self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+        assert_raises(sa.exc.ArgumentError, compile_mappers)
     
-    def tearDown(self):
+    def teardown(self):
         clear_mappers()    
         
 class TypeMatchTest(_base.MappedTest):
     """test errors raised when trying to add items whose type is not handled by a relation"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("a", metadata,
               Column('aid', Integer, primary_key=True),
               Column('data', String(30)))
@@ -924,7 +941,7 @@ class TypeMatchTest(_base.MappedTest):
         sess.add(a1)
         sess.add(b1)
         sess.add(c1)
-        self.assertRaisesMessage(sa.orm.exc.FlushError,
+        assert_raises_message(sa.orm.exc.FlushError,
                                  "Attempting to flush an item", sess.flush)
 
     @testing.resolve_artifact_names
@@ -945,7 +962,7 @@ class TypeMatchTest(_base.MappedTest):
         sess.add(a1)
         sess.add(b1)
         sess.add(c1)
-        self.assertRaisesMessage(sa.orm.exc.FlushError,
+        assert_raises_message(sa.orm.exc.FlushError,
                                  "Attempting to flush an item", sess.flush)
 
     @testing.resolve_artifact_names
@@ -962,7 +979,7 @@ class TypeMatchTest(_base.MappedTest):
         sess = create_session()
         sess.add(b1)
         sess.add(d1)
-        self.assertRaisesMessage(sa.orm.exc.FlushError,
+        assert_raises_message(sa.orm.exc.FlushError,
                                  "Attempting to flush an item", sess.flush)
 
     @testing.resolve_artifact_names
@@ -977,12 +994,13 @@ class TypeMatchTest(_base.MappedTest):
         d1 = D()
         d1.a = b1
         sess = create_session()
-        self.assertRaisesMessage(AssertionError,
+        assert_raises_message(AssertionError,
                                  "doesn't handle objects of type", sess.add, d1)
 
 class TypedAssociationTable(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         class MySpecialType(sa.types.TypeDecorator):
             impl = String
             def process_bind_param(self, value, dialect):
@@ -1033,7 +1051,8 @@ class TypedAssociationTable(_base.MappedTest):
 class ViewOnlyOverlappingNames(_base.MappedTest):
     """'viewonly' mappings with overlapping PK column names."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("t1", metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(40)))
@@ -1092,7 +1111,8 @@ class ViewOnlyOverlappingNames(_base.MappedTest):
 class ViewOnlyUniqueNames(_base.MappedTest):
     """'viewonly' mappings with unique PK column names."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("t1", metadata,
             Column('t1id', Integer, primary_key=True),
             Column('data', String(40)))
@@ -1182,7 +1202,8 @@ class ViewOnlyLocalRemoteM2M(testing.TestBase):
 class ViewOnlyNonEquijoin(_base.MappedTest):
     """'viewonly' mappings based on non-equijoins."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('foos', metadata,
                      Column('id', Integer, primary_key=True))
         Table('bars', metadata,
@@ -1223,7 +1244,8 @@ class ViewOnlyNonEquijoin(_base.MappedTest):
 class ViewOnlyRepeatedRemoteColumn(_base.MappedTest):
     """'viewonly' mappings that contain the same 'remote' column twice"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('foos', metadata,
               Column('id', Integer, primary_key=True),
               Column('bid1', Integer,ForeignKey('bars.id')),
@@ -1270,7 +1292,8 @@ class ViewOnlyRepeatedRemoteColumn(_base.MappedTest):
 class ViewOnlyRepeatedLocalColumn(_base.MappedTest):
     """'viewonly' mappings that contain the same 'local' column twice"""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('foos', metadata,
               Column('id', Integer, primary_key=True),
               Column('data', String(50)))
@@ -1317,7 +1340,8 @@ class ViewOnlyRepeatedLocalColumn(_base.MappedTest):
 class ViewOnlyComplexJoin(_base.MappedTest):
     """'viewonly' mappings with a complex join condition."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(50)))
@@ -1332,7 +1356,8 @@ class ViewOnlyComplexJoin(_base.MappedTest):
             Column('t2id', Integer, ForeignKey('t2.id')),
             Column('t3id', Integer, ForeignKey('t3.id')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class T1(_base.ComparableEntity):
             pass
         class T2(_base.ComparableEntity):
@@ -1379,14 +1404,15 @@ class ViewOnlyComplexJoin(_base.MappedTest):
             't1':relation(T1),
             't3s':relation(T3, secondary=t2tot3)})
         mapper(T3, t3)
-        self.assertRaisesMessage(sa.exc.ArgumentError,
+        assert_raises_message(sa.exc.ArgumentError,
                                  "Specify remote_side argument",
                                  sa.orm.compile_mappers)
 
 
 class ExplicitLocalRemoteTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
             Column('id', String(50), primary_key=True),
             Column('data', String(50)))
@@ -1395,8 +1421,9 @@ class ExplicitLocalRemoteTest(_base.MappedTest):
             Column('data', String(50)),
             Column('t1id', String(50)))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_classes(self):
+    def setup_classes(cls):
         class T1(_base.ComparableEntity):
             pass
         class T2(_base.ComparableEntity):
@@ -1508,7 +1535,7 @@ class ExplicitLocalRemoteTest(_base.MappedTest):
                            foreign_keys=[t2.c.t1id],
                            remote_side=[t2.c.t1id])})
         mapper(T2, t2)
-        self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers)
+        assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers)
 
     @testing.resolve_artifact_names
     def test_escalation_2(self):
@@ -1517,18 +1544,20 @@ class ExplicitLocalRemoteTest(_base.MappedTest):
                            primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id),
                            _local_remote_pairs=[(t1.c.id, t2.c.t1id)])})
         mapper(T2, t2)
-        self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers)
+        assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers)
 
 class InvalidRemoteSideTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(50)),
             Column('t_id', Integer, ForeignKey('t1.id'))
             )
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_classes(self):
+    def setup_classes(cls):
         class T1(_base.ComparableEntity):
             pass
 
@@ -1538,7 +1567,7 @@ class InvalidRemoteSideTest(_base.MappedTest):
             't1s':relation(T1, backref='parent')
         })
 
-        self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
+        assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
                     "both of the same direction <symbol 'ONETOMANY>.  Did you "
                     "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
 
@@ -1548,7 +1577,7 @@ class InvalidRemoteSideTest(_base.MappedTest):
             't1s':relation(T1, backref=backref('parent', remote_side=t1.c.id), remote_side=t1.c.id)
         })
 
-        self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
+        assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
                     "both of the same direction <symbol 'MANYTOONE>.  Did you "
                     "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
 
@@ -1560,7 +1589,7 @@ class InvalidRemoteSideTest(_base.MappedTest):
         })
 
         # can't be sure of ordering here
-        self.assertRaisesMessage(sa.exc.ArgumentError, 
+        assert_raises_message(sa.exc.ArgumentError, 
                     "both of the same direction <symbol 'ONETOMANY>.  Did you "
                     "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
 
@@ -1572,14 +1601,15 @@ class InvalidRemoteSideTest(_base.MappedTest):
         })
 
         # can't be sure of ordering here
-        self.assertRaisesMessage(sa.exc.ArgumentError, 
+        assert_raises_message(sa.exc.ArgumentError, 
                     "both of the same direction <symbol 'MANYTOONE>.  Did you "
                     "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
 
         
 class InvalidRelationEscalationTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('foos', metadata,
               Column('id', Integer, primary_key=True),
               Column('fid', Integer))
@@ -1587,7 +1617,8 @@ class InvalidRelationEscalationTest(_base.MappedTest):
               Column('id', Integer, primary_key=True),
               Column('fid', Integer))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Foo(_base.Entity):
             pass
         class Bar(_base.Entity):
@@ -1599,7 +1630,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
             'bars':relation(Bar)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine join condition between parent/child "
             "tables on relation", sa.orm.compile_mappers)
@@ -1610,7 +1641,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
             'foos':relation(Foo)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine join condition between parent/child "
             "tables on relation", sa.orm.compile_mappers)
@@ -1622,7 +1653,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             primaryjoin=foos.c.id>bars.c.fid)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for primaryjoin condition",
             sa.orm.compile_mappers)
@@ -1635,7 +1666,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             foreign_keys=bars.c.fid)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not locate any equated, locally mapped column pairs "
             "for primaryjoin condition", sa.orm.compile_mappers)
@@ -1648,7 +1679,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             foreign_keys=[foos.c.id, bars.c.fid])})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError, 
                 "Do the columns in 'foreign_keys' represent only the "
                 "'foreign' columns in this join condition ?", 
@@ -1665,7 +1696,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             )})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError, 
                 "could not determine any local/remote column pairs",
                 sa.orm.compile_mappers)
@@ -1681,7 +1712,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             )})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError, 
                 "could not determine any local/remote column pairs",
                 sa.orm.compile_mappers)
@@ -1694,7 +1725,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             primaryjoin=foos.c.id>foos.c.fid)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for primaryjoin condition",
             sa.orm.compile_mappers)
@@ -1707,7 +1738,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             foreign_keys=[foos.c.fid])})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not locate any equated, locally mapped column pairs "
             "for primaryjoin condition", sa.orm.compile_mappers)
@@ -1720,7 +1751,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             viewonly=True)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for primaryjoin condition",
             sa.orm.compile_mappers)
@@ -1733,7 +1764,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             viewonly=True)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Specify the 'foreign_keys' argument to indicate which columns "
             "on the relation are foreign.", sa.orm.compile_mappers)
@@ -1756,7 +1787,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             primaryjoin=foos.c.id==bars.c.fid)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for primaryjoin condition",
             sa.orm.compile_mappers)
@@ -1767,7 +1798,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
             'foos':relation(Foo,
                             primaryjoin=foos.c.id==foos.c.fid)})
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for primaryjoin condition",
             sa.orm.compile_mappers)
@@ -1779,7 +1810,7 @@ class InvalidRelationEscalationTest(_base.MappedTest):
                             primaryjoin=foos.c.id==foos.c.fid,
                             foreign_keys=[bars.c.id])})
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for primaryjoin condition",
             sa.orm.compile_mappers)
@@ -1787,7 +1818,8 @@ class InvalidRelationEscalationTest(_base.MappedTest):
 
 class InvalidRelationEscalationTestM2M(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('foos', metadata,
               Column('id', Integer, primary_key=True))
         Table('foobars', metadata,
@@ -1795,8 +1827,9 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest):
         Table('bars', metadata,
               Column('id', Integer, primary_key=True))
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_classes(self):
+    def setup_classes(cls):
         class Foo(_base.Entity):
             pass
         class Bar(_base.Entity):
@@ -1808,7 +1841,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest):
             'bars': relation(Bar, secondary=foobars)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine join condition between parent/child tables "
             "on relation", sa.orm.compile_mappers)
@@ -1821,7 +1854,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest):
                             primaryjoin=foos.c.id > foobars.c.fid)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine join condition between parent/child tables "
             "on relation",
@@ -1836,7 +1869,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest):
                              secondaryjoin=foobars.c.bid<=bars.c.id)})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for primaryjoin condition",
             sa.orm.compile_mappers)
@@ -1851,7 +1884,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest):
                             foreign_keys=[foobars.c.fid])})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not determine relation direction for secondaryjoin "
             "condition", sa.orm.compile_mappers)
@@ -1866,11 +1899,9 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest):
                             foreign_keys=[foobars.c.fid, foobars.c.bid])})
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(
+        assert_raises_message(
             sa.exc.ArgumentError,
             "Could not locate any equated, locally mapped column pairs for "
             "secondaryjoin condition", sa.orm.compile_mappers)
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 83%
rename from test/orm/scoping.py
rename to test/orm/test_scoping.py
index bdfc5a9d58bba3c1b55824f06ec6dbc2106810e8..2117e8dccbf47cc2935d65598601b96c544c7745 100644 (file)
@@ -1,10 +1,13 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
 from sqlalchemy.orm import scoped_session
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, query
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, query
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
 
 
 class _ScopedTest(_base.MappedTest):
@@ -15,18 +18,21 @@ class _ScopedTest(_base.MappedTest):
     _artifact_registries = (
         _base.MappedTest._artifact_registries + ('scoping',))
 
-    def setUpAll(self):
-        type(self).scoping = _base.adict()
-        _base.MappedTest.setUpAll(self)
+    @classmethod
+    def setup_class(cls):
+        cls.scoping = _base.adict()
+        super(_ScopedTest, cls).setup_class()
 
-    def tearDownAll(self):
-        self.scoping.clear()
-        _base.MappedTest.tearDownAll(self)
+    @classmethod
+    def teardown_class(cls):
+        cls.scoping.clear()
+        super(_ScopedTest, cls).teardown_class()
 
 
 class ScopedSessionTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('table1', metadata,
               Column('id', Integer, primary_key=True),
               Column('data', String(30)))
@@ -73,7 +79,8 @@ class ScopedSessionTest(_base.MappedTest):
 
 class ScopedMapperTest(_ScopedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('table1', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(30)))
@@ -81,24 +88,27 @@ class ScopedMapperTest(_ScopedTest):
             Column('id', Integer, primary_key=True),
             Column('someid', None, ForeignKey('table1.id')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class SomeObject(_base.ComparableEntity):
             pass
         class SomeOtherObject(_base.ComparableEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         Session = scoped_session(sa.orm.create_session)
         Session.mapper(SomeObject, table1, properties={
             'options':relation(SomeOtherObject)
         })
         Session.mapper(SomeOtherObject, table2)
 
-        self.scoping['Session'] = Session
+        cls.scoping['Session'] = Session
 
+    @classmethod
     @testing.resolve_artifact_names
-    def insert_data(self):
+    def insert_data(cls):
         s = SomeObject()
         s.id = 1
         s.data = 'hello'
@@ -145,7 +155,7 @@ class ScopedMapperTest(_ScopedTest):
         scope.mapper(B, table2)
 
         A(foo='bar')
-        self.assertRaises(TypeError, B, foo='bar')
+        assert_raises(TypeError, B, foo='bar')
 
         scope = scoped_session(sa.orm.sessionmaker())
 
@@ -158,7 +168,7 @@ class ScopedMapperTest(_ScopedTest):
         scope.mapper(C, table1)
         scope.mapper(D, table2)
 
-        self.assertRaises(TypeError, C, foo='bar')
+        assert_raises(TypeError, C, foo='bar')
         D(foo='bar')
 
     @testing.resolve_artifact_names
@@ -170,7 +180,7 @@ class ScopedMapperTest(_ScopedTest):
         Session.mapper(ValidatedOtherObject, table2, validate=True)
 
         v1 = ValidatedOtherObject(someid=12)
-        self.assertRaises(sa.exc.ArgumentError, ValidatedOtherObject,
+        assert_raises(sa.exc.ArgumentError, ValidatedOtherObject,
                           someid=12, bogus=345)
 
     @testing.resolve_artifact_names
@@ -186,7 +196,8 @@ class ScopedMapperTest(_ScopedTest):
 
 class ScopedMapperTest2(_ScopedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('table1', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(30)),
@@ -196,14 +207,16 @@ class ScopedMapperTest2(_ScopedTest):
             Column('someid', None, ForeignKey('table1.id')),
             Column('somedata', String(30)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class BaseClass(_base.ComparableEntity):
             pass
         class SubClass(BaseClass):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         Session = scoped_session(sa.orm.sessionmaker())
 
         Session.mapper(BaseClass, table1,
@@ -213,7 +226,7 @@ class ScopedMapperTest2(_ScopedTest):
                        polymorphic_identity='sub',
                        inherits=BaseClass)
 
-        self.scoping['Session'] = Session
+        cls.scoping['Session'] = Session
 
     @testing.resolve_artifact_names
     def test_inheritance(self):
@@ -234,5 +247,3 @@ class ScopedMapperTest2(_ScopedTest):
             SubClass.query.all())
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 71%
rename from test/orm/selectable.py
rename to test/orm/test_selectable.py
index 74c41c85230d8e4da321bd818636f38dccf54a21..0a20253607772a0b9348ac693e6b0b1696e122b6 100644 (file)
@@ -1,22 +1,27 @@
 """Generic mapping to Select statements"""
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, String, Integer, select
-from testlib.sa.orm import mapper, create_session
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import String, Integer, select
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
 
 
 # TODO: more tests mapping to selects
 
 class SelectableNoFromsTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('common', metadata,
               Column('id', Integer, primary_key=True),
               Column('data', Integer),
               Column('extra', String(45)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Subset(_base.ComparableEntity):
             pass
 
@@ -24,7 +29,7 @@ class SelectableNoFromsTest(_base.MappedTest):
     def test_no_tables(self):
 
         selectable = select(["x", "y", "z"])
-        self.assertRaisesMessage(sa.exc.InvalidRequestError,
+        assert_raises_message(sa.exc.InvalidRequestError,
                                  "Could not find any Table objects",
                                  mapper, Subset, selectable)
 
@@ -48,5 +53,3 @@ class SelectableNoFromsTest(_base.MappedTest):
             Subset(data=1))
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 93%
rename from test/orm/session.py
rename to test/orm/test_session.py
index 6cbd62a50e152b24120cc7a088a0b5eac26175f6..3020d66e9dcfc45caed78d60a6247c1e703c0b87 100644 (file)
@@ -1,14 +1,17 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import gc
 import inspect
 import pickle
 from sqlalchemy.orm import create_session, sessionmaker, attributes
-from testlib import engines, sa, testing, config
-from testlib.sa import Table, Column, Integer, String, Sequence
-from testlib.sa.orm import mapper, relation, backref, eagerload
-from testlib.testing import eq_
-from engine import _base as engine_base
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import engines, testing, config
+from sqlalchemy import Integer, String, Sequence
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, eagerload
+from sqlalchemy.test.testing import eq_
+from test.engine import _base as engine_base
+from test.orm import _base, _fixtures
 
 
 class SessionTest(_fixtures.FixtureTest):
@@ -508,7 +511,7 @@ class SessionTest(_fixtures.FixtureTest):
 
         sess.commit()
 
-        self.assertEquals(set(sess.query(User).all()), set([u2]))
+        eq_(set(sess.query(User).all()), set([u2]))
 
         sess.begin()
         sess.begin_nested()
@@ -518,7 +521,7 @@ class SessionTest(_fixtures.FixtureTest):
         sess.commit() # commit the nested transaction
         sess.rollback()
 
-        self.assertEquals(set(sess.query(User).all()), set([u2]))
+        eq_(set(sess.query(User).all()), set([u2]))
 
         sess.close()
 
@@ -541,7 +544,7 @@ class SessionTest(_fixtures.FixtureTest):
 
         sess.close()
 
-        self.assertEquals(len(sess.query(User).all()), 1)
+        eq_(len(sess.query(User).all()), 1)
 
         t1 = sess.begin()
         t2 = sess.begin_nested()
@@ -572,7 +575,7 @@ class SessionTest(_fixtures.FixtureTest):
 
         sess.close()
 
-        self.assertEquals(len(sess.query(User).all()), 1)
+        eq_(len(sess.query(User).all()), 1)
 
     @testing.resolve_artifact_names
     def test_error_on_using_inactive_session(self):
@@ -587,7 +590,7 @@ class SessionTest(_fixtures.FixtureTest):
         sess.flush()
 
         sess.rollback()
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True)
+        assert_raises_message(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True)
         sess.close()
 
     @testing.resolve_artifact_names
@@ -612,7 +615,7 @@ class SessionTest(_fixtures.FixtureTest):
         sess.flush()
         assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c
 
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect())
+        assert_raises_message(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect())
 
         transaction.rollback()
         assert len(sess.query(User).all()) == 0
@@ -667,8 +670,8 @@ class SessionTest(_fixtures.FixtureTest):
 
         user = User(name='u1')
 
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.update, user)
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.delete, user)
+        assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.update, user)
+        assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.delete, user)
 
         s.add(user)
         s.flush()
@@ -694,13 +697,13 @@ class SessionTest(_fixtures.FixtureTest):
         assert user in s
         assert user not in s.dirty
 
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already persistent", s.save, user)
+        assert_raises_message(sa.exc.InvalidRequestError, "is already persistent", s.save, user)
 
         s2 = create_session()
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user)
+        assert_raises_message(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user)
 
         u2 = s2.query(User).get(user.id)
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2)
+        assert_raises_message(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2)
 
         s.expire(user)
         s.expunge(user)
@@ -1029,7 +1032,7 @@ class SessionTest(_fixtures.FixtureTest):
         u = User(name='u1')
         sess.add(u)
         sess.flush()
-        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+        eq_(sess.query(User).order_by(User.name).all(), 
             [
                 User(name='another u1'),
                 User(name='u1')
@@ -1037,7 +1040,7 @@ class SessionTest(_fixtures.FixtureTest):
         )
         
         sess.flush()
-        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+        eq_(sess.query(User).order_by(User.name).all(), 
             [
                 User(name='another u1'),
                 User(name='u1')
@@ -1046,7 +1049,7 @@ class SessionTest(_fixtures.FixtureTest):
 
         u.name='u2'
         sess.flush()
-        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+        eq_(sess.query(User).order_by(User.name).all(), 
             [
                 User(name='another u1'),
                 User(name='another u2'),
@@ -1056,7 +1059,7 @@ class SessionTest(_fixtures.FixtureTest):
 
         sess.delete(u)
         sess.flush()
-        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+        eq_(sess.query(User).order_by(User.name).all(), 
             [
                 User(name='another u1'),
             ]
@@ -1075,7 +1078,7 @@ class SessionTest(_fixtures.FixtureTest):
         u = User(name='u1')
         sess.add(u)
         sess.flush()
-        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+        eq_(sess.query(User).order_by(User.name).all(), 
             [
                 User(name='u1')
             ]
@@ -1084,7 +1087,7 @@ class SessionTest(_fixtures.FixtureTest):
         sess.add(User(name='u2'))
         sess.flush()
         sess.expunge_all()
-        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+        eq_(sess.query(User).order_by(User.name).all(), 
             [
                 User(name='u1 modified'),
                 User(name='u2')
@@ -1102,7 +1105,7 @@ class SessionTest(_fixtures.FixtureTest):
         
         sess = create_session(extension=MyExt())
         sess.add(User(name='foo'))
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush)
+        assert_raises_message(sa.exc.InvalidRequestError, "already flushing", sess.flush)
 
     @testing.resolve_artifact_names
     def test_pickled_update(self):
@@ -1113,7 +1116,7 @@ class SessionTest(_fixtures.FixtureTest):
         u1 = User(name='u1')
         sess1.add(u1)
 
-        self.assertRaisesMessage(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1)
+        assert_raises_message(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1)
 
         u2 = pickle.loads(pickle.dumps(u1))
 
@@ -1139,7 +1142,7 @@ class SessionTest(_fixtures.FixtureTest):
         assert u2 is not None and u2 is not u1
         assert u2 in sess
 
-        self.assertRaises(Exception, lambda: sess.add(u1))
+        assert_raises(Exception, lambda: sess.add(u1))
 
         sess.expunge(u2)
         assert u2 not in sess
@@ -1181,24 +1184,26 @@ class DisposedStates(_base.MappedTest):
     run_inserts = 'once'
     run_deletes = None
     
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         global t1
         t1 = Table('t1', metadata, 
             Column('id', Integer, primary_key=True),
             Column('data', String(50))
             )
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         global T
         class T(object):
             def __init__(self, data):
                 self.data = data
         mapper(T, t1)
     
-    def tearDown(self):
+    def teardown(self):
         from sqlalchemy.orm.session import _sessions
         _sessions.clear()
-        super(DisposedStates, self).tearDown()
+        super(DisposedStates, self).teardown()
         
     def _set_imap_in_disposal(self, sess, *objs):
         """remove selected objects from the given session, as though they 
@@ -1291,7 +1296,7 @@ class SessionInterface(testing.TestBase):
         def x_raises_(obj, method, *args, **kw):
             watchdog.add(method)
             callable_ = getattr(obj, method)
-            self.assertRaises(sa.orm.exc.UnmappedInstanceError,
+            assert_raises(sa.orm.exc.UnmappedInstanceError,
                               callable_, *args, **kw)
 
         def raises_(method, *args, **kw):
@@ -1343,7 +1348,7 @@ class SessionInterface(testing.TestBase):
         def raises_(method, *args, **kw):
             watchdog.add(method)
             callable_ = getattr(create_session(), method)
-            self.assertRaises(sa.orm.exc.UnmappedClassError,
+            assert_raises(sa.orm.exc.UnmappedClassError,
                               callable_, *args, **kw)
 
         raises_('connection', mapper=user_arg)
@@ -1395,32 +1400,27 @@ class SessionInterface(testing.TestBase):
 
 
 class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest):
-    def create_engine(self):
+    @classmethod
+    def create_engine(cls):
         return engines.testing_engine(options=dict(strategy='threadlocal'))
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users', metadata,
               Column('id', Integer, primary_key=True),
               Column('name', String(20)),
               test_needs_acid=True)
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users)
 
-    def setUpAll(self):
-        engine_base.AltEngineTest.setUpAll(self)
-        _base.MappedTest.setUpAll(self)
-
-
-    def tearDownAll(self):
-        _base.MappedTest.tearDownAll(self)
-        engine_base.AltEngineTest.tearDownAll(self)
-
     @testing.exclude('mysql', '<', (5, 0, 3), 'FIXME: unknown')
     @testing.resolve_artifact_names
     def test_session_nesting(self):
@@ -1432,5 +1432,3 @@ class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest):
         self.engine.commit()
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 86%
rename from test/orm/transaction.py
rename to test/orm/test_transaction.py
index 0fcd55df32e672341a4f28cf1431e7e2b3446535..5aa541cdadaf7b4ec324f2f7c6d07003ae8c577a 100644 (file)
@@ -1,13 +1,13 @@
-import testenv; testenv.configure_for_tests()
 
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy.orm import attributes
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
 
-from testlib import testing
-from orm import _base
-from orm._fixtures import FixtureTest, User, Address, users, addresses
+from sqlalchemy.test import testing
+from test.orm import _base
+from test.orm._fixtures import FixtureTest, User, Address, users, addresses
 
 import gc
 
@@ -16,7 +16,8 @@ class TransactionTest(FixtureTest):
     run_inserts = None
     session = sessionmaker()
 
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user',
                                  cascade="all, delete-orphan"),
@@ -32,7 +33,7 @@ class FixtureDataTest(TransactionTest):
         u1 = sess.query(User).get(7)
         u1.name = 'ed'
         sess.rollback()
-        self.assertEquals(u1.name, 'jack')
+        eq_(u1.name, 'jack')
 
     def test_commit_persistent(self):
         sess = self.session()
@@ -40,7 +41,7 @@ class FixtureDataTest(TransactionTest):
         u1.name = 'ed'
         sess.flush()
         sess.commit()
-        self.assertEquals(u1.name, 'ed')
+        eq_(u1.name, 'ed')
 
     def test_concurrent_commit_persistent(self):
         s1 = self.session()
@@ -157,7 +158,7 @@ class AutoExpireTest(TransactionTest):
         u1.addresses.remove(a1)
 
         s.flush()
-        self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), [])
+        eq_(s.query(Address).filter(Address.email_address=='foo').all(), [])
         s.rollback()
         assert a1 not in s.deleted
         assert u1.addresses == [a1]
@@ -168,7 +169,7 @@ class AutoExpireTest(TransactionTest):
         sess.add(u1)
         sess.flush()
         sess.commit()
-        self.assertEquals(u1.name, 'newuser')
+        eq_(u1.name, 'newuser')
 
 
     def test_concurrent_commit_pending(self):
@@ -212,8 +213,8 @@ class RollbackRecoverTest(TransactionTest):
         u1.name = 'edward'
         a1.email_address = 'foober'
         s.add(u2)
-        self.assertRaises(sa_exc.FlushError, s.commit)
-        self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+        assert_raises(sa_exc.FlushError, s.commit)
+        assert_raises(sa_exc.InvalidRequestError, s.commit)
         s.rollback()
         assert u2 not in s
         assert a2 not in s
@@ -224,7 +225,7 @@ class RollbackRecoverTest(TransactionTest):
         u1.name = 'edward'
         a1.email_address = 'foober'
         s.commit()
-        self.assertEquals(
+        eq_(
             s.query(User).all(),
             [User(id=1, name='edward', addresses=[Address(email_address='foober')])]
         )
@@ -244,8 +245,8 @@ class RollbackRecoverTest(TransactionTest):
         a1.email_address = 'foober'
         s.begin_nested()
         s.add(u2)
-        self.assertRaises(sa_exc.FlushError, s.commit)
-        self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+        assert_raises(sa_exc.FlushError, s.commit)
+        assert_raises(sa_exc.InvalidRequestError, s.commit)
         s.rollback()
         assert u2 not in s
         assert a2 not in s
@@ -271,15 +272,15 @@ class SavepointTest(TransactionTest):
         u1.name = 'edward'
         u2.name = 'jackward'
         s.add_all([u3, u4])
-        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+        eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
         s.rollback()
         assert u1.name == 'ed'
         assert u2.name == 'jack'
-        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+        eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
         s.commit()
         assert u1.name == 'ed'
         assert u2.name == 'jack'
-        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+        eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
 
     @testing.requires.savepoints
     def test_savepoint_delete(self):
@@ -287,11 +288,11 @@ class SavepointTest(TransactionTest):
         u1 = User(name='ed')
         s.add(u1)
         s.commit()
-        self.assertEquals(s.query(User).filter_by(name='ed').count(), 1)
+        eq_(s.query(User).filter_by(name='ed').count(), 1)
         s.begin_nested()
         s.delete(u1)
         s.commit()
-        self.assertEquals(s.query(User).filter_by(name='ed').count(), 0)
+        eq_(s.query(User).filter_by(name='ed').count(), 0)
         s.commit()
 
     @testing.requires.savepoints
@@ -307,16 +308,16 @@ class SavepointTest(TransactionTest):
         u1.name = 'edward'
         u2.name = 'jackward'
         s.add_all([u3, u4])
-        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+        eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
         s.commit()
         def go():
             assert u1.name == 'edward'
             assert u2.name == 'jackward'
-            self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+            eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
         self.assert_sql_count(testing.db, go, 1)
 
         s.commit()
-        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+        eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
 
     @testing.requires.savepoints
     def test_savepoint_rollback_collections(self):
@@ -330,20 +331,20 @@ class SavepointTest(TransactionTest):
         s.begin_nested()
         u2 = User(name='jack', addresses=[Address(email_address='bat')])
         s.add(u2)
-        self.assertEquals(s.query(User).order_by(User.id).all(),
+        eq_(s.query(User).order_by(User.id).all(),
             [
                 User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
                 User(name='jack', addresses=[Address(email_address='bat')])
             ]
         )
         s.rollback()
-        self.assertEquals(s.query(User).order_by(User.id).all(),
+        eq_(s.query(User).order_by(User.id).all(),
             [
                 User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
             ]
         )
         s.commit()
-        self.assertEquals(s.query(User).order_by(User.id).all(),
+        eq_(s.query(User).order_by(User.id).all(),
             [
                 User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
             ]
@@ -361,21 +362,21 @@ class SavepointTest(TransactionTest):
         s.begin_nested()
         u2 = User(name='jack', addresses=[Address(email_address='bat')])
         s.add(u2)
-        self.assertEquals(s.query(User).order_by(User.id).all(),
+        eq_(s.query(User).order_by(User.id).all(),
             [
                 User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
                 User(name='jack', addresses=[Address(email_address='bat')])
             ]
         )
         s.commit()
-        self.assertEquals(s.query(User).order_by(User.id).all(),
+        eq_(s.query(User).order_by(User.id).all(),
             [
                 User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
                 User(name='jack', addresses=[Address(email_address='bat')])
             ]
         )
         s.commit()
-        self.assertEquals(s.query(User).order_by(User.id).all(),
+        eq_(s.query(User).order_by(User.id).all(),
             [
                 User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
                 User(name='jack', addresses=[Address(email_address='bat')])
@@ -476,7 +477,7 @@ class AccountingFlagsTest(TransactionTest):
 class AutoCommitTest(TransactionTest):
     def test_begin_nested_requires_trans(self):
         sess = create_session(autocommit=True)
-        self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested)
+        assert_raises(sa_exc.InvalidRequestError, sess.begin_nested)
 
     def test_begin_preflush(self):
         sess = create_session(autocommit=True)
@@ -495,5 +496,3 @@ class AutoCommitTest(TransactionTest):
         
 
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 95%
rename from test/orm/unitofwork.py
rename to test/orm/test_unitofwork.py
index c5e3afd01484f091888ff9c50387c63371992e7e..f95346902be8ca91a7c0ded35cdc1dc6202ea51a 100644 (file)
@@ -1,19 +1,21 @@
 # coding: utf-8
 """Tests unitofwork operations."""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import datetime
 import operator
 from sqlalchemy.orm import mapper as orm_mapper
 
-from testlib import engines, sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, literal_column
-from testlib.sa.orm import mapper, relation, create_session, column_property
-from testlib.testing import eq_, ne_
-from orm import _base, _fixtures
-from engine import _base as engine_base
-import pickleable
-from testlib.assertsql import AllOf, CompiledSQL
+import sqlalchemy as sa
+from sqlalchemy.test import engines, testing, pickleable
+from sqlalchemy import Integer, String, ForeignKey, literal_column
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, column_property
+from sqlalchemy.test.testing import eq_, ne_
+from test.orm import _base, _fixtures
+from test.engine import _base as engine_base
+from sqlalchemy.test.assertsql import AllOf, CompiledSQL
 import gc
 
 class UnitOfWorkTest(object):
@@ -22,7 +24,8 @@ class UnitOfWorkTest(object):
 class HistoryTest(_fixtures.FixtureTest):
     run_inserts = None
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.ComparableEntity):
             pass
         class Address(_base.ComparableEntity):
@@ -51,14 +54,16 @@ class HistoryTest(_fixtures.FixtureTest):
 
 
 class VersioningTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('version_table', metadata,
               Column('id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
               Column('version_id', Integer, nullable=False),
               Column('value', String(40), nullable=False))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Foo(_base.ComparableEntity):
             pass
 
@@ -86,7 +91,7 @@ class VersioningTest(_base.MappedTest):
         # Only dialects with a sane rowcount can detect the
         # ConcurrentModificationError
         if testing.db.dialect.supports_sane_rowcount:
-            self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit)
+            assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit)
             s1.rollback()
         else:
             s1.commit()
@@ -102,7 +107,7 @@ class VersioningTest(_base.MappedTest):
         s1.delete(f2)
 
         if testing.db.dialect.supports_sane_multi_rowcount:
-            self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit)
+            assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit)
         else:
             s1.commit()
 
@@ -124,7 +129,7 @@ class VersioningTest(_base.MappedTest):
         s2.commit()
 
         # load, version is wrong
-        self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
+        assert_raises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
 
         # reload it
         s1.query(Foo).populate_existing().get(f1s1.id)
@@ -153,7 +158,8 @@ class VersioningTest(_base.MappedTest):
 class UnicodeTest(_base.MappedTest):
     __requires__ = ('unicode_connections',)
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('uni_t1', metadata,
             Column('id',  Integer, primary_key=True,
                    test_needs_autoincrement=True),
@@ -163,7 +169,8 @@ class UnicodeTest(_base.MappedTest):
                    test_needs_autoincrement=True),
             Column('txt', sa.Unicode(50), ForeignKey('uni_t1')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Test(_base.BasicEntity):
             pass
         class Test2(_base.BasicEntity):
@@ -205,10 +212,12 @@ class UnicodeTest(_base.MappedTest):
 class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest):
     __requires__ = ('unicode_connections', 'unicode_ddl',)
 
-    def create_engine(self):
+    @classmethod
+    def create_engine(cls):
         return engines.utf8_engine()
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         t1 = Table('unitable1', metadata,
               Column(u'méil', Integer, primary_key=True, key='a', test_needs_autoincrement=True),
               Column(u'\u6e2c\u8a66', Integer, key='b'),
@@ -223,16 +232,16 @@ class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest):
               test_needs_fk=True,
               test_needs_autoincrement=True)
 
-        self.tables['t1'] = t1
-        self.tables['t2'] = t2
+        cls.tables['t1'] = t1
+        cls.tables['t2'] = t2
 
-    def setUpAll(self):
-        engine_base.AltEngineTest.setUpAll(self)
-        _base.MappedTest.setUpAll(self)
+    @classmethod
+    def setup_class(cls):
+        super(UnicodeSchemaTest, cls).setup_class()
 
-    def tearDownAll(self):
-        _base.MappedTest.tearDownAll(self)
-        engine_base.AltEngineTest.tearDownAll(self)
+    @classmethod
+    def teardown_class(cls):
+        super(UnicodeSchemaTest, cls).teardown_class()
 
     @testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.')
     @testing.resolve_artifact_names
@@ -298,19 +307,22 @@ class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest):
 
 class MutableTypesTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('mutable_t', metadata,
             Column('id', Integer, primary_key=True,
                    test_needs_autoincrement=True),
             Column('data', sa.PickleType),
             Column('val', sa.Unicode(30)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Foo(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Foo, mutable_t)
 
     @testing.resolve_artifact_names
@@ -433,20 +445,23 @@ class MutableTypesTest(_base.MappedTest):
         self.sql_count_(0, session.commit)
 
 
-class PickledDicts(_base.MappedTest):
+class PickledDictsTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('mutable_t', metadata,
             Column('id', Integer, primary_key=True,
                    test_needs_autoincrement=True),
             Column('data', sa.PickleType(comparator=operator.eq)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Foo(_base.BasicEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(Foo, mutable_t)
 
     @testing.resolve_artifact_names
@@ -519,7 +534,8 @@ class PickledDicts(_base.MappedTest):
 
 class PKTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('multipk1', metadata,
               Column('multi_id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -537,7 +553,8 @@ class PKTest(_base.MappedTest):
               Column('date_assigned', sa.Date, key='assigned', primary_key=True),
               Column('data', String(30)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Entry(_base.BasicEntity):
             pass
 
@@ -587,7 +604,8 @@ class PKTest(_base.MappedTest):
 class ForeignPKTest(_base.MappedTest):
     """Detection of the relationship direction on PK joins."""
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("people", metadata,
               Column('person', String(10), primary_key=True),
               Column('firstname', String(10)),
@@ -598,7 +616,8 @@ class ForeignPKTest(_base.MappedTest):
                      primary_key=True),
               Column('site', String(10)))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Person(_base.BasicEntity):
             pass
         class PersonSite(_base.BasicEntity):
@@ -629,19 +648,22 @@ class ForeignPKTest(_base.MappedTest):
 
 class ClauseAttributesTest(_base.MappedTest):
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('users_t', metadata,
             Column('id', Integer, primary_key=True,
                    test_needs_autoincrement=True),
             Column('name', String(30)),
             Column('counter', Integer, default=1))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class User(_base.ComparableEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         mapper(User, users_t)
 
     @testing.resolve_artifact_names
@@ -697,7 +719,8 @@ class ClauseAttributesTest(_base.MappedTest):
 class PassiveDeletesTest(_base.MappedTest):
     __requires__ = ('foreign_keys',)
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('mytable', metadata,
               Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(30)),
@@ -712,7 +735,8 @@ class PassiveDeletesTest(_base.MappedTest):
                                       ondelete="CASCADE"),
               test_needs_fk=True)
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class MyClass(_base.BasicEntity):
             pass
         class MyOtherClass(_base.BasicEntity):
@@ -773,7 +797,8 @@ class PassiveDeletesTest(_base.MappedTest):
 class ExtraPassiveDeletesTest(_base.MappedTest):
     __requires__ = ('foreign_keys',)
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('mytable', metadata,
               Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
               Column('data', String(30)),
@@ -788,7 +813,8 @@ class ExtraPassiveDeletesTest(_base.MappedTest):
                                       ['mytable.id']),
               test_needs_fk=True)
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class MyClass(_base.BasicEntity):
             pass
         class MyOtherClass(_base.BasicEntity):
@@ -829,7 +855,7 @@ class ExtraPassiveDeletesTest(_base.MappedTest):
         assert myothertable.count().scalar() == 4
         mc = session.query(MyClass).get(mc.id)
         session.delete(mc)
-        self.assertRaises(sa.exc.DBAPIError, session.flush)
+        assert_raises(sa.exc.DBAPIError, session.flush)
 
     @testing.resolve_artifact_names
     def test_extra_passive_2(self):
@@ -851,7 +877,7 @@ class ExtraPassiveDeletesTest(_base.MappedTest):
         mc = session.query(MyClass).get(mc.id)
         session.delete(mc)
         mc.children[0].data = 'some new data'
-        self.assertRaises(sa.exc.DBAPIError, session.flush)
+        assert_raises(sa.exc.DBAPIError, session.flush)
 
 
 class DefaultTest(_base.MappedTest):
@@ -864,7 +890,8 @@ class DefaultTest(_base.MappedTest):
 
     """
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql')
 
         if use_string_defaults:
@@ -876,8 +903,8 @@ class DefaultTest(_base.MappedTest):
             hohoval = 9
             althohoval = 15
 
-        self.other_artifacts['hohoval'] = hohoval
-        self.other_artifacts['althohoval'] = althohoval
+        cls.other_artifacts['hohoval'] = hohoval
+        cls.other_artifacts['althohoval'] = althohoval
 
         dt = Table('default_t', metadata,
             Column('id', Integer, primary_key=True,
@@ -906,7 +933,8 @@ class DefaultTest(_base.MappedTest):
             st.append_column(
                 Column('hoho', hohotype, ForeignKey('default_t.hoho')))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Hoho(_base.ComparableEntity):
             pass
         class Secondary(_base.ComparableEntity):
@@ -1037,7 +1065,8 @@ class DefaultTest(_base.MappedTest):
                     Secondary(data='s2')]))
 
 class ColumnPropertyTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('data', metadata, 
             Column('id', Integer, primary_key=True),
             Column('a', String(50)),
@@ -1049,7 +1078,8 @@ class ColumnPropertyTest(_base.MappedTest):
             Column('c', String(50)),
             )
             
-    def setup_mappers(self):
+    @classmethod
+    def setup_mappers(cls):
         class Data(_base.BasicEntity):
             pass
         
@@ -1079,7 +1109,7 @@ class ColumnPropertyTest(_base.MappedTest):
         sd1 = SubData(a="hello", b="there", c="hi")
         sess.add(sd1)
         sess.flush()
-        self.assertEquals(sd1.aplusb, "hello there")
+        eq_(sd1.aplusb, "hello there")
         
     @testing.resolve_artifact_names
     def _test(self):
@@ -1089,16 +1119,16 @@ class ColumnPropertyTest(_base.MappedTest):
         sess.add(d1)
         sess.flush()
         
-        self.assertEquals(d1.aplusb, "hello there")
+        eq_(d1.aplusb, "hello there")
         
         d1.b = "bye"
         sess.flush()
-        self.assertEquals(d1.aplusb, "hello bye")
+        eq_(d1.aplusb, "hello bye")
         
         d1.b = 'foobar'
         d1.aplusb = 'im setting this explicitly'
         sess.flush()
-        self.assertEquals(d1.aplusb, "im setting this explicitly")
+        eq_(d1.aplusb, "im setting this explicitly")
     
 class OneToManyTest(_fixtures.FixtureTest):
     run_inserts = None
@@ -1596,7 +1626,7 @@ class SaveTest(_fixtures.FixtureTest):
         u1 = User(name='user1')
         u2 = User(name='user2')
         session.add_all((u1, u2))
-        self.assertRaises(AssertionError, session.flush)
+        assert_raises(AssertionError, session.flush)
 
 
 class ManyToOneTest(_fixtures.FixtureTest):
@@ -2029,7 +2059,8 @@ class SaveTest2(_fixtures.FixtureTest):
         )
 
 class SaveTest3(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('items', metadata,
               Column('item_id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
@@ -2045,7 +2076,8 @@ class SaveTest3(_base.MappedTest):
               Column('keyword_id', Integer, ForeignKey("keywords")),
               Column('foo', sa.Boolean, default=True))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class Keyword(_base.BasicEntity):
             pass
         class Item(_base.BasicEntity):
@@ -2076,7 +2108,8 @@ class SaveTest3(_base.MappedTest):
         assert assoc.count().scalar() == 0
 
 class BooleanColTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('t1_t', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('name', String(30)),
@@ -2118,7 +2151,8 @@ class BooleanColTest(_base.MappedTest):
 
 
 class RowSwitchTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         # parent
         Table('t5', metadata,
             Column('id', Integer, primary_key=True),
@@ -2140,7 +2174,8 @@ class RowSwitchTest(_base.MappedTest):
             Column('t5id', Integer, ForeignKey('t5.id'),nullable=False),
             Column('t7id', Integer, ForeignKey('t7.id'),nullable=False))
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class T5(_base.ComparableEntity):
             pass
 
@@ -2240,7 +2275,8 @@ class RowSwitchTest(_base.MappedTest):
         assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)]
 
 class InheritingRowSwitchTest(_base.MappedTest):
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table('parent', metadata,
             Column('id', Integer, primary_key=True),
             Column('pdata', String(30))
@@ -2251,7 +2287,8 @@ class InheritingRowSwitchTest(_base.MappedTest):
             Column('cdata', String(30))
         )
 
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class P(_base.ComparableEntity):
             pass
 
@@ -2290,7 +2327,8 @@ class TransactionTest(_base.MappedTest):
     # be specified.  it'll raise immediately post-INSERT, instead of at
     # COMMIT. either way, this test should pass.
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         t1 = Table('t1', metadata,
             Column('id', Integer, primary_key=True))
 
@@ -2299,15 +2337,17 @@ class TransactionTest(_base.MappedTest):
             Column('t1_id', Integer,
                    ForeignKey('t1.id', deferrable=True, initially='deferred')
                    ))
-    def setup_classes(self):
+    @classmethod
+    def setup_classes(cls):
         class T1(_base.ComparableEntity):
             pass
 
         class T2(_base.ComparableEntity):
             pass
 
+    @classmethod
     @testing.resolve_artifact_names
-    def setup_mappers(self):
+    def setup_mappers(cls):
         orm_mapper(T1, t1)
         orm_mapper(T2, t2)
 
@@ -2332,5 +2372,3 @@ class TransactionTest(_base.MappedTest):
         if testing.against('postgres'):
             t1.bind.engine.dispose()
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 95%
rename from test/orm/utils.py
rename to test/orm/test_utils.py
index 813121a446586a211e5027e09e514d194221b066..06533a243b20245ffaa7194b9f2eb4fec09935ff 100644 (file)
@@ -1,4 +1,4 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 from sqlalchemy.orm import interfaces, util
 from sqlalchemy import Column
 from sqlalchemy import Integer
@@ -8,10 +8,10 @@ from sqlalchemy.orm import aliased
 from sqlalchemy.orm import mapper, create_session
 
 
-from testlib import TestBase, testing
+from sqlalchemy.test import TestBase, testing
 
-from orm import _fixtures
-from testlib.testing import eq_
+from test.orm import _fixtures
+from sqlalchemy.test.testing import eq_
 
 
 class ExtensionCarrierTest(TestBase):
@@ -22,7 +22,7 @@ class ExtensionCarrierTest(TestBase):
         assert carrier.translate_row() is interfaces.EXT_CONTINUE
         assert 'translate_row' not in carrier
 
-        self.assertRaises(AttributeError, lambda: carrier.snickysnack)
+        assert_raises(AttributeError, lambda: carrier.snickysnack)
 
         class Partial(object):
             def __init__(self, marker):
@@ -74,7 +74,7 @@ class AliasedClassTest(TestBase):
         table = self.point_map(Point)
         alias = aliased(Point)
 
-        self.assertRaises(TypeError, alias)
+        assert_raises(TypeError, alias)
 
     def test_instancemethods(self):
         class Point(object):
@@ -236,6 +236,4 @@ class IdentityKeyTest(_fixtures.FixtureTest):
         key = util.identity_key(User, row=row)
         eq_(key, (User, (1,)))
     
-if __name__ == '__main__':
-    testenv.main()
 
diff --git a/test/profiling/alltests.py b/test/profiling/alltests.py
deleted file mode 100644 (file)
index 1940109..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
-    modules_to_test = (
-        'profiling.memusage',
-        'profiling.compiler',
-        'profiling.pool',
-        'profiling.zoomark',
-        'profiling.zoomark_orm',
-        )
-    alltests = unittest.TestSuite()
-    if testenv.testlib.config.coverage_enabled:
-        return alltests
-        
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
index c1a107eeb353bd4226a69e4a74553ea10a9f0414..48879ae7e3a8c6c7c96d76f77a91f2ecf861c33a 100644 (file)
@@ -1,4 +1,4 @@
-from engine import _base as engine_base
+from test.engine import _base as engine_base
 
 
 TablesTest = engine_base.TablesTest
diff --git a/test/sql/alltests.py b/test/sql/alltests.py
deleted file mode 100644 (file)
index f01b0e6..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
-    modules_to_test = (
-        'sql.testtypes',
-        'sql.columns',
-        'sql.constraints',
-
-        'sql.generative',
-
-        # SQL syntax
-        'sql.select',
-        'sql.selectable',
-        'sql.case_statement',
-        'sql.labels',
-        'sql.unicode',
-
-        # assorted round-trip tests
-        'sql.functions',
-        'sql.query',
-        'sql.quote',
-        'sql.rowcount',
-
-        # defaults, sequences (postgres/oracle)
-        'sql.defaults',
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-if __name__ == '__main__':
-    testenv.main(suite())
similarity index 94%
rename from test/sql/case_statement.py
rename to test/sql/test_case_statement.py
index 1d5383749516b657cecc2cf90c759740eccf5e1a..3f3abe7e1900b743bafe572bfa567f8ec946d9b8 100644 (file)
@@ -1,14 +1,15 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 import sys
 from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy import util, exc
 from sqlalchemy.sql import table, column
 
 
 class CaseTest(TestBase, AssertsCompiledSQL):
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         metadata = MetaData(testing.db)
         global info_table
         info_table = Table('infos', metadata,
@@ -24,7 +25,8 @@ class CaseTest(TestBase, AssertsCompiledSQL):
                 {'pk':4, 'info':'pk_4_data'},
                 {'pk':5, 'info':'pk_5_data'},
                 {'pk':6, 'info':'pk_6_data'})
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         info_table.drop()
 
     @testing.fails_on('firebird', 'FIXME: unknown')
@@ -93,7 +95,7 @@ class CaseTest(TestBase, AssertsCompiledSQL):
     def test_literal_interpretation(self):
         t = table('test', column('col1'))
         
-        self.assertRaises(exc.ArgumentError, case, [("x", "y")])
+        assert_raises(exc.ArgumentError, case, [("x", "y")])
         
         self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END")
         self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END")
@@ -133,5 +135,3 @@ class CaseTest(TestBase, AssertsCompiledSQL):
             ('other', 3),
         ]
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 79%
rename from test/sql/columns.py
rename to test/sql/test_columns.py
index 661be891aee6ef0b54e4c8b45176bd1348ee1685..e9dabe1421fb391516d76286f58a8fb956c3a5e7 100644 (file)
@@ -1,7 +1,7 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy import exc, sql
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy import Table, Column  # don't use testlib's wrappers
 
 
@@ -37,7 +37,7 @@ class ColumnDefinitionTest(TestBase):
     def test_incomplete(self):
         c = self.columns()
 
-        self.assertRaises(exc.ArgumentError, Table, 't', MetaData(), *c)
+        assert_raises(exc.ArgumentError, Table, 't', MetaData(), *c)
 
     def test_incomplete_key(self):
         c = Column(Integer)
@@ -52,9 +52,7 @@ class ColumnDefinitionTest(TestBase):
 
 
     def test_bogus(self):
-        self.assertRaises(exc.ArgumentError, Column, 'foo', name='bar')
-        self.assertRaises(exc.ArgumentError, Column, 'foo', Integer,
+        assert_raises(exc.ArgumentError, Column, 'foo', name='bar')
+        assert_raises(exc.ArgumentError, Column, 'foo', Integer,
                           type_=Integer())
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 94%
rename from test/sql/constraints.py
rename to test/sql/test_constraints.py
index d019aa0378bb30098d70efed426d50fb8377fc95..8abeb3533817b535393e524fe02f72795aac013d 100644 (file)
@@ -1,16 +1,16 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy import exc
-from testlib import *
-from testlib import config, engines
+from sqlalchemy.test import *
+from sqlalchemy.test import config, engines
 
 class ConstraintTest(TestBase, AssertsExecutionResults):
 
-    def setUp(self):
+    def setup(self):
         global metadata
         metadata = MetaData(testing.db)
 
-    def tearDown(self):
+    def teardown(self):
         metadata.drop_all()
 
     def test_constraint(self):
@@ -33,7 +33,7 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
     def test_double_fk_usage_raises(self):
         f = ForeignKey('b.id')
         
-        self.assertRaises(exc.InvalidRequestError, Table, "a", metadata,
+        assert_raises(exc.InvalidRequestError, Table, "a", metadata,
             Column('x', Integer, f),
             Column('y', Integer, f)
         )
@@ -219,14 +219,14 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
 
         t1 = Table("sometable", MetaData(), Column("foo", Integer))
         schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
-        self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
+        eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
         schemagen.buffer.truncate(0)
         schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo))
-        self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
+        eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
 
         schemadrop = dialect.schemadropper(dialect, None)
         schemadrop.execute = lambda: None
-        self.assertRaises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
+        assert_raises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
 
     
 class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
@@ -245,7 +245,7 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
         def clear(self):
             del self.statements[:]
 
-    def setUp(self):
+    def setup(self):
         self.sql = self.accum()
         opts = config.db_opts.copy()
         opts['strategy'] = 'mock'
@@ -333,5 +333,3 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
         assert 'INITIALLY DEFERRED' in self.sql, self.sql
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 96%
rename from test/sql/defaults.py
rename to test/sql/test_defaults.py
index bea6dc04bea01d7eb253cec4e30a8f49afd57817..96415746650108e921a3fb5ed32142734d705d53 100644 (file)
@@ -1,16 +1,19 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import datetime
 from sqlalchemy import Sequence, Column, func
 from sqlalchemy.sql import select, text
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Integer, String, ForeignKey, Boolean
-from testlib.testing import eq_
-from sql import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.testing import eq_
+from test.sql import _base
 
 
 class DefaultTest(testing.TestBase):
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global t, f, f2, ts, currenttime, metadata, default_generator
 
         db = testing.db
@@ -117,10 +120,11 @@ class DefaultTest(testing.TestBase):
                    server_default='ddl'))
         t.create()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         t.drop()
 
-    def tearDown(self):
+    def teardown(self):
         default_generator['x'] = 50
         t.delete().execute()
 
@@ -139,7 +143,7 @@ class DefaultTest(testing.TestBase):
         fn4 = FN4()
 
         for fn in fn1, fn2, fn3, fn4:
-            self.assertRaisesMessage(sa.exc.ArgumentError,
+            assert_raises_message(sa.exc.ArgumentError,
                                      ex_msg,
                                      sa.ColumnDefault, fn)
 
@@ -387,7 +391,8 @@ class DefaultTest(testing.TestBase):
 class PKDefaultTest(_base.TablesTest):
     __requires__ = ('subqueries',)
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         t2 = Table('t2', metadata,
             Column('nextid', Integer))
 
@@ -411,7 +416,8 @@ class PKDefaultTest(_base.TablesTest):
 class PKIncrementTest(_base.TablesTest):
     run_define_tables = 'each'
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         Table("aitable", metadata,
               Column('id', Integer, Sequence('ai_id_seq', optional=True),
                      primary_key=True),
@@ -484,8 +490,8 @@ class EmptyInsertTest(testing.TestBase):
         
         try:
             result = t1.insert().execute()
-            self.assertEquals(1, select([func.count(text('*'))], from_obj=t1).scalar())
-            self.assertEquals(True, t1.select().scalar())
+            eq_(1, select([func.count(text('*'))], from_obj=t1).scalar())
+            eq_(True, t1.select().scalar())
         finally:
             metadata.drop_all()
 
@@ -493,7 +499,8 @@ class AutoIncrementTest(_base.TablesTest):
     __requires__ = ('identity',)
     run_define_tables = 'each'
 
-    def define_tables(self, metadata):
+    @classmethod
+    def define_tables(cls, metadata):
         """Each test manipulates self.metadata individually."""
 
     @testing.exclude('sqlite', '<', (3, 4), 'no database support')
@@ -542,7 +549,8 @@ class AutoIncrementTest(_base.TablesTest):
 class SequenceTest(testing.TestBase):
     __requires__ = ('sequences',)
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global cartitems, sometable, metadata
         metadata = MetaData(testing.db)
         cartitems = Table("cartitems", metadata,
@@ -626,9 +634,8 @@ class SequenceTest(testing.TestBase):
         x = cartitems.c.cart_id.sequence.execute()
         self.assert_(1 <= x <= 4)
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 97%
rename from test/sql/functions.py
rename to test/sql/test_functions.py
index 17d8a35e97464ec893679aa6cd80a95fb167e72c..e9bf49ce30af1016896f633652f5a65f8a2f1eef 100644 (file)
@@ -1,15 +1,15 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
 import datetime
 from sqlalchemy import *
 from sqlalchemy.sql import table, column
 from sqlalchemy import databases, sql, util
 from sqlalchemy.sql.compiler import BIND_TEMPLATES
 from sqlalchemy.engine import default
-from testlib.engines import all_dialects
+from sqlalchemy.test.engines import all_dialects
 from sqlalchemy import types as sqltypes
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy.sql.functions import GenericFunction
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
 from decimal import Decimal as _python_Decimal
 
 from sqlalchemy.databases import *
@@ -237,7 +237,7 @@ class ExecuteTest(TestBase):
             t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi")
 
             res = exec_sorted(select([t2.c.value, t2.c.stuff]))
-            self.assertEquals(res, [(-14, 'hi'), (3, None), (7, None)])
+            eq_(res, [(-14, 'hi'), (3, None), (7, None)])
 
             t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff")
             assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")]
@@ -315,5 +315,3 @@ def exec_sorted(statement, *args, **kw):
     return sorted([tuple(row)
                    for row in statement.execute(*args, **kw).fetchall()])
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 99%
rename from test/sql/generative.py
rename to test/sql/test_generative.py
index 3947a450fe7dc04c5941a64f46772eaf354ce8cc..ca427ca5f57d5043c0c6ce127f1965d1239ba77b 100644 (file)
@@ -1,8 +1,7 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.sql import table, column, ClauseElement
 from sqlalchemy.sql.expression import  _clone, _from_objects
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy.sql.visitors import *
 from sqlalchemy import util
 from sqlalchemy.sql import util as sql_util
@@ -12,7 +11,8 @@ class TraversalTest(TestBase, AssertsExecutionResults):
     """test ClauseVisitor's traversal, particularly its ability to copy and modify
     a ClauseElement in place."""
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global A, B
 
         # establish two ficticious ClauseElements.
@@ -162,7 +162,8 @@ class TraversalTest(TestBase, AssertsExecutionResults):
 class ClauseTest(TestBase, AssertsCompiledSQL):
     """test copy-in-place behavior of various ClauseElements."""
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global t1, t2
         t1 = table("table1",
             column("col1"),
@@ -361,7 +362,8 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
         
 class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global t1, t2
         t1 = table("table1",
             column("col1"),
@@ -630,7 +632,8 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
         )
 
 class SpliceJoinsTest(TestBase, AssertsCompiledSQL):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global table1, table2, table3, table4
         def _table(name):
             return table(name, column("col1"), column("col2"),column("col3"))
@@ -691,7 +694,8 @@ class SpliceJoinsTest(TestBase, AssertsCompiledSQL):
 class SelectTest(TestBase, AssertsCompiledSQL):
     """tests the generative capability of Select"""
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global t1, t2
         t1 = table("table1",
             column("col1"),
@@ -772,7 +776,8 @@ class InsertTest(TestBase, AssertsCompiledSQL):
 
     # fixme: consolidate converage from elsewhere here and expand
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global t1, t2
         t1 = table("table1",
             column("col1"),
@@ -811,5 +816,3 @@ class InsertTest(TestBase, AssertsCompiledSQL):
                             "table1 (col1, col2, col3) "
                             "VALUES (:col1, :col2, :col3)")
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 95%
rename from test/sql/labels.py
rename to test/sql/test_labels.py
index 94ee20342e6ac7cb5d39c98c06a3337920c773ba..b946b0ae9885078a769ce3a17587487ddae291ea 100644 (file)
@@ -1,7 +1,7 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy import exc as exceptions
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy.engine import default
 
 IDENT_LENGTH = 29
@@ -16,7 +16,8 @@ class LabelTypeTest(TestBase):
         assert isinstance(select([t.c.col2]).as_scalar().label('lala').type, Float)
 
 class LongLabelsTest(TestBase, AssertsCompiledSQL):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, table1, table2, maxlen
         metadata = MetaData(testing.db)
         table1 = Table("some_large_named_table", metadata,
@@ -34,20 +35,21 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL):
         maxlen = testing.db.dialect.max_identifier_length
         testing.db.dialect.max_identifier_length = IDENT_LENGTH
 
-    def tearDown(self):
+    def teardown(self):
         table1.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
         testing.db.dialect.max_identifier_length = maxlen
 
     def test_too_long_name_disallowed(self):
         m = MetaData(testing.db)
         t1 = Table("this_name_is_too_long_for_what_were_doing_in_this_test", m, Column('foo', Integer))
-        self.assertRaises(exceptions.IdentifierError, m.create_all)
-        self.assertRaises(exceptions.IdentifierError, m.drop_all)
-        self.assertRaises(exceptions.IdentifierError, t1.create)
-        self.assertRaises(exceptions.IdentifierError, t1.drop)
+        assert_raises(exceptions.IdentifierError, m.create_all)
+        assert_raises(exceptions.IdentifierError, m.drop_all)
+        assert_raises(exceptions.IdentifierError, t1.create)
+        assert_raises(exceptions.IdentifierError, t1.drop)
         
     def test_result(self):
         table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
@@ -191,5 +193,3 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL):
             "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :_1) AS _1", dialect=compile_dialect)
         
         
-if __name__ == '__main__':
-    testenv.main()
similarity index 94%
rename from test/sql/query.py
rename to test/sql/test_query.py
index b428d8991c9f5064461a69b8b50bf01e6ae0175a..c9305b615f058a07e6f4cc2d7cf21034ea81458c 100644 (file)
@@ -1,14 +1,14 @@
-import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
 from sqlalchemy import exc, sql
 from sqlalchemy.engine import default
-from testlib import *
-from testlib.testing import eq_
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
 
 class QueryTest(TestBase):
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global users, users2, addresses, metadata
         metadata = MetaData(testing.db)
         users = Table('query_users', metadata,
@@ -31,7 +31,8 @@ class QueryTest(TestBase):
         users.delete().execute()
         users2.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_insert(self):
@@ -174,19 +175,19 @@ class QueryTest(TestBase):
         )
         
         concat = ("test: " + users.c.user_name).label('thedata')
-        self.assertEquals(
+        eq_(
             select([concat]).order_by(concat).execute().fetchall(),
             [("test: ed",), ("test: fred",), ("test: jack",)]
         )
 
         concat = ("test: " + users.c.user_name).label('thedata')
-        self.assertEquals(
+        eq_(
             select([concat]).order_by(desc(concat)).execute().fetchall(),
             [("test: jack",), ("test: fred",), ("test: ed",)]
         )
 
         concat = ("test: " + users.c.user_name).label('thedata')
-        self.assertEquals(
+        eq_(
             select([concat]).order_by(concat + "x").execute().fetchall(),
             [("test: ed",), ("test: fred",), ("test: jack",)]
         )
@@ -211,11 +212,11 @@ class QueryTest(TestBase):
     def test_or_and_as_columns(self):
         true, false = literal(True), literal(False)
         
-        self.assertEquals(testing.db.execute(select([and_(true, false)])).scalar(), False)
-        self.assertEquals(testing.db.execute(select([and_(true, true)])).scalar(), True)
-        self.assertEquals(testing.db.execute(select([or_(true, false)])).scalar(), True)
-        self.assertEquals(testing.db.execute(select([or_(false, false)])).scalar(), False)
-        self.assertEquals(testing.db.execute(select([not_(or_(false, false))])).scalar(), True)
+        eq_(testing.db.execute(select([and_(true, false)])).scalar(), False)
+        eq_(testing.db.execute(select([and_(true, true)])).scalar(), True)
+        eq_(testing.db.execute(select([or_(true, false)])).scalar(), True)
+        eq_(testing.db.execute(select([or_(false, false)])).scalar(), False)
+        eq_(testing.db.execute(select([not_(or_(false, false))])).scalar(), True)
 
         row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone()
         assert row.x == False
@@ -272,13 +273,13 @@ class QueryTest(TestBase):
             {'user_id':4, 'user_name':'OnE'},
         )
 
-        self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )])
+        eq_(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )])
 
-        self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
+        eq_(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
 
         if testing.against('postgres'):
-            self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
-            self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
+            eq_(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
+            eq_(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
 
 
     def test_compiled_execute(self):
@@ -388,7 +389,7 @@ class QueryTest(TestBase):
 
         def a_eq(executable, wanted):
             got = list(executable.execute())
-            self.assertEquals(got, wanted)
+            eq_(got, wanted)
 
         for labels in False, True:
             a_eq(users.select(order_by=[users.c.user_id],
@@ -511,23 +512,23 @@ class QueryTest(TestBase):
     def test_keys(self):
         users.insert().execute(user_id=1, user_name='foo')
         r = users.select().execute().fetchone()
-        self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+        eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
 
     def test_items(self):
         users.insert().execute(user_id=1, user_name='foo')
         r = users.select().execute().fetchone()
-        self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
+        eq_([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
 
     def test_len(self):
         users.insert().execute(user_id=1, user_name='foo')
         r = users.select().execute().fetchone()
-        self.assertEqual(len(r), 2)
+        eq_(len(r), 2)
         r.close()
         r = testing.db.execute('select user_name, user_id from query_users').fetchone()
-        self.assertEqual(len(r), 2)
+        eq_(len(r), 2)
         r.close()
         r = testing.db.execute('select user_name from query_users').fetchone()
-        self.assertEqual(len(r), 1)
+        eq_(len(r), 1)
         r.close()
 
     def test_cant_execute_join(self):
@@ -542,19 +543,19 @@ class QueryTest(TestBase):
         # should return values in column definition order
         users.insert().execute(user_id=1, user_name='foo')
         r = users.select(users.c.user_id==1).execute().fetchone()
-        self.assertEqual(r[0], 1)
-        self.assertEqual(r[1], 'foo')
-        self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
-        self.assertEqual(r.values(), [1, 'foo'])
+        eq_(r[0], 1)
+        eq_(r[1], 'foo')
+        eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+        eq_(r.values(), [1, 'foo'])
 
     def test_column_order_with_text_query(self):
         # should return values in query order
         users.insert().execute(user_id=1, user_name='foo')
         r = testing.db.execute('select user_name, user_id from query_users').fetchone()
-        self.assertEqual(r[0], 'foo')
-        self.assertEqual(r[1], 1)
-        self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
-        self.assertEqual(r.values(), ['foo', 1])
+        eq_(r[0], 'foo')
+        eq_(r[1], 1)
+        eq_([x.lower() for x in r.keys()], ['user_name', 'user_id'])
+        eq_(r.values(), ['foo', 1])
 
     @testing.crashes('oracle', 'FIXME: unknown, varify not fails_on()')
     @testing.crashes('firebird', 'An identifier must begin with a letter')
@@ -657,9 +658,10 @@ class PercentSchemaNamesTest(TestBase):
     operation the same way we do for text() and column labels.
     
     """
+    @classmethod
     @testing.crashes('mysql', 'mysqldb calls name % (params)')
     @testing.crashes('postgres', 'postgres calls name % (params)')
-    def setUpAll(self):
+    def setup_class(cls):
         global percent_table, metadata
         metadata = MetaData(testing.db)
         percent_table = Table('percent%table', metadata,
@@ -669,9 +671,10 @@ class PercentSchemaNamesTest(TestBase):
         )
         metadata.create_all()
 
+    @classmethod
     @testing.crashes('mysql', 'mysqldb calls name % (params)')
     @testing.crashes('postgres', 'postgres calls name % (params)')
-    def tearDownAll(self):
+    def teardown_class(cls):
         metadata.drop_all()
     
     @testing.crashes('mysql', 'mysqldb calls name % (params)')
@@ -734,7 +737,8 @@ class PercentSchemaNamesTest(TestBase):
         
 class LimitTest(TestBase):
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global users, addresses, metadata
         metadata = MetaData(testing.db)
         users = Table('query_users', metadata,
@@ -746,9 +750,7 @@ class LimitTest(TestBase):
             Column('user_id', Integer, ForeignKey('query_users.user_id')),
             Column('address', String(30)))
         metadata.create_all()
-        self._data()
-        
-    def _data(self):
+
         users.insert().execute(user_id=1, user_name='john')
         addresses.insert().execute(address_id=1, user_id=1, address='addr1')
         users.insert().execute(user_id=2, user_name='jack')
@@ -764,7 +766,8 @@ class LimitTest(TestBase):
         users.insert().execute(user_id=7, user_name='fido')
         addresses.insert().execute(address_id=7, user_id=7, address='addr5')
         
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_select_limit(self):
@@ -805,7 +808,8 @@ class LimitTest(TestBase):
 class CompoundTest(TestBase):
     """test compound statements like UNION, INTERSECT, particularly their ability to nest on
     different databases."""
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, t1, t2, t3
         metadata = MetaData(testing.db)
         t1 = Table('t1', metadata,
@@ -842,7 +846,8 @@ class CompoundTest(TestBase):
             dict(col2="t3col2r3", col3="ccc", col4="bbb"),
         ])
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def _fetchall_sorted(self, executed):
@@ -861,10 +866,10 @@ class CompoundTest(TestBase):
         wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
                   ('ccc', 'aaa')]
         found1 = self._fetchall_sorted(u.execute())
-        self.assertEquals(found1, wanted)
+        eq_(found1, wanted)
 
         found2 = self._fetchall_sorted(u.alias('bar').select().execute())
-        self.assertEquals(found2, wanted)
+        eq_(found2, wanted)
 
     def test_union_ordered(self):
         (s1, s2) = (
@@ -877,7 +882,7 @@ class CompoundTest(TestBase):
 
         wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
                   ('ccc', 'aaa')]
-        self.assertEquals(u.execute().fetchall(), wanted)
+        eq_(u.execute().fetchall(), wanted)
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     @testing.requires.subqueries
@@ -892,7 +897,7 @@ class CompoundTest(TestBase):
 
         wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
                   ('ccc', 'aaa')]
-        self.assertEquals(u.alias('bar').select().execute().fetchall(), wanted)
+        eq_(u.alias('bar').select().execute().fetchall(), wanted)
 
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
     @testing.fails_on('mysql', 'FIXME: unknown')
@@ -908,10 +913,10 @@ class CompoundTest(TestBase):
 
         wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)]
         found1 = self._fetchall_sorted(e.execute())
-        self.assertEquals(found1, wanted)
+        eq_(found1, wanted)
 
         found2 = self._fetchall_sorted(e.alias('foo').select().execute())
-        self.assertEquals(found2, wanted)
+        eq_(found2, wanted)
 
     @testing.crashes('firebird', 'Does not support intersect')
     @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
@@ -925,10 +930,10 @@ class CompoundTest(TestBase):
         wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
 
         found1 = self._fetchall_sorted(i.execute())
-        self.assertEquals(found1, wanted)
+        eq_(found1, wanted)
 
         found2 = self._fetchall_sorted(i.alias('bar').select().execute())
-        self.assertEquals(found2, wanted)
+        eq_(found2, wanted)
 
     @testing.crashes('firebird', 'Does not support except')
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
@@ -945,7 +950,7 @@ class CompoundTest(TestBase):
                   ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
 
         found = self._fetchall_sorted(e.alias('bar').select().execute())
-        self.assertEquals(found, wanted)
+        eq_(found, wanted)
 
     @testing.crashes('firebird', 'Does not support except')
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
@@ -962,10 +967,10 @@ class CompoundTest(TestBase):
                   ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
 
         found1 = self._fetchall_sorted(e.execute())
-        self.assertEquals(found1, wanted)
+        eq_(found1, wanted)
 
         found2 = self._fetchall_sorted(e.alias('bar').select().execute())
-        self.assertEquals(found2, wanted)
+        eq_(found2, wanted)
 
     @testing.crashes('firebird', 'Does not support except')
     @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
@@ -981,8 +986,8 @@ class CompoundTest(TestBase):
                 select([t3.c.col3], t3.c.col3 == 'ccc'), #ccc
             )
         )
-        self.assertEquals(e.execute().fetchall(), [('ccc',)])
-        self.assertEquals(e.alias('foo').select().execute().fetchall(),
+        eq_(e.execute().fetchall(), [('ccc',)])
+        eq_(e.alias('foo').select().execute().fetchall(),
                           [('ccc',)])
 
     @testing.crashes('firebird', 'Does not support intersect')
@@ -999,7 +1004,7 @@ class CompoundTest(TestBase):
         wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         found = self._fetchall_sorted(u.execute())
 
-        self.assertEquals(found, wanted)
+        eq_(found, wanted)
 
     @testing.crashes('firebird', 'Does not support intersect')
     @testing.fails_on('mysql', 'FIXME: unknown')
@@ -1015,7 +1020,7 @@ class CompoundTest(TestBase):
 
         wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
         found = self._fetchall_sorted(ua.select().execute())
-        self.assertEquals(found, wanted)
+        eq_(found, wanted)
 
 
 class JoinTest(TestBase):
@@ -1028,7 +1033,8 @@ class JoinTest(TestBase):
     database seems to be sensitive to this.
     """
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata
         global t1, t2, t3
 
@@ -1057,7 +1063,8 @@ class JoinTest(TestBase):
                             {'t2_id': 21, 't1_id': 11, 'name': 't2 #21'})
         t3.insert().execute({'t3_id': 30, 't2_id': 20, 'name': 't3 #30'})
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def assertRows(self, statement, expected):
@@ -1066,7 +1073,7 @@ class JoinTest(TestBase):
         found = sorted([tuple(row)
                        for row in statement.execute().fetchall()])
 
-        self.assertEquals(found, sorted(expected))
+        eq_(found, sorted(expected))
 
     def test_join_x1(self):
         """Joins t1->t2."""
@@ -1289,7 +1296,8 @@ class JoinTest(TestBase):
 
 
 class OperatorTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global metadata, flds
         metadata = MetaData(testing.db)
         flds = Table('flds', metadata,
@@ -1304,18 +1312,14 @@ class OperatorTest(TestBase):
             dict(intcol=13, strcol='bar')
         ])
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     def test_modulo(self):
-        self.assertEquals(
+        eq_(
             select([flds.c.intcol % 3],
                    order_by=flds.c.idcol).execute().fetchall(),
             [(2,),(1,)]
         )
-
-
-
-if __name__ == "__main__":
-    testenv.main()
similarity index 97%
rename from test/sql/quote.py
rename to test/sql/test_quote.py
index 106189afe09285f4032a400df9372195d763f7c1..64e097b85fa266cec41914ef3b05650e20dc8ff9 100644 (file)
@@ -1,12 +1,12 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy import sql
 from sqlalchemy.sql import compiler
-from testlib import *
+from sqlalchemy.test import *
 
 
 class QuoteTest(TestBase, AssertsCompiledSQL):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         # TODO: figure out which databases/which identifiers allow special
         # characters to be used, such as: spaces, quote characters,
         # punctuation characters, set up tests for those as well.
@@ -24,11 +24,12 @@ class QuoteTest(TestBase, AssertsCompiledSQL):
         table1.create()
         table2.create()
 
-    def tearDown(self):
+    def teardown(self):
         table1.delete().execute()
         table2.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         table1.drop()
         table2.drop()
 
@@ -207,5 +208,3 @@ class PreparerTest(TestBase):
         a_eq(unformat('`foo`.bar'), ['foo', 'bar'])
         a_eq(unformat('`foo`.`b``a``r`.`baz`'), ['foo', 'b`a`r', 'baz'])
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 91%
rename from test/sql/rowcount.py
rename to test/sql/test_rowcount.py
index 3c9caad75424c8f3d2240b4280faa73451e16373..82301a4a5c84d068b8a9a3e31648ef49fe32c0d6 100644 (file)
@@ -1,11 +1,11 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
 
 
 class FoundRowsTest(TestBase, AssertsExecutionResults):
     """tests rowcount functionality"""
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         metadata = MetaData(testing.db)
 
         global employees_table
@@ -17,7 +17,7 @@ class FoundRowsTest(TestBase, AssertsExecutionResults):
         )
         employees_table.create()
 
-    def setUp(self):
+    def setup(self):
         global data
         data = [ ('Angela', 'A'),
                  ('Andrew', 'A'),
@@ -31,10 +31,11 @@ class FoundRowsTest(TestBase, AssertsExecutionResults):
 
         i = employees_table.insert()
         i.execute(*[{'name':n, 'department':d} for n, d in data])
-    def tearDown(self):
+    def teardown(self):
         employees_table.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         employees_table.drop()
 
     def testbasic(self):
@@ -67,5 +68,3 @@ class FoundRowsTest(TestBase, AssertsExecutionResults):
         if testing.db.dialect.supports_sane_rowcount:
             assert r.rowcount == 3
 
-if __name__ == '__main__':
-    testenv.main()
similarity index 98%
rename from test/sql/select.py
rename to test/sql/test_select.py
index 2ec5b8da5152e2ea2cbd5b89aeae5cccb3338a28..1d9e531de3aed92961c512f72698e92632239e67 100644 (file)
@@ -1,4 +1,4 @@
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import datetime, re, operator
 from sqlalchemy import *
 from sqlalchemy import exc, sql, util
@@ -6,7 +6,7 @@ from sqlalchemy.sql import table, column, label, compiler
 from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.engine import default
 from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
-from testlib import *
+from sqlalchemy.test import *
 
 table1 = table('mytable',
     column('myid', Integer),
@@ -185,7 +185,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         t2 = table('t2', column('c'), column('d'))
         s = select([t.c.a]).where(t.c.a==t2.c.d).as_scalar()
         s2 =select([t, t2, s])
-        self.assertRaises(exc.InvalidRequestError, str, s2)
+        assert_raises(exc.InvalidRequestError, str, s2)
 
         # intentional again
         s = s.correlate(t, t2)
@@ -1149,10 +1149,10 @@ UNION SELECT mytable.myid FROM mytable"
 
         # check that conflicts with "unique" params are caught
         s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('myid_1')))
-        self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+        assert_raises_message(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
 
         s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('myid_1')))
-        self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+        assert_raises_message(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
 
     def test_binds_no_hash_collision(self):
         """test that construct_params doesn't corrupt dict due to hash collisions"""
@@ -1287,20 +1287,20 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
                     )
 
         def check_results(dialect, expected_results, literal):
-            self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results')
-            self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
-            self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
-            self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
-            self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
-            self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
+            eq_(len(expected_results), 5, 'Incorrect number of expected results')
+            eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
+            eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
+            eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
+            eq_(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
+            eq_(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
             # fixme: shoving all of this dialect-specific stuff in one test
             # is now officialy completely ridiculous AND non-obviously omits
             # coverage on other dialects.
             sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect)
             if isinstance(dialect, type(mysql.dialect())):
-                self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
+                eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
             else:
-                self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
+                eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
 
         # first test with Postgres engine
         check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s')
@@ -1548,5 +1548,3 @@ class SchemaTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(table4.insert(values=(2, 5, 'test')), "INSERT INTO remote_owner.remotetable (rem_id, datatype_id, value) VALUES "\
             "(:rem_id, :datatype_id, :value)")
 
-if __name__ == "__main__":
-    testenv.main()
old mode 100755 (executable)
new mode 100644 (file)
similarity index 98%
rename from test/sql/selectable.py
rename to test/sql/test_selectable.py
index e9ed5f5..a172eb4
@@ -1,8 +1,8 @@
 """Test various algorithmic properties of selectables."""
 
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
 from sqlalchemy.sql import util as sql_util, visitors
 from sqlalchemy import exc
 from sqlalchemy.sql import table, column
@@ -228,7 +228,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
 
         s = select([t2, t3], use_labels=True)
 
-        self.assertRaises(exc.NoReferencedTableError, s.join, t1)
+        assert_raises(exc.NoReferencedTableError, s.join, t1)
         
 class PrimaryKeyTest(TestBase, AssertsExecutionResults):
     def test_join_pk_collapse_implicit(self):
@@ -304,12 +304,12 @@ class PrimaryKeyTest(TestBase, AssertsExecutionResults):
             Column('id', Integer, ForeignKey( 'Employee.id', ), primary_key=True),
         )
 
-        self.assertEquals(
+        eq_(
             util.column_set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key),
             util.column_set([employee.c.id])
         )
 
-        self.assertEquals(
+        eq_(
             util.column_set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key),
             util.column_set([employee.c.id])
         )
@@ -329,7 +329,7 @@ class ReduceTest(TestBase, AssertsExecutionResults):
             Column('t3data', String(30)))
         
         
-        self.assertEquals(
+        eq_(
             util.column_set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])),
             util.column_set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data])
         )
@@ -349,7 +349,7 @@ class ReduceTest(TestBase, AssertsExecutionResults):
 
        s = select([engineers, managers]).where(engineers.c.engineer_name==managers.c.manager_name)
        
-       self.assertEquals(util.column_set(sql_util.reduce_columns(list(s.c), s)),
+       eq_(util.column_set(sql_util.reduce_columns(list(s.c), s)),
         util.column_set([s.c.engineer_id, s.c.engineer_name, s.c.manager_id])
         )
        
@@ -374,7 +374,7 @@ class ReduceTest(TestBase, AssertsExecutionResults):
            )
         
         pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin')
-        self.assertEquals(
+        eq_(
             util.column_set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])),
             util.column_set([pjoin.c.people_person_id])
         )
@@ -398,7 +398,7 @@ class ReduceTest(TestBase, AssertsExecutionResults):
             'Item':base_item_table.join(item_table),
             }, None, 'item_join')
             
-        self.assertEquals(
+        eq_(
             util.column_set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])),
             util.column_set([item_join.c.id, item_join.c.dummy, item_join.c.child_name])
         )    
@@ -423,7 +423,7 @@ class ReduceTest(TestBase, AssertsExecutionResults):
                 'c': page_table.join(magazine_page_table).join(classified_page_table),
             }, None, 'page_join')
             
-        self.assertEquals(
+        eq_(
             util.column_set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])),
             util.column_set([pjoin.c.id])
         )    
@@ -522,5 +522,3 @@ class AnnotationsTest(TestBase):
         assert b4.left is bin.left  # since column is immutable
         assert b4.right is not bin.right is not b2.right is not b3.right
         
-if __name__ == "__main__":
-    testenv.main()
similarity index 94%
rename from test/sql/testtypes.py
rename to test/sql/test_types.py
index e5cffe328207bc2192bca608b1a24409e5cc5616..13b6d0954e52c4050835fdc7358c5995324728d5 100644 (file)
@@ -1,13 +1,13 @@
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 import decimal
-import testenv; testenv.configure_for_tests()
-import datetime, os, pickleable, re
+import datetime, os, re
 from sqlalchemy import *
 from sqlalchemy import exc, types, util
 from sqlalchemy.sql import operators
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
 import sqlalchemy.engine.url as url
 from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
-from testlib import *
+from sqlalchemy.test import *
 
 
 class AdaptTest(TestBase):
@@ -121,13 +121,14 @@ class UserDefinedTest(TestBase):
             l
         ):
             for col in row[1:5]:
-                self.assertEquals(col, assertstr)
-            self.assertEquals(row[5], assertint)
-            self.assertEquals(row[6], assertint2)
+                eq_(col, assertstr)
+            eq_(row[5], assertint)
+            eq_(row[6], assertint2)
             for col in row[3], row[4]:
                 assert isinstance(col, unicode)
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global users, metadata
 
         class MyType(types.TypeEngine):
@@ -226,7 +227,8 @@ class UserDefinedTest(TestBase):
 
         metadata.create_all()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
 class ColumnsTest(TestBase, AssertsExecutionResults):
@@ -263,14 +265,15 @@ class ColumnsTest(TestBase, AssertsExecutionResults):
         )
 
         for aCol in testTable.c:
-            self.assertEquals(
+            eq_(
                 expectedResults[aCol.name],
                 db.dialect.schemagenerator(db.dialect, db, None, None).\
                   get_column_specification(aCol))
 
 class UnicodeTest(TestBase, AssertsExecutionResults):
     """tests the Unicode type.  also tests the TypeDecorator with instances in the types package."""
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global unicode_table
         metadata = MetaData(testing.db)
         unicode_table = Table('unicode_table', metadata,
@@ -280,10 +283,11 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
             Column('plain_varchar', String(250))
             )
         unicode_table.create()
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         unicode_table.drop()
 
-    def tearDown(self):
+    def teardown(self):
         unicode_table.delete().execute()
 
     def test_round_trip(self):
@@ -387,7 +391,8 @@ class BinaryTest(TestBase, AssertsExecutionResults):
         ('mysql', '<', (4, 1, 1)),  # screwy varbinary types
         )
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global binary_table, MyPickleType
 
         class MyPickleType(types.TypeDecorator):
@@ -416,10 +421,11 @@ class BinaryTest(TestBase, AssertsExecutionResults):
         )
         binary_table.create()
 
-    def tearDown(self):
+    def teardown(self):
         binary_table.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         binary_table.drop()
 
     @testing.fails_on('mssql', 'MSSQl BINARY type right pads the fixed length with \x00')
@@ -439,21 +445,22 @@ class BinaryTest(TestBase, AssertsExecutionResults):
             text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db)
         ):
             l = stmt.execute().fetchall()
-            self.assertEquals(list(stream1), list(l[0]['data']))
-            self.assertEquals(list(stream1[0:100]), list(l[0]['data_slice']))
-            self.assertEquals(list(stream2), list(l[1]['data']))
-            self.assertEquals(testobj1, l[0]['pickled'])
-            self.assertEquals(testobj2, l[1]['pickled'])
-            self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata)
-            self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff')
+            eq_(list(stream1), list(l[0]['data']))
+            eq_(list(stream1[0:100]), list(l[0]['data_slice']))
+            eq_(list(stream2), list(l[1]['data']))
+            eq_(testobj1, l[0]['pickled'])
+            eq_(testobj2, l[1]['pickled'])
+            eq_(testobj3.moredata, l[0]['mypickle'].moredata)
+            eq_(l[0]['mypickle'].stuff, 'this is the right stuff')
 
     def load_stream(self, name, len=12579):
-        f = os.path.join(os.path.dirname(testenv.__file__), name)
+        f = os.path.join(os.path.dirname(__file__), "..", name)
         # put a number less than the typical MySQL default BLOB size
         return file(f).read(len)
 
 class ExpressionTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global test_table, meta
 
         class MyCustomType(types.TypeEngine):
@@ -481,7 +488,8 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
 
         test_table.insert().execute({'id':1, 'data':'somedata', 'atimestamp':datetime.date(2007, 10, 15), 'avalue':25})
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         meta.drop_all()
 
     def test_control(self):
@@ -523,7 +531,8 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
         assert testing.db.execute(select([expr])).scalar() == -15
 
 class DateTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global users_with_date, insert_data
 
         db = testing.db
@@ -605,7 +614,8 @@ class DateTest(TestBase, AssertsExecutionResults):
         for idict in insert_dicts:
             users_with_date.insert().execute(**idict)
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         users_with_date.drop()
 
     def testdate(self):
@@ -649,8 +659,8 @@ class DateTest(TestBase, AssertsExecutionResults):
 
             # test mismatched date/datetime
             t.insert().execute(adate=d2, adatetime=d2)
-            self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
-            self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
+            eq_(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
+            eq_(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
 
         finally:
             t.drop(checkfirst=True)
@@ -674,7 +684,8 @@ def _missing_decimal():
         return True
 
 class NumericTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global numeric_table, metadata
         metadata = MetaData(testing.db)
         numeric_table = Table('numeric_table', metadata,
@@ -686,10 +697,11 @@ class NumericTest(TestBase, AssertsExecutionResults):
         )
         metadata.create_all()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
-    def tearDown(self):
+    def teardown(self):
         numeric_table.delete().execute()
 
     @testing.fails_if(_missing_decimal)
@@ -723,7 +735,7 @@ class NumericTest(TestBase, AssertsExecutionResults):
             assert isinstance(row['fcasdec'], decimal.Decimal)
 
     def test_length_deprecation(self):
-        self.assertRaises(exc.SADeprecationWarning, Numeric, length=8)
+        assert_raises(exc.SADeprecationWarning, Numeric, length=8)
         
         @testing.uses_deprecated(".*is deprecated for Numeric")
         def go():
@@ -751,7 +763,8 @@ class NumericTest(TestBase, AssertsExecutionResults):
                 
             
 class IntervalTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global interval_table, metadata
         metadata = MetaData(testing.db)
         interval_table = Table("intervaltable", metadata,
@@ -760,10 +773,11 @@ class IntervalTest(TestBase, AssertsExecutionResults):
             )
         metadata.create_all()
 
-    def tearDown(self):
+    def teardown(self):
         interval_table.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         metadata.drop_all()
 
     def test_roundtrip(self):
@@ -776,14 +790,16 @@ class IntervalTest(TestBase, AssertsExecutionResults):
         assert interval_table.select().execute().fetchone()['interval'] is None
 
 class BooleanTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global bool_table
         metadata = MetaData(testing.db)
         bool_table = Table('booltest', metadata,
             Column('id', Integer, primary_key=True),
             Column('value', Boolean))
         bool_table.create()
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         bool_table.drop()
     def testbasic(self):
         bool_table.insert().execute(id=1, value=True)
@@ -802,11 +818,11 @@ class PickleTest(TestBase):
     def test_noeq_deprecation(self):
         p1 = PickleType()
         
-        self.assertRaises(DeprecationWarning, 
+        assert_raises(DeprecationWarning, 
             p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)
         )
 
-        self.assertRaises(DeprecationWarning, 
+        assert_raises(DeprecationWarning, 
             p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)
         )
         
@@ -833,7 +849,7 @@ class PickleTest(TestBase):
         ):
             assert p1.compare_values(p1.copy_value(obj), obj)
 
-        self.assertRaises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo'))
+        assert_raises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo'))
         
     def test_nonmutable_comparison(self):
         p1 = PickleType()
@@ -846,11 +862,13 @@ class PickleTest(TestBase):
             assert p1.compare_values(p1.copy_value(obj), obj)
     
 class CallableTest(TestBase):
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global meta
         meta = MetaData(testing.db)
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         meta.drop_all()
 
     def test_callable_as_arg(self):
@@ -871,5 +889,3 @@ class CallableTest(TestBase):
         assert isinstance(thang_table.c.name.type, Unicode)
         thang_table.create()
 
-if __name__ == "__main__":
-    testenv.main()
similarity index 95%
rename from test/sql/unicode.py
rename to test/sql/test_unicode.py
index c5002aaffb2fc9445661fd1e046484f0319064f0..d759132678e5fe0a61127e0aee9a2ee7393ca8ec 100644 (file)
@@ -1,16 +1,16 @@
 # coding: utf-8
 """verrrrry basic unicode column name testing"""
 
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from testlib import *
-from testlib.engines import utf8_engine
+from sqlalchemy.test import *
+from sqlalchemy.test.engines import utf8_engine
 from sqlalchemy.sql import column
 
 class UnicodeSchemaTest(TestBase):
     __requires__ = ('unicode_ddl',)
 
-    def setUpAll(self):
+    @classmethod
+    def setup_class(cls):
         global unicode_bind, metadata, t1, t2, t3
 
         unicode_bind = utf8_engine()
@@ -56,13 +56,14 @@ class UnicodeSchemaTest(TestBase):
                        )
         metadata.create_all()
 
-    def tearDown(self):
+    def teardown(self):
         if metadata.tables:
             t3.delete().execute()
             t2.delete().execute()
             t1.delete().execute()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         global unicode_bind
         metadata.drop_all()
         del unicode_bind
@@ -135,5 +136,3 @@ class EscapesDefaultsTest(testing.TestBase):
             t1.drop()
 
 
-if __name__ == '__main__':
-    testenv.main()
diff --git a/test/testenv.py b/test/testenv.py
deleted file mode 100644 (file)
index 808a3c5..0000000
+++ /dev/null
@@ -1,36 +0,0 @@
-"""First import for all test cases, sets sys.path and loads configuration."""
-
-import sys, os, logging, warnings
-
-if sys.version_info < (2, 4):
-    warnings.filterwarnings('ignore', category=FutureWarning)
-
-
-from testlib.testing import main
-import testlib.config
-
-
-_setup = False
-
-def configure_for_tests():
-    """import testenv; testenv.configure_for_tests()"""
-
-    global _setup
-    if not _setup:
-        sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
-        logging.basicConfig()
-
-        testlib.config.configure()
-        _setup = True
-
-def simple_setup():
-    """import testenv; testenv.simple_setup()"""
-
-    global _setup
-    if not _setup:
-        sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
-        logging.basicConfig()
-
-        testlib.config.configure_defaults()
-        _setup = True
-
diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py
deleted file mode 100644 (file)
index 5b8075d..0000000
+++ /dev/null
@@ -1,38 +0,0 @@
-"""Enhance unittest and instrument SQLAlchemy classes for testing.
-
-Load after sqlalchemy imports to use instrumented stand-ins like Table.
-"""
-
-import sys
-import testlib.config
-from testlib.schema import Table, Column
-import testlib.testing as testing
-from testlib.testing import \
-     AssertsCompiledSQL, \
-     AssertsExecutionResults, \
-     ComparesTables, \
-     TestBase, \
-     rowset
-from testlib.orm import mapper
-import testlib.profiling as profiling
-import testlib.engines as engines
-import testlib.requires as requires
-from testlib.compat import _function_named
-
-
-__all__ = ('testing',
-           'mapper',
-           'Table', 'Column',
-           'rowset',
-           'TestBase', 'AssertsExecutionResults',
-           'AssertsCompiledSQL', 'ComparesTables',
-           'profiling', 'engines',
-           '_function_named')
-
-
-testing.requires = requires
-
-sys.modules['testlib.sa'] = sa = testing.CompositeModule(
-    'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule(
-    'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm'))
-sys.modules['testlib.sa.orm'] = sa.orm
diff --git a/test/testlib/compat.py b/test/testlib/compat.py
deleted file mode 100644 (file)
index 73eb2d6..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-import types
-import __builtin__
-
-__all__ = '_function_named', 'callable'
-
-
-def _function_named(fn, newname):
-    try:
-        fn.__name__ = newname
-    except:
-        fn = types.FunctionType(fn.func_code, fn.func_globals, newname,
-                          fn.func_defaults, fn.func_closure)
-    return fn
-
-try:
-    callable = __builtin__.callable
-except NameError:
-    def callable(fn): return hasattr(fn, '__call__')
-
diff --git a/test/testlib/config.py b/test/testlib/config.py
deleted file mode 100644 (file)
index cef4c6e..0000000
+++ /dev/null
@@ -1,344 +0,0 @@
-import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
-logging, require = None, None
-
-
-__all__ = 'parser', 'configure', 'options',
-
-db = None
-db_label, db_url, db_opts = None, None, {}
-
-options = None
-file_config = None
-coverage_enabled = False
-
-base_config = """
-[db]
-sqlite=sqlite:///:memory:
-sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test
-oracle=oracle://scott:tiger@127.0.0.1:1521
-oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
-mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
-firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
-maxdb=maxdb://MONA:RED@/maxdb1
-"""
-
-parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
-
-def configure():
-    global options, config
-    global getopts_options, file_config
-
-    file_config = ConfigParser.ConfigParser()
-    file_config.readfp(StringIO.StringIO(base_config))
-    file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
-
-    # Opt parsing can fire immediate actions, like logging and coverage
-    (options, args) = parser.parse_args()
-    sys.argv[1:] = args
-
-    # Lazy setup of other options (post coverage)
-    for fn in post_configure:
-        fn(options, file_config)
-
-    return options, file_config
-
-def configure_defaults():
-    global options, config
-    global getopts_options, file_config
-    global db
-
-    file_config = ConfigParser.ConfigParser()
-    file_config.readfp(StringIO.StringIO(base_config))
-    file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
-    (options, args) = parser.parse_args([])
-
-    # make error messages raised by decorators that depend on a default
-    # database clearer.
-    class _engine_bomb(object):
-        def __getattr__(self, key):
-            raise RuntimeError('No default engine available, testlib '
-                               'was configured with defaults only.')
-
-    db = _engine_bomb()
-    import testlib.testing
-    testlib.testing.db = db
-
-    return options, file_config
-
-def _log(option, opt_str, value, parser):
-    global logging
-    if not logging:
-        import logging
-        logging.basicConfig()
-
-    if opt_str.endswith('-info'):
-        logging.getLogger(value).setLevel(logging.INFO)
-    elif opt_str.endswith('-debug'):
-        logging.getLogger(value).setLevel(logging.DEBUG)
-
-def _start_cumulative_coverage(option, opt_str, value, parser):
-    _start_coverage(option, opt_str, value, parser, erase=False)
-
-def _start_coverage(option, opt_str, value, parser, erase=True):
-    import sys, atexit, coverage
-    true_out = sys.stdout
-    
-    global coverage_enabled
-    coverage_enabled = True
-    
-    def _iter_covered_files(mod, recursive=True):
-        
-        if recursive:
-            ff = os.walk
-        else:
-            ff = os.listdir
-            
-        for rec in ff(os.path.dirname(mod.__file__)):
-            for x in rec[2]:
-                if x.endswith('.py'):
-                    yield os.path.join(rec[0], x)
-            
-    def _stop():
-        coverage.stop()
-        true_out.write("\nPreparing coverage report...\n")
-
-        from sqlalchemy import sql, orm, engine, \
-                            ext, databases, log
-                        
-        import sqlalchemy
-        
-        for modset in [
-            _iter_covered_files(sqlalchemy, recursive=False),
-            _iter_covered_files(databases),
-            _iter_covered_files(engine),
-            _iter_covered_files(ext),
-            _iter_covered_files(orm),
-        ]:
-            coverage.report(list(modset),
-                            show_missing=False, ignore_errors=False,
-                            file=true_out)
-    atexit.register(_stop)
-    if erase:
-        coverage.erase()
-    coverage.start()
-
-def _list_dbs(*args):
-    print "Available --db options (use --dburi to override)"
-    for macro in sorted(file_config.options('db')):
-        print "%20s\t%s" % (macro, file_config.get('db', macro))
-    sys.exit(0)
-
-def _server_side_cursors(options, opt_str, value, parser):
-    db_opts['server_side_cursors'] = True
-
-def _engine_strategy(options, opt_str, value, parser):
-    if value:
-        db_opts['strategy'] = value
-
-opt = parser.add_option
-opt("--verbose", action="store_true", dest="verbose",
-    help="enable stdout echoing/printing")
-opt("--quiet", action="store_true", dest="quiet", help="suppress output")
-opt("--log-info", action="callback", type="string", callback=_log,
-    help="turn on info logging for <LOG> (multiple OK)")
-opt("--log-debug", action="callback", type="string", callback=_log,
-    help="turn on debug logging for <LOG> (multiple OK)")
-opt("--require", action="append", dest="require", default=[],
-    help="require a particular driver or module version (multiple OK)")
-opt("--db", action="store", dest="db", default="sqlite",
-    help="Use prefab database uri")
-opt('--dbs', action='callback', callback=_list_dbs,
-    help="List available prefab dbs")
-opt("--dburi", action="store", dest="dburi",
-    help="Database uri (overrides --db)")
-opt("--dropfirst", action="store_true", dest="dropfirst",
-    help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")
-opt("--mockpool", action="store_true", dest="mockpool",
-    help="Use mock pool (asserts only one connection used)")
-opt("--enginestrategy", action="callback", type="string",
-    callback=_engine_strategy,
-    help="Engine strategy (plain or threadlocal, defaults to plain)")
-opt("--reversetop", action="store_true", dest="reversetop", default=False,
-    help="Reverse the collection ordering for topological sorts (helps "
-          "reveal dependency issues)")
-opt("--unhashable", action="store_true", dest="unhashable", default=False,
-    help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
-opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
-    help="Disallow SQLAlchemy from performing == on mapped test objects.")
-opt("--truthless", action="store_true", dest="truthless", default=False,
-    help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
-opt("--serverside", action="callback", callback=_server_side_cursors,
-    help="Turn on server side cursors for PG")
-opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
-    help="Use the specified MySQL storage engine for all tables, default is "
-         "a db-default/InnoDB combo.")
-opt("--table-option", action="append", dest="tableopts", default=[],
-    help="Add a dialect-specific table option, key=value")
-opt("--coverage", action="callback", callback=_start_coverage,
-    help="Dump a full coverage report after running tests")
-opt("--cumulative-coverage", action="callback", callback=_start_cumulative_coverage,
-    help="Like --coverage, but accumlate coverage into the current DB")
-opt("--profile", action="append", dest="profile_targets", default=[],
-    help="Enable a named profile target (multiple OK.)")
-opt("--profile-sort", action="store", dest="profile_sort", default=None,
-    help="Sort profile stats with this comma-separated sort order")
-opt("--profile-limit", type="int", action="store", dest="profile_limit",
-    default=None,
-    help="Limit function count in profile stats")
-
-class _ordered_map(object):
-    def __init__(self):
-        self._keys = list()
-        self._data = dict()
-
-    def __setitem__(self, key, value):
-        if key not in self._keys:
-            self._keys.append(key)
-        self._data[key] = value
-
-    def __iter__(self):
-        for key in self._keys:
-            yield self._data[key]
-
-# at one point in refactoring, modules were injecting into the config
-# process.  this could probably just become a list now.
-post_configure = _ordered_map()
-
-def _engine_uri(options, file_config):
-    global db_label, db_url
-    db_label = 'sqlite'
-    if options.dburi:
-        db_url = options.dburi
-        db_label = db_url[:db_url.index(':')]
-    elif options.db:
-        db_label = options.db
-        db_url = None
-
-    if db_url is None:
-        if db_label not in file_config.options('db'):
-            raise RuntimeError(
-                "Unknown engine.  Specify --dbs for known engines.")
-        db_url = file_config.get('db', db_label)
-post_configure['engine_uri'] = _engine_uri
-
-def _require(options, file_config):
-    if not(options.require or
-           (file_config.has_section('require') and
-            file_config.items('require'))):
-        return
-
-    try:
-        import pkg_resources
-    except ImportError:
-        raise RuntimeError("setuptools is required for version requirements")
-
-    cmdline = []
-    for requirement in options.require:
-        pkg_resources.require(requirement)
-        cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
-
-    if file_config.has_section('require'):
-        for label, requirement in file_config.items('require'):
-            if not label == db_label or label.startswith('%s.' % db_label):
-                continue
-            seen = [c for c in cmdline if requirement.startswith(c)]
-            if seen:
-                continue
-            pkg_resources.require(requirement)
-post_configure['require'] = _require
-
-def _engine_pool(options, file_config):
-    if options.mockpool:
-        from sqlalchemy import pool
-        db_opts['poolclass'] = pool.AssertionPool
-post_configure['engine_pool'] = _engine_pool
-
-def _create_testing_engine(options, file_config):
-    from testlib import engines, testing
-    global db
-    db = engines.testing_engine(db_url, db_opts)
-    testing.db = db
-post_configure['create_engine'] = _create_testing_engine
-
-def _prep_testing_database(options, file_config):
-    from testlib import engines
-    from sqlalchemy import schema
-
-    try:
-        # also create alt schemas etc. here?
-        if options.dropfirst:
-            e = engines.utf8_engine()
-            existing = e.table_names()
-            if existing:
-                if not options.quiet:
-                    print "Dropping existing tables in database: " + db_url
-                    try:
-                        print "Tables: %s" % ', '.join(existing)
-                    except:
-                        pass
-                    print "Abort within 5 seconds..."
-                    time.sleep(5)
-                md = schema.MetaData(e, reflect=True)
-                md.drop_all()
-            e.dispose()
-    except (KeyboardInterrupt, SystemExit):
-        raise
-    except Exception, e:
-        if not options.quiet:
-            warnings.warn(RuntimeWarning(
-                "Error checking for existing tables in testing "
-                "database: %s" % e))
-post_configure['prep_db'] = _prep_testing_database
-
-def _set_table_options(options, file_config):
-    import testlib.schema
-
-    table_options = testlib.schema.table_options
-    for spec in options.tableopts:
-        key, value = spec.split('=')
-        table_options[key] = value
-
-    if options.mysql_engine:
-        table_options['mysql_engine'] = options.mysql_engine
-post_configure['table_options'] = _set_table_options
-
-def _reverse_topological(options, file_config):
-    if options.reversetop:
-        from sqlalchemy.orm import unitofwork
-        from sqlalchemy import topological
-        class RevQueueDepSort(topological.QueueDependencySorter):
-            def __init__(self, tuples, allitems):
-                self.tuples = list(tuples)
-                self.allitems = list(allitems)
-                self.tuples.reverse()
-                self.allitems.reverse()
-        topological.QueueDependencySorter = RevQueueDepSort
-        unitofwork.DependencySorter = RevQueueDepSort
-post_configure['topological'] = _reverse_topological
-
-def _set_profile_targets(options, file_config):
-    from testlib import profiling
-
-    profile_config = profiling.profile_config
-
-    for target in options.profile_targets:
-        profile_config['targets'].add(target)
-
-    if options.profile_sort:
-        profile_config['sort'] = options.profile_sort.split(',')
-
-    if options.profile_limit:
-        profile_config['limit'] = options.profile_limit
-
-    if options.quiet:
-        profile_config['report'] = False
-
-    # magic "all" target
-    if 'all' in profiling.all_targets:
-        targets = profile_config['targets']
-        if 'all' in targets and len(targets) != 1:
-            targets.clear()
-            targets.add('all')
-post_configure['profile_targets'] = _set_profile_targets
diff --git a/test/testlib/coverage.py b/test/testlib/coverage.py
deleted file mode 100644 (file)
index fc0f2c2..0000000
+++ /dev/null
@@ -1,1098 +0,0 @@
-#!/usr/bin/python
-#
-#             Perforce Defect Tracking Integration Project
-#              <http://www.ravenbrook.com/project/p4dti/>
-#
-#                   COVERAGE.PY -- COVERAGE TESTING
-#
-#             Gareth Rees, Ravenbrook Limited, 2001-12-04
-#                     Ned Batchelder, 2004-12-12
-#         http://nedbatchelder.com/code/modules/coverage.html
-#
-#
-# 1. INTRODUCTION
-#
-# This module provides coverage testing for Python code.
-#
-# The intended readership is all Python developers.
-#
-# This document is not confidential.
-#
-# See [GDR 2001-12-04a] for the command-line interface, programmatic
-# interface and limitations.  See [GDR 2001-12-04b] for requirements and
-# design.
-
-r"""\
-Usage:
-
-coverage.py -x [-p] MODULE.py [ARG1 ARG2 ...]
-    Execute module, passing the given command-line arguments, collecting
-    coverage data. With the -p option, write to a temporary file containing
-    the machine name and process ID.
-
-coverage.py -e
-    Erase collected coverage data.
-
-coverage.py -c
-    Collect data from multiple coverage files (as created by -p option above)
-    and store it into a single file representing the union of the coverage.
-
-coverage.py -r [-m] [-o dir1,dir2,...] FILE1 FILE2 ...
-    Report on the statement coverage for the given files.  With the -m
-    option, show line numbers of the statements that weren't executed.
-
-coverage.py -a [-d dir] [-o dir1,dir2,...] FILE1 FILE2 ...
-    Make annotated copies of the given files, marking statements that
-    are executed with > and statements that are missed with !.  With
-    the -d option, make the copies in that directory.  Without the -d
-    option, make each copy in the same directory as the original.
-
--o dir,dir2,...
-  Omit reporting or annotating files when their filename path starts with
-  a directory listed in the omit list.
-  e.g. python coverage.py -i -r -o c:\python23,lib\enthought\traits
-
-Coverage data is saved in the file .coverage by default.  Set the
-COVERAGE_FILE environment variable to save it somewhere else."""
-
-__version__ = "2.75.20070722"    # see detailed history at the end of this file.
-
-import compiler
-import compiler.visitor
-import glob
-import os
-import re
-import string
-import symbol
-import sys
-import threading
-import token
-import types
-from socket import gethostname
-
-
-# 2. IMPLEMENTATION
-#
-# This uses the "singleton" pattern.
-#
-# The word "morf" means a module object (from which the source file can
-# be deduced by suitable manipulation of the __file__ attribute) or a
-# filename.
-#
-# When we generate a coverage report we have to canonicalize every
-# filename in the coverage dictionary just in case it refers to the
-# module we are reporting on.  It seems a shame to throw away this
-# information so the data in the coverage dictionary is transferred to
-# the 'cexecuted' dictionary under the canonical filenames.
-#
-# The coverage dictionary is called "c" and the trace function "t".  The
-# reason for these short names is that Python looks up variables by name
-# at runtime and so execution time depends on the length of variables!
-# In the bottleneck of this application it's appropriate to abbreviate
-# names to increase speed.
-
-class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
-    """ A visitor for a parsed Abstract Syntax Tree which finds executable
-        statements.
-    """
-    def __init__(self, statements, excluded, suite_spots):
-        compiler.visitor.ASTVisitor.__init__(self)
-        self.statements = statements
-        self.excluded = excluded
-        self.suite_spots = suite_spots
-        self.excluding_suite = 0
-        
-    def doRecursive(self, node):
-        for n in node.getChildNodes():
-            self.dispatch(n)
-
-    visitStmt = visitModule = doRecursive
-    
-    def doCode(self, node):
-        if hasattr(node, 'decorators') and node.decorators:
-            self.dispatch(node.decorators)
-            self.recordAndDispatch(node.code)
-        else:
-            self.doSuite(node, node.code)
-            
-    visitFunction = visitClass = doCode
-
-    def getFirstLine(self, node):
-        # Find the first line in the tree node.
-        lineno = node.lineno
-        for n in node.getChildNodes():
-            f = self.getFirstLine(n)
-            if lineno and f:
-                lineno = min(lineno, f)
-            else:
-                lineno = lineno or f
-        return lineno
-
-    def getLastLine(self, node):
-        # Find the first line in the tree node.
-        lineno = node.lineno
-        for n in node.getChildNodes():
-            lineno = max(lineno, self.getLastLine(n))
-        return lineno
-    
-    def doStatement(self, node):
-        self.recordLine(self.getFirstLine(node))
-
-    visitAssert = visitAssign = visitAssTuple = visitPrint = \
-        visitPrintnl = visitRaise = visitSubscript = visitDecorators = \
-        doStatement
-    
-    def visitPass(self, node):
-        # Pass statements have weird interactions with docstrings.  If this
-        # pass statement is part of one of those pairs, claim that the statement
-        # is on the later of the two lines.
-        l = node.lineno
-        if l:
-            lines = self.suite_spots.get(l, [l,l])
-            self.statements[lines[1]] = 1
-        
-    def visitDiscard(self, node):
-        # Discard nodes are statements that execute an expression, but then
-        # discard the results.  This includes function calls, so we can't 
-        # ignore them all.  But if the expression is a constant, the statement
-        # won't be "executed", so don't count it now.
-        if node.expr.__class__.__name__ != 'Const':
-            self.doStatement(node)
-
-    def recordNodeLine(self, node):
-        # Stmt nodes often have None, but shouldn't claim the first line of
-        # their children (because the first child might be an ignorable line
-        # like "global a").
-        if node.__class__.__name__ != 'Stmt':
-            return self.recordLine(self.getFirstLine(node))
-        else:
-            return 0
-    
-    def recordLine(self, lineno):
-        # Returns a bool, whether the line is included or excluded.
-        if lineno:
-            # Multi-line tests introducing suites have to get charged to their
-            # keyword.
-            if lineno in self.suite_spots:
-                lineno = self.suite_spots[lineno][0]
-            # If we're inside an excluded suite, record that this line was
-            # excluded.
-            if self.excluding_suite:
-                self.excluded[lineno] = 1
-                return 0
-            # If this line is excluded, or suite_spots maps this line to
-            # another line that is exlcuded, then we're excluded.
-            elif self.excluded.has_key(lineno) or \
-                 self.suite_spots.has_key(lineno) and \
-                 self.excluded.has_key(self.suite_spots[lineno][1]):
-                return 0
-            # Otherwise, this is an executable line.
-            else:
-                self.statements[lineno] = 1
-                return 1
-        return 0
-    
-    default = recordNodeLine
-    
-    def recordAndDispatch(self, node):
-        self.recordNodeLine(node)
-        self.dispatch(node)
-
-    def doSuite(self, intro, body, exclude=0):
-        exsuite = self.excluding_suite
-        if exclude or (intro and not self.recordNodeLine(intro)):
-            self.excluding_suite = 1
-        self.recordAndDispatch(body)
-        self.excluding_suite = exsuite
-        
-    def doPlainWordSuite(self, prevsuite, suite):
-        # Finding the exclude lines for else's is tricky, because they aren't
-        # present in the compiler parse tree.  Look at the previous suite,
-        # and find its last line.  If any line between there and the else's
-        # first line are excluded, then we exclude the else.
-        lastprev = self.getLastLine(prevsuite)
-        firstelse = self.getFirstLine(suite)
-        for l in range(lastprev+1, firstelse):
-            if self.suite_spots.has_key(l):
-                self.doSuite(None, suite, exclude=self.excluded.has_key(l))
-                break
-        else:
-            self.doSuite(None, suite)
-        
-    def doElse(self, prevsuite, node):
-        if node.else_:
-            self.doPlainWordSuite(prevsuite, node.else_)
-    
-    def visitFor(self, node):
-        self.doSuite(node, node.body)
-        self.doElse(node.body, node)
-
-    visitWhile = visitFor
-
-    def visitIf(self, node):
-        # The first test has to be handled separately from the rest.
-        # The first test is credited to the line with the "if", but the others
-        # are credited to the line with the test for the elif.
-        self.doSuite(node, node.tests[0][1])
-        for t, n in node.tests[1:]:
-            self.doSuite(t, n)
-        self.doElse(node.tests[-1][1], node)
-
-    def visitTryExcept(self, node):
-        self.doSuite(node, node.body)
-        for i in range(len(node.handlers)):
-            a, b, h = node.handlers[i]
-            if not a:
-                # It's a plain "except:".  Find the previous suite.
-                if i > 0:
-                    prev = node.handlers[i-1][2]
-                else:
-                    prev = node.body
-                self.doPlainWordSuite(prev, h)
-            else:
-                self.doSuite(a, h)
-        self.doElse(node.handlers[-1][2], node)
-    
-    def visitTryFinally(self, node):
-        self.doSuite(node, node.body)
-        self.doPlainWordSuite(node.body, node.final)
-        
-    def visitGlobal(self, node):
-        # "global" statements don't execute like others (they don't call the
-        # trace function), so don't record their line numbers.
-        pass
-
-the_coverage = None
-
-class CoverageException(Exception): pass
-
-class coverage:
-    # Name of the cache file (unless environment variable is set).
-    cache_default = ".coverage"
-
-    # Environment variable naming the cache file.
-    cache_env = "COVERAGE_FILE"
-
-    # A dictionary with an entry for (Python source file name, line number
-    # in that file) if that line has been executed.
-    c = {}
-    
-    # A map from canonical Python source file name to a dictionary in
-    # which there's an entry for each line number that has been
-    # executed.
-    cexecuted = {}
-
-    # Cache of results of calling the analysis2() method, so that you can
-    # specify both -r and -a without doing double work.
-    analysis_cache = {}
-
-    # Cache of results of calling the canonical_filename() method, to
-    # avoid duplicating work.
-    canonical_filename_cache = {}
-
-    def __init__(self):
-        global the_coverage
-        if the_coverage:
-            raise CoverageException, "Only one coverage object allowed."
-        self.usecache = 1
-        self.cache = None
-        self.parallel_mode = False
-        self.exclude_re = ''
-        self.nesting = 0
-        self.cstack = []
-        self.xstack = []
-        self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.sep)
-        self.exclude('# *pragma[: ]*[nN][oO] *[cC][oO][vV][eE][rR]')
-
-    # t(f, x, y).  This method is passed to sys.settrace as a trace function.  
-    # See [van Rossum 2001-07-20b, 9.2] for an explanation of sys.settrace and 
-    # the arguments and return value of the trace function.
-    # See [van Rossum 2001-07-20a, 3.2] for a description of frame and code
-    # objects.
-    
-    def t(self, f, w, unused):                                   #pragma: no cover
-        if w == 'line':
-            #print "Executing %s @ %d" % (f.f_code.co_filename, f.f_lineno)
-            self.c[(f.f_code.co_filename, f.f_lineno)] = 1
-            for c in self.cstack:
-                c[(f.f_code.co_filename, f.f_lineno)] = 1
-        return self.t
-    
-    def help(self, error=None):     #pragma: no cover
-        if error:
-            print error
-            print
-        print __doc__
-        sys.exit(1)
-
-    def command_line(self, argv, help_fn=None):
-        import getopt
-        help_fn = help_fn or self.help
-        settings = {}
-        optmap = {
-            '-a': 'annotate',
-            '-c': 'collect',
-            '-d:': 'directory=',
-            '-e': 'erase',
-            '-h': 'help',
-            '-i': 'ignore-errors',
-            '-m': 'show-missing',
-            '-p': 'parallel-mode',
-            '-r': 'report',
-            '-x': 'execute',
-            '-o:': 'omit=',
-            }
-        short_opts = string.join(map(lambda o: o[1:], optmap.keys()), '')
-        long_opts = optmap.values()
-        options, args = getopt.getopt(argv, short_opts, long_opts)
-        for o, a in options:
-            if optmap.has_key(o):
-                settings[optmap[o]] = 1
-            elif optmap.has_key(o + ':'):
-                settings[optmap[o + ':']] = a
-            elif o[2:] in long_opts:
-                settings[o[2:]] = 1
-            elif o[2:] + '=' in long_opts:
-                settings[o[2:]+'='] = a
-            else:       #pragma: no cover
-                pass    # Can't get here, because getopt won't return anything unknown.
-
-        if settings.get('help'):
-            help_fn()
-
-        for i in ['erase', 'execute']:
-            for j in ['annotate', 'report', 'collect']:
-                if settings.get(i) and settings.get(j):
-                    help_fn("You can't specify the '%s' and '%s' "
-                              "options at the same time." % (i, j))
-
-        args_needed = (settings.get('execute')
-                       or settings.get('annotate')
-                       or settings.get('report'))
-        action = (settings.get('erase') 
-                  or settings.get('collect')
-                  or args_needed)
-        if not action:
-            help_fn("You must specify at least one of -e, -x, -c, -r, or -a.")
-        if not args_needed and args:
-            help_fn("Unexpected arguments: %s" % " ".join(args))
-        
-        self.parallel_mode = settings.get('parallel-mode')
-        self.get_ready()
-
-        if settings.get('erase'):
-            self.erase()
-        if settings.get('execute'):
-            if not args:
-                help_fn("Nothing to do.")
-            sys.argv = args
-            self.start()
-            import __main__
-            sys.path[0] = os.path.dirname(sys.argv[0])
-            execfile(sys.argv[0], __main__.__dict__)
-        if settings.get('collect'):
-            self.collect()
-        if not args:
-            args = self.cexecuted.keys()
-        
-        ignore_errors = settings.get('ignore-errors')
-        show_missing = settings.get('show-missing')
-        directory = settings.get('directory=')
-
-        omit = settings.get('omit=')
-        if omit is not None:
-            omit = omit.split(',')
-        else:
-            omit = []
-
-        if settings.get('report'):
-            self.report(args, show_missing, ignore_errors, omit_prefixes=omit)
-        if settings.get('annotate'):
-            self.annotate(args, directory, ignore_errors, omit_prefixes=omit)
-
-    def use_cache(self, usecache, cache_file=None):
-        self.usecache = usecache
-        if cache_file and not self.cache:
-            self.cache_default = cache_file
-        
-    def get_ready(self, parallel_mode=False):
-        if self.usecache and not self.cache:
-            self.cache = os.environ.get(self.cache_env, self.cache_default)
-            if self.parallel_mode:
-                self.cache += "." + gethostname() + "." + str(os.getpid())
-            self.restore()
-        self.analysis_cache = {}
-        
-    def start(self, parallel_mode=False):
-        self.get_ready()
-        if self.nesting == 0:                               #pragma: no cover
-            sys.settrace(self.t)
-            if hasattr(threading, 'settrace'):
-                threading.settrace(self.t)
-        self.nesting += 1
-        
-    def stop(self):
-        self.nesting -= 1
-        if self.nesting == 0:                               #pragma: no cover
-            sys.settrace(None)
-            if hasattr(threading, 'settrace'):
-                threading.settrace(None)
-
-    def erase(self):
-        self.get_ready()
-        self.c = {}
-        self.analysis_cache = {}
-        self.cexecuted = {}
-        if self.cache and os.path.exists(self.cache):
-            os.remove(self.cache)
-
-    def exclude(self, re):
-        if self.exclude_re:
-            self.exclude_re += "|"
-        self.exclude_re += "(" + re + ")"
-
-    def begin_recursive(self):
-        self.cstack.append(self.c)
-        self.xstack.append(self.exclude_re)
-        
-    def end_recursive(self):
-        self.c = self.cstack.pop()
-        self.exclude_re = self.xstack.pop()
-
-    # save().  Save coverage data to the coverage cache.
-
-    def save(self):
-        if self.usecache and self.cache:
-            self.canonicalize_filenames()
-            cache = open(self.cache, 'wb')
-            import marshal
-            marshal.dump(self.cexecuted, cache)
-            cache.close()
-
-    # restore().  Restore coverage data from the coverage cache (if it exists).
-
-    def restore(self):
-        self.c = {}
-        self.cexecuted = {}
-        assert self.usecache
-        if os.path.exists(self.cache):
-            self.cexecuted = self.restore_file(self.cache)
-
-    def restore_file(self, file_name):
-        try:
-            cache = open(file_name, 'rb')
-            import marshal
-            cexecuted = marshal.load(cache)
-            cache.close()
-            if isinstance(cexecuted, types.DictType):
-                return cexecuted
-            else:
-                return {}
-        except:
-            return {}
-
-    # collect(). Collect data in multiple files produced by parallel mode
-
-    def collect(self):
-        cache_dir, local = os.path.split(self.cache)
-        for f in os.listdir(cache_dir or '.'):
-            if not f.startswith(local):
-                continue
-
-            full_path = os.path.join(cache_dir, f)
-            cexecuted = self.restore_file(full_path)
-            self.merge_data(cexecuted)
-
-    def merge_data(self, new_data):
-        for file_name, file_data in new_data.items():
-            if self.cexecuted.has_key(file_name):
-                self.merge_file_data(self.cexecuted[file_name], file_data)
-            else:
-                self.cexecuted[file_name] = file_data
-
-    def merge_file_data(self, cache_data, new_data):
-        for line_number in new_data.keys():
-            if not cache_data.has_key(line_number):
-                cache_data[line_number] = new_data[line_number]
-
-    # canonical_filename(filename).  Return a canonical filename for the
-    # file (that is, an absolute path with no redundant components and
-    # normalized case).  See [GDR 2001-12-04b, 3.3].
-
-    def canonical_filename(self, filename):
-        if not self.canonical_filename_cache.has_key(filename):
-            f = filename
-            if os.path.isabs(f) and not os.path.exists(f):
-                f = os.path.basename(f)
-            if not os.path.isabs(f):
-                for path in [os.curdir] + sys.path:
-                    g = os.path.join(path, f)
-                    if os.path.exists(g):
-                        f = g
-                        break
-            cf = os.path.normcase(os.path.abspath(f))
-            self.canonical_filename_cache[filename] = cf
-        return self.canonical_filename_cache[filename]
-
-    # canonicalize_filenames().  Copy results from "c" to "cexecuted", 
-    # canonicalizing filenames on the way.  Clear the "c" map.
-
-    def canonicalize_filenames(self):
-        for filename, lineno in self.c.keys():
-            if filename == '<string>':
-                # Can't do anything useful with exec'd strings, so skip them.
-                continue
-            f = self.canonical_filename(filename)
-            if not self.cexecuted.has_key(f):
-                self.cexecuted[f] = {}
-            self.cexecuted[f][lineno] = 1
-        self.c = {}
-
-    # morf_filename(morf).  Return the filename for a module or file.
-
-    def morf_filename(self, morf):
-        if isinstance(morf, types.ModuleType):
-            if not hasattr(morf, '__file__'):
-                raise CoverageException, "Module has no __file__ attribute."
-            f = morf.__file__
-        else:
-            f = morf
-        return self.canonical_filename(f)
-
-    # analyze_morf(morf).  Analyze the module or filename passed as
-    # the argument.  If the source code can't be found, raise an error.
-    # Otherwise, return a tuple of (1) the canonical filename of the
-    # source code for the module, (2) a list of lines of statements
-    # in the source code, (3) a list of lines of excluded statements,
-    # and (4), a map of line numbers to multi-line line number ranges, for
-    # statements that cross lines.
-    
-    def analyze_morf(self, morf):
-        if self.analysis_cache.has_key(morf):
-            return self.analysis_cache[morf]
-        filename = self.morf_filename(morf)
-        ext = os.path.splitext(filename)[1]
-        if ext == '.pyc':
-            if not os.path.exists(filename[0:-1]):
-                raise CoverageException, ("No source for compiled code '%s'."
-                                   % filename)
-            filename = filename[0:-1]
-        elif ext != '.py':
-            raise CoverageException, "File '%s' not Python source." % filename
-        source = open(filename, 'r')
-        lines, excluded_lines, line_map = self.find_executable_statements(
-            source.read(), exclude=self.exclude_re
-            )
-        source.close()
-        result = filename, lines, excluded_lines, line_map
-        self.analysis_cache[morf] = result
-        return result
-
-    def first_line_of_tree(self, tree):
-        while True:
-            if len(tree) == 3 and type(tree[2]) == type(1):
-                return tree[2]
-            tree = tree[1]
-    
-    def last_line_of_tree(self, tree):
-        while True:
-            if len(tree) == 3 and type(tree[2]) == type(1):
-                return tree[2]
-            tree = tree[-1]
-    
-    def find_docstring_pass_pair(self, tree, spots):
-        for i in range(1, len(tree)):
-            if self.is_string_constant(tree[i]) and self.is_pass_stmt(tree[i+1]):
-                first_line = self.first_line_of_tree(tree[i])
-                last_line = self.last_line_of_tree(tree[i+1])
-                self.record_multiline(spots, first_line, last_line)
-        
-    def is_string_constant(self, tree):
-        try:
-            return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.expr_stmt
-        except:
-            return False
-        
-    def is_pass_stmt(self, tree):
-        try:
-            return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.pass_stmt
-        except:
-            return False
-
-    def record_multiline(self, spots, i, j):
-        for l in range(i, j+1):
-            spots[l] = (i, j)
-            
-    def get_suite_spots(self, tree, spots):
-        """ Analyze a parse tree to find suite introducers which span a number
-            of lines.
-        """
-        for i in range(1, len(tree)):
-            if type(tree[i]) == type(()):
-                if tree[i][0] == symbol.suite:
-                    # Found a suite, look back for the colon and keyword.
-                    lineno_colon = lineno_word = None
-                    for j in range(i-1, 0, -1):
-                        if tree[j][0] == token.COLON:
-                            # Colons are never executed themselves: we want the
-                            # line number of the last token before the colon.
-                            lineno_colon = self.last_line_of_tree(tree[j-1])
-                        elif tree[j][0] == token.NAME:
-                            if tree[j][1] == 'elif':
-                                # Find the line number of the first non-terminal
-                                # after the keyword.
-                                t = tree[j+1]
-                                while t and token.ISNONTERMINAL(t[0]):
-                                    t = t[1]
-                                if t:
-                                    lineno_word = t[2]
-                            else:
-                                lineno_word = tree[j][2]
-                            break
-                        elif tree[j][0] == symbol.except_clause:
-                            # "except" clauses look like:
-                            # ('except_clause', ('NAME', 'except', lineno), ...)
-                            if tree[j][1][0] == token.NAME:
-                                lineno_word = tree[j][1][2]
-                                break
-                    if lineno_colon and lineno_word:
-                        # Found colon and keyword, mark all the lines
-                        # between the two with the two line numbers.
-                        self.record_multiline(spots, lineno_word, lineno_colon)
-
-                    # "pass" statements are tricky: different versions of Python
-                    # treat them differently, especially in the common case of a
-                    # function with a doc string and a single pass statement.
-                    self.find_docstring_pass_pair(tree[i], spots)
-                    
-                elif tree[i][0] == symbol.simple_stmt:
-                    first_line = self.first_line_of_tree(tree[i])
-                    last_line = self.last_line_of_tree(tree[i])
-                    if first_line != last_line:
-                        self.record_multiline(spots, first_line, last_line)
-                self.get_suite_spots(tree[i], spots)
-
-    def find_executable_statements(self, text, exclude=None):
-        # Find lines which match an exclusion pattern.
-        excluded = {}
-        suite_spots = {}
-        if exclude:
-            reExclude = re.compile(exclude)
-            lines = text.split('\n')
-            for i in range(len(lines)):
-                if reExclude.search(lines[i]):
-                    excluded[i+1] = 1
-
-        # Parse the code and analyze the parse tree to find out which statements
-        # are multiline, and where suites begin and end.
-        import parser
-        tree = parser.suite(text+'\n\n').totuple(1)
-        self.get_suite_spots(tree, suite_spots)
-        #print "Suite spots:", suite_spots
-        
-        # Use the compiler module to parse the text and find the executable
-        # statements.  We add newlines to be impervious to final partial lines.
-        statements = {}
-        ast = compiler.parse(text+'\n\n')
-        visitor = StatementFindingAstVisitor(statements, excluded, suite_spots)
-        compiler.walk(ast, visitor, walker=visitor)
-
-        lines = statements.keys()
-        lines.sort()
-        excluded_lines = excluded.keys()
-        excluded_lines.sort()
-        return lines, excluded_lines, suite_spots
-
-    # format_lines(statements, lines).  Format a list of line numbers
-    # for printing by coalescing groups of lines as long as the lines
-    # represent consecutive statements.  This will coalesce even if
-    # there are gaps between statements, so if statements =
-    # [1,2,3,4,5,10,11,12,13,14] and lines = [1,2,5,10,11,13,14] then
-    # format_lines will return "1-2, 5-11, 13-14".
-
-    def format_lines(self, statements, lines):
-        pairs = []
-        i = 0
-        j = 0
-        start = None
-        pairs = []
-        while i < len(statements) and j < len(lines):
-            if statements[i] == lines[j]:
-                if start == None:
-                    start = lines[j]
-                end = lines[j]
-                j = j + 1
-            elif start:
-                pairs.append((start, end))
-                start = None
-            i = i + 1
-        if start:
-            pairs.append((start, end))
-        def stringify(pair):
-            start, end = pair
-            if start == end:
-                return "%d" % start
-            else:
-                return "%d-%d" % (start, end)
-        ret = string.join(map(stringify, pairs), ", ")
-        return ret
-
-    # Backward compatibility with version 1.
-    def analysis(self, morf):
-        f, s, _, m, mf = self.analysis2(morf)
-        return f, s, m, mf
-
-    def analysis2(self, morf):
-        filename, statements, excluded, line_map = self.analyze_morf(morf)
-        self.canonicalize_filenames()
-        if not self.cexecuted.has_key(filename):
-            self.cexecuted[filename] = {}
-        missing = []
-        for line in statements:
-            lines = line_map.get(line, [line, line])
-            for l in range(lines[0], lines[1]+1):
-                if self.cexecuted[filename].has_key(l):
-                    break
-            else:
-                missing.append(line)
-        return (filename, statements, excluded, missing,
-                self.format_lines(statements, missing))
-
-    def relative_filename(self, filename):
-        """ Convert filename to relative filename from self.relative_dir.
-        """
-        return filename.replace(self.relative_dir, "")
-
-    def morf_name(self, morf):
-        """ Return the name of morf as used in report.
-        """
-        if isinstance(morf, types.ModuleType):
-            return morf.__name__
-        else:
-            return self.relative_filename(os.path.splitext(morf)[0])
-
-    def filter_by_prefix(self, morfs, omit_prefixes):
-        """ Return list of morfs where the morf name does not begin
-            with any one of the omit_prefixes.
-        """
-        filtered_morfs = []
-        for morf in morfs:
-            for prefix in omit_prefixes:
-                if self.morf_name(morf).startswith(prefix):
-                    break
-            else:
-                filtered_morfs.append(morf)
-
-        return filtered_morfs
-
-    def morf_name_compare(self, x, y):
-        return cmp(self.morf_name(x), self.morf_name(y))
-
-    def report(self, morfs, show_missing=1, ignore_errors=0, file=None, omit_prefixes=[]):
-        if not isinstance(morfs, types.ListType):
-            morfs = [morfs]
-        # On windows, the shell doesn't expand wildcards.  Do it here.
-        globbed = []
-        for morf in morfs:
-            if isinstance(morf, basestring):
-                globbed.extend(glob.glob(morf))
-            else:
-                globbed.append(morf)
-        morfs = globbed
-        
-        morfs = self.filter_by_prefix(morfs, omit_prefixes)
-        morfs.sort(self.morf_name_compare)
-
-        max_name = max([5,] + map(len, map(self.morf_name, morfs)))
-        fmt_name = "%%- %ds  " % max_name
-        fmt_err = fmt_name + "%s: %s"
-        header = fmt_name % "Name" + " Stmts   Exec  Cover"
-        fmt_coverage = fmt_name + "% 6d % 6d % 5d%%"
-        if show_missing:
-            header = header + "   Missing"
-            fmt_coverage = fmt_coverage + "   %s"
-        if not file:
-            file = sys.stdout
-        print >>file, header
-        print >>file, "-" * len(header)
-        total_statements = 0
-        total_executed = 0
-        for morf in morfs:
-            name = self.morf_name(morf)
-            try:
-                _, statements, _, missing, readable  = self.analysis2(morf)
-                n = len(statements)
-                m = n - len(missing)
-                if n > 0:
-                    pc = 100.0 * m / n
-                else:
-                    pc = 100.0
-                args = (name, n, m, pc)
-                if show_missing:
-                    args = args + (readable,)
-                print >>file, fmt_coverage % args
-                total_statements = total_statements + n
-                total_executed = total_executed + m
-            except KeyboardInterrupt:                       #pragma: no cover
-                raise
-            except:
-                if not ignore_errors:
-                    typ, msg = sys.exc_info()[0:2]
-                    print >>file, fmt_err % (name, typ, msg)
-        if len(morfs) > 1:
-            print >>file, "-" * len(header)
-            if total_statements > 0:
-                pc = 100.0 * total_executed / total_statements
-            else:
-                pc = 100.0
-            args = ("TOTAL", total_statements, total_executed, pc)
-            if show_missing:
-                args = args + ("",)
-            print >>file, fmt_coverage % args
-
-    # annotate(morfs, ignore_errors).
-
-    blank_re = re.compile(r"\s*(#|$)")
-    else_re = re.compile(r"\s*else\s*:\s*(#|$)")
-
-    def annotate(self, morfs, directory=None, ignore_errors=0, omit_prefixes=[]):
-        morfs = self.filter_by_prefix(morfs, omit_prefixes)
-        for morf in morfs:
-            try:
-                filename, statements, excluded, missing, _ = self.analysis2(morf)
-                self.annotate_file(filename, statements, excluded, missing, directory)
-            except KeyboardInterrupt:
-                raise
-            except:
-                if not ignore_errors:
-                    raise
-                
-    def annotate_file(self, filename, statements, excluded, missing, directory=None):
-        source = open(filename, 'r')
-        if directory:
-            dest_file = os.path.join(directory,
-                                     os.path.basename(filename)
-                                     + ',cover')
-        else:
-            dest_file = filename + ',cover'
-        dest = open(dest_file, 'w')
-        lineno = 0
-        i = 0
-        j = 0
-        covered = 1
-        while 1:
-            line = source.readline()
-            if line == '':
-                break
-            lineno = lineno + 1
-            while i < len(statements) and statements[i] < lineno:
-                i = i + 1
-            while j < len(missing) and missing[j] < lineno:
-                j = j + 1
-            if i < len(statements) and statements[i] == lineno:
-                covered = j >= len(missing) or missing[j] > lineno
-            if self.blank_re.match(line):
-                dest.write('  ')
-            elif self.else_re.match(line):
-                # Special logic for lines containing only 'else:'.  
-                # See [GDR 2001-12-04b, 3.2].
-                if i >= len(statements) and j >= len(missing):
-                    dest.write('! ')
-                elif i >= len(statements) or j >= len(missing):
-                    dest.write('> ')
-                elif statements[i] == missing[j]:
-                    dest.write('! ')
-                else:
-                    dest.write('> ')
-            elif lineno in excluded:
-                dest.write('- ')
-            elif covered:
-                dest.write('> ')
-            else:
-                dest.write('! ')
-            dest.write(line)
-        source.close()
-        dest.close()
-
-# Singleton object.
-the_coverage = coverage()
-
-# Module functions call methods in the singleton object.
-def use_cache(*args, **kw): 
-    return the_coverage.use_cache(*args, **kw)
-
-def start(*args, **kw): 
-    return the_coverage.start(*args, **kw)
-
-def stop(*args, **kw): 
-    return the_coverage.stop(*args, **kw)
-
-def erase(*args, **kw): 
-    return the_coverage.erase(*args, **kw)
-
-def begin_recursive(*args, **kw): 
-    return the_coverage.begin_recursive(*args, **kw)
-
-def end_recursive(*args, **kw): 
-    return the_coverage.end_recursive(*args, **kw)
-
-def exclude(*args, **kw): 
-    return the_coverage.exclude(*args, **kw)
-
-def analysis(*args, **kw): 
-    return the_coverage.analysis(*args, **kw)
-
-def analysis2(*args, **kw): 
-    return the_coverage.analysis2(*args, **kw)
-
-def report(*args, **kw): 
-    return the_coverage.report(*args, **kw)
-
-def annotate(*args, **kw): 
-    return the_coverage.annotate(*args, **kw)
-
-def annotate_file(*args, **kw): 
-    return the_coverage.annotate_file(*args, **kw)
-
-# Save coverage data when Python exits.  (The atexit module wasn't
-# introduced until Python 2.0, so use sys.exitfunc when it's not
-# available.)
-try:
-    import atexit
-    atexit.register(the_coverage.save)
-except ImportError:
-    sys.exitfunc = the_coverage.save
-
-# Command-line interface.
-if __name__ == '__main__':
-    the_coverage.command_line(sys.argv[1:])
-
-
-# A. REFERENCES
-#
-# [GDR 2001-12-04a] "Statement coverage for Python"; Gareth Rees;
-# Ravenbrook Limited; 2001-12-04;
-# <http://www.nedbatchelder.com/code/modules/rees-coverage.html>.
-#
-# [GDR 2001-12-04b] "Statement coverage for Python: design and
-# analysis"; Gareth Rees; Ravenbrook Limited; 2001-12-04;
-# <http://www.nedbatchelder.com/code/modules/rees-design.html>.
-#
-# [van Rossum 2001-07-20a] "Python Reference Manual (releae 2.1.1)";
-# Guide van Rossum; 2001-07-20;
-# <http://www.python.org/doc/2.1.1/ref/ref.html>.
-#
-# [van Rossum 2001-07-20b] "Python Library Reference"; Guido van Rossum;
-# 2001-07-20; <http://www.python.org/doc/2.1.1/lib/lib.html>.
-#
-#
-# B. DOCUMENT HISTORY
-#
-# 2001-12-04 GDR Created.
-#
-# 2001-12-06 GDR Added command-line interface and source code
-# annotation.
-#
-# 2001-12-09 GDR Moved design and interface to separate documents.
-#
-# 2001-12-10 GDR Open cache file as binary on Windows.  Allow
-# simultaneous -e and -x, or -a and -r.
-#
-# 2001-12-12 GDR Added command-line help.  Cache analysis so that it
-# only needs to be done once when you specify -a and -r.
-#
-# 2001-12-13 GDR Improved speed while recording.  Portable between
-# Python 1.5.2 and 2.1.1.
-#
-# 2002-01-03 GDR Module-level functions work correctly.
-#
-# 2002-01-07 GDR Update sys.path when running a file with the -x option,
-# so that it matches the value the program would get if it were run on
-# its own.
-#
-# 2004-12-12 NMB Significant code changes.
-# - Finding executable statements has been rewritten so that docstrings and
-#   other quirks of Python execution aren't mistakenly identified as missing
-#   lines.
-# - Lines can be excluded from consideration, even entire suites of lines.
-# - The filesystem cache of covered lines can be disabled programmatically.
-# - Modernized the code.
-#
-# 2004-12-14 NMB Minor tweaks.  Return 'analysis' to its original behavior
-# and add 'analysis2'.  Add a global for 'annotate', and factor it, adding
-# 'annotate_file'.
-#
-# 2004-12-31 NMB Allow for keyword arguments in the module global functions.
-# Thanks, Allen.
-#
-# 2005-12-02 NMB Call threading.settrace so that all threads are measured.
-# Thanks Martin Fuzzey. Add a file argument to report so that reports can be 
-# captured to a different destination.
-#
-# 2005-12-03 NMB coverage.py can now measure itself.
-#
-# 2005-12-04 NMB Adapted Greg Rogers' patch for using relative filenames,
-# and sorting and omitting files to report on.
-#
-# 2006-07-23 NMB Applied Joseph Tate's patch for function decorators.
-#
-# 2006-08-21 NMB Applied Sigve Tjora and Mark van der Wal's fixes for argument
-# handling.
-#
-# 2006-08-22 NMB Applied Geoff Bache's parallel mode patch.
-#
-# 2006-08-23 NMB Refactorings to improve testability.  Fixes to command-line
-# logic for parallel mode and collect.
-#
-# 2006-08-25 NMB "#pragma: nocover" is excluded by default.
-#
-# 2006-09-10 NMB Properly ignore docstrings and other constant expressions that
-# appear in the middle of a function, a problem reported by Tim Leslie.
-# Minor changes to avoid lint warnings.
-#
-# 2006-09-17 NMB coverage.erase() shouldn't clobber the exclude regex.
-# Change how parallel mode is invoked, and fix erase() so that it erases the
-# cache when called programmatically.
-#
-# 2007-07-21 NMB In reports, ignore code executed from strings, since we can't
-# do anything useful with it anyway.
-# Better file handling on Linux, thanks Guillaume Chazarain.
-# Better shell support on Windows, thanks Noel O'Boyle.
-# Python 2.2 support maintained, thanks Catherine Proulx.
-#
-# 2007-07-22 NMB Python 2.5 now fully supported. The method of dealing with
-# multi-line statements is now less sensitive to the exact line that Python
-# reports during execution. Pass statements are handled specially so that their
-# disappearance during execution won't throw off the measurement.
-
-# C. COPYRIGHT AND LICENCE
-#
-# Copyright 2001 Gareth Rees.  All rights reserved.
-# Copyright 2004-2007 Ned Batchelder.  All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# 1. Redistributions of source code must retain the above copyright
-#    notice, this list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright
-#    notice, this list of conditions and the following disclaimer in the
-#    documentation and/or other materials provided with the
-#    distribution.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# HOLDERS AND CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
-# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
-# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
-# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
-# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
-# DAMAGE.
-#
-# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $
diff --git a/test/testlib/sa_unittest.py b/test/testlib/sa_unittest.py
deleted file mode 100644 (file)
index 7eb2c07..0000000
+++ /dev/null
@@ -1,787 +0,0 @@
-#!/usr/bin/env python
-'''
-unittest.py from Python 2.5.
-
-SQLAlchemy extends unittest internals to provide setUpAll()/tearDownAll()
-so we include a fixed version here to insulate from changes.  2.6 and
-3.0's unittest is incompatible with our changes.
-
-Approaches to removing this dependency are:
-
-* find a unittest-supported method of grouping UnitTest classes within 
-a setUpAll()/tearDownAll() pair, such that all tests within a single 
-UnitTest class are executed within a single execution of setUpAll()/
-tearDownAll().  It may be possible to create nested TestSuite objects
-to accomplish this but it's not clear.
-* migrate to a different system such as nose.
-
-Copyright (c) 1999-2003 Steve Purcell
-This module is free software, and you may redistribute it and/or modify
-it under the same terms as Python itself, so long as this copyright message
-and disclaimer are retained in their original form.
-
-IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
-SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
-THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
-DAMAGE.
-
-THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
-PARTICULAR PURPOSE.  THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
-AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
-SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
-'''
-
-__author__ = "Steve Purcell"
-__email__ = "stephen_purcell at yahoo dot com"
-__version__ = "#Revision: 1.63 $"[11:-2]
-
-import time
-import sys
-import traceback
-import os
-import types
-from testlib.compat import callable
-
-##############################################################################
-# Exported classes and functions
-##############################################################################
-__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',
-           'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader']
-
-# Expose obsolete functions for backwards compatibility
-__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
-
-
-##############################################################################
-# Test framework core
-##############################################################################
-
-# All classes defined herein are 'new-style' classes, allowing use of 'super()'
-__metaclass__ = type
-
-def _strclass(cls):
-    return "%s.%s" % (cls.__module__, cls.__name__)
-
-__unittest = 1
-
-class TestResult:
-    """Holder for test result information.
-
-    Test results are automatically managed by the TestCase and TestSuite
-    classes, and do not need to be explicitly manipulated by writers of tests.
-
-    Each instance holds the total number of tests run, and collections of
-    failures and errors that occurred among those test runs. The collections
-    contain tuples of (testcase, exceptioninfo), where exceptioninfo is the
-    formatted traceback of the error that occurred.
-    """
-    def __init__(self):
-        self.failures = []
-        self.errors = []
-        self.testsRun = 0
-        self.shouldStop = 0
-
-    def startTest(self, test):
-        "Called when the given test is about to be run"
-        self.testsRun = self.testsRun + 1
-
-    def stopTest(self, test):
-        "Called when the given test has been run"
-        pass
-
-    def addError(self, test, err):
-        """Called when an error has occurred. 'err' is a tuple of values as
-        returned by sys.exc_info().
-        """
-        self.errors.append((test, self._exc_info_to_string(err, test)))
-
-    def addFailure(self, test, err):
-        """Called when an error has occurred. 'err' is a tuple of values as
-        returned by sys.exc_info()."""
-        self.failures.append((test, self._exc_info_to_string(err, test)))
-
-    def addSuccess(self, test):
-        "Called when a test has completed successfully"
-        pass
-
-    def wasSuccessful(self):
-        "Tells whether or not this result was a success"
-        return len(self.failures) == len(self.errors) == 0
-
-    def stop(self):
-        "Indicates that the tests should be aborted"
-        self.shouldStop = True
-
-    def _exc_info_to_string(self, err, test):
-        """Converts a sys.exc_info()-style tuple of values into a string."""
-        exctype, value, tb = err
-        # Skip test runner traceback levels
-        while tb and self._is_relevant_tb_level(tb):
-            tb = tb.tb_next
-        if exctype is test.failureException:
-            # Skip assert*() traceback levels
-            length = self._count_relevant_tb_levels(tb)
-            return ''.join(traceback.format_exception(exctype, value, tb, length))
-        return ''.join(traceback.format_exception(exctype, value, tb))
-
-    def _is_relevant_tb_level(self, tb):
-        return tb.tb_frame.f_globals.has_key('__unittest')
-
-    def _count_relevant_tb_levels(self, tb):
-        length = 0
-        while tb and not self._is_relevant_tb_level(tb):
-            length += 1
-            tb = tb.tb_next
-        return length
-
-    def __repr__(self):
-        return "<%s run=%i errors=%i failures=%i>" % \
-               (_strclass(self.__class__), self.testsRun, len(self.errors),
-                len(self.failures))
-
-class TestCase:
-    """A class whose instances are single test cases.
-
-    By default, the test code itself should be placed in a method named
-    'runTest'.
-
-    If the fixture may be used for many test cases, create as
-    many test methods as are needed. When instantiating such a TestCase
-    subclass, specify in the constructor arguments the name of the test method
-    that the instance is to execute.
-
-    Test authors should subclass TestCase for their own tests. Construction
-    and deconstruction of the test's environment ('fixture') can be
-    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
-
-    If it is necessary to override the __init__ method, the base class
-    __init__ method must always be called. It is important that subclasses
-    should not change the signature of their __init__ method, since instances
-    of the classes are instantiated automatically by parts of the framework
-    in order to be run.
-    """
-
-    # This attribute determines which exception will be raised when
-    # the instance's assertion methods fail; test methods raising this
-    # exception will be deemed to have 'failed' rather than 'errored'
-
-    failureException = AssertionError
-
-    def __init__(self, methodName='runTest'):
-        """Create an instance of the class that will use the named test
-           method when executed. Raises a ValueError if the instance does
-           not have a method with the specified name.
-        """
-        try:
-            self._testMethodName = methodName
-            testMethod = getattr(self, methodName)
-            self._testMethodDoc = testMethod.__doc__
-        except AttributeError:
-            raise ValueError, "no such test method in %s: %s" % \
-                  (self.__class__, methodName)
-
-    def setUp(self):
-        "Hook method for setting up the test fixture before exercising it."
-        pass
-
-    def tearDown(self):
-        "Hook method for deconstructing the test fixture after testing it."
-        pass
-
-    def countTestCases(self):
-        return 1
-
-    def defaultTestResult(self):
-        return TestResult()
-
-    def shortDescription(self):
-        """Returns a one-line description of the test, or None if no
-        description has been provided.
-
-        The default implementation of this method returns the first line of
-        the specified test method's docstring.
-        """
-        doc = self._testMethodDoc
-        return doc and doc.split("\n")[0].strip() or None
-
-    def id(self):
-        return "%s.%s" % (_strclass(self.__class__), self._testMethodName)
-
-    def __str__(self):
-        return "%s (%s)" % (self._testMethodName, _strclass(self.__class__))
-
-    def __repr__(self):
-        return "<%s testMethod=%s>" % \
-               (_strclass(self.__class__), self._testMethodName)
-
-    def run(self, result=None):
-        if result is None: result = self.defaultTestResult()
-        result.startTest(self)
-        testMethod = getattr(self, self._testMethodName)
-        try:
-            try:
-                self.setUp()
-            except KeyboardInterrupt:
-                raise
-            except:
-                result.addError(self, self._exc_info())
-                return
-
-            ok = False
-            try:
-                testMethod()
-                ok = True
-            except self.failureException:
-                result.addFailure(self, self._exc_info())
-            except KeyboardInterrupt:
-                raise
-            except:
-                result.addError(self, self._exc_info())
-
-            try:
-                self.tearDown()
-            except KeyboardInterrupt:
-                raise
-            except:
-                result.addError(self, self._exc_info())
-                ok = False
-            if ok: result.addSuccess(self)
-        finally:
-            result.stopTest(self)
-
-    def __call__(self, *args, **kwds):
-        return self.run(*args, **kwds)
-
-    def debug(self):
-        """Run the test without collecting errors in a TestResult"""
-        self.setUp()
-        getattr(self, self._testMethodName)()
-        self.tearDown()
-
-    def _exc_info(self):
-        """Return a version of sys.exc_info() with the traceback frame
-           minimised; usually the top level of the traceback frame is not
-           needed.
-        """
-        exctype, excvalue, tb = sys.exc_info()
-        if sys.platform[:4] == 'java': ## tracebacks look different in Jython
-            return (exctype, excvalue, tb)
-        return (exctype, excvalue, tb)
-
-    def fail(self, msg=None):
-        """Fail immediately, with the given message."""
-        raise self.failureException, msg
-
-    def failIf(self, expr, msg=None):
-        "Fail the test if the expression is true."
-        if expr: raise self.failureException, msg
-
-    def failUnless(self, expr, msg=None):
-        """Fail the test unless the expression is true."""
-        if not expr: raise self.failureException, msg
-
-    def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
-        """Fail unless an exception of class excClass is thrown
-           by callableObj when invoked with arguments args and keyword
-           arguments kwargs. If a different type of exception is
-           thrown, it will not be caught, and the test case will be
-           deemed to have suffered an error, exactly as for an
-           unexpected exception.
-        """
-        try:
-            callableObj(*args, **kwargs)
-        except excClass:
-            return
-        else:
-            if hasattr(excClass,'__name__'): excName = excClass.__name__
-            else: excName = str(excClass)
-            raise self.failureException, "%s not raised" % excName
-
-    def failUnlessEqual(self, first, second, msg=None):
-        """Fail if the two objects are unequal as determined by the '=='
-           operator.
-        """
-        if not first == second:
-            raise self.failureException, \
-                  (msg or '%r != %r' % (first, second))
-
-    def failIfEqual(self, first, second, msg=None):
-        """Fail if the two objects are equal as determined by the '=='
-           operator.
-        """
-        if first == second:
-            raise self.failureException, \
-                  (msg or '%r == %r' % (first, second))
-
-    def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
-        """Fail if the two objects are unequal as determined by their
-           difference rounded to the given number of decimal places
-           (default 7) and comparing to zero.
-
-           Note that decimal places (from zero) are usually not the same
-           as significant digits (measured from the most signficant digit).
-        """
-        if round(second-first, places) != 0:
-            raise self.failureException, \
-                  (msg or '%r != %r within %r places' % (first, second, places))
-
-    def failIfAlmostEqual(self, first, second, places=7, msg=None):
-        """Fail if the two objects are equal as determined by their
-           difference rounded to the given number of decimal places
-           (default 7) and comparing to zero.
-
-           Note that decimal places (from zero) are usually not the same
-           as significant digits (measured from the most signficant digit).
-        """
-        if round(second-first, places) == 0:
-            raise self.failureException, \
-                  (msg or '%r == %r within %r places' % (first, second, places))
-
-    # Synonyms for assertion methods
-
-    assertEqual = assertEquals = failUnlessEqual
-
-    assertNotEqual = assertNotEquals = failIfEqual
-
-    assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
-
-    assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
-
-    assertRaises = failUnlessRaises
-
-    assert_ = assertTrue = failUnless
-
-    assertFalse = failIf
-
-
-
-class TestSuite:
-    """A test suite is a composite test consisting of a number of TestCases.
-
-    For use, create an instance of TestSuite, then add test case instances.
-    When all tests have been added, the suite can be passed to a test
-    runner, such as TextTestRunner. It will run the individual test cases
-    in the order in which they were added, aggregating the results. When
-    subclassing, do not forget to call the base class constructor.
-    """
-    def __init__(self, tests=()):
-        self._tests = []
-        self.addTests(tests)
-
-    def __repr__(self):
-        return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)
-
-    __str__ = __repr__
-
-    def __iter__(self):
-        return iter(self._tests)
-
-    def countTestCases(self):
-        cases = 0
-        for test in self._tests:
-            cases += test.countTestCases()
-        return cases
-
-    def addTest(self, test):
-        # sanity checks
-        if not callable(test):
-            raise TypeError("the test to add must be callable")
-        if (isinstance(test, (type, types.ClassType)) and
-            issubclass(test, (TestCase, TestSuite))):
-            raise TypeError("TestCases and TestSuites must be instantiated "
-                            "before passing them to addTest()")
-        self._tests.append(test)
-
-    def addTests(self, tests):
-        if isinstance(tests, basestring):
-            raise TypeError("tests must be an iterable of tests, not a string")
-        for test in tests:
-            self.addTest(test)
-
-    def run(self, result):
-        for test in self._tests:
-            if result.shouldStop:
-                break
-            test(result)
-        return result
-
-    def __call__(self, *args, **kwds):
-        return self.run(*args, **kwds)
-
-    def debug(self):
-        """Run the tests without collecting errors in a TestResult"""
-        for test in self._tests: test.debug()
-
-
-class FunctionTestCase(TestCase):
-    """A test case that wraps a test function.
-
-    This is useful for slipping pre-existing test functions into the
-    PyUnit framework. Optionally, set-up and tidy-up functions can be
-    supplied. As with TestCase, the tidy-up ('tearDown') function will
-    always be called if the set-up ('setUp') function ran successfully.
-    """
-
-    def __init__(self, testFunc, setUp=None, tearDown=None,
-                 description=None):
-        TestCase.__init__(self)
-        self.__setUpFunc = setUp
-        self.__tearDownFunc = tearDown
-        self.__testFunc = testFunc
-        self.__description = description
-
-    def setUp(self):
-        if self.__setUpFunc is not None:
-            self.__setUpFunc()
-
-    def tearDown(self):
-        if self.__tearDownFunc is not None:
-            self.__tearDownFunc()
-
-    def runTest(self):
-        self.__testFunc()
-
-    def id(self):
-        return self.__testFunc.__name__
-
-    def __str__(self):
-        return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
-
-    def __repr__(self):
-        return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc)
-
-    def shortDescription(self):
-        if self.__description is not None: return self.__description
-        doc = self.__testFunc.__doc__
-        return doc and doc.split("\n")[0].strip() or None
-
-
-
-##############################################################################
-# Locating and loading tests
-##############################################################################
-
-class TestLoader:
-    """This class is responsible for loading tests according to various
-    criteria and returning them wrapped in a Test
-    """
-    testMethodPrefix = 'test'
-    suiteClass = TestSuite
-
-    def loadTestsFromTestCase(self, testCaseClass):
-        """Return a suite of all tests cases contained in testCaseClass"""
-        if issubclass(testCaseClass, TestSuite):
-            raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?")
-        testCaseNames = self.getTestCaseNames(testCaseClass)
-        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
-            testCaseNames = ['runTest']
-        return self.suiteClass(map(testCaseClass, testCaseNames))
-
-    def loadTestsFromModule(self, module):
-        """Return a suite of all tests cases contained in the given module"""
-        tests = []
-        for name in dir(module):
-            obj = getattr(module, name)
-            if (isinstance(obj, (type, types.ClassType)) and
-                issubclass(obj, TestCase)):
-                tests.append(self.loadTestsFromTestCase(obj))
-        return self.suiteClass(tests)
-
-    def loadTestsFromName(self, name, module=None):
-        """Return a suite of all tests cases given a string specifier.
-
-        The name may resolve either to a module, a test case class, a
-        test method within a test case class, or a callable object which
-        returns a TestCase or TestSuite instance.
-
-        The method optionally resolves the names relative to a given module.
-        """
-        parts = name.split('.')
-        if module is None:
-            parts_copy = parts[:]
-            while parts_copy:
-                try:
-                    module = __import__('.'.join(parts_copy))
-                    break
-                except ImportError:
-                    del parts_copy[-1]
-                    if not parts_copy: raise
-            parts = parts[1:]
-        obj = module
-        for part in parts:
-            parent, obj = obj, getattr(obj, part)
-
-        if type(obj) == types.ModuleType:
-            return self.loadTestsFromModule(obj)
-        elif (isinstance(obj, (type, types.ClassType)) and
-              issubclass(obj, TestCase)):
-            return self.loadTestsFromTestCase(obj)
-        elif type(obj) == types.UnboundMethodType:
-            return parent(obj.__name__)
-        elif isinstance(obj, TestSuite):
-            return obj
-        elif callable(obj):
-            test = obj()
-            if not isinstance(test, (TestCase, TestSuite)):
-                raise ValueError, \
-                      "calling %s returned %s, not a test" % (obj,test)
-            return test
-        else:
-            raise ValueError, "don't know how to make test from: %s" % obj
-
-    def loadTestsFromNames(self, names, module=None):
-        """Return a suite of all tests cases found using the given sequence
-        of string specifiers. See 'loadTestsFromName()'.
-        """
-        suites = [self.loadTestsFromName(name, module) for name in names]
-        return self.suiteClass(suites)
-
-    def getTestCaseNames(self, testCaseClass):
-        """Return a sorted sequence of method names found within testCaseClass
-        """
-
-        def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
-            return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname))
-        testFnNames = filter(isTestMethod, dir(testCaseClass))
-        for baseclass in testCaseClass.__bases__:
-            for testFnName in self.getTestCaseNames(baseclass):
-                if testFnName not in testFnNames:  # handle overridden methods
-                    testFnNames.append(testFnName)
-        testFnNames.sort()
-        return testFnNames
-
-
-
-defaultTestLoader = TestLoader()
-
-
-##############################################################################
-# Patches for old functions: these functions should be considered obsolete
-##############################################################################
-
-def _makeLoader(prefix, sortUsing, suiteClass=None):
-    loader = TestLoader()
-    loader.testMethodPrefix = prefix
-    if suiteClass: loader.suiteClass = suiteClass
-    return loader
-
-def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
-    return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
-
-def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
-    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
-
-def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
-    return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
-
-
-##############################################################################
-# Text UI
-##############################################################################
-
-class _WritelnDecorator:
-    """Used to decorate file-like objects with a handy 'writeln' method"""
-    def __init__(self,stream):
-        self.stream = stream
-
-    def __getattr__(self, attr):
-        return getattr(self.stream,attr)
-
-    def writeln(self, arg=None):
-        if arg: self.write(arg)
-        self.write('\n') # text-mode streams translate to \r\n if needed
-
-
-class _TextTestResult(TestResult):
-    """A test result class that can print formatted text results to a stream.
-
-    Used by TextTestRunner.
-    """
-    separator1 = '=' * 70
-    separator2 = '-' * 70
-
-    def __init__(self, stream, descriptions, verbosity):
-        TestResult.__init__(self)
-        self.stream = stream
-        self.showAll = verbosity > 1
-        self.dots = verbosity == 1
-        self.descriptions = descriptions
-
-    def getDescription(self, test):
-        if self.descriptions:
-            return test.shortDescription() or str(test)
-        else:
-            return str(test)
-
-    def startTest(self, test):
-        TestResult.startTest(self, test)
-        if self.showAll:
-            self.stream.write(self.getDescription(test))
-            self.stream.write(" ... ")
-
-    def addSuccess(self, test):
-        TestResult.addSuccess(self, test)
-        if self.showAll:
-            self.stream.writeln("ok")
-        elif self.dots:
-            self.stream.write('.')
-
-    def addError(self, test, err):
-        TestResult.addError(self, test, err)
-        if self.showAll:
-            self.stream.writeln("ERROR")
-        elif self.dots:
-            self.stream.write('E')
-
-    def addFailure(self, test, err):
-        TestResult.addFailure(self, test, err)
-        if self.showAll:
-            self.stream.writeln("FAIL")
-        elif self.dots:
-            self.stream.write('F')
-
-    def printErrors(self):
-        if self.dots or self.showAll:
-            self.stream.writeln()
-        self.printErrorList('ERROR', self.errors)
-        self.printErrorList('FAIL', self.failures)
-
-    def printErrorList(self, flavour, errors):
-        for test, err in errors:
-            self.stream.writeln(self.separator1)
-            self.stream.writeln("%s: %s" % (flavour,self.getDescription(test)))
-            self.stream.writeln(self.separator2)
-            self.stream.writeln("%s" % err)
-
-
-class TextTestRunner:
-    """A test runner class that displays results in textual form.
-
-    It prints out the names of tests as they are run, errors as they
-    occur, and a summary of the results at the end of the test run.
-    """
-    def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1):
-        self.stream = _WritelnDecorator(stream)
-        self.descriptions = descriptions
-        self.verbosity = verbosity
-
-    def _makeResult(self):
-        return _TextTestResult(self.stream, self.descriptions, self.verbosity)
-
-    def run(self, test):
-        "Run the given test case or test suite."
-        result = self._makeResult()
-        startTime = time.time()
-        test(result)
-        stopTime = time.time()
-        timeTaken = stopTime - startTime
-        result.printErrors()
-        self.stream.writeln(result.separator2)
-        run = result.testsRun
-        self.stream.writeln("Ran %d test%s in %.3fs" %
-                            (run, run != 1 and "s" or "", timeTaken))
-        self.stream.writeln()
-        if not result.wasSuccessful():
-            self.stream.write("FAILED (")
-            failed, errored = map(len, (result.failures, result.errors))
-            if failed:
-                self.stream.write("failures=%d" % failed)
-            if errored:
-                if failed: self.stream.write(", ")
-                self.stream.write("errors=%d" % errored)
-            self.stream.writeln(")")
-        else:
-            self.stream.writeln("OK")
-        return result
-
-
-
-##############################################################################
-# Facilities for running tests from the command line
-##############################################################################
-
-class TestProgram:
-    """A command-line program that runs a set of tests; this is primarily
-       for making test modules conveniently executable.
-    """
-    USAGE = """\
-Usage: %(progName)s [options] [test] [...]
-
-Options:
-  -h, --help       Show this message
-  -v, --verbose    Verbose output
-  -q, --quiet      Minimal output
-
-Examples:
-  %(progName)s                               - run default set of tests
-  %(progName)s MyTestSuite                   - run suite 'MyTestSuite'
-  %(progName)s MyTestCase.testSomething      - run MyTestCase.testSomething
-  %(progName)s MyTestCase                    - run all 'test*' test methods
-                                               in MyTestCase
-"""
-    def __init__(self, module='__main__', defaultTest=None,
-                 argv=None, testRunner=None, testLoader=defaultTestLoader):
-        if type(module) == type(''):
-            self.module = __import__(module)
-            for part in module.split('.')[1:]:
-                self.module = getattr(self.module, part)
-        else:
-            self.module = module
-        if argv is None:
-            argv = sys.argv
-        self.verbosity = 1
-        self.defaultTest = defaultTest
-        self.testRunner = testRunner
-        self.testLoader = testLoader
-        self.progName = os.path.basename(argv[0])
-        self.parseArgs(argv)
-        self.runTests()
-
-    def usageExit(self, msg=None):
-        if msg: print msg
-        print self.USAGE % self.__dict__
-        sys.exit(2)
-
-    def parseArgs(self, argv):
-        import getopt
-        try:
-            options, args = getopt.getopt(argv[1:], 'hHvq',
-                                          ['help','verbose','quiet'])
-            for opt, value in options:
-                if opt in ('-h','-H','--help'):
-                    self.usageExit()
-                if opt in ('-q','--quiet'):
-                    self.verbosity = 0
-                if opt in ('-v','--verbose'):
-                    self.verbosity = 2
-            if len(args) == 0 and self.defaultTest is None:
-                self.test = self.testLoader.loadTestsFromModule(self.module)
-                return
-            if len(args) > 0:
-                self.testNames = args
-            else:
-                self.testNames = (self.defaultTest,)
-            self.createTests()
-        except getopt.error, msg:
-            self.usageExit(msg)
-
-    def createTests(self):
-        self.test = self.testLoader.loadTestsFromNames(self.testNames,
-                                                       self.module)
-
-    def runTests(self):
-        if self.testRunner is None:
-            self.testRunner = TextTestRunner(verbosity=self.verbosity)
-        result = self.testRunner.run(self.test)
-        sys.exit(not result.wasSuccessful())
-
-main = TestProgram
-
-
-##############################################################################
-# Executing this module from the command line
-##############################################################################
-
-if __name__ == "__main__":
-    main(module=None)
diff --git a/test/zblog/alltests.py b/test/zblog/alltests.py
deleted file mode 100644 (file)
index 34a188e..0000000
+++ /dev/null
@@ -1,18 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
-    modules_to_test = (
-        'zblog.tests',
-        )
-    alltests = unittest.TestSuite()
-    for name in modules_to_test:
-        mod = __import__(name)
-        for token in name.split('.')[1:]:
-            mod = getattr(mod, token)
-        alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    return alltests
-
-
-if __name__ == '__main__':
-    testenv.main(suite())
index 0d789f3d0b92624c669babf419323aacd6d176e7..5203bd866a41c3cf573351acedfc11917a42b863 100644 (file)
@@ -1,8 +1,7 @@
 """mapper.py - defines mappers for domain objects, mapping operations"""
 
-import zblog.tables as tables
-import zblog.user as user
-from zblog.blog import *
+import tables, user
+from blog import *
 from sqlalchemy import *
 from sqlalchemy.orm import *
 import sqlalchemy.util as util
index 4fce48a4c4bd4b9a443eabd84c2d77a97d4fcf52..36c7aeb8b19cc67a9b8621e5e7fab444766c976f 100644 (file)
@@ -1,7 +1,6 @@
 """application table metadata objects are described here."""
 
 from sqlalchemy import *
-from testlib import *
 
 
 metadata = MetaData()
similarity index 79%
rename from test/zblog/tests.py
rename to test/zblog/test_zblog.py
index f784c27962c74c9579484001f88ab54450575439..8170766cb253bcb5f28e5359fbfd0cf599691447 100644 (file)
@@ -1,33 +1,39 @@
-import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from testlib import *
-from zblog import mappers, tables
-from zblog.user import *
-from zblog.blog import *
+from sqlalchemy.test import *
+import mappers, tables
+from user import *
+from blog import *
 
 
 class ZBlogTest(TestBase, AssertsExecutionResults):
 
-    def create_tables(self):
+    @classmethod
+    def create_tables(cls):
         tables.metadata.drop_all(bind=testing.db)
         tables.metadata.create_all(bind=testing.db)
-    def drop_tables(self):
+    
+    @classmethod
+    def drop_tables(cls):
         tables.metadata.drop_all(bind=testing.db)
 
-    def setUpAll(self):
-        self.create_tables()
-    def tearDownAll(self):
-        self.drop_tables()
-    def tearDown(self):
+    @classmethod
+    def setup_class(cls):
+        cls.create_tables()
+    @classmethod
+    def teardown_class(cls):
+        cls.drop_tables()
+    def teardown(self):
         pass
-    def setUp(self):
+    def setup(self):
         pass
 
 
 class SavePostTest(ZBlogTest):
-    def setUpAll(self):
-        super(SavePostTest, self).setUpAll()
+    @classmethod
+    def setup_class(cls):
+        super(SavePostTest, cls).setup_class()
+        
         mappers.zblog_mappers()
         global blog_id, user_id
         s = create_session(bind=testing.db)
@@ -41,9 +47,10 @@ class SavePostTest(ZBlogTest):
         user_id = user.id
         s.close()
 
-    def tearDownAll(self):
+    @classmethod
+    def teardown_class(cls):
         clear_mappers()
-        super(SavePostTest, self).tearDownAll()
+        super(SavePostTest, cls).teardown_class()
 
     def testattach(self):
         """test that a transient/pending instance has proper bi-directional behavior.
@@ -93,5 +100,3 @@ class SavePostTest(ZBlogTest):
             s.rollback()
 
 
-if __name__ == "__main__":
-    testenv.main()