From: Mike Bayer Date: Wed, 10 Jun 2009 21:18:24 +0000 (+0000) Subject: - unit tests have been migrated from unittest to nose. X-Git-Tag: rel_0_5_5~19 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=45cec095b4904ba71425d2fe18c143982dd08f43;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - unit tests have been migrated from unittest to nose. See README.unittests for information on how to run the tests. [ticket:970] --- diff --git a/CHANGES b/CHANGES index 0653bd68da..9a4596c72f 100644 --- a/CHANGES +++ b/CHANGES @@ -6,6 +6,10 @@ CHANGES 0.5.5 ======= +- general + - unit tests have been migrated from unittest to nose. + See README.unittests for information on how to run + the tests. [ticket:970] - orm - Fixed bug introduced in 0.5.4 whereby Composite types fail when default-holding columns are flushed. diff --git a/README.unittests b/README.unittests index f70f6ab177..bfc31e28fb 100644 --- a/README.unittests +++ b/README.unittests @@ -2,65 +2,63 @@ SQLALCHEMY UNIT TESTS ===================== -SETUP ------ -SQLite support is required. These instructions assume standard Python 2.4 or -higher. - -The 'test' directory must be on the PYTHONPATH. +SQLAlchemy unit tests by default run using Python's built-in sqlite3 +module. If running on Python 2.4, pysqlite must be installed. -cd into the SQLAlchemy distribution directory +As of 0.5.5, unit tests are run using nose. Documentation and +downloads for nose are available at: -In bash: +http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html - $ export PYTHONPATH=./test/ -On windows: +SQLAlchemy implements a nose plugin that must be present when tests are run. +This plugin is available when SQLAlchemy is installed via setuptools. - C:\sa\> set PYTHONPATH=test\ +SETUP +----- - Adjust any other use Unix-style paths in this README as needed. +All that's required is for SQLAlchemy to be installed via setuptools. +For example, to create a local install in a source distribution directory: -The unittest framework will automatically prepend the lib/ directory to -sys.path. This forces the local version of SQLAlchemy to be used, bypassing -any setuptools-installed installations (setuptools places .egg files ahead of -plain directories, even if on PYTHONPATH, unfortunately). + $ export PYTHONPATH=. + $ python setup.py develop -d . +The above will create a setuptools "development" distribution in the local +path, which allows the Nose plugin to be available when nosetests is run. +The plugin is enabled using the "with-sqlalchemy=True" configuration +in setup.cfg. RUNNING ALL TESTS ----------------- To run all tests: - $ python test/alltests.py + $ nosetests + +Assuming all tests pass, this is a very unexciting output. To make it more +intersesting: + $ nosetests -v RUNNING INDIVIDUAL TESTS ------------------------- -Any unittest module can be run directly from the module file: +Any test module can be run directly by specifying its module name: - python test/orm/mapper.py + $ nosetests test.orm.test_mapper -To run a specific test within the module, specify it as ClassName.methodname: +To run a specific test within the module, specify it as module:ClassName.methodname: - python test/orm/mapper.py MapperTest.testget + $ nosetests test.orm.test_mapper:MapperTest.test_utils COMMAND LINE OPTIONS -------------------- -Help is available via --help +Help is available via --help: - $ python test/alltests.py --help + $ nosetests --help - usage: alltests.py [options] [tests...] - - Options: - -h, --help show this help message and exit - --verbose enable stdout echoing/printing - --quiet suppress output - [...] - -Command line options can applied to alltests.py or any individual test module. -Many are available. The most commonly used are '--db' and '--dburi'. +The --help screen is a combination of common nose options and options which +the SQLAlchemy nose plugin adds. The most commonly SQLAlchemy-specific +options used are '--db' and '--dburi'. DATABASE TARGETS @@ -78,7 +76,7 @@ preexisting tables will interfere with the tests If you'll be running the tests frequently, database aliases can save a lot of typing. The --dbs option lists the built-in aliases and their matching URLs: - $ python test/alltests.py --dbs + $ nosetests --dbs Available --db options (use --dburi to override) mysql mysql://scott:tiger@127.0.0.1:3306/test oracle oracle://scott:tiger@127.0.0.1:1521 @@ -87,7 +85,7 @@ typing. The --dbs option lists the built-in aliases and their matching URLs: To run tests against an aliased database: - $ python test/alltests.py --db=postgres + $ nosetests --db=postgres To customize the URLs with your own users or hostnames, make a simple .ini file called `test.cfg` at the top level of the SQLAlchemy source distribution @@ -106,7 +104,7 @@ SQLAlchemy logs its activity and debugging through Python's logging package. Any log target can be directed to the console with command line options, such as: - $ python test/orm/unitofwork.py --log-info=sqlalchemy.orm.mapper \ + $ nosetests test.orm.unitofwork --log-info=sqlalchemy.orm.mapper \ --log-debug=sqlalchemy.pool --log-info=sqlalchemy.engine This would log mapper configuration, connection pool checkouts, and SQL @@ -115,21 +113,10 @@ statement execution. BUILT-IN COVERAGE REPORTING ------------------------------ -Coverage is tracked with coverage.py module, included in the './test/' -directory. Running the test suite with the --coverage switch will generate a -local file ".coverage" containing coverage details, and a report will be -printed to standard output with an overview of the coverage gathered from the -last unittest run (the file is deleted between runs). - -After the suite has been run with --coverage, an annotated version of any -source file can be generated, marking statements that are executed with > and -statements that are missed with !, by running the coverage.py utility with the -"-a" (annotate) option, such as: - - $ python ./test/testlib/coverage.py -a ./lib/sqlalchemy/sql.py +Coverage is tracked using Nose's coverage plugin. See the nose +documentation for details. Basic usage is: -This will create a new annotated file ./lib/sqlalchemy/sql.py,cover. Pretty -cool! + $ nosetests test.sql.test_query --with-coverage BIG COVERAGE TIP !!! There is an issue where existing .pyc files may store the incorrect filepaths, which will break the coverage system. If diff --git a/convert.py b/convert.py new file mode 100644 index 0000000000..b574c27a92 --- /dev/null +++ b/convert.py @@ -0,0 +1,228 @@ +import os +import subprocess +import re + +def walk(): + for root, dirs, files in os.walk("./test/"): + if root.endswith("/perf"): + continue + + for fname in files: + if not fname.endswith(".py"): + continue + if fname == "alltests.py": + subprocess.call(["svn", "remove", os.path.join(root, fname)]) + elif fname.startswith("_") or fname == "__init__.py" or fname == "pickleable.py": + convert(os.path.join(root, fname)) + elif not fname.startswith("test_"): + if os.path.exists(os.path.join(root, "test_" + fname)): + os.unlink(os.path.join(root, "test_" + fname)) + subprocess.call(["svn", "rename", os.path.join(root, fname), os.path.join(root, "test_" + fname)]) + convert(os.path.join(root, "test_" + fname)) + + +def convert(fname): + lines = list(file(fname)) + replaced = [] + flags = {} + + while lines: + for reg, handler in handlers: + m = reg.match(lines[0]) + if m: + handler(lines, replaced, flags) + break + + post_handler(lines, replaced, flags) + f = file(fname, 'w') + f.write("".join(replaced)) + f.close() + +handlers = [] + + +def post_handler(lines, replaced, flags): + imports = [] + if "needs_eq" in flags: + imports.append("eq_") + if "needs_assert_raises" in flags: + imports += ["assert_raises", "assert_raises_message"] + if imports: + for i, line in enumerate(replaced): + if "import" in line: + replaced.insert(i, "from sqlalchemy.test.testing import %s\n" % ", ".join(imports)) + break + +def remove_line(lines, replaced, flags): + lines.pop(0) + +handlers.append((re.compile(r"import testenv; testenv\.configure_for_tests"), remove_line)) +handlers.append((re.compile(r"(.*\s)?import sa_unittest"), remove_line)) + + +def import_testlib_sa(lines, replaced, flags): + line = lines.pop(0) + line = line.replace("import testlib.sa", "import sqlalchemy") + replaced.append(line) +handlers.append((re.compile("import testlib\.sa"), import_testlib_sa)) + +def from_testlib_sa(lines, replaced, flags): + line = lines.pop(0) + while True: + if line.endswith("\\\n"): + line = line[0:-2] + lines.pop(0) + else: + break + + components = re.compile(r'from testlib\.sa import (.*)').match(line) + if components: + components = re.split(r"\s*,\s*", components.group(1)) + line = "from sqlalchemy import %s\n" % (", ".join(c for c in components if c not in ("Table", "Column"))) + replaced.append(line) + if "Table" in components: + replaced.append("from sqlalchemy.test.schema import Table\n") + if "Column" in components: + replaced.append("from sqlalchemy.test.schema import Column\n") + return + + line = line.replace("testlib.sa", "sqlalchemy") + replaced.append(line) +handlers.append((re.compile("from testlib\.sa.*import"), from_testlib_sa)) + +def from_testlib(lines, replaced, flags): + line = lines.pop(0) + + components = re.compile(r'from testlib import (.*)').match(line) + if components: + components = re.split(r"\s*,\s*", components.group(1)) + if "sa" in components: + replaced.append("import sqlalchemy as sa\n") + replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa"))) + return + elif "sa as tsa" in components: + replaced.append("import sqlalchemy as tsa\n") + replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa"))) + return + + line = line.replace("testlib", "sqlalchemy.test") + replaced.append(line) +handlers.append((re.compile(r"from testlib"), from_testlib)) + +def from_orm(lines, replaced, flags): + line = lines.pop(0) + line = line.replace("from orm import", "from test.orm import") + line = line.replace("from orm.", "from test.orm.") + replaced.append(line) +handlers.append((re.compile(r'from orm( import|\.)'), from_orm)) + +def assert_equals(lines, replaced, flags): + line = lines.pop(0) + line = line.replace("self.assertEquals", "eq_") + line = line.replace("self.assertEqual", "eq_") + replaced.append(line) + flags["needs_eq"] = True +handlers.append((re.compile(r"\s*self\.assertEqual(s)?"), assert_equals)) + +def assert_raises(lines, replaced, flags): + line = lines.pop(0) + line = line.replace("self.assertRaisesMessage", "assert_raises_message") + line = line.replace("self.assertRaises", "assert_raises") + replaced.append(line) + flags["needs_assert_raises"] = True +handlers.append((re.compile(r"\s*self\.assertRaises(Message)?"), assert_raises)) + +def setup_all(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def setUpAll\(self\)\:").match(line).group(1) + replaced.append("%s@classmethod\n" % whitespace) + replaced.append("%sdef setup_class(cls):\n" % whitespace) +handlers.append((re.compile(r"\s*def setUpAll\(self\)"), setup_all)) + +def teardown_all(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def tearDownAll\(self\)\:").match(line).group(1) + replaced.append("%s@classmethod\n" % whitespace) + replaced.append("%sdef teardown_class(cls):\n" % whitespace) +handlers.append((re.compile(r"\s*def tearDownAll\(self\)"), teardown_all)) + +def setup(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def setUp\(self\)\:").match(line).group(1) + replaced.append("%sdef setup(self):\n" % whitespace) +handlers.append((re.compile(r"\s*def setUp\(self\)"), setup)) + +def teardown(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def tearDown\(self\)\:").match(line).group(1) + replaced.append("%sdef teardown(self):\n" % whitespace) +handlers.append((re.compile(r"\s*def tearDown\(self\)"), teardown)) + +def define_tables(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def define_tables").match(line).group(1) + replaced.append("%s@classmethod\n" % whitespace) + replaced.append("%sdef define_tables(cls, metadata):\n" % whitespace) +handlers.append((re.compile(r"\s*def define_tables\(self, metadata\)"), define_tables)) + +def setup_mappers(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def setup_mappers").match(line).group(1) + + i = -1 + while re.match("\s*@testing", replaced[i]): + i -= 1 + + replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace) + replaced.append("%sdef setup_mappers(cls):\n" % whitespace) +handlers.append((re.compile(r"\s*def setup_mappers\(self\)"), setup_mappers)) + +def setup_classes(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def setup_classes").match(line).group(1) + + i = -1 + while re.match("\s*@testing", replaced[i]): + i -= 1 + + replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace) + replaced.append("%sdef setup_classes(cls):\n" % whitespace) +handlers.append((re.compile(r"\s*def setup_classes\(self\)"), setup_classes)) + +def insert_data(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def insert_data").match(line).group(1) + + i = -1 + while re.match("\s*@testing", replaced[i]): + i -= 1 + + replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace) + replaced.append("%sdef insert_data(cls):\n" % whitespace) +handlers.append((re.compile(r"\s*def insert_data\(self\)"), insert_data)) + +def fixtures(lines, replaced, flags): + line = lines.pop(0) + whitespace = re.compile(r"(\s*)def fixtures").match(line).group(1) + + i = -1 + while re.match("\s*@testing", replaced[i]): + i -= 1 + + replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace) + replaced.append("%sdef fixtures(cls):\n" % whitespace) +handlers.append((re.compile(r"\s*def fixtures\(self\)"), fixtures)) + + +def call_main(lines, replaced, flags): + replaced.pop(-1) + lines.pop(0) +handlers.append((re.compile(r"\s+testenv\.main\(\)"), call_main)) + +def default(lines, replaced, flags): + replaced.append(lines.pop(0)) +handlers.append((re.compile(r".*"), default)) + + +if __name__ == '__main__': + convert("test/orm/inheritance/abc_inheritance.py") +# walk() diff --git a/lib/sqlalchemy/test/__init__.py b/lib/sqlalchemy/test/__init__.py new file mode 100644 index 0000000000..d69cedefdd --- /dev/null +++ b/lib/sqlalchemy/test/__init__.py @@ -0,0 +1,26 @@ +"""Testing environment and utilities. + +This package contains base classes and routines used by +the unit tests. Tests are based on Nose and bootstrapped +by noseplugin.NoseSQLAlchemy. + +""" + +from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config +from sqlalchemy.test.schema import Column, Table +from sqlalchemy.test.testing import \ + AssertsCompiledSQL, \ + AssertsExecutionResults, \ + ComparesTables, \ + TestBase, \ + rowset + + +__all__ = ('testing', + 'Column', 'Table', + 'rowset', + 'TestBase', 'AssertsExecutionResults', + 'AssertsCompiledSQL', 'ComparesTables', + 'engines', 'profiling', 'pickleable') + + diff --git a/test/testlib/assertsql.py b/lib/sqlalchemy/test/assertsql.py similarity index 100% rename from test/testlib/assertsql.py rename to lib/sqlalchemy/test/assertsql.py diff --git a/lib/sqlalchemy/test/config.py b/lib/sqlalchemy/test/config.py new file mode 100644 index 0000000000..6ea5667cc3 --- /dev/null +++ b/lib/sqlalchemy/test/config.py @@ -0,0 +1,177 @@ +import optparse, os, sys, re, ConfigParser, StringIO, time, warnings +logging = None + +__all__ = 'parser', 'configure', 'options', + +db = None +db_label, db_url, db_opts = None, None, {} + +options = None +file_config = None + +base_config = """ +[db] +sqlite=sqlite:///:memory: +sqlite_file=sqlite:///querytest.db +postgres=postgres://scott:tiger@127.0.0.1:5432/test +mysql=mysql://scott:tiger@127.0.0.1:3306/test +oracle=oracle://scott:tiger@127.0.0.1:1521 +oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 +mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test +firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb +maxdb=maxdb://MONA:RED@/maxdb1 +""" + +def _log(option, opt_str, value, parser): + global logging + if not logging: + import logging + logging.basicConfig() + + if opt_str.endswith('-info'): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith('-debug'): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + print "Available --db options (use --dburi to override)" + for macro in sorted(file_config.options('db')): + print "%20s\t%s" % (macro, file_config.get('db', macro)) + sys.exit(0) + +def _server_side_cursors(options, opt_str, value, parser): + db_opts['server_side_cursors'] = True + +def _engine_strategy(options, opt_str, value, parser): + if value: + db_opts['strategy'] = value + +class _ordered_map(object): + def __init__(self): + self._keys = list() + self._data = dict() + + def __setitem__(self, key, value): + if key not in self._keys: + self._keys.append(key) + self._data[key] = value + + def __iter__(self): + for key in self._keys: + yield self._data[key] + +# at one point in refactoring, modules were injecting into the config +# process. this could probably just become a list now. +post_configure = _ordered_map() + +def _engine_uri(options, file_config): + global db_label, db_url + db_label = 'sqlite' + if options.dburi: + db_url = options.dburi + db_label = db_url[:db_url.index(':')] + elif options.db: + db_label = options.db + db_url = None + + if db_url is None: + if db_label not in file_config.options('db'): + raise RuntimeError( + "Unknown engine. Specify --dbs for known engines.") + db_url = file_config.get('db', db_label) +post_configure['engine_uri'] = _engine_uri + +def _require(options, file_config): + if not(options.require or + (file_config.has_section('require') and + file_config.items('require'))): + return + + try: + import pkg_resources + except ImportError: + raise RuntimeError("setuptools is required for version requirements") + + cmdline = [] + for requirement in options.require: + pkg_resources.require(requirement) + cmdline.append(re.split('\s*(=)', requirement, 1)[0]) + + if file_config.has_section('require'): + for label, requirement in file_config.items('require'): + if not label == db_label or label.startswith('%s.' % db_label): + continue + seen = [c for c in cmdline if requirement.startswith(c)] + if seen: + continue + pkg_resources.require(requirement) +post_configure['require'] = _require + +def _engine_pool(options, file_config): + if options.mockpool: + from sqlalchemy import pool + db_opts['poolclass'] = pool.AssertionPool +post_configure['engine_pool'] = _engine_pool + +def _create_testing_engine(options, file_config): + from sqlalchemy.test import engines, testing + global db + db = engines.testing_engine(db_url, db_opts) + testing.db = db +post_configure['create_engine'] = _create_testing_engine + +def _prep_testing_database(options, file_config): + from sqlalchemy.test import engines + from sqlalchemy import schema + + try: + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + except (KeyboardInterrupt, SystemExit): + raise + except Exception, e: + warnings.warn(RuntimeWarning( + "Error checking for existing tables in testing " + "database: %s" % e)) +post_configure['prep_db'] = _prep_testing_database + +def _set_table_options(options, file_config): + from sqlalchemy.test import schema + + table_options = schema.table_options + for spec in options.tableopts: + key, value = spec.split('=') + table_options[key] = value + + if options.mysql_engine: + table_options['mysql_engine'] = options.mysql_engine +post_configure['table_options'] = _set_table_options + +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm import unitofwork + from sqlalchemy import topological + class RevQueueDepSort(topological.QueueDependencySorter): + def __init__(self, tuples, allitems): + self.tuples = list(tuples) + self.allitems = list(allitems) + self.tuples.reverse() + self.allitems.reverse() + topological.QueueDependencySorter = RevQueueDepSort + unitofwork.DependencySorter = RevQueueDepSort +post_configure['topological'] = _reverse_topological + diff --git a/test/testlib/engines.py b/lib/sqlalchemy/test/engines.py similarity index 96% rename from test/testlib/engines.py rename to lib/sqlalchemy/test/engines.py index 4068f43d0a..f0001978bf 100644 --- a/test/testlib/engines.py +++ b/lib/sqlalchemy/test/engines.py @@ -1,7 +1,7 @@ import sys, types, weakref from collections import deque -from testlib import config -from testlib.compat import _function_named, callable +import config +from sqlalchemy.util import function_named, callable class ConnectionKiller(object): def __init__(self): @@ -44,7 +44,7 @@ def assert_conns_closed(fn): fn(*args, **kw) finally: testing_reaper.assert_all_closed() - return _function_named(decorated, fn.__name__) + return function_named(decorated, fn.__name__) def rollback_open_connections(fn): """Decorator that rolls back all open connections after fn execution.""" @@ -54,7 +54,7 @@ def rollback_open_connections(fn): fn(*args, **kw) finally: testing_reaper.rollback_all() - return _function_named(decorated, fn.__name__) + return function_named(decorated, fn.__name__) def close_open_connections(fn): """Decorator that closes all connections after fn execution.""" @@ -64,7 +64,7 @@ def close_open_connections(fn): fn(*args, **kw) finally: testing_reaper.close_all() - return _function_named(decorated, fn.__name__) + return function_named(decorated, fn.__name__) def all_dialects(): import sqlalchemy.databases as d @@ -104,7 +104,7 @@ def testing_engine(url=None, options=None): """Produce an engine configured by --options with optional overrides.""" from sqlalchemy import create_engine - from testlib.assertsql import asserter + from sqlalchemy.test.assertsql import asserter url = url or config.db_url options = options or config.db_opts diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py new file mode 100644 index 0000000000..263d2d7831 --- /dev/null +++ b/lib/sqlalchemy/test/noseplugin.py @@ -0,0 +1,156 @@ +import logging +import os +import re +import sys +import time +import warnings +import ConfigParser +import StringIO +from config import db, db_label, db_url, file_config, base_config, \ + post_configure, \ + _list_dbs, _server_side_cursors, _engine_strategy, \ + _engine_uri, _require, _engine_pool, \ + _create_testing_engine, _prep_testing_database, \ + _set_table_options, _reverse_topological, _log +from sqlalchemy.test import testing, config, requires +from nose.plugins import Plugin +from nose.util import tolist +import nose.case + +log = logging.getLogger('nose.plugins.sqlalchemy') + +class NoseSQLAlchemy(Plugin): + """ + Handles the setup and extra properties required for testing SQLAlchemy + """ + enabled = True + name = 'sqlalchemy' + score = 100 + + def options(self, parser, env=os.environ): + Plugin.options(self, parser, env) + opt = parser.add_option + #opt("--verbose", action="store_true", dest="verbose", + #help="enable stdout echoing/printing") + #opt("--quiet", action="store_true", dest="quiet", help="suppress output") + opt("--log-info", action="callback", type="string", callback=_log, + help="turn on info logging for (multiple OK)") + opt("--log-debug", action="callback", type="string", callback=_log, + help="turn on debug logging for (multiple OK)") + opt("--require", action="append", dest="require", default=[], + help="require a particular driver or module version (multiple OK)") + opt("--db", action="store", dest="db", default="sqlite", + help="Use prefab database uri") + opt('--dbs', action='callback', callback=_list_dbs, + help="List available prefab dbs") + opt("--dburi", action="store", dest="dburi", + help="Database uri (overrides --db)") + opt("--dropfirst", action="store_true", dest="dropfirst", + help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)") + opt("--mockpool", action="store_true", dest="mockpool", + help="Use mock pool (asserts only one connection used)") + opt("--enginestrategy", action="callback", type="string", + callback=_engine_strategy, + help="Engine strategy (plain or threadlocal, defaults to plain)") + opt("--reversetop", action="store_true", dest="reversetop", default=False, + help="Reverse the collection ordering for topological sorts (helps " + "reveal dependency issues)") + opt("--unhashable", action="store_true", dest="unhashable", default=False, + help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") + opt("--noncomparable", action="store_true", dest="noncomparable", default=False, + help="Disallow SQLAlchemy from performing == on mapped test objects.") + opt("--truthless", action="store_true", dest="truthless", default=False, + help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") + opt("--serverside", action="callback", callback=_server_side_cursors, + help="Turn on server side cursors for PG") + opt("--mysql-engine", action="store", dest="mysql_engine", default=None, + help="Use the specified MySQL storage engine for all tables, default is " + "a db-default/InnoDB combo.") + opt("--table-option", action="append", dest="tableopts", default=[], + help="Add a dialect-specific table option, key=value") + + global file_config + file_config = ConfigParser.ConfigParser() + file_config.readfp(StringIO.StringIO(base_config)) + file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) + config.file_config = file_config + + def configure(self, options, conf): + Plugin.configure(self, options, conf) + + import testing, requires + testing.db = db + testing.requires = requires + + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(options, file_config) + + def describeTest(self, test): + return "" + + def wantClass(self, cls): + """Return true if you want the main test selector to collect + tests from this class, false if you don't, and None if you don't + care. + + :Parameters: + cls : class + The class being examined by the selector + + """ + + if not issubclass(cls, testing.TestBase): + return False + else: + if (hasattr(cls, '__whitelist__') and + testing.db.name in cls.__whitelist__): + return True + else: + return not self.__should_skip_for(cls) + + def __should_skip_for(self, cls): + if hasattr(cls, '__requires__'): + def test_suite(): return 'ok' + for requirement in cls.__requires__: + check = getattr(requires, requirement) + if check(test_suite)() != 'ok': + # The requirement will perform messaging. + return True + if (hasattr(cls, '__unsupported_on__') and + testing.db.name in cls.__unsupported_on__): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if (getattr(cls, '__skip_if__', False)): + for c in getattr(cls, '__skip_if__'): + if c(): + print "'%s' skipped by %s" % ( + cls.__class__.__name__, c.__name__) + return True + for rule in getattr(cls, '__excluded_on__', ()): + if testing._is_excluded(*rule): + print "'%s' unsupported on DB %s version %s" % ( + cls.__class__.__name__, testing.db.name, + _server_version()) + return True + return False + + #def begin(self): + #pass + + def beforeTest(self, test): + testing.resetwarnings() + + def afterTest(self, test): + testing.resetwarnings() + + #def handleError(self, test, err): + #pass + + #def finalize(self, result=None): + #pass diff --git a/test/testlib/orm.py b/lib/sqlalchemy/test/orm.py similarity index 96% rename from test/testlib/orm.py rename to lib/sqlalchemy/test/orm.py index 22d6246011..7ec13c5559 100644 --- a/test/testlib/orm.py +++ b/lib/sqlalchemy/test/orm.py @@ -1,8 +1,6 @@ import inspect, re -from testlib import config, testing - -sa = None -orm = None +import config, testing +from sqlalchemy import orm __all__ = 'mapper', @@ -93,10 +91,6 @@ def _make_blocker(method_name, fallback): return method def mapper(type_, *args, **kw): - global orm - if orm is None: - from sqlalchemy import orm - forbidden = [ ('__hash__', 'unhashable', lambda s: id(s)), ('__eq__', 'noncomparable', lambda s, o: s is o), diff --git a/test/pickleable.py b/lib/sqlalchemy/test/pickleable.py similarity index 90% rename from test/pickleable.py rename to lib/sqlalchemy/test/pickleable.py index ffb22f3a24..9794e424db 100644 --- a/test/pickleable.py +++ b/lib/sqlalchemy/test/pickleable.py @@ -1,5 +1,9 @@ -"""since the cPickle module as of py2.4 uses erroneous relative imports, define the various -picklable classes here so we can test PickleType stuff without issue.""" +""" + +some objects used for pickle tests, declared in their own module so that they +are easily pickleable. + +""" class Foo(object): diff --git a/test/testlib/profiling.py b/lib/sqlalchemy/test/profiling.py similarity index 89% rename from test/testlib/profiling.py rename to lib/sqlalchemy/test/profiling.py index 89db330111..ca4b31cbd8 100644 --- a/test/testlib/profiling.py +++ b/lib/sqlalchemy/test/profiling.py @@ -1,8 +1,13 @@ -"""Profiling support for unit and performance tests.""" +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" import os, sys -from testlib.compat import _function_named -import testlib.config +from sqlalchemy.util import function_named +import config __all__ = 'profiled', 'function_call_count', 'conditional_call_count' @@ -43,12 +48,8 @@ def profiled(target=None, **target_opts): elapsed, load_stats, result = _profile( filename, fn, *args, **kw) - if not testlib.config.options.quiet: - print "Profiled target '%s', wall time: %.2f seconds" % ( - target, elapsed) - report = target_opts.get('report', profile_config['report']) - if report and testlib.config.options.verbose: + if report: sort_ = target_opts.get('sort', profile_config['sort']) limit = target_opts.get('limit', profile_config['limit']) print "Profile report for target '%s' (%s)" % ( @@ -63,7 +64,7 @@ def profiled(target=None, **target_opts): #stats.print_callers() os.unlink(filename) return result - return _function_named(profiled, fn.__name__) + return function_named(profiled, fn.__name__) return decorator def function_call_count(count=None, versions={}, variance=0.05): @@ -113,10 +114,9 @@ def function_call_count(count=None, versions={}, variance=0.05): stats = stat_loader() calls = stats.total_calls - if testlib.config.options.verbose: - stats.sort_stats('calls', 'cumulative') - stats.print_stats() - #stats.print_callers() + stats.sort_stats('calls', 'cumulative') + stats.print_stats() + #stats.print_callers() deviance = int(count * variance) if (calls < (count - deviance) or calls > (count + deviance)): @@ -129,7 +129,7 @@ def function_call_count(count=None, versions={}, variance=0.05): finally: if os.path.exists(filename): os.unlink(filename) - return _function_named(counted, fn.__name__) + return function_named(counted, fn.__name__) return decorator def conditional_call_count(discriminator, categories): @@ -155,7 +155,7 @@ def conditional_call_count(discriminator, categories): rewrapped = function_call_count(*criteria)(fn) return rewrapped(*args, **kw) - return _function_named(at_runtime, fn.__name__) + return function_named(at_runtime, fn.__name__) return decorator diff --git a/test/testlib/requires.py b/lib/sqlalchemy/test/requires.py similarity index 99% rename from test/testlib/requires.py rename to lib/sqlalchemy/test/requires.py index b20929a83b..b23b8620da 100644 --- a/test/testlib/requires.py +++ b/lib/sqlalchemy/test/requires.py @@ -5,7 +5,7 @@ target database. """ -from testlib.testing import \ +from testing import \ _block_unconditionally as no_support, \ _chain_decorators_on, \ exclude, \ diff --git a/test/testlib/schema.py b/lib/sqlalchemy/test/schema.py similarity index 92% rename from test/testlib/schema.py rename to lib/sqlalchemy/test/schema.py index 7009fd65d8..f96805fe49 100644 --- a/test/testlib/schema.py +++ b/lib/sqlalchemy/test/schema.py @@ -1,6 +1,9 @@ -from testlib import testing +"""Enhanced versions of schema.Table and schema.Column which establish +desired state for different backends. +""" -schema = None +from sqlalchemy.test import testing +from sqlalchemy import schema __all__ = 'Table', 'Column', @@ -9,10 +12,6 @@ table_options = {} def Table(*args, **kw): """A schema.Table wrapper/hook for dialect-specific tweaks.""" - global schema - if schema is None: - from sqlalchemy import schema - test_opts = dict([(k,kw.pop(k)) for k in kw.keys() if k.startswith('test_')]) @@ -65,10 +64,6 @@ def Table(*args, **kw): def Column(*args, **kw): """A schema.Column wrapper/hook for dialect-specific tweaks.""" - global schema - if schema is None: - from sqlalchemy import schema - test_opts = dict([(k,kw.pop(k)) for k in kw.keys() if k.startswith('test_')]) diff --git a/test/testlib/testing.py b/lib/sqlalchemy/test/testing.py similarity index 71% rename from test/testlib/testing.py rename to lib/sqlalchemy/test/testing.py index 408dda79f1..36c7d340a3 100644 --- a/test/testlib/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -1,28 +1,17 @@ """TestCase and TestSuite artifacts and testing decorators.""" -# monkeypatches unittest.TestLoader.suiteClass at import time - import itertools import operator import re import sys import types -from testlib import sa_unittest as unittest import warnings from cStringIO import StringIO -import testlib.config as config -from testlib.compat import _function_named, callable - -# Delayed imports -MetaData = None -Session = None -clear_mappers = None -sa_exc = None -schema = None -sqltypes = None -util = None +from sqlalchemy.test import config, assertsql +from sqlalchemy.util import function_named +from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema _ops = { '<': operator.lt, '>': operator.gt, @@ -68,7 +57,7 @@ def fails_if(callable_): raise AssertionError( "Unexpected success for '%s' (condition: %s)" % (fn_name, description)) - return _function_named(maybe, fn_name) + return function_named(maybe, fn_name) return decorate @@ -89,7 +78,7 @@ def future(fn): else: raise AssertionError( "Unexpected success for future test '%s'" % fn_name) - return _function_named(decorated, fn_name) + return function_named(decorated, fn_name) def fails_on(dbs, reason): """Mark a test as expected to fail on the specified database @@ -118,7 +107,7 @@ def fails_on(dbs, reason): raise AssertionError( "Unexpected success for '%s' on DB implementation '%s'" % (fn_name, config.db.name)) - return _function_named(maybe, fn_name) + return function_named(maybe, fn_name) return decorate def fails_on_everything_except(*dbs): @@ -145,7 +134,7 @@ def fails_on_everything_except(*dbs): raise AssertionError( "Unexpected success for '%s' on DB implementation '%s'" % (fn_name, config.db.name)) - return _function_named(maybe, fn_name) + return function_named(maybe, fn_name) return decorate def crashes(db, reason): @@ -168,7 +157,7 @@ def crashes(db, reason): return True else: return fn(*args, **kw) - return _function_named(maybe, fn_name) + return function_named(maybe, fn_name) return decorate def _block_unconditionally(db, reason): @@ -192,7 +181,7 @@ def _block_unconditionally(db, reason): return True else: return fn(*args, **kw) - return _function_named(maybe, fn_name) + return function_named(maybe, fn_name) return decorate @@ -221,7 +210,7 @@ def exclude(db, op, spec, reason): return True else: return fn(*args, **kw) - return _function_named(maybe, fn_name) + return function_named(maybe, fn_name) return decorate def _should_carp_about_exclusion(reason): @@ -281,7 +270,7 @@ def skip_if(predicate, reason=None): return True else: return fn(*args, **kw) - return _function_named(maybe, fn_name) + return function_named(maybe, fn_name) return decorate def emits_warning(*messages): @@ -299,10 +288,6 @@ def emits_warning(*messages): # - update: jython looks ok, it uses cpython's module def decorate(fn): def safe(*args, **kw): - global sa_exc - if sa_exc is None: - import sqlalchemy.exc as sa_exc - # todo: should probably be strict about this, too filters = [dict(action='ignore', category=sa_exc.SAPendingDeprecationWarning)] @@ -320,7 +305,7 @@ def emits_warning(*messages): return fn(*args, **kw) finally: resetwarnings() - return _function_named(safe, fn.__name__) + return function_named(safe, fn.__name__) return decorate def emits_warning_on(db, *warnings): @@ -344,7 +329,7 @@ def emits_warning_on(db, *warnings): else: wrapped = emits_warning(*warnings)(fn) return wrapped(*args, **kw) - return _function_named(maybe, fn.__name__) + return function_named(maybe, fn.__name__) return decorate def uses_deprecated(*messages): @@ -361,10 +346,6 @@ def uses_deprecated(*messages): def decorate(fn): def safe(*args, **kw): - global sa_exc - if sa_exc is None: - import sqlalchemy.exc as sa_exc - # todo: should probably be strict about this, too filters = [dict(action='ignore', category=sa_exc.SAPendingDeprecationWarning)] @@ -387,16 +368,12 @@ def uses_deprecated(*messages): return fn(*args, **kw) finally: resetwarnings() - return _function_named(safe, fn.__name__) + return function_named(safe, fn.__name__) return decorate def resetwarnings(): """Reset warning behavior to testing defaults.""" - global sa_exc - if sa_exc is None: - import sqlalchemy.exc as sa_exc - warnings.filterwarnings('ignore', category=sa_exc.SAPendingDeprecationWarning) warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning) @@ -474,7 +451,23 @@ def startswith_(a, fragment, msg=None): assert a.startswith(fragment), msg or "%r does not start with %r" % ( a, fragment) +def assert_raises(except_cls, callable_, *args, **kw): + try: + callable_(*args, **kw) + assert False, "Callable did not raise an exception" + except except_cls, e: + pass +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + try: + callable_(*args, **kwargs) + assert False, "Callable did not raise an exception" + except except_cls, e: + assert re.search(msg, str(e)), "%r !~ %s" % (msg, e) + +def fail(msg): + assert False, msg + def fixture(table, columns, *rows): """Insert data into table after creation.""" def onload(event, schema_item, connection): @@ -484,43 +477,6 @@ def fixture(table, columns, *rows): for column_values in rows]) table.append_ddl_listener('after-create', onload) -def _import_by_name(name): - submodule = name.split('.')[-1] - return __import__(name, globals(), locals(), [submodule]) - -class CompositeModule(types.ModuleType): - """Merged attribute access for multiple modules.""" - - # break the habit - __all__ = () - - def __init__(self, name, *modules, **overrides): - """Construct a new lazy composite of modules. - - Modules may be string names or module-like instances. Individual - attribute overrides may be specified as keyword arguments for - convenience. - - The constructed module will resolve attribute access in reverse order: - overrides, then each member of reversed(modules). Modules specified - by name will be loaded lazily when encountered in attribute - resolution. - - """ - types.ModuleType.__init__(self, name) - self.__modules = list(reversed(modules)) - for key, value in overrides.iteritems(): - setattr(self, key, value) - - def __getattr__(self, key): - for idx, mod in enumerate(self.__modules): - if isinstance(mod, basestring): - self.__modules[idx] = mod = _import_by_name(mod) - if hasattr(mod, key): - return getattr(mod, key) - raise AttributeError(key) - - def resolve_artifact_names(fn): """Decorator, augment function globals with tables and classes. @@ -546,7 +502,7 @@ def resolve_artifact_names(fn): fn.func_code, context, fn.func_name, fn.func_defaults, fn.func_closure) return rebound(*args, **kwargs) - return _function_named(resolved, fn.func_name) + return function_named(resolved, fn.func_name) class adict(dict): """Dict keys available as attributes. Shadows.""" @@ -560,7 +516,7 @@ class adict(dict): return tuple([self[key] for key in keys]) -class TestBase(unittest.TestCase): +class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. __whitelist__ = () @@ -579,37 +535,11 @@ class TestBase(unittest.TestCase): # skipped. __skip_if__ = None - _artifact_registries = () - _sa_first_test = False - _sa_last_test = False - - def __init__(self, *args, **params): - unittest.TestCase.__init__(self, *args, **params) - - def setUpAll(self): - pass - - def tearDownAll(self): - pass - - def shortDescription(self): - """overridden to not return docstrings""" - return None - - def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs): - try: - callable_(*args, **kwargs) - assert False, "Callable did not raise an exception" - except except_cls, e: - assert re.search(msg, str(e)), "%r !~ %s" % (msg, e) - - if not hasattr(unittest.TestCase, 'assertTrue'): - assertTrue = unittest.TestCase.failUnless - if not hasattr(unittest.TestCase, 'assertFalse'): - assertFalse = unittest.TestCase.failIf - + def assert_(self, val, msg=None): + assert val, msg + class AssertsCompiledSQL(object): def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None): if dialect is None: @@ -626,25 +556,20 @@ class AssertsCompiledSQL(object): cc = re.sub(r'\n', '', str(c)) - self.assertEquals(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) if checkparams is not None: - self.assertEquals(c.construct_params(params), checkparams) + eq_(c.construct_params(params), checkparams) class ComparesTables(object): def assert_tables_equal(self, table, reflected_table): - global sqltypes, schema - if sqltypes is None: - import sqlalchemy.types as sqltypes - if schema is None: - import sqlalchemy.schema as schema base_mro = sqltypes.TypeEngine.__mro__ assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): - self.assertEquals(c.name, reflected_c.name) + eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] - self.assertEquals(c.primary_key, reflected_c.primary_key) - self.assertEquals(c.nullable, reflected_c.nullable) + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) assert len( set(type(reflected_c.type).__mro__).difference(base_mro).intersection( set(type(c.type).__mro__).difference(base_mro) @@ -652,9 +577,9 @@ class ComparesTables(object): ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) if isinstance(c.type, sqltypes.String): - self.assertEquals(c.type.length, reflected_c.type.length) + eq_(c.type.length, reflected_c.type.length) - self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) if c.default: assert isinstance(reflected_c.server_default, schema.FetchedValue) @@ -704,10 +629,6 @@ class AssertsExecutionResults(object): numbers of rows that the test suite manipulates. """ - global util - if util is None: - from sqlalchemy import util - class frozendict(dict): def __hash__(self): return id(self) @@ -716,11 +637,11 @@ class AssertsExecutionResults(object): expected = set([frozendict(e) for e in expected]) for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): - self.fail('Unexpected type "%s", expected "%s"' % ( + fail('Unexpected type "%s", expected "%s"' % ( type(wrong).__name__, cls.__name__)) if len(found) != len(expected): - self.fail('Unexpected object count "%s", expected "%s"' % ( + fail('Unexpected object count "%s", expected "%s"' % ( len(found), len(expected))) NOVALUE = object() @@ -743,13 +664,12 @@ class AssertsExecutionResults(object): found.remove(found_item) break else: - self.fail( + fail( "Expected %s instance with attributes %s not found." % ( cls.__name__, repr(expected_item))) return True def assert_sql_execution(self, db, callable_, *rules): - from testlib import assertsql assertsql.asserter.add_rules(rules) try: callable_() @@ -758,8 +678,6 @@ class AssertsExecutionResults(object): assertsql.asserter.clear_rules() def assert_sql(self, db, callable_, list_, with_sequences=None): - from testlib import assertsql - if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): rules = with_sequences else: @@ -778,142 +696,6 @@ class AssertsExecutionResults(object): self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): - from testlib import assertsql self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) - -class TTestSuite(unittest.TestSuite): - """A TestSuite with once per TestCase setUpAll() and tearDownAll()""" - - def __init__(self, tests=()): - if len(tests) > 0 and isinstance(tests[0], TestBase): - self._initTest = tests[0] - else: - self._initTest = None - - for t in tests: - if isinstance(t, TestBase): - t._sa_first_test = True - break - for t in reversed(tests): - if isinstance(t, TestBase): - t._sa_last_test = True - break - unittest.TestSuite.__init__(self, tests) - - def run(self, result): - init = getattr(self, '_initTest', None) - if init is not None: - if (hasattr(init, '__whitelist__') and - config.db.name in init.__whitelist__): - pass - else: - if self.__should_skip_for(init): - return True - try: - resetwarnings() - init.setUpAll() - except: - # skip tests if global setup fails - ex = self.__exc_info() - for test in self._tests: - result.addError(test, ex) - return False - try: - resetwarnings() - for test in self._tests: - if result.shouldStop: - break - test(result) - return result - finally: - try: - resetwarnings() - if init is not None: - init.tearDownAll() - except: - result.addError(init, self.__exc_info()) - pass - - def __should_skip_for(self, cls): - if hasattr(cls, '__requires__'): - global requires - if requires is None: - from testing import requires - def test_suite(): return 'ok' - for requirement in cls.__requires__: - check = getattr(requires, requirement) - if check(test_suite)() != 'ok': - # The requirement will perform messaging. - return True - if (hasattr(cls, '__unsupported_on__') and - config.db.name in cls.__unsupported_on__): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, config.db.name) - return True - if (getattr(cls, '__only_on__', None) not in (None,config.db.name)): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, config.db.name) - return True - if (getattr(cls, '__skip_if__', False)): - for c in getattr(cls, '__skip_if__'): - if c(): - print "'%s' skipped by %s" % ( - cls.__class__.__name__, c.__name__) - return True - for rule in getattr(cls, '__excluded_on__', ()): - if _is_excluded(*rule): - print "'%s' unsupported on DB %s version %s" % ( - cls.__class__.__name__, config.db.name, - _server_version()) - return True - return False - - - def __exc_info(self): - """Return a version of sys.exc_info() with the traceback frame - minimised; usually the top level of the traceback frame is not - needed. - ripped off out of unittest module since its double __ - """ - exctype, excvalue, tb = sys.exc_info() - if sys.platform[:4] == 'java': ## tracebacks look different in Jython - return (exctype, excvalue, tb) - return (exctype, excvalue, tb) - -# monkeypatch -unittest.TestLoader.suiteClass = TTestSuite - - -class DevNullWriter(object): - def write(self, msg): - pass - def flush(self): - pass - -def runTests(suite): - verbose = config.options.verbose - quiet = config.options.quiet - orig_stdout = sys.stdout - - try: - if not verbose or quiet: - sys.stdout = DevNullWriter() - runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2) - return runner.run(suite) - finally: - if not verbose or quiet: - sys.stdout = orig_stdout - -def main(suite=None): - if not suite: - if sys.argv[1:]: - suite =unittest.TestLoader().loadTestsFromNames( - sys.argv[1:], __import__('__main__')) - else: - suite = unittest.TestLoader().loadTestsFromModule( - __import__('__main__')) - - result = runTests(suite) - sys.exit(not result.wasSuccessful()) diff --git a/setup.cfg b/setup.cfg index 01bb954499..25ee974dba 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,6 @@ [egg_info] tag_build = dev tag_svn_revision = true + +[nosetests] +with-sqlalchemy = true \ No newline at end of file diff --git a/setup.py b/setup.py index 6a24677652..3d65f022e0 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,13 @@ setup(name = "SQLAlchemy", packages = find_packages('lib'), package_dir = {'':'lib'}, license = "MIT License", + + entry_points = { + 'nose.plugins.0.10': [ + 'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy', + ] + }, + long_description = """\ SQLAlchemy is: diff --git a/test/profiling/__init__.py b/test/aaa_profiling/__init__.py similarity index 100% rename from test/profiling/__init__.py rename to test/aaa_profiling/__init__.py diff --git a/test/profiling/compiler.py b/test/aaa_profiling/test_compiler.py similarity index 84% rename from test/profiling/compiler.py rename to test/aaa_profiling/test_compiler.py index 26260068a6..3e4274d47d 100644 --- a/test/profiling/compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -1,10 +1,10 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * -from testlib import * +from sqlalchemy.test import * class CompileTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global t1, t2, metadata metadata = MetaData() t1 = Table('t1', metadata, @@ -28,5 +28,3 @@ class CompileTest(TestBase, AssertsExecutionResults): s = select([t1], t1.c.c2==t2.c.c1) s.compile() -if __name__ == '__main__': - testenv.main() diff --git a/test/profiling/memusage.py b/test/aaa_profiling/test_memusage.py similarity index 96% rename from test/profiling/memusage.py rename to test/aaa_profiling/test_memusage.py index ccafc7bd7e..70a3cf8cd6 100644 --- a/test/profiling/memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1,14 +1,16 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import gc from sqlalchemy.orm import mapper, relation, create_session, clear_mappers, sessionmaker from sqlalchemy.orm.mapper import _mapper_registry from sqlalchemy.orm.session import _sessions import operator -from testlib import testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, PickleType +from sqlalchemy.test import testing +from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column import sqlalchemy as sa from sqlalchemy.sql import column -from orm import _base +from test.orm import _base class A(_base.ComparableEntity): @@ -52,7 +54,7 @@ def assert_no_mappers(): assert len(_mapper_registry) == 0 class EnsureZeroed(_base.ORMTest): - def setUp(self): + def setup(self): _sessions.clear() _mapper_registry.clear() @@ -106,7 +108,7 @@ class MemUsageTest(EnsureZeroed): sess.expunge_all() alist = sess.query(A).all() - self.assertEquals( + eq_( [ A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), A(col2="a2", bs=[]), @@ -157,7 +159,7 @@ class MemUsageTest(EnsureZeroed): sess.expunge_all() alist = sess.query(A).order_by(A.col1).all() - self.assertEquals( + eq_( [ A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), A(col2="a2", bs=[]), @@ -217,7 +219,7 @@ class MemUsageTest(EnsureZeroed): sess.expunge_all() alist = sess.query(A).order_by(A.col1).all() - self.assertEquals( + eq_( [ A(), A(), B(col3='b1'), B(col3='b2') ], @@ -281,7 +283,7 @@ class MemUsageTest(EnsureZeroed): sess.expunge_all() alist = sess.query(A).order_by(A.col1).all() - self.assertEquals( + eq_( [ A(bs=[B(col2='b1')]), A(bs=[B(col2='b2')]) ], @@ -398,5 +400,3 @@ class MemUsageTest(EnsureZeroed): cast.compile(dialect=dialect) go() -if __name__ == '__main__': - testenv.main() diff --git a/test/profiling/pool.py b/test/aaa_profiling/test_pool.py similarity index 87% rename from test/profiling/pool.py rename to test/aaa_profiling/test_pool.py index f3f69222c0..7bb61deb28 100644 --- a/test/profiling/pool.py +++ b/test/aaa_profiling/test_pool.py @@ -1,6 +1,5 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * -from testlib import * +from sqlalchemy.test import * from sqlalchemy.pool import QueuePool @@ -9,7 +8,7 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): def close(self): pass - def setUp(self): + def setup(self): global pool pool = QueuePool(creator=self.Connection, pool_size=3, max_overflow=-1, @@ -39,5 +38,3 @@ class QueuePoolTest(TestBase, AssertsExecutionResults): c2 = go() -if __name__ == '__main__': - testenv.main() diff --git a/test/profiling/zoomark.py b/test/aaa_profiling/test_zoomark.py similarity index 99% rename from test/profiling/zoomark.py rename to test/aaa_profiling/test_zoomark.py index c9f3d9df80..be29318964 100644 --- a/test/profiling/zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -6,9 +6,8 @@ An adaptation of Robert Brewers' ZooMark speed tests. import datetime import sys import time -import testenv; testenv.configure_for_tests() from sqlalchemy import * -from testlib import * +from sqlalchemy.test import * ITERATIONS = 1 @@ -356,5 +355,3 @@ class ZooMarkTest(TestBase): self.test_baseline_8_drop() -if __name__ == '__main__': - testenv.main() diff --git a/test/profiling/zoomark_orm.py b/test/aaa_profiling/test_zoomark_orm.py similarity index 99% rename from test/profiling/zoomark_orm.py rename to test/aaa_profiling/test_zoomark_orm.py index 5d7192261d..57e1e24049 100644 --- a/test/profiling/zoomark_orm.py +++ b/test/aaa_profiling/test_zoomark_orm.py @@ -6,10 +6,9 @@ An adaptation of Robert Brewers' ZooMark speed tests. import datetime import sys import time -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * +from sqlalchemy.test import * ITERATIONS = 1 @@ -318,5 +317,3 @@ class ZooMarkTest(TestBase): self.test_baseline_7_drop() -if __name__ == '__main__': - testenv.main() diff --git a/test/alltests.py b/test/alltests.py deleted file mode 100644 index b014bc9da1..0000000000 --- a/test/alltests.py +++ /dev/null @@ -1,24 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -import orm.alltests as orm -import base.alltests as base -import sql.alltests as sql -import engine.alltests as engine -import dialect.alltests as dialect -import ext.alltests as ext -import zblog.alltests as zblog -import profiling.alltests as profiling - -# The profiling tests are sensitive to foibles of CPython VM state, so -# run them first. Ideally, each should be run in a fresh interpreter. - -def suite(): - alltests = unittest.TestSuite() - for suite in (profiling, base, engine, sql, dialect, orm, ext, zblog): - alltests.addTest(suite.suite()) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/base/alltests.py b/test/base/alltests.py deleted file mode 100644 index 3fef623776..0000000000 --- a/test/base/alltests.py +++ /dev/null @@ -1,21 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - # core utilities - 'base.dependency', - 'base.utils', - 'base.except', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/base/dependency.py b/test/base/test_dependency.py similarity index 97% rename from test/base/dependency.py rename to test/base/test_dependency.py index 8fcd093b25..0457d552a4 100644 --- a/test/base/dependency.py +++ b/test/base/test_dependency.py @@ -1,6 +1,5 @@ -import testenv; testenv.configure_for_tests() import sqlalchemy.topological as topological -from testlib import TestBase +from sqlalchemy.test import TestBase class DependencySortTest(TestBase): @@ -185,5 +184,3 @@ class DependencySortTest(TestBase): head = topological.sort_as_tree(tuples, []) -if __name__ == "__main__": - testenv.main() diff --git a/test/base/except.py b/test/base/test_except.py similarity index 96% rename from test/base/except.py rename to test/base/test_except.py index 3f4d654771..efb18a153c 100644 --- a/test/base/except.py +++ b/test/base/test_except.py @@ -1,8 +1,7 @@ """Tests exceptions and DB-API exception wrapping.""" -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest import exceptions as stdlib_exceptions from sqlalchemy import exc as sa_exceptions +from sqlalchemy.test import TestBase class Error(stdlib_exceptions.StandardError): @@ -18,7 +17,7 @@ class OutOfSpec(DatabaseError): pass -class WrapTest(unittest.TestCase): +class WrapTest(TestBase): def test_db_error_normal(self): try: raise sa_exceptions.DBAPIError.instance( @@ -118,5 +117,3 @@ class WrapTest(unittest.TestCase): self.assert_(True) -if __name__ == "__main__": - testenv.main() diff --git a/test/base/utils.py b/test/base/test_utils.py similarity index 94% rename from test/base/utils.py rename to test/base/test_utils.py index bc3fc02838..39561e9682 100644 --- a/test/base/utils.py +++ b/test/base/test_utils.py @@ -1,8 +1,8 @@ -import testenv; testenv.configure_for_tests() -import copy, threading, unittest +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import copy, threading from sqlalchemy import util, sql, exc -from testlib import TestBase -from testlib.testing import eq_, is_, ne_ +from sqlalchemy.test import TestBase +from sqlalchemy.test.testing import eq_, is_, ne_ class OrderedDictTest(TestBase): def test_odict(self): @@ -159,7 +159,7 @@ class HashEqOverride(object): return True -class IdentitySetTest(unittest.TestCase): +class IdentitySetTest(TestBase): def assert_eq(self, identityset, expected_iterable): expected = sorted([id(o) for o in expected_iterable]) found = sorted([id(o) for o in identityset]) @@ -205,7 +205,7 @@ class IdentitySetTest(unittest.TestCase): ids.discard(o1) ids.add(o1) ids.remove(o1) - self.assertRaises(KeyError, ids.remove, o1) + assert_raises(KeyError, ids.remove, o1) eq_(ids.copy(), ids) @@ -260,8 +260,8 @@ class IdentitySetTest(unittest.TestCase): except TypeError: assert True - self.assertRaises(TypeError, cmp, ids) - self.assertRaises(TypeError, hash, ids) + assert_raises(TypeError, cmp, ids) + assert_raises(TypeError, hash, ids) def test_difference(self): os1 = util.IdentitySet([1,2,3]) @@ -271,12 +271,12 @@ class IdentitySetTest(unittest.TestCase): eq_(os1 - os2, util.IdentitySet([1, 2])) eq_(os2 - os1, util.IdentitySet([4, 5])) - self.assertRaises(TypeError, lambda: os1 - s2) - self.assertRaises(TypeError, lambda: os1 - [3, 4, 5]) - self.assertRaises(TypeError, lambda: s1 - os2) - self.assertRaises(TypeError, lambda: s1 - [3, 4, 5]) + assert_raises(TypeError, lambda: os1 - s2) + assert_raises(TypeError, lambda: os1 - [3, 4, 5]) + assert_raises(TypeError, lambda: s1 - os2) + assert_raises(TypeError, lambda: s1 - [3, 4, 5]) -class OrderedIdentitySetTest(unittest.TestCase): +class OrderedIdentitySetTest(TestBase): def assert_eq(self, identityset, expected_iterable): expected = [id(o) for o in expected_iterable] @@ -303,7 +303,7 @@ class OrderedIdentitySetTest(unittest.TestCase): eq_(s1.union(s2).intersection(s3), [a, d, f]) -class DictlikeIteritemsTest(unittest.TestCase): +class DictlikeIteritemsTest(TestBase): baseline = set([('a', 1), ('b', 2), ('c', 3)]) def _ok(self, instance): @@ -311,7 +311,7 @@ class DictlikeIteritemsTest(unittest.TestCase): eq_(set(iterator), self.baseline) def _notok(self, instance): - self.assertRaises(TypeError, + assert_raises(TypeError, util.dictlike_iteritems, instance) @@ -638,7 +638,7 @@ class WeakIdentityMappingTest(TestBase): def test_update(self): data, wim = self._fixture() - self.assertRaises(NotImplementedError, wim.update) + assert_raises(NotImplementedError, wim.update) def test_weak_clear(self): data, wim = self._fixture() @@ -838,13 +838,13 @@ class AsInterfaceTest(TestBase): def test_instance(self): obj = object() - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, cls=self.Something) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, methods=('foo')) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, cls=self.Something, required=('foo')) obj = self.Something() @@ -860,25 +860,25 @@ class AsInterfaceTest(TestBase): for obj in partial, slotted: eq_(obj, util.as_interface(obj, cls=self.Something)) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, methods=('foo')) eq_(obj, util.as_interface(obj, methods=('bar',))) eq_(obj, util.as_interface(obj, cls=self.Something, required=('bar',))) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, cls=self.Something, required=('foo',)) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, cls=self.Something, required=self.Something) def test_dict(self): obj = {} - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, cls=self.Something) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, methods=('foo')) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, cls=self.Something, required=('foo')) def assertAdapted(obj, *methods): @@ -911,13 +911,13 @@ class AsInterfaceTest(TestBase): res = util.as_interface(obj, methods=('foo', 'bar'), required=('foo',)) assertAdapted(res, 'foo', 'bar') - self.assertRaises(TypeError, util.as_interface, obj, methods=('foo',)) + assert_raises(TypeError, util.as_interface, obj, methods=('foo',)) - self.assertRaises(TypeError, util.as_interface, obj, + assert_raises(TypeError, util.as_interface, obj, methods=('foo', 'bar', 'baz'), required=('baz',)) obj = {'foo': 123} - self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something) + assert_raises(TypeError, util.as_interface, obj, cls=self.Something) class TestClassHierarchy(TestBase): @@ -955,5 +955,3 @@ class TestClassHierarchy(TestBase): eq_(set(util.class_hierarchy(A)), set((A, B, object))) -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/alltests.py b/test/dialect/alltests.py deleted file mode 100644 index 0defb6a15e..0000000000 --- a/test/dialect/alltests.py +++ /dev/null @@ -1,28 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - - -def suite(): - modules_to_test = ( - 'dialect.access', - 'dialect.firebird', - 'dialect.informix', - 'dialect.maxdb', - 'dialect.mssql', - 'dialect.mysql', - 'dialect.oracle', - 'dialect.postgres', - 'dialect.sqlite', - 'dialect.sybase', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/dialect/access.py b/test/dialect/test_access.py similarity index 86% rename from test/dialect/access.py rename to test/dialect/test_access.py index 57af45a9d6..0ea8d9a61a 100644 --- a/test/dialect/access.py +++ b/test/dialect/test_access.py @@ -1,8 +1,7 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import sql from sqlalchemy.databases import access -from testlib import * +from sqlalchemy.test import * class CompileTest(TestBase, AssertsCompiledSQL): @@ -30,5 +29,3 @@ class CompileTest(TestBase, AssertsCompiledSQL): 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst) -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/firebird.py b/test/dialect/test_firebird.py similarity index 84% rename from test/dialect/firebird.py rename to test/dialect/test_firebird.py index 5a0109dcc4..fa608c9a18 100644 --- a/test/dialect/firebird.py +++ b/test/dialect/test_firebird.py @@ -1,9 +1,9 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy.databases import firebird from sqlalchemy.exc import ProgrammingError from sqlalchemy.sql import table, column -from testlib import * +from sqlalchemy.test import * class DomainReflectionTest(TestBase, AssertsExecutionResults): @@ -11,7 +11,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): __only_on__ = 'firebird' - def setUpAll(self): + @classmethod + def setup_class(cls): con = testing.db.connect() try: con.execute('CREATE DOMAIN int_domain AS INTEGER DEFAULT 42 NOT NULL') @@ -38,7 +39,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): NEW.question = gen_id(gen_testtable_id, 1); END''') - def tearDownAll(self): + @classmethod + def teardown_class(cls): con = testing.db.connect() con.execute('DROP TABLE testtable') con.execute('DROP DOMAIN int_domain') @@ -50,22 +52,22 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): def test_table_is_reflected(self): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) - self.assertEquals(set(table.columns.keys()), + eq_(set(table.columns.keys()), set(['question', 'answer', 'remark', 'photo', 'd', 't', 'dt']), "Columns of reflected table didn't equal expected columns") - self.assertEquals(table.c.question.primary_key, True) - self.assertEquals(table.c.question.sequence.name, 'gen_testtable_id') - self.assertEquals(table.c.question.type.__class__, firebird.FBInteger) - self.assertEquals(table.c.question.server_default.arg.text, "42") - self.assertEquals(table.c.answer.type.__class__, firebird.FBString) - self.assertEquals(table.c.answer.server_default.arg.text, "'no answer'") - self.assertEquals(table.c.remark.type.__class__, firebird.FBText) - self.assertEquals(table.c.remark.server_default.arg.text, "''") - self.assertEquals(table.c.photo.type.__class__, firebird.FBBinary) + eq_(table.c.question.primary_key, True) + eq_(table.c.question.sequence.name, 'gen_testtable_id') + eq_(table.c.question.type.__class__, firebird.FBInteger) + eq_(table.c.question.server_default.arg.text, "42") + eq_(table.c.answer.type.__class__, firebird.FBString) + eq_(table.c.answer.server_default.arg.text, "'no answer'") + eq_(table.c.remark.type.__class__, firebird.FBText) + eq_(table.c.remark.server_default.arg.text, "''") + eq_(table.c.photo.type.__class__, firebird.FBBinary) # The following assume a Dialect 3 database - self.assertEquals(table.c.d.type.__class__, firebird.FBDate) - self.assertEquals(table.c.t.type.__class__, firebird.FBTime) - self.assertEquals(table.c.dt.type.__class__, firebird.FBDateTime) + eq_(table.c.d.type.__class__, firebird.FBDate) + eq_(table.c.t.type.__class__, firebird.FBTime) + eq_(table.c.dt.type.__class__, firebird.FBDateTime) class CompileTest(TestBase, AssertsCompiledSQL): @@ -140,10 +142,10 @@ class ReturningTest(TestBase, AssertsExecutionResults): table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute() - self.assertEqual(result.fetchall(), [(1,)]) + eq_(result.fetchall(), [(1,)]) result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - self.assertEqual(result2.fetchall(), [(1,True),(2,False)]) + eq_(result2.fetchall(), [(1,True),(2,False)]) finally: table.drop() @@ -159,19 +161,19 @@ class ReturningTest(TestBase, AssertsExecutionResults): try: result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - self.assertEqual(result.fetchall(), [(1,)]) + eq_(result.fetchall(), [(1,)]) # Multiple inserts only return the last row result2 = table.insert(firebird_returning=[table]).execute( [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - self.assertEqual(result2.fetchall(), [(3,3,True)]) + eq_(result2.fetchall(), [(3,3,True)]) result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False}) - self.assertEqual([dict(row) for row in result3], [{'ID':4}]) + eq_([dict(row) for row in result3], [{'ID':4}]) result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons') - self.assertEqual([dict(row) for row in result4], [{'PERSONS': 10}]) + eq_([dict(row) for row in result4], [{'PERSONS': 10}]) finally: table.drop() @@ -188,10 +190,10 @@ class ReturningTest(TestBase, AssertsExecutionResults): table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute() - self.assertEqual(result.fetchall(), [(1,)]) + eq_(result.fetchall(), [(1,)]) result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - self.assertEqual(result2.fetchall(), [(2,False),]) + eq_(result2.fetchall(), [(2,False),]) finally: table.drop() @@ -224,5 +226,3 @@ class MiscFBTests(TestBase): assert len(version) == 3, "Got strange version info: %s" % repr(version) -if __name__ == '__main__': - testenv.main() diff --git a/test/dialect/informix.py b/test/dialect/test_informix.py similarity index 88% rename from test/dialect/informix.py rename to test/dialect/test_informix.py index 1fbbaa0cb4..86a4e751d4 100644 --- a/test/dialect/informix.py +++ b/test/dialect/test_informix.py @@ -1,7 +1,6 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.databases import informix -from testlib import * +from sqlalchemy.test import * class CompileTest(TestBase, AssertsCompiledSQL): @@ -19,5 +18,3 @@ class CompileTest(TestBase, AssertsCompiledSQL): self.assert_compile(t1.update().values({t1.c.col1 : t1.c.col1 + 1}), 'UPDATE t1 SET col1=(t1.col1 + ?)') -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/maxdb.py b/test/dialect/test_maxdb.py similarity index 96% rename from test/dialect/maxdb.py rename to test/dialect/test_maxdb.py index c2daf8959a..033a05533f 100644 --- a/test/dialect/maxdb.py +++ b/test/dialect/test_maxdb.py @@ -1,12 +1,12 @@ """MaxDB-specific tests.""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import StringIO, sys from sqlalchemy import * from sqlalchemy import exc, sql from decimal import Decimal from sqlalchemy.databases import maxdb -from testlib import * +from sqlalchemy.test import * # TODO @@ -43,12 +43,12 @@ class ReflectionTest(TestBase, AssertsExecutionResults): cols = ['d1','d2','n1','i1'] t.insert().execute(dict(zip(cols,vals))) roundtrip = list(t.select().execute()) - self.assertEquals(roundtrip, [tuple([1] + vals)]) + eq_(roundtrip, [tuple([1] + vals)]) t.insert().execute(dict(zip(['id'] + cols, [2] + list(roundtrip[0][1:])))) roundtrip2 = list(t.select(order_by=t.c.id).execute()) - self.assertEquals(roundtrip2, [tuple([1] + vals), + eq_(roundtrip2, [tuple([1] + vals), tuple([2] + vals)]) finally: try: @@ -233,8 +233,6 @@ class DBAPITest(TestBase, AssertsExecutionResults): def test_modulo_operator(self): st = str(select([sql.column('col') % 5]).compile(testing.db)) - self.assertEquals(st, 'SELECT mod(col, ?) FROM DUAL') + eq_(st, 'SELECT mod(col, ?) FROM DUAL') -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/mssql.py b/test/dialect/test_mssql.py old mode 100755 new mode 100644 similarity index 92% rename from test/dialect/mssql.py rename to test/dialect/test_mssql.py index 50f9594ef3..5e2c9a672d --- a/test/dialect/mssql.py +++ b/test/dialect/test_mssql.py @@ -1,14 +1,14 @@ # -*- encoding: utf-8 -import testenv; testenv.configure_for_tests() -import datetime, os, pickleable, re +from sqlalchemy.test.testing import eq_ +import datetime, os, re from sqlalchemy import * from sqlalchemy import types, exc from sqlalchemy.orm import * from sqlalchemy.sql import table, column from sqlalchemy.databases import mssql import sqlalchemy.engine.url as url -from testlib import * -from testlib.testing import eq_ +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ class CompileTest(TestBase, AssertsCompiledSQL): @@ -162,7 +162,8 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' __dialect__ = mssql.MSSQLDialect() - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, cattable metadata = MetaData(testing.db) @@ -172,10 +173,10 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): PrimaryKeyConstraint('id', name='PK_cattable'), ) - def setUp(self): + def setup(self): metadata.create_all() - def tearDown(self): + def teardown(self): metadata.drop_all() def test_compiled(self): @@ -185,12 +186,12 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): cattable.insert().values(id=9, description='Python').execute() cats = cattable.select().order_by(cattable.c.id).execute() - self.assertEqual([(9, 'Python')], list(cats)) + eq_([(9, 'Python')], list(cats)) result = cattable.insert().values(description='PHP').execute() - self.assertEqual([10], result.last_inserted_ids()) + eq_([10], result.last_inserted_ids()) lastcat = cattable.select().order_by(desc(cattable.c.id)).execute() - self.assertEqual((10, 'PHP'), lastcat.fetchone()) + eq_((10, 'PHP'), lastcat.fetchone()) def test_executemany(self): cattable.insert().execute([ @@ -201,7 +202,7 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): ]) cats = cattable.select().order_by(cattable.c.id).execute() - self.assertEqual([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats)) + eq_([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats)) cattable.insert().execute([ {'description': 'PHP'}, @@ -209,7 +210,7 @@ class IdentityInsertTest(TestBase, AssertsCompiledSQL): ]) lastcats = cattable.select().order_by(desc(cattable.c.id)).limit(2).execute() - self.assertEqual([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats)) + eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats)) class ReflectionTest(TestBase): @@ -330,7 +331,8 @@ class Foo(object): class GenerativeQueryTest(TestBase): __only_on__ = 'mssql' - def setUpAll(self): + @classmethod + def setup_class(cls): global foo, metadata metadata = MetaData(testing.db) foo = Table('foo', metadata, @@ -347,7 +349,8 @@ class GenerativeQueryTest(TestBase): sess.add(Foo(bar=i, range=i%10)) sess.flush() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() clear_mappers() @@ -361,7 +364,7 @@ class GenerativeQueryTest(TestBase): class SchemaTest(TestBase): - def setUp(self): + def setup(self): t = Table('sometable', MetaData(), Column('pk_column', Integer), Column('test_column', String) @@ -418,7 +421,8 @@ class MatchTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' __skip_if__ = (full_text_search_missing, ) - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, cattable, matchtable metadata = MetaData(testing.db) @@ -456,7 +460,8 @@ class MatchTest(TestBase, AssertsCompiledSQL): ]) DDL("WAITFOR DELAY '00:00:05'").execute(bind=engines.testing_engine()) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() connection = testing.db.connect() connection.execute("DROP FULLTEXT CATALOG Catalog") @@ -467,44 +472,44 @@ class MatchTest(TestBase, AssertsCompiledSQL): def test_simple_match(self): results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([2, 5], [r.id for r in results]) + eq_([2, 5], [r.id for r in results]) def test_simple_match_with_apostrophe(self): results = matchtable.select().where(matchtable.c.title.match('"Matz''s"')).execute().fetchall() - self.assertEquals([3], [r.id for r in results]) + eq_([3], [r.id for r in results]) def test_simple_prefix_match(self): results = matchtable.select().where(matchtable.c.title.match('"nut*"')).execute().fetchall() - self.assertEquals([5], [r.id for r in results]) + eq_([5], [r.id for r in results]) def test_simple_inflectional_match(self): results = matchtable.select().where(matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")')).execute().fetchall() - self.assertEquals([2], [r.id for r in results]) + eq_([2], [r.id for r in results]) def test_or_match(self): results1 = matchtable.select().where(or_(matchtable.c.title.match('nutshell'), matchtable.c.title.match('ruby')) ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([3, 5], [r.id for r in results1]) + eq_([3, 5], [r.id for r in results1]) results2 = matchtable.select().where(matchtable.c.title.match('nutshell OR ruby'), ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([3, 5], [r.id for r in results2]) + eq_([3, 5], [r.id for r in results2]) def test_and_match(self): results1 = matchtable.select().where(and_(matchtable.c.title.match('python'), matchtable.c.title.match('nutshell')) ).execute().fetchall() - self.assertEquals([5], [r.id for r in results1]) + eq_([5], [r.id for r in results1]) results2 = matchtable.select().where(matchtable.c.title.match('python AND nutshell'), ).execute().fetchall() - self.assertEquals([5], [r.id for r in results2]) + eq_([5], [r.id for r in results2]) def test_match_across_joins(self): results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id, or_(cattable.c.description.match('Ruby'), matchtable.c.title.match('nutshell'))) ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([1, 3, 5], [r.id for r in results]) + eq_([1, 3, 5], [r.id for r in results]) class ParseConnectTest(TestBase, AssertsCompiledSQL): @@ -514,77 +519,78 @@ class ParseConnectTest(TestBase, AssertsCompiledSQL): u = url.make_url('mssql://mydsn') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) + eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) def test_pyodbc_connect_old_style_dsn_trusted(self): u = url.make_url('mssql:///?dsn=mydsn') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) + eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection) def test_pyodbc_connect_dsn_non_trusted(self): u = url.make_url('mssql://username:password@mydsn') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['dsn=mydsn;UID=username;PWD=password'], {}], connection) + eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_dsn_extra(self): u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection) + eq_([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection) def test_pyodbc_connect(self): u = url.make_url('mssql://username:password@hostspec/database') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) + eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_comma_port(self): u = url.make_url('mssql://username:password@hostspec:12345/database') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection) + eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_connect_config_port(self): u = url.make_url('mssql://username:password@hostspec/database?port=12345') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection) + eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection) def test_pyodbc_extra_connect(self): u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection) + eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection) def test_pyodbc_odbc_connect(self): u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) + eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_odbc_connect_with_dsn(self): u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection) + eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection) def test_pyodbc_odbc_connect_ignores_other_values(self): u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword') dialect = mssql.MSSQLDialect_pyodbc() connection = dialect.create_connect_args(u) - self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) + eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) class TypesTest(TestBase): __only_on__ = 'mssql' - def setUpAll(self): + @classmethod + def setup_class(cls): global numeric_table, metadata metadata = MetaData(testing.db) - def tearDown(self): + def teardown(self): metadata.drop_all() def test_decimal_notation(self): @@ -611,7 +617,7 @@ class TypesTest(TestBase): numeric_table.insert().execute(numericcol=value) for value in select([numeric_table.c.numericcol]).execute(): - self.assertTrue(value[0] in test_items, "%s not in test_items" % value[0]) + assert value[0] in test_items, "%s not in test_items" % value[0] except Exception, e: raise e @@ -763,7 +769,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): t.insert().execute(adate=d1, adatetime=d2, atime=t1) - self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) + eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) finally: t.drop(checkfirst=True) @@ -1072,7 +1078,8 @@ def colspec(c): class BinaryTest(TestBase, AssertsExecutionResults): """Test the Binary and VarBinary types""" - def setUpAll(self): + @classmethod + def setup_class(cls): global binary_table, MyPickleType class MyPickleType(types.TypeDecorator): @@ -1102,10 +1109,11 @@ class BinaryTest(TestBase, AssertsExecutionResults): ) binary_table.create() - def tearDown(self): + def teardown(self): binary_table.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): binary_table.drop() def test_binary(self): @@ -1124,23 +1132,21 @@ class BinaryTest(TestBase, AssertsExecutionResults): text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db) ): l = stmt.execute().fetchall() - self.assertEquals(list(stream1), list(l[0]['data'])) + eq_(list(stream1), list(l[0]['data'])) paddedstream = list(stream1[0:100]) paddedstream.extend(['\x00'] * (100 - len(paddedstream))) - self.assertEquals(paddedstream, list(l[0]['data_slice'])) + eq_(paddedstream, list(l[0]['data_slice'])) - self.assertEquals(list(stream2), list(l[1]['data'])) - self.assertEquals(list(stream2), list(l[1]['data_image'])) - self.assertEquals(testobj1, l[0]['pickled']) - self.assertEquals(testobj2, l[1]['pickled']) - self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata) - self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff') + eq_(list(stream2), list(l[1]['data'])) + eq_(list(stream2), list(l[1]['data_image'])) + eq_(testobj1, l[0]['pickled']) + eq_(testobj2, l[1]['pickled']) + eq_(testobj3.moredata, l[0]['mypickle'].moredata) + eq_(l[0]['mypickle'].stuff, 'this is the right stuff') def load_stream(self, name, len=3000): - f = os.path.join(os.path.dirname(testenv.__file__), name) + f = os.path.join(os.path.dirname(__file__), "..", name) return file(f).read(len) -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/mysql.py b/test/dialect/test_mysql.py similarity index 97% rename from test/dialect/mysql.py rename to test/dialect/test_mysql.py index fa8a85ec45..8adb2d71c5 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/test_mysql.py @@ -1,10 +1,10 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import sets from sqlalchemy import * from sqlalchemy import sql, exc from sqlalchemy.databases import mysql -from testlib.testing import eq_ -from testlib import * +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test import * class TypesTest(TestBase, AssertsExecutionResults): @@ -522,7 +522,7 @@ class TypesTest(TestBase, AssertsExecutionResults): [set_table.c.s3], set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute()) found = set([frozenset(row[0]) for row in rows]) - self.assertEquals(found, + eq_(found, set([frozenset(['5']), frozenset(['5', '7'])])) finally: meta.drop_all() @@ -812,7 +812,7 @@ class TypesTest(TestBase, AssertsExecutionResults): if got != wanted: print "Expected %s" % wanted print "Found %s" % got - self.assertEqual(got, wanted) + eq_(got, wanted) class SQLTest(TestBase, AssertsCompiledSQL): @@ -831,24 +831,24 @@ class SQLTest(TestBase, AssertsCompiledSQL): kw['prefixes'] = prefixes return str(select(['q'], **kw).compile(dialect=dialect)) - self.assertEqual(gen(None), 'SELECT q') - self.assertEqual(gen(True), 'SELECT DISTINCT q') - self.assertEqual(gen(1), 'SELECT DISTINCT q') - self.assertEqual(gen('diSTInct'), 'SELECT DISTINCT q') - self.assertEqual(gen('DISTINCT'), 'SELECT DISTINCT q') + eq_(gen(None), 'SELECT q') + eq_(gen(True), 'SELECT DISTINCT q') + eq_(gen(1), 'SELECT DISTINCT q') + eq_(gen('diSTInct'), 'SELECT DISTINCT q') + eq_(gen('DISTINCT'), 'SELECT DISTINCT q') # Standard SQL - self.assertEqual(gen('all'), 'SELECT ALL q') - self.assertEqual(gen('distinctrow'), 'SELECT DISTINCTROW q') + eq_(gen('all'), 'SELECT ALL q') + eq_(gen('distinctrow'), 'SELECT DISTINCTROW q') # Interaction with MySQL prefix extensions - self.assertEqual( + eq_( gen(None, ['straight_join']), 'SELECT straight_join q') - self.assertEqual( + eq_( gen('all', ['HIGH_PRIORITY SQL_SMALL_RESULT']), 'SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL q') - self.assertEqual( + eq_( gen(True, ['high_priority', sql.text('sql_cache')]), 'SELECT high_priority sql_cache DISTINCT q') @@ -997,7 +997,7 @@ class SQLTest(TestBase, AssertsCompiledSQL): class RawReflectionTest(TestBase): - def setUp(self): + def setup(self): self.dialect = mysql.dialect() self.reflector = mysql.MySQLSchemaReflector( self.dialect.identifier_preparer) @@ -1059,7 +1059,8 @@ class ExecutionTest(TestBase): class MatchTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mysql' - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, cattable, matchtable metadata = MetaData(testing.db) @@ -1096,7 +1097,8 @@ class MatchTest(TestBase, AssertsCompiledSQL): 'category_id': 1} ]) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_expression(self): @@ -1110,14 +1112,14 @@ class MatchTest(TestBase, AssertsCompiledSQL): order_by(matchtable.c.id). execute(). fetchall()) - self.assertEquals([2, 5], [r.id for r in results]) + eq_([2, 5], [r.id for r in results]) def test_simple_match_with_apostrophe(self): results = (matchtable.select(). where(matchtable.c.title.match('"Matz''s"')). execute(). fetchall()) - self.assertEquals([3], [r.id for r in results]) + eq_([3], [r.id for r in results]) def test_or_match(self): results1 = (matchtable.select(). @@ -1126,13 +1128,13 @@ class MatchTest(TestBase, AssertsCompiledSQL): order_by(matchtable.c.id). execute(). fetchall()) - self.assertEquals([3, 5], [r.id for r in results1]) + eq_([3, 5], [r.id for r in results1]) results2 = (matchtable.select(). where(matchtable.c.title.match('nutshell ruby')). order_by(matchtable.c.id). execute(). fetchall()) - self.assertEquals([3, 5], [r.id for r in results2]) + eq_([3, 5], [r.id for r in results2]) def test_and_match(self): @@ -1141,12 +1143,12 @@ class MatchTest(TestBase, AssertsCompiledSQL): matchtable.c.title.match('nutshell'))). execute(). fetchall()) - self.assertEquals([5], [r.id for r in results1]) + eq_([5], [r.id for r in results1]) results2 = (matchtable.select(). where(matchtable.c.title.match('+python +nutshell')). execute(). fetchall()) - self.assertEquals([5], [r.id for r in results2]) + eq_([5], [r.id for r in results2]) def test_match_across_joins(self): results = (matchtable.select(). @@ -1156,12 +1158,10 @@ class MatchTest(TestBase, AssertsCompiledSQL): order_by(matchtable.c.id). execute(). fetchall()) - self.assertEquals([1, 3, 5], [r.id for r in results]) + eq_([1, 3, 5], [r.id for r in results]) def colspec(c): return testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None).get_column_specification(c) -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/oracle.py b/test/dialect/test_oracle.py similarity index 97% rename from test/dialect/oracle.py rename to test/dialect/test_oracle.py index 2186f22595..16175c8512 100644 --- a/test/dialect/oracle.py +++ b/test/dialect/test_oracle.py @@ -1,19 +1,20 @@ # coding: utf-8 -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy.sql import table, column from sqlalchemy.databases import oracle -from testlib import * -from testlib.testing import eq_ -from testlib.engines import testing_engine +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.engines import testing_engine import os class OutParamTest(TestBase, AssertsExecutionResults): __only_on__ = 'oracle' - def setUpAll(self): + @classmethod + def setup_class(cls): testing.db.execute(""" create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number, z_out OUT varchar) IS retval number; @@ -29,7 +30,8 @@ create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT numb result = testing.db.execute(text("begin foo(:x_in, :x_out, :y_out, :z_out); end;", bindparams=[bindparam('x_in', Numeric), outparam('x_out', Numeric), outparam('y_out', Numeric), outparam('z_out', String)]), x_in=5) assert result.out_parameters == {'x_out':10, 'y_out':75, 'z_out':None}, result.out_parameters - def tearDownAll(self): + @classmethod + def teardown_class(cls): testing.db.execute("DROP PROCEDURE foo") @@ -200,7 +202,7 @@ class MultiSchemaTest(TestBase, AssertsCompiledSQL): try: parent.insert().execute({'pid':1}) child.insert().execute({'cid':1, 'pid':1}) - self.assertEquals(child.select().execute().fetchall(), [(1, 1)]) + eq_(child.select().execute().fetchall(), [(1, 1)]) finally: meta.drop_all() @@ -217,7 +219,7 @@ class MultiSchemaTest(TestBase, AssertsCompiledSQL): try: parent.insert().execute({'pid':1}) child.insert().execute({'cid':1, 'pid':1}) - self.assertEquals(child.select().execute().fetchall(), [(1, 1)]) + eq_(child.select().execute().fetchall(), [(1, 1)]) finally: meta.drop_all() @@ -346,7 +348,8 @@ class TypesTest(TestBase, AssertsCompiledSQL): class BufferedColumnTest(TestBase, AssertsCompiledSQL): __only_on__ = 'oracle' - def setUpAll(self): + @classmethod + def setup_class(cls): global binary_table, stream, meta meta = MetaData(testing.db) binary_table = Table('binary_table', meta, @@ -360,18 +363,19 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL): for i in range(1, 11): binary_table.insert().execute(id=i, data=stream) - def tearDownAll(self): + @classmethod + def teardown_class(cls): meta.drop_all() def test_fetch(self): - self.assertEquals( + eq_( binary_table.select().execute().fetchall() , [(i, stream) for i in range(1, 11)], ) def test_fetch_single_arraysize(self): eng = testing_engine(options={'arraysize':1}) - self.assertEquals( + eq_( eng.execute(binary_table.select()).fetchall(), [(i, stream) for i in range(1, 11)], ) @@ -393,5 +397,3 @@ class ExecuteTest(TestBase): def test_basic(self): assert testing.db.execute("/*+ this is a comment */ SELECT 1 FROM DUAL").fetchall() == [(1,)] -if __name__ == '__main__': - testenv.main() diff --git a/test/dialect/postgres.py b/test/dialect/test_postgres.py similarity index 88% rename from test/dialect/postgres.py rename to test/dialect/test_postgres.py index 2dfbe018cc..8ca714badc 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/test_postgres.py @@ -1,11 +1,11 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy import exc from sqlalchemy.databases import postgres from sqlalchemy.engine.strategies import MockEngineStrategy -from testlib import * +from sqlalchemy.test import * from sqlalchemy.sql import table, column @@ -86,10 +86,10 @@ class ReturningTest(TestBase, AssertsExecutionResults): table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute() - self.assertEqual(result.fetchall(), [(1,)]) + eq_(result.fetchall(), [(1,)]) result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() - self.assertEqual(result2.fetchall(), [(1,True),(2,False)]) + eq_(result2.fetchall(), [(1,True),(2,False)]) finally: table.drop() @@ -105,22 +105,22 @@ class ReturningTest(TestBase, AssertsExecutionResults): try: result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False}) - self.assertEqual(result.fetchall(), [(1,)]) + eq_(result.fetchall(), [(1,)]) @testing.fails_on('postgres', 'Known limitation of psycopg2') def test_executemany(): # return value is documented as failing with psycopg2/executemany result2 = table.insert(postgres_returning=[table]).execute( [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) - self.assertEqual(result2.fetchall(), [(2, 2, False), (3,3,True)]) + eq_(result2.fetchall(), [(2, 2, False), (3,3,True)]) test_executemany() result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False}) - self.assertEqual([dict(row) for row in result3], [{'double_id':8}]) + eq_([dict(row) for row in result3], [{'double_id':8}]) result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons') - self.assertEqual([dict(row) for row in result4], [{'persons': 10}]) + eq_([dict(row) for row in result4], [{'persons': 10}]) finally: table.drop() @@ -128,11 +128,12 @@ class ReturningTest(TestBase, AssertsExecutionResults): class InsertTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres' - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata metadata = MetaData(testing.db) - def tearDown(self): + def teardown(self): metadata.drop_all() metadata.tables.clear() @@ -397,7 +398,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres' - def setUpAll(self): + @classmethod + def setup_class(cls): con = testing.db.connect() for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42', 'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'): @@ -410,7 +412,8 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)') con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)') - def tearDownAll(self): + @classmethod + def teardown_class(cls): con = testing.db.connect() con.execute('DROP TABLE testtable') con.execute('DROP TABLE alt_schema.testtable') @@ -421,32 +424,32 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): def test_table_is_reflected(self): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) - self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns") - self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger) + eq_(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns") + eq_(table.c.answer.type.__class__, postgres.PGInteger) def test_domain_is_reflected(self): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True) - self.assertEquals(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value") - self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.") + eq_(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value") + assert not table.columns.answer.nullable, "Expected reflected column to not be nullable." def test_table_is_reflected_alt_schema(self): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True, schema='alt_schema') - self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns") - self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger) + eq_(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns") + eq_(table.c.anything.type.__class__, postgres.PGInteger) def test_schema_domain_is_reflected(self): metadata = MetaData(testing.db) table = Table('testtable', metadata, autoload=True, schema='alt_schema') - self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value") - self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.") + eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value") + assert table.columns.answer.nullable, "Expected reflected column to be nullable." def test_crosschema_domain_is_reflected(self): metadata = MetaData(testing.db) table = Table('crosschema', metadata, autoload=True) - self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value") - self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.") + eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value") + assert table.columns.answer.nullable, "Expected reflected column to be nullable." def test_unknown_types(self): from sqlalchemy.databases import postgres @@ -455,7 +458,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults): postgres.ischema_names = {} try: m2 = MetaData(testing.db) - self.assertRaises(exc.SAWarning, Table, "testtable", m2, autoload=True) + assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True) @testing.emits_warning('Did not recognize type') def warns(): @@ -519,15 +522,15 @@ class MiscTest(TestBase, AssertsExecutionResults): t = Table('mytable', MetaData(testing.db), Column('id', Integer, primary_key=True), Column('a', String(8))) - self.assertEquals( + eq_( str(t.select(distinct=t.c.a)), 'SELECT DISTINCT ON (mytable.a) mytable.id, mytable.a \n' 'FROM mytable') - self.assertEquals( + eq_( str(t.select(distinct=['id','a'])), 'SELECT DISTINCT ON (id, a) mytable.id, mytable.a \n' 'FROM mytable') - self.assertEquals( + eq_( str(t.select(distinct=[t.c.id, t.c.a])), 'SELECT DISTINCT ON (mytable.id, mytable.a) mytable.id, mytable.a \n' 'FROM mytable') @@ -616,14 +619,14 @@ class MiscTest(TestBase, AssertsExecutionResults): users.insert().execute(id=3, name='name3') users.insert().execute(id=4, name='name4') - self.assertEquals(users.select().where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')]) - self.assertEquals(users.select(use_labels=True).where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')]) + eq_(users.select().where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')]) + eq_(users.select(use_labels=True).where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')]) users.delete().where(users.c.id==3).execute() - self.assertEquals(users.select().where(users.c.name=='name3').execute().fetchall(), []) + eq_(users.select().where(users.c.name=='name3').execute().fetchall(), []) users.update().where(users.c.name=='name4').execute(name='newname') - self.assertEquals(users.select(use_labels=True).where(users.c.id==4).execute().fetchall(), [(4, 'newname')]) + eq_(users.select(use_labels=True).where(users.c.id==4).execute().fetchall(), [(4, 'newname')]) finally: users.drop() @@ -733,7 +736,8 @@ class TimezoneTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres' - def setUpAll(self): + @classmethod + def setup_class(cls): global tztable, notztable, metadata metadata = MetaData(testing.db) @@ -749,7 +753,8 @@ class TimezoneTest(TestBase, AssertsExecutionResults): Column("name", String(20)), ) metadata.create_all() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_with_timezone(self): @@ -769,7 +774,8 @@ class TimezoneTest(TestBase, AssertsExecutionResults): class ArrayTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres' - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, arrtable metadata = MetaData(testing.db) @@ -779,47 +785,48 @@ class ArrayTest(TestBase, AssertsExecutionResults): Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False) ) metadata.create_all() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_reflect_array_column(self): metadata2 = MetaData(testing.db) tbl = Table('arrtable', metadata2, autoload=True) - self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray)) - self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray)) - self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer)) - self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String)) + assert isinstance(tbl.c.intarr.type, postgres.PGArray) + assert isinstance(tbl.c.strarr.type, postgres.PGArray) + assert isinstance(tbl.c.intarr.type.item_type, Integer) + assert isinstance(tbl.c.strarr.type.item_type, String) def test_insert_array(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = arrtable.select().execute().fetchall() - self.assertEquals(len(results), 1) - self.assertEquals(results[0]['intarr'], [1,2,3]) - self.assertEquals(results[0]['strarr'], ['abc','def']) + eq_(len(results), 1) + eq_(results[0]['intarr'], [1,2,3]) + eq_(results[0]['strarr'], ['abc','def']) arrtable.delete().execute() def test_array_where(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) arrtable.insert().execute(intarr=[4,5,6], strarr='ABC') results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall() - self.assertEquals(len(results), 1) - self.assertEquals(results[0]['intarr'], [1,2,3]) + eq_(len(results), 1) + eq_(results[0]['intarr'], [1,2,3]) arrtable.delete().execute() def test_array_concat(self): arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def']) results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall() - self.assertEquals(len(results), 1) - self.assertEquals(results[0][0], [1,2,3,4,5,6]) + eq_(len(results), 1) + eq_(results[0][0], [1,2,3,4,5,6]) arrtable.delete().execute() def test_array_subtype_resultprocessor(self): arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']]) arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6']) results = arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall() - self.assertEquals(len(results), 2) - self.assertEquals(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6']) - self.assertEquals(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']]) + eq_(len(results), 2) + eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6']) + eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']]) arrtable.delete().execute() def test_array_mutability(self): @@ -839,23 +846,23 @@ class ArrayTest(TestBase, AssertsExecutionResults): sess.flush() sess.expunge_all() foo = sess.query(Foo).get(1) - self.assertEquals(foo.intarr, [1,2,3]) + eq_(foo.intarr, [1,2,3]) foo.intarr.append(4) sess.flush() sess.expunge_all() foo = sess.query(Foo).get(1) - self.assertEquals(foo.intarr, [1,2,3,4]) + eq_(foo.intarr, [1,2,3,4]) foo.intarr = [] sess.flush() sess.expunge_all() - self.assertEquals(foo.intarr, []) + eq_(foo.intarr, []) foo.intarr = None sess.flush() sess.expunge_all() - self.assertEquals(foo.intarr, None) + eq_(foo.intarr, None) # Errors in r4217: foo = Foo() @@ -872,16 +879,18 @@ class TimeStampTest(TestBase, AssertsExecutionResults): connection = engine.connect() s = select([func.TIMESTAMP("12/25/07").label("ts")]) result = connection.execute(s).fetchone() - self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0)) + eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0)) class ServerSideCursorsTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres' - def setUpAll(self): + @classmethod + def setup_class(cls): global ss_engine ss_engine = engines.testing_engine(options={'server_side_cursors':True}) - def tearDownAll(self): + @classmethod + def teardown_class(cls): ss_engine.dispose() def test_uses_ss(self): @@ -906,12 +915,12 @@ class ServerSideCursorsTest(TestBase, AssertsExecutionResults): nextid = ss_engine.execute(Sequence('test_table_id_seq')) test_table.insert().execute(id=nextid, data='data2') - self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')]) + eq_(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')]) test_table.update().where(test_table.c.id==2).values(data=test_table.c.data + ' updated').execute() - self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')]) + eq_(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')]) test_table.delete().execute() - self.assertEquals(test_table.count().scalar(), 0) + eq_(test_table.count().scalar(), 0) finally: test_table.drop(checkfirst=True) @@ -921,7 +930,8 @@ class SpecialTypesTest(TestBase, ComparesTables): __only_on__ = 'postgres' __excluded_on__ = (('postgres', '<', (8, 3, 0)),) - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, table metadata = MetaData(testing.db) @@ -935,7 +945,8 @@ class SpecialTypesTest(TestBase, ComparesTables): metadata.create_all() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_reflection(self): @@ -949,7 +960,8 @@ class MatchTest(TestBase, AssertsCompiledSQL): __only_on__ = 'postgres' __excluded_on__ = (('postgres', '<', (8, 3, 0)),) - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, cattable, matchtable metadata = MetaData(testing.db) @@ -976,7 +988,8 @@ class MatchTest(TestBase, AssertsCompiledSQL): {'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1} ]) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_expression(self): @@ -984,42 +997,40 @@ class MatchTest(TestBase, AssertsCompiledSQL): def test_simple_match(self): results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([2, 5], [r.id for r in results]) + eq_([2, 5], [r.id for r in results]) def test_simple_match_with_apostrophe(self): results = matchtable.select().where(matchtable.c.title.match("Matz''s")).execute().fetchall() - self.assertEquals([3], [r.id for r in results]) + eq_([3], [r.id for r in results]) def test_simple_derivative_match(self): results = matchtable.select().where(matchtable.c.title.match('nutshells')).execute().fetchall() - self.assertEquals([5], [r.id for r in results]) + eq_([5], [r.id for r in results]) def test_or_match(self): results1 = matchtable.select().where(or_(matchtable.c.title.match('nutshells'), matchtable.c.title.match('rubies')) ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([3, 5], [r.id for r in results1]) + eq_([3, 5], [r.id for r in results1]) results2 = matchtable.select().where(matchtable.c.title.match('nutshells | rubies'), ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([3, 5], [r.id for r in results2]) + eq_([3, 5], [r.id for r in results2]) def test_and_match(self): results1 = matchtable.select().where(and_(matchtable.c.title.match('python'), matchtable.c.title.match('nutshells')) ).execute().fetchall() - self.assertEquals([5], [r.id for r in results1]) + eq_([5], [r.id for r in results1]) results2 = matchtable.select().where(matchtable.c.title.match('python & nutshells'), ).execute().fetchall() - self.assertEquals([5], [r.id for r in results2]) + eq_([5], [r.id for r in results2]) def test_match_across_joins(self): results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id, or_(cattable.c.description.match('Ruby'), matchtable.c.title.match('nutshells'))) ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([1, 3, 5], [r.id for r in results]) + eq_([1, 3, 5], [r.id for r in results]) -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/sqlite.py b/test/dialect/test_sqlite.py similarity index 93% rename from test/dialect/sqlite.py rename to test/dialect/test_sqlite.py index d01be3521d..eb4581e20f 100644 --- a/test/dialect/sqlite.py +++ b/test/dialect/test_sqlite.py @@ -1,11 +1,11 @@ """SQLite-specific tests.""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime from sqlalchemy import * from sqlalchemy import exc, sql from sqlalchemy.databases import sqlite -from testlib import * +from sqlalchemy.test import * class TestTypes(TestBase, AssertsExecutionResults): @@ -34,22 +34,22 @@ class TestTypes(TestBase, AssertsExecutionResults): meta.drop_all() def test_string_dates_raise(self): - self.assertRaises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30))) + assert_raises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30))) def test_time_microseconds(self): dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) # 125 usec - self.assertEquals(str(dt), '2008-06-27 12:00:00.000125') + eq_(str(dt), '2008-06-27 12:00:00.000125') sldt = sqlite.SLDateTime() bp = sldt.bind_processor(None) - self.assertEquals(bp(dt), '2008-06-27 12:00:00.000125') + eq_(bp(dt), '2008-06-27 12:00:00.000125') rp = sldt.result_processor(None) - self.assertEquals(rp(bp(dt)), dt) + eq_(rp(bp(dt)), dt) sldt.__legacy_microseconds__ = True bp = sldt.bind_processor(None) - self.assertEquals(bp(dt), '2008-06-27 12:00:00.125') - self.assertEquals(rp(bp(dt)), dt) + eq_(bp(dt), '2008-06-27 12:00:00.125') + eq_(rp(bp(dt)), dt) def test_no_convert_unicode(self): """test no utf-8 encoding occurs""" @@ -163,7 +163,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): rt = Table('t_defaults', m2, autoload=True) expected = [c[1] for c in specs] for i, reflected in enumerate(rt.c): - self.assertEquals(reflected.server_default.arg.text, expected[i]) + eq_(reflected.server_default.arg.text, expected[i]) finally: m.drop_all() @@ -184,7 +184,7 @@ class TestDefaults(TestBase, AssertsExecutionResults): rt = Table('r_defaults', m, autoload=True) for i, reflected in enumerate(rt.c): - self.assertEquals(reflected.server_default.arg.text, expected[i]) + eq_(reflected.server_default.arg.text, expected[i]) finally: db.execute("DROP TABLE r_defaults") @@ -258,7 +258,7 @@ class DialectTest(TestBase, AssertsExecutionResults): schema='alt_schema') meta.create_all(cx) - self.assertEquals(dialect.table_names(cx, 'alt_schema'), + eq_(dialect.table_names(cx, 'alt_schema'), ['created']) assert len(alt_master.c) > 0 @@ -350,7 +350,7 @@ class InsertTest(TestBase, AssertsExecutionResults): table.insert().execute() rows = table.select().execute().fetchall() - self.assertEquals(len(rows), wanted) + eq_(len(rows), wanted) finally: table.drop() @@ -362,7 +362,7 @@ class InsertTest(TestBase, AssertsExecutionResults): @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') def test_empty_insert_pk2(self): - self.assertRaises( + assert_raises( exc.DBAPIError, self._test_empty_insert, Table('b', MetaData(testing.db), @@ -371,7 +371,7 @@ class InsertTest(TestBase, AssertsExecutionResults): @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support') def test_empty_insert_pk3(self): - self.assertRaises( + assert_raises( exc.DBAPIError, self._test_empty_insert, Table('c', MetaData(testing.db), @@ -429,7 +429,8 @@ class MatchTest(TestBase, AssertsCompiledSQL): __only_on__ = 'sqlite' __skip_if__ = (full_text_search_missing, ) - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, cattable, matchtable metadata = MetaData(testing.db) @@ -465,7 +466,8 @@ class MatchTest(TestBase, AssertsCompiledSQL): {'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1} ]) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_expression(self): @@ -473,29 +475,27 @@ class MatchTest(TestBase, AssertsCompiledSQL): def test_simple_match(self): results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([2, 5], [r.id for r in results]) + eq_([2, 5], [r.id for r in results]) def test_simple_prefix_match(self): results = matchtable.select().where(matchtable.c.title.match('nut*')).execute().fetchall() - self.assertEquals([5], [r.id for r in results]) + eq_([5], [r.id for r in results]) def test_or_match(self): results2 = matchtable.select().where(matchtable.c.title.match('nutshell OR ruby'), ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([3, 5], [r.id for r in results2]) + eq_([3, 5], [r.id for r in results2]) def test_and_match(self): results2 = matchtable.select().where(matchtable.c.title.match('python nutshell'), ).execute().fetchall() - self.assertEquals([5], [r.id for r in results2]) + eq_([5], [r.id for r in results2]) def test_match_across_joins(self): results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id, cattable.c.description.match('Ruby')) ).order_by(matchtable.c.id).execute().fetchall() - self.assertEquals([1, 3], [r.id for r in results]) + eq_([1, 3], [r.id for r in results]) -if __name__ == "__main__": - testenv.main() diff --git a/test/dialect/sybase.py b/test/dialect/test_sybase.py similarity index 85% rename from test/dialect/sybase.py rename to test/dialect/test_sybase.py index 32b9904d8a..37de91d1c4 100644 --- a/test/dialect/sybase.py +++ b/test/dialect/test_sybase.py @@ -1,8 +1,7 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import sql from sqlalchemy.databases import sybase -from testlib import * +from sqlalchemy.test import * class CompileTest(TestBase, AssertsCompiledSQL): @@ -27,5 +26,3 @@ class CompileTest(TestBase, AssertsCompiledSQL): -if __name__ == "__main__": - testenv.main() diff --git a/test/engine/_base.py b/test/engine/_base.py index 3c31d378ad..ec91243d24 100644 --- a/test/engine/_base.py +++ b/test/engine/_base.py @@ -1,5 +1,6 @@ -from testlib import sa, testing -from testlib.testing import adict +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy.test.testing import adict class TablesTest(testing.TestBase): @@ -27,41 +28,38 @@ class TablesTest(testing.TestBase): tables = None other_artifacts = None - def setUpAll(self): - if self.run_setup_bind is None: - assert self.bind is not None - assert self.run_deletes in (None, 'each') - if self.run_inserts == 'once': - assert self.run_deletes is None + @classmethod + def setup_class(cls): + if cls.run_setup_bind is None: + assert cls.bind is not None + assert cls.run_deletes in (None, 'each') + if cls.run_inserts == 'once': + assert cls.run_deletes is None - cls = self.__class__ if cls.tables is None: cls.tables = adict() if cls.other_artifacts is None: cls.other_artifacts = adict() - if self.bind is None: - setattr(type(self), 'bind', self.setup_bind()) - - if self.metadata is None: - setattr(type(self), 'metadata', sa.MetaData()) + if cls.bind is None: + setattr(cls, 'bind', cls.setup_bind()) - if self.metadata.bind is None: - self.metadata.bind = self.bind + if cls.metadata is None: + setattr(cls, 'metadata', sa.MetaData()) - if self.run_define_tables: - self.define_tables(self.metadata) - self.metadata.create_all() - self.tables.update(self.metadata.tables) + if cls.metadata.bind is None: + cls.metadata.bind = cls.bind - if self.run_inserts: - self._load_fixtures() - self.insert_data() + if cls.run_define_tables == 'once': + cls.define_tables(cls.metadata) + cls.metadata.create_all() + cls.tables.update(cls.metadata.tables) - def setUp(self): - if self._sa_first_test: - return + if cls.run_inserts == 'once': + cls._load_fixtures() + cls.insert_data() + def setup(self): cls = self.__class__ if self.setup_bind == 'each': @@ -79,7 +77,7 @@ class TablesTest(testing.TestBase): self._load_fixtures() self.insert_data() - def tearDown(self): + def teardown(self): # no need to run deletes if tables are recreated on setup if self.run_define_tables != 'each' and self.run_deletes: for table in reversed(self.metadata.sorted_tables): @@ -92,33 +90,39 @@ class TablesTest(testing.TestBase): if self.run_dispose_bind == 'each': self.dispose_bind(self.bind) - def tearDownAll(self): - self.metadata.drop_all() + @classmethod + def teardown_class(cls): + cls.metadata.drop_all() - if self.dispose_bind: - self.dispose_bind(self.bind) + if cls.dispose_bind: + cls.dispose_bind(cls.bind) - self.metadata.bind = None + cls.metadata.bind = None - if self.run_setup_bind is not None: - self.bind = None + if cls.run_setup_bind is not None: + cls.bind = None - def setup_bind(self): + @classmethod + def setup_bind(cls): return testing.db - def dispose_bind(self, bind): + @classmethod + def dispose_bind(cls, bind): if hasattr(bind, 'dispose'): bind.dispose() elif hasattr(bind, 'close'): bind.close() - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): raise NotImplementedError() - def fixtures(self): + @classmethod + def fixtures(cls): return {} - def insert_data(self): + @classmethod + def insert_data(cls): pass def sql_count_(self, count, fn): @@ -147,14 +151,17 @@ class TablesTest(testing.TestBase): class AltEngineTest(testing.TestBase): engine = None - def setUpAll(self): - type(self).engine = self.create_engine() - testing.TestBase.setUpAll(self) - - def tearDownAll(self): - testing.TestBase.tearDownAll(self) - self.engine.dispose() - type(self).engine = None - - def create_engine(self): + @classmethod + def setup_class(cls): + cls.engine = cls.create_engine() + super(AltEngineTest, cls).setup_class() + + @classmethod + def teardown_class(cls): + cls.engine.dispose() + cls.engine = None + super(AltEngineTest, cls).teardown_class() + + @classmethod + def create_engine(cls): raise NotImplementedError diff --git a/test/engine/alltests.py b/test/engine/alltests.py deleted file mode 100644 index ed722aa3b5..0000000000 --- a/test/engine/alltests.py +++ /dev/null @@ -1,31 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - - -def suite(): - modules_to_test = ( - # connectivity, execution - 'engine.parseconnect', - 'engine.pool', - 'engine.bind', - 'engine.reconnect', - 'engine.execute', - 'engine.metadata', - 'engine.transaction', - - # schema/tables - 'engine.reflection', - 'engine.ddlevents', - - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/engine/bind.py b/test/engine/test_bind.py similarity index 96% rename from test/engine/bind.py rename to test/engine/test_bind.py index 5b8605aada..7fd3009bca 100644 --- a/test/engine/bind.py +++ b/test/engine/test_bind.py @@ -1,11 +1,14 @@ """tests the "bind" attribute/argument across schema and SQL, including the deprecated versions of these arguments""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import engine, exc from sqlalchemy import MetaData, ThreadLocalMetaData -from testlib.sa import Table, Column, Integer, text -from testlib import sa, testing +from sqlalchemy import Integer, text +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +import sqlalchemy as sa +from sqlalchemy.test import testing class BindTest(testing.TestBase): @@ -43,7 +46,7 @@ class BindTest(testing.TestBase): meth() assert False except exc.UnboundExecutionError, e: - self.assertEquals( + eq_( str(e), "The MetaData " "is not bound to an Engine or Connection. " @@ -61,7 +64,7 @@ class BindTest(testing.TestBase): meth() assert False except exc.UnboundExecutionError, e: - self.assertEquals( + eq_( str(e), "The Table 'test_table' " "is not bound to an Engine or Connection. " @@ -85,7 +88,7 @@ class BindTest(testing.TestBase): meth() assert False except exc.UnboundExecutionError, e: - self.assertEquals( + eq_( str(e), "The Table 'test_table' " "is not bound to an Engine or Connection. " @@ -219,5 +222,3 @@ class BindTest(testing.TestBase): metadata.drop_all(bind=testing.db) -if __name__ == '__main__': - testenv.main() diff --git a/test/engine/ddlevents.py b/test/engine/test_ddlevents.py similarity index 92% rename from test/engine/ddlevents.py rename to test/engine/test_ddlevents.py index 8274c63476..5716006d93 100644 --- a/test/engine/ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -1,9 +1,11 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy.schema import DDL from sqlalchemy import create_engine -from testlib.sa import MetaData, Table, Column, Integer, String -import testlib.sa as tsa -from testlib import TestBase, testing, engines +from sqlalchemy import MetaData, Integer, String +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +import sqlalchemy as tsa +from sqlalchemy.test import TestBase, testing, engines class DDLEventTest(TestBase): @@ -37,7 +39,7 @@ class DDLEventTest(TestBase): assert bind is self.bind self.state = action - def setUp(self): + def setup(self): self.bind = engines.mock_engine() self.metadata = MetaData() self.table = Table('t', self.metadata, Column('id', Integer)) @@ -174,14 +176,14 @@ class DDLEventTest(TestBase): fn = lambda *a: None table.append_ddl_listener('before-create', fn) - self.assertRaises(LookupError, table.append_ddl_listener, 'blah', fn) + assert_raises(LookupError, table.append_ddl_listener, 'blah', fn) metadata.append_ddl_listener('before-create', fn) - self.assertRaises(LookupError, metadata.append_ddl_listener, 'blah', fn) + assert_raises(LookupError, metadata.append_ddl_listener, 'blah', fn) class DDLExecutionTest(TestBase): - def setUp(self): + def setup(self): self.engine = engines.mock_engine() self.metadata = MetaData(self.engine) self.users = Table('users', self.metadata, @@ -303,19 +305,19 @@ class DDLTest(TestBase): ddl = DDL('%(schema)s-%(table)s-%(fullname)s') - self.assertEquals(ddl._expand(sane_alone, bind), '-t-t') - self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t') - self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"') - self.assertEquals(ddl._expand(insane_schema, bind), + eq_(ddl._expand(sane_alone, bind), '-t-t') + eq_(ddl._expand(sane_schema, bind), 's-t-s.t') + eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"') + eq_(ddl._expand(insane_schema, bind), '"s s"-"t t"-"s s"."t t"') # overrides are used piece-meal and verbatim. ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s', context={'schema':'S S', 'table': 'T T', 'bonus': 'b'}) - self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b') - self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b') - self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b') - self.assertEquals(ddl._expand(insane_schema, bind), + eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b') + eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b') + eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b') + eq_(ddl._expand(insane_schema, bind), 'S S-T T-"s s"."t t"-b') def test_filter(self): cx = self.mock_engine() @@ -338,5 +340,3 @@ class DDLTest(TestBase): assert repr(DDL('s', on='engine', context={'a':1})) -if __name__ == "__main__": - testenv.main() diff --git a/test/engine/execute.py b/test/engine/test_execute.py similarity index 94% rename from test/engine/execute.py rename to test/engine/test_execute.py index 515c99d309..08bf80fe2f 100644 --- a/test/engine/execute.py +++ b/test/engine/test_execute.py @@ -1,15 +1,17 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import re from sqlalchemy.interfaces import ConnectionProxy -from testlib.sa import MetaData, Table, Column, Integer, String, INT, \ - VARCHAR, func, bindparam -import testlib.sa as tsa -from testlib import TestBase, testing, engines +from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, bindparam +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +import sqlalchemy as tsa +from sqlalchemy.test import TestBase, testing, engines users, metadata = None, None class ExecuteTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global users, metadata metadata = MetaData(testing.db) users = Table('users', metadata, @@ -18,9 +20,10 @@ class ExecuteTest(TestBase): ) metadata.create_all() - def tearDown(self): + def teardown(self): testing.db.connect().execute(users.delete()) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite') @@ -82,7 +85,7 @@ class ExecuteTest(TestBase): def test_empty_insert(self): """test that execute() interprets [] as a list with no params""" result = testing.db.execute(users.insert().values(user_name=bindparam('name')), []) - self.assertEquals(result.rowcount, 1) + eq_(result.rowcount, 1) class ProxyConnectionTest(TestBase): @testing.fails_on('firebird', 'Data type unknown') @@ -162,5 +165,3 @@ class ProxyConnectionTest(TestBase): assert_stmts(cursor, cursor_stmts) -if __name__ == "__main__": - testenv.main() diff --git a/test/engine/metadata.py b/test/engine/test_metadata.py similarity index 93% rename from test/engine/metadata.py rename to test/engine/test_metadata.py index c8fc6f7e0f..024d1b854f 100644 --- a/test/engine/metadata.py +++ b/test/engine/test_metadata.py @@ -1,11 +1,12 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message import pickle from sqlalchemy import MetaData -from testlib.sa import Table, Column, Integer, String, UniqueConstraint, \ - CheckConstraint, ForeignKey -import testlib.sa as tsa -from testlib import TestBase, ComparesTables, testing, engines -from testlib.testing import eq_ +from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +import sqlalchemy as tsa +from sqlalchemy.test import TestBase, ComparesTables, testing, engines +from sqlalchemy.test.testing import eq_ class MetaDataTest(TestBase, ComparesTables): def test_metadata_connect(self): @@ -137,13 +138,13 @@ class MetaDataTest(TestBase, ComparesTables): def test_nonexistent(self): - self.assertRaises(tsa.exc.NoSuchTableError, Table, + assert_raises(tsa.exc.NoSuchTableError, Table, 'fake_table', MetaData(testing.db), autoload=True) class TableOptionsTest(TestBase): - def setUp(self): + def setup(self): self.engine = engines.mock_engine() self.metadata = MetaData(self.engine) @@ -160,5 +161,3 @@ class TableOptionsTest(TestBase): table2.create() assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)] -if __name__ == '__main__': - testenv.main() diff --git a/test/engine/parseconnect.py b/test/engine/test_parseconnect.py similarity index 98% rename from test/engine/parseconnect.py rename to test/engine/test_parseconnect.py index c82ca6d58d..6b7ac37b20 100644 --- a/test/engine/parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -1,9 +1,8 @@ -import testenv; testenv.configure_for_tests() import ConfigParser, StringIO import sqlalchemy.engine.url as url from sqlalchemy import create_engine, engine_from_config -import testlib.sa as tsa -from testlib import TestBase +import sqlalchemy as tsa +from sqlalchemy.test import TestBase class ParseConnectTest(TestBase): @@ -229,5 +228,3 @@ class MockCursor(object): pass mock_dbapi = MockDBAPI() -if __name__ == "__main__": - testenv.main() diff --git a/test/engine/pool.py b/test/engine/test_pool.py similarity index 99% rename from test/engine/pool.py rename to test/engine/test_pool.py index b712e24128..43a0fc38b7 100644 --- a/test/engine/pool.py +++ b/test/engine/test_pool.py @@ -1,8 +1,7 @@ -import testenv; testenv.configure_for_tests() import threading, time, gc from sqlalchemy import pool, interfaces -import testlib.sa as tsa -from testlib import TestBase +import sqlalchemy as tsa +from sqlalchemy.test import TestBase mcid = 1 @@ -37,10 +36,11 @@ mock_dbapi = MockDBAPI() class PoolTestBase(TestBase): - def setUp(self): + def setup(self): pool.clear_managers() - def tearDownAll(self): + @classmethod + def teardown_class(cls): pool.clear_managers() class PoolTest(PoolTestBase): @@ -662,5 +662,3 @@ class NullPoolTest(PoolTestBase): -if __name__ == "__main__": - testenv.main() diff --git a/test/engine/reconnect.py b/test/engine/test_reconnect.py similarity index 88% rename from test/engine/reconnect.py rename to test/engine/test_reconnect.py index 4f383d2dde..3a525c2a70 100644 --- a/test/engine/reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,8 +1,10 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import weakref -from testlib.sa import select, MetaData, Table, Column, Integer, String, pool -import testlib.sa as tsa -from testlib import TestBase, testing, engines +from sqlalchemy import select, MetaData, Integer, String, pool +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +import sqlalchemy as tsa +from sqlalchemy.test import TestBase, testing, engines import time import gc @@ -47,7 +49,7 @@ class MockCursor(object): db, dbapi = None, None class MockReconnectTest(TestBase): - def setUp(self): + def setup(self): global db, dbapi dbapi = MockDBAPI() @@ -176,17 +178,17 @@ class MockReconnectTest(TestBase): engine = None class RealReconnectTest(TestBase): - def setUp(self): + def setup(self): global engine engine = engines.reconnecting_engine() - def tearDown(self): + def teardown(self): engine.dispose() def test_reconnect(self): conn = engine.connect() - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() @@ -202,7 +204,7 @@ class RealReconnectTest(TestBase): assert conn.invalidated assert conn.invalidated - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated # one more time @@ -214,7 +216,7 @@ class RealReconnectTest(TestBase): if not e.connection_invalidated: raise assert conn.invalidated - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated conn.close() @@ -222,7 +224,7 @@ class RealReconnectTest(TestBase): def test_null_pool(self): engine = engines.reconnecting_engine(options=dict(poolclass=pool.NullPool)) conn = engine.connect() - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() try: @@ -233,12 +235,12 @@ class RealReconnectTest(TestBase): raise assert not conn.closed assert conn.invalidated - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated def test_close(self): conn = engine.connect() - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() @@ -252,14 +254,14 @@ class RealReconnectTest(TestBase): conn.close() conn = engine.connect() - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) def test_with_transaction(self): conn = engine.connect() trans = conn.begin() - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.closed engine.test_shutdown() @@ -295,7 +297,7 @@ class RealReconnectTest(TestBase): assert not trans.is_active assert conn.invalidated - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) assert not conn.invalidated class RecycleTest(TestBase): @@ -304,19 +306,19 @@ class RecycleTest(TestBase): engine = engines.reconnecting_engine(options={'pool_recycle':1, 'pool_threadlocal':threadlocal}) conn = engine.contextual_connect() - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) conn.close() engine.test_shutdown() time.sleep(2) conn = engine.contextual_connect() - self.assertEquals(conn.execute(select([1])).scalar(), 1) + eq_(conn.execute(select([1])).scalar(), 1) conn.close() meta, table, engine = None, None, None class InvalidateDuringResultTest(TestBase): - def setUp(self): + def setup(self): global meta, table, engine engine = engines.reconnecting_engine() meta = MetaData(engine) @@ -328,7 +330,7 @@ class InvalidateDuringResultTest(TestBase): [{'id':i, 'name':'row %d' % i} for i in range(1, 100)] ) - def tearDown(self): + def teardown(self): meta.drop_all() engine.dispose() @@ -350,5 +352,3 @@ class InvalidateDuringResultTest(TestBase): assert conn.invalidated -if __name__ == '__main__': - testenv.main() diff --git a/test/engine/reflection.py b/test/engine/test_reflection.py similarity index 96% rename from test/engine/reflection.py rename to test/engine/test_reflection.py index d8412237fb..ea80776a6a 100644 --- a/test/engine/reflection.py +++ b/test/engine/test_reflection.py @@ -1,8 +1,11 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import StringIO, unicodedata import sqlalchemy as sa -from testlib.sa import MetaData, Table, Column -from testlib import TestBase, ComparesTables, testing, engines, sa as tsa +from sqlalchemy import MetaData +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +import sqlalchemy as tsa +from sqlalchemy.test import TestBase, ComparesTables, testing, engines metadata, users = None, None @@ -62,7 +65,7 @@ class ReflectionTest(TestBase, ComparesTables): foo = Table('foo', meta2, autoload=True, include_columns=['b', 'f', 'e']) # test that cols come back in original order - self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f']) + eq_([c.name for c in foo.c], ['b', 'e', 'f']) for c in ('b', 'f', 'e'): assert c in foo.c for c in ('a', 'c', 'd'): @@ -73,7 +76,7 @@ class ReflectionTest(TestBase, ComparesTables): foo = Table('foo', meta3, autoload=True) foo = Table('foo', meta3, include_columns=['b', 'f', 'e'], useexisting=True) - self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f']) + eq_([c.name for c in foo.c], ['b', 'e', 'f']) for c in ('b', 'f', 'e'): assert c in foo.c for c in ('a', 'c', 'd'): @@ -103,7 +106,7 @@ class ReflectionTest(TestBase, ComparesTables): dialect_module.ischema_names = {} try: m2 = MetaData(testing.db) - self.assertRaises(tsa.exc.SAWarning, Table, "test", m2, autoload=True) + assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True) @testing.emits_warning('Did not recognize type') def warns(): @@ -282,7 +285,7 @@ class ReflectionTest(TestBase, ComparesTables): a2 = Table('a', m2, include_columns=['z'], autoload=True) b2 = Table('b', m2, autoload=True) - self.assertRaises(tsa.exc.NoReferencedColumnError, a2.join, b2) + assert_raises(tsa.exc.NoReferencedColumnError, a2.join, b2) finally: meta.drop_all() @@ -405,7 +408,7 @@ class ReflectionTest(TestBase, ComparesTables): Column('slot', sa.String(128)), ) - self.assertRaisesMessage(tsa.exc.InvalidRequestError, "Could not find table 'pkgs' with which to generate a foreign key", metadata.create_all) + assert_raises_message(tsa.exc.InvalidRequestError, "Could not find table 'pkgs' with which to generate a foreign key", metadata.create_all) def test_composite_pks(self): """test reflection of a composite primary key""" @@ -608,7 +611,8 @@ class ReflectionTest(TestBase, ComparesTables): m1.drop_all() class CreateDropTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, users metadata = MetaData() users = Table('users', metadata, @@ -656,16 +660,16 @@ class CreateDropTest(TestBase): def test_createdrop(self): metadata.create_all(bind=testing.db) - self.assertEqual( testing.db.has_table('items'), True ) - self.assertEqual( testing.db.has_table('email_addresses'), True ) + eq_( testing.db.has_table('items'), True ) + eq_( testing.db.has_table('email_addresses'), True ) metadata.create_all(bind=testing.db) - self.assertEqual( testing.db.has_table('items'), True ) + eq_( testing.db.has_table('items'), True ) metadata.drop_all(bind=testing.db) - self.assertEqual( testing.db.has_table('items'), False ) - self.assertEqual( testing.db.has_table('email_addresses'), False ) + eq_( testing.db.has_table('items'), False ) + eq_( testing.db.has_table('email_addresses'), False ) metadata.drop_all(bind=testing.db) - self.assertEqual( testing.db.has_table('items'), False ) + eq_( testing.db.has_table('items'), False ) def test_tablenames(self): metadata.create_all(bind=testing.db) @@ -800,7 +804,8 @@ class SchemaTest(TestBase): class HasSequenceTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, users metadata = MetaData() users = Table('users', metadata, @@ -811,10 +816,8 @@ class HasSequenceTest(TestBase): @testing.requires.sequences def test_hassequence(self): metadata.create_all(bind=testing.db) - self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True) + eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True) metadata.drop_all(bind=testing.db) - self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False) + eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False) -if __name__ == "__main__": - testenv.main() diff --git a/test/engine/transaction.py b/test/engine/test_transaction.py similarity index 96% rename from test/engine/transaction.py rename to test/engine/test_transaction.py index 1fa3856108..7d40adf6d0 100644 --- a/test/engine/transaction.py +++ b/test/engine/test_transaction.py @@ -1,13 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import sys, time, threading -from testlib.sa import create_engine, MetaData, Table, Column, INT, VARCHAR, \ - Sequence, select, Integer, String, func, text -from testlib import TestBase, testing +from sqlalchemy import create_engine, MetaData, INT, VARCHAR, Sequence, select, Integer, String, func, text +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.test import TestBase, testing users, metadata = None, None class TransactionTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global users, metadata metadata = MetaData() users = Table('query_users', metadata, @@ -17,9 +19,10 @@ class TransactionTest(TestBase): ) users.create(testing.db) - def tearDown(self): + def teardown(self): testing.db.connect().execute(users.delete()) - def tearDownAll(self): + @classmethod + def teardown_class(cls): users.drop(testing.db) def test_commits(self): @@ -166,7 +169,7 @@ class TransactionTest(TestBase): connection.execute(users.insert(), user_id=3, user_name='user3') transaction.commit() - self.assertEquals( + eq_( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(3,)] ) @@ -183,7 +186,7 @@ class TransactionTest(TestBase): connection.execute(users.insert(), user_id=3, user_name='user3') transaction.commit() - self.assertEquals( + eq_( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(2,),(3,)] ) @@ -202,7 +205,7 @@ class TransactionTest(TestBase): connection.execute(users.insert(), user_id=4, user_name='user4') transaction.commit() - self.assertEquals( + eq_( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(4,)] ) @@ -230,7 +233,7 @@ class TransactionTest(TestBase): transaction.prepare() transaction.rollback() - self.assertEquals( + eq_( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(2,)] ) @@ -264,7 +267,7 @@ class TransactionTest(TestBase): transaction.commit() - self.assertEquals( + eq_( connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(2,),(5,)] ) @@ -285,19 +288,17 @@ class TransactionTest(TestBase): connection.close() connection2 = testing.db.connect() - self.assertEquals( + eq_( connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [] ) recoverables = connection2.recover_twophase() - self.assertTrue( - transaction.xid in recoverables - ) + assert transaction.xid in recoverables connection2.commit_prepared(transaction.xid, recover=True) - self.assertEquals( + eq_( connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,)] ) @@ -327,16 +328,18 @@ class TransactionTest(TestBase): xa.commit() result = conn.execute(select([users.c.user_name]).order_by(users.c.user_id)) - self.assertEqual(result.fetchall(), [('user1',),('user4',)]) + eq_(result.fetchall(), [('user1',),('user4',)]) conn.close() class AutoRollbackTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata metadata = MetaData() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all(testing.db) def test_rollback_deadlock(self): @@ -368,17 +371,19 @@ class ExplicitAutoCommitTest(TestBase): __only_on__ = 'postgres' - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, foo metadata = MetaData(testing.db) foo = Table('foo', metadata, Column('id', Integer, primary_key=True), Column('data', String(100))) metadata.create_all() testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql") - def tearDown(self): + def teardown(self): foo.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): testing.db.execute("drop function insert_foo(varchar)") metadata.drop_all() @@ -437,7 +442,8 @@ class ExplicitAutoCommitTest(TestBase): tlengine = None class TLTransactionTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global users, metadata, tlengine tlengine = create_engine(testing.db.url, strategy='threadlocal') metadata = MetaData() @@ -447,15 +453,16 @@ class TLTransactionTest(TestBase): test_needs_acid=True, ) users.create(tlengine) - def tearDown(self): + def teardown(self): tlengine.execute(users.delete()) - def tearDownAll(self): + @classmethod + def teardown_class(cls): users.drop(tlengine) tlengine.dispose() def test_nested_unsupported(self): - self.assertRaises(NotImplementedError, tlengine.contextual_connect().begin_nested) - self.assertRaises(NotImplementedError, tlengine.begin_nested) + assert_raises(NotImplementedError, tlengine.contextual_connect().begin_nested) + assert_raises(NotImplementedError, tlengine.begin_nested) def test_connection_close(self): """test that when connections are closed for real, transactions are rolled back and disposed.""" @@ -688,14 +695,15 @@ class TLTransactionTest(TestBase): tlengine.prepare() tlengine.rollback() - self.assertEquals( + eq_( tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(), [(1,),(2,)] ) counters = None class ForUpdateTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global counters, metadata metadata = MetaData() counters = Table('forupdate_counters', metadata, @@ -704,9 +712,10 @@ class ForUpdateTest(TestBase): test_needs_acid=True, ) counters.create(testing.db) - def tearDown(self): + def teardown(self): testing.db.connect().execute(counters.delete()) - def tearDownAll(self): + @classmethod + def teardown_class(cls): counters.drop(testing.db) def increment(self, count, errors, update_style=True, delay=0.005): @@ -829,5 +838,3 @@ class ForUpdateTest(TestBase): self.assert_(len(errors) != 0) -if __name__ == "__main__": - testenv.main() diff --git a/test/ext/alltests.py b/test/ext/alltests.py deleted file mode 100644 index 9f5353e04f..0000000000 --- a/test/ext/alltests.py +++ /dev/null @@ -1,36 +0,0 @@ -import testenv; testenv.configure_for_tests() -import doctest, sys - -from testlib import sa_unittest as unittest - - -def suite(): - unittest_modules = ( - 'ext.declarative', - 'ext.orderinglist', - 'ext.associationproxy', - 'ext.serializer', - 'ext.compiler', - ) - - if sys.version_info < (2, 4): - doctest_modules = () - else: - doctest_modules = ( - ('sqlalchemy.ext.orderinglist', {'optionflags': doctest.ELLIPSIS}), - ('sqlalchemy.ext.sqlsoup', {}) - ) - - alltests = unittest.TestSuite() - for name in unittest_modules: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - for name, opts in doctest_modules: - alltests.addTest(doctest.DocTestSuite(name, **opts)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/ext/associationproxy.py b/test/ext/test_associationproxy.py similarity index 97% rename from test/ext/associationproxy.py rename to test/ext/test_associationproxy.py index 821ed90721..742f98baf8 100644 --- a/test/ext/associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -1,10 +1,10 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import gc from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.collections import collection from sqlalchemy.ext.associationproxy import * -from testlib import * +from sqlalchemy.test import * class DictCollection(dict): @@ -34,7 +34,7 @@ class ObjectCollection(object): return iter(self.values) class _CollectionOperations(TestBase): - def setUp(self): + def setup(self): collection_class = self.collection_class metadata = MetaData(testing.db) @@ -77,7 +77,7 @@ class _CollectionOperations(TestBase): self.session = create_session() self.Parent, self.Child = Parent, Child - def tearDown(self): + def teardown(self): self.metadata.drop_all() def roundtrip(self, obj): @@ -189,7 +189,7 @@ class _CollectionOperations(TestBase): self.assert_(p1.children == after) self.assert_([c.name for c in p1._children] == after) - self.assertRaises(TypeError, set, [p1.children]) + assert_raises(TypeError, set, [p1.children]) p1.children *= 0 after = [] @@ -342,7 +342,7 @@ class CustomDictTest(DictTest): except TypeError: self.assert_(True) - self.assertRaises(TypeError, set, [p1.children]) + assert_raises(TypeError, set, [p1.children]) class SetTest(_CollectionOperations): @@ -458,7 +458,7 @@ class SetTest(_CollectionOperations): except TypeError: self.assert_(True) - self.assertRaises(TypeError, set, [p1.children]) + assert_raises(TypeError, set, [p1.children]) def test_set_comparisons(self): @@ -473,19 +473,19 @@ class SetTest(_CollectionOperations): set(['c','d']), set(['e', 'f', 'g']), set()): - self.assertEqual(p1.children.union(other), + eq_(p1.children.union(other), control.union(other)) - self.assertEqual(p1.children.difference(other), + eq_(p1.children.difference(other), control.difference(other)) - self.assertEqual((p1.children - other), + eq_((p1.children - other), (control - other)) - self.assertEqual(p1.children.intersection(other), + eq_(p1.children.intersection(other), control.intersection(other)) - self.assertEqual(p1.children.symmetric_difference(other), + eq_(p1.children.symmetric_difference(other), control.symmetric_difference(other)) - self.assertEqual(p1.children.issubset(other), + eq_(p1.children.issubset(other), control.issubset(other)) - self.assertEqual(p1.children.issuperset(other), + eq_(p1.children.issuperset(other), control.issuperset(other)) self.assert_((p1.children == other) == (control == other)) @@ -714,7 +714,7 @@ class ScalarTest(TestBase): class LazyLoadTest(TestBase): - def setUp(self): + def setup(self): metadata = MetaData(testing.db) parents_table = Table('Parent', metadata, @@ -748,7 +748,7 @@ class LazyLoadTest(TestBase): self.Parent, self.Child = Parent, Child self.table = parents_table - def tearDown(self): + def teardown(self): self.metadata.drop_all() def roundtrip(self, obj): @@ -823,7 +823,7 @@ class LazyLoadTest(TestBase): class ReconstitutionTest(TestBase): - def setUp(self): + def setup(self): metadata = MetaData(testing.db) parents = Table('parents', metadata, Column('id', Integer, primary_key=True, @@ -852,7 +852,7 @@ class ReconstitutionTest(TestBase): self.metadata = metadata self.Parent = Parent - def tearDown(self): + def teardown(self): self.metadata.drop_all() def test_weak_identity_map(self): @@ -883,5 +883,3 @@ class ReconstitutionTest(TestBase): assert set(p_copy.kids) == set(['c1', 'c2']), p.kids -if __name__ == "__main__": - testenv.main() diff --git a/test/ext/compiler.py b/test/ext/test_compiler.py similarity index 97% rename from test/ext/compiler.py rename to test/ext/test_compiler.py index 370ea62ab0..ce25490998 100644 --- a/test/ext/compiler.py +++ b/test/ext/test_compiler.py @@ -1,9 +1,8 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.sql.expression import ClauseElement, ColumnClause from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import table, column -from testlib import * +from sqlalchemy.test import * class UserDefinedTest(TestBase, AssertsCompiledSQL): @@ -122,5 +121,3 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL): "DROP THINGY", ) -if __name__ == '__main__': - testenv.main() diff --git a/test/ext/declarative.py b/test/ext/test_declarative.py similarity index 96% rename from test/ext/declarative.py rename to test/ext/test_declarative.py index f5130b2153..c49c00cec0 100644 --- a/test/ext/declarative.py +++ b/test/ext/test_declarative.py @@ -1,21 +1,24 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy.ext import declarative as decl from sqlalchemy import exc -from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index -from testlib.sa.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred -from testlib.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred +from sqlalchemy.test.testing import eq_ -from orm._base import ComparableEntity, MappedTest +from test.orm._base import ComparableEntity, MappedTest class DeclarativeTestBase(testing.TestBase, testing.AssertsExecutionResults): - def setUp(self): + def setup(self): global Base Base = decl.declarative_base(testing.db) - def tearDown(self): + def teardown(self): clear_mappers() Base.metadata.drop_all() @@ -64,7 +67,7 @@ class DeclarativeTest(DeclarativeTestBase): def go(): class User(Base): id = Column('id', Integer, primary_key=True) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go) + assert_raises_message(sa.exc.InvalidRequestError, "does not have a __table__", go) def test_cant_add_columns(self): t = Table('t', Base.metadata, Column('id', Integer, primary_key=True), Column('data', String)) @@ -73,7 +76,7 @@ class DeclarativeTest(DeclarativeTestBase): __table__ = t foo = Column(Integer, primary_key=True) # can't specify new columns not already in the table - self.assertRaisesMessage(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go) + assert_raises_message(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go) # regular re-mapping works tho class Bar(Base): @@ -144,7 +147,7 @@ class DeclarativeTest(DeclarativeTestBase): sess.add(u1) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(User).filter(User.name == 'ed').one(), + eq_(sess.query(User).filter(User.name == 'ed').one(), User(name='ed', addresses=[Address(email='xyz'), Address(email='def'), Address(email='abc')]) ) @@ -152,7 +155,7 @@ class DeclarativeTest(DeclarativeTestBase): __tablename__ = 'foo' id = Column(Integer, primary_key=True) rel = relation("User", primaryjoin="User.addresses==Foo.id") - self.assertRaisesMessage(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers) + assert_raises_message(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers) def test_string_dependency_resolution_in_backref(self): class User(Base, ComparableEntity): @@ -206,7 +209,7 @@ class DeclarativeTest(DeclarativeTestBase): sess.add(u1) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(User).filter(User.name == 'ed').one(), + eq_(sess.query(User).filter(User.name == 'ed').one(), User(name='ed', addresses=[Address(email='abc'), Address(email='def'), Address(email='xyz')]) ) @@ -224,7 +227,7 @@ class DeclarativeTest(DeclarativeTestBase): # this used to raise an error when accessing User.id but that's no longer the case # since we got rid of _CompileOnAttr. - self.assertRaises(sa.exc.ArgumentError, compile_mappers) + assert_raises(sa.exc.ArgumentError, compile_mappers) def test_nice_dependency_error_works_with_hasattr(self): class User(Base): @@ -235,7 +238,7 @@ class DeclarativeTest(DeclarativeTestBase): # hasattr() on a compile-loaded attribute hasattr(User.addresses, 'property') # the exeption is preserved - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) + assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) def test_custom_base(self): class MyBase(object): @@ -255,7 +258,7 @@ class DeclarativeTest(DeclarativeTestBase): i = Index('my_index', User.name) # compile fails due to the nonexistent Addresses relation - self.assertRaises(sa.exc.InvalidRequestError, compile_mappers) + assert_raises(sa.exc.InvalidRequestError, compile_mappers) # index configured assert i in User.__table__.indexes @@ -440,7 +443,7 @@ class DeclarativeTest(DeclarativeTestBase): id = Column('id', Integer, primary_key=True), name = Column('name', String(50)) assert False - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Mapper Mapper|User|users could not assemble any primary key", define) @@ -1204,7 +1207,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): __mapper_args__ = {'polymorphic_identity':'engineer'} primary_language = Column('primary_language', String(50)) foo_bar = Column(Integer, primary_key=True) - self.assertRaisesMessage(sa.exc.ArgumentError, "place primary key", go) + assert_raises_message(sa.exc.ArgumentError, "place primary key", go) def test_single_no_table_args(self): class Person(Base, ComparableEntity): @@ -1219,7 +1222,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): __mapper_args__ = {'polymorphic_identity':'engineer'} primary_language = Column('primary_language', String(50)) __table_args__ = () - self.assertRaisesMessage(sa.exc.ArgumentError, "place __table_args__", go) + assert_raises_message(sa.exc.ArgumentError, "place __table_args__", go) def test_concrete(self): engineers = Table('engineers', Base.metadata, @@ -1270,10 +1273,11 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): ) -def produce_test(inline, stringbased): +def _produce_test(inline, stringbased): class ExplicitJoinTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global User, Address Base = decl.declarative_base(metadata=metadata) @@ -1300,7 +1304,8 @@ def produce_test(inline, stringbased): else: Address.user = relation(User, primaryjoin=User.id==Address.user_id, backref="addresses") - def insert_data(self): + @classmethod + def insert_data(cls): params = [dict(zip(('id', 'name'), column_values)) for column_values in [(7, 'jack'), (8, 'ed'), @@ -1337,12 +1342,13 @@ def produce_test(inline, stringbased): for inline in (True, False): for stringbased in (True, False): - testclass = produce_test(inline, stringbased) + testclass = _produce_test(inline, stringbased) exec("%s = testclass" % testclass.__name__) del testclass class DeclarativeReflectionTest(testing.TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global reflection_metadata reflection_metadata = MetaData(testing.db) @@ -1364,15 +1370,16 @@ class DeclarativeReflectionTest(testing.TestBase): reflection_metadata.create_all() - def setUp(self): + def setup(self): global Base Base = decl.declarative_base(testing.db) - def tearDown(self): + def teardown(self): for t in reversed(reflection_metadata.sorted_tables): t.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): reflection_metadata.drop_all() def test_basic(self): @@ -1436,7 +1443,7 @@ class DeclarativeReflectionTest(testing.TestBase): eq_(a1, Address(email='two')) eq_(a1.user, User(nom='u1')) - self.assertRaises(TypeError, User, name='u3') + assert_raises(TypeError, User, name='u3') def test_supplied_fk(self): meta = MetaData(testing.db) diff --git a/test/ext/orderinglist.py b/test/ext/test_orderinglist.py similarity index 98% rename from test/ext/orderinglist.py rename to test/ext/test_orderinglist.py index c111a02de6..4adc779606 100644 --- a/test/ext/orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -1,9 +1,8 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.orderinglist import * -from testlib.testing import eq_ -from testlib import * +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test import * metadata = None @@ -39,7 +38,7 @@ def alpha_ordering(index, collection): return s class OrderingListTest(TestBase): - def setUp(self): + def setup(self): global metadata, slides_table, bullets_table, Slide, Bullet slides_table, bullets_table = None, None Slide, Bullet = None, None @@ -87,7 +86,7 @@ class OrderingListTest(TestBase): metadata.create_all() - def tearDown(self): + def teardown(self): metadata.drop_all() def test_append_no_reorder(self): @@ -399,5 +398,3 @@ class OrderingListTest(TestBase): self.assert_(alpha[li].position == pos) -if __name__ == "__main__": - testenv.main() diff --git a/test/ext/serializer.py b/test/ext/test_serializer.py similarity index 89% rename from test/ext/serializer.py rename to test/ext/test_serializer.py index 048eccdfd1..b8a8e3fef9 100644 --- a/test/ext/serializer.py +++ b/test/ext/test_serializer.py @@ -1,13 +1,15 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy.ext import serializer from sqlalchemy import exc -from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, select, desc, func, util -from testlib.sa.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased -from testlib.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import MetaData, Integer, String, ForeignKey, select, desc, func, util +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased +from sqlalchemy.test.testing import eq_ -from orm._base import ComparableEntity, MappedTest +from test.orm._base import ComparableEntity, MappedTest class User(ComparableEntity): @@ -21,7 +23,8 @@ class SerializeTest(MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global users, addresses users = Table('users', metadata, Column('id', Integer, primary_key=True), @@ -33,7 +36,8 @@ class SerializeTest(MappedTest): Column('user_id', Integer, ForeignKey('users.id')), ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): global Session Session = scoped_session(sessionmaker()) @@ -44,7 +48,8 @@ class SerializeTest(MappedTest): compile_mappers() - def insert_data(self): + @classmethod + def insert_data(cls): params = [dict(zip(('id', 'name'), column_values)) for column_values in [(7, 'jack'), (8, 'ed'), diff --git a/test/orm/_base.py b/test/orm/_base.py index 9e599a6f16..8d695e912b 100644 --- a/test/orm/_base.py +++ b/test/orm/_base.py @@ -2,9 +2,10 @@ import gc import inspect import sys import types -from testlib import config, sa, testing -from testlib.testing import resolve_artifact_names, adict -from testlib.compat import _function_named +import sqlalchemy as sa +from sqlalchemy.test import config, testing +from sqlalchemy.test.testing import resolve_artifact_names, adict +from sqlalchemy.util import function_named _repr_stack = set() @@ -95,7 +96,8 @@ class ComparableEntity(BasicEntity): class ORMTest(testing.TestBase, testing.AssertsExecutionResults): __requires__ = ('subqueries',) - def tearDownAll(self): + @classmethod + def teardown_class(cls): sa.orm.session.Session.close_all() sa.orm.clear_mappers() # TODO: ensure mapper registry is empty @@ -124,18 +126,18 @@ class MappedTest(ORMTest): classes = None other_artifacts = None - def setUpAll(self): - if self.run_setup_classes == 'each': - assert self.run_setup_mappers != 'once' + @classmethod + def setup_class(cls): + if cls.run_setup_classes == 'each': + assert cls.run_setup_mappers != 'once' - assert self.run_deletes in (None, 'each') - if self.run_inserts == 'once': - assert self.run_deletes is None + assert cls.run_deletes in (None, 'each') + if cls.run_inserts == 'once': + assert cls.run_deletes is None - assert not hasattr(self, 'keep_mappers') - assert not hasattr(self, 'keep_data') + assert not hasattr(cls, 'keep_mappers') + assert not hasattr(cls, 'keep_data') - cls = self.__class__ if cls.tables is None: cls.tables = adict() if cls.classes is None: @@ -143,35 +145,32 @@ class MappedTest(ORMTest): if cls.other_artifacts is None: cls.other_artifacts = adict() - if self.metadata is None: - setattr(type(self), 'metadata', sa.MetaData()) + if cls.metadata is None: + setattr(cls, 'metadata', sa.MetaData()) - if self.metadata.bind is None: - self.metadata.bind = getattr(self, 'engine', config.db) + if cls.metadata.bind is None: + cls.metadata.bind = getattr(cls, 'engine', config.db) - if self.run_define_tables: - self.define_tables(self.metadata) - self.metadata.create_all() - self.tables.update(self.metadata.tables) + if cls.run_define_tables == 'once': + cls.define_tables(cls.metadata) + cls.metadata.create_all() + cls.tables.update(cls.metadata.tables) - if self.run_setup_classes: + if cls.run_setup_classes == 'once': baseline = subclasses(BasicEntity) - self.setup_classes() - self._register_new_class_artifacts(baseline) + cls.setup_classes() + cls._register_new_class_artifacts(baseline) - if self.run_setup_mappers: + if cls.run_setup_mappers == 'once': baseline = subclasses(BasicEntity) - self.setup_mappers() - self._register_new_class_artifacts(baseline) + cls.setup_mappers() + cls._register_new_class_artifacts(baseline) - if self.run_inserts: - self._load_fixtures() - self.insert_data() - - def setUp(self): - if self._sa_first_test: - return + if cls.run_inserts == 'once': + cls._load_fixtures() + cls.insert_data() + def setup(self): if self.run_define_tables == 'each': self.tables.clear() self.metadata.drop_all() @@ -195,7 +194,7 @@ class MappedTest(ORMTest): self._load_fixtures() self.insert_data() - def tearDown(self): + def teardown(self): sa.orm.session.Session.close_all() # some tests create mappers in the test bodies @@ -213,26 +212,32 @@ class MappedTest(ORMTest): print >> sys.stderr, "Error emptying table %s: %r" % ( table, ex) - def tearDownAll(self): - for cls in self.classes.values(): - self.unregister_class(cls) - ORMTest.tearDownAll(self) - self.metadata.drop_all() - self.metadata.bind = None + @classmethod + def teardown_class(cls): + for cl in cls.classes.values(): + cls.unregister_class(cl) + ORMTest.teardown_class() + cls.metadata.drop_all() + cls.metadata.bind = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): raise NotImplementedError() - def setup_classes(self): + @classmethod + def setup_classes(cls): pass - def setup_mappers(self): + @classmethod + def setup_mappers(cls): pass - def fixtures(self): + @classmethod + def fixtures(cls): return {} - def insert_data(self): + @classmethod + def insert_data(cls): pass def sql_count_(self, count, fn): @@ -260,15 +265,16 @@ class MappedTest(ORMTest): if name[0].isupper: delattr(cls, name) del cls.classes[name] - - def _load_fixtures(self): + + @classmethod + def _load_fixtures(cls): headers, rows = {}, {} - for table, data in self.fixtures().iteritems(): + for table, data in cls.fixtures().iteritems(): if isinstance(table, basestring): - table = self.tables[table] + table = cls.tables[table] headers[table] = data[0] rows[table] = data[1:] - for table in self.metadata.sorted_tables: + for table in cls.metadata.sorted_tables: if table not in headers: continue table.bind.execute( diff --git a/test/orm/_fixtures.py b/test/orm/_fixtures.py index f036b92b2a..14709ec433 100644 --- a/test/orm/_fixtures.py +++ b/test/orm/_fixtures.py @@ -1,7 +1,9 @@ -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import attributes -from testlib.testing import fixture -from orm import _base +from sqlalchemy import MetaData, Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import attributes +from sqlalchemy.test.testing import fixture +from test.orm import _base __all__ = () @@ -227,34 +229,21 @@ class FixtureTest(_base.MappedTest): Address=Address, Dingaling=Dingaling) - def setUpAll(self): - assert not hasattr(self, 'refresh_data') - assert not hasattr(self, 'only_tables') - #refresh_data = False - #only_tables = False - - #if type(self) is not FixtureTest: - # setattr(type(self), 'classes', _base.adict(self.classes)) - - #if self.run_setup_classes: - # for cls in self.classes.values(): - # self.register_class(cls) - super(FixtureTest, self).setUpAll() - - #if not self.only_tables and self.keep_data: - # _registry.load() - - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): pass - def setup_classes(self): - for cls in self.fixture_classes.values(): - self.register_class(cls) + @classmethod + def setup_classes(cls): + for cl in cls.fixture_classes.values(): + cls.register_class(cl) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): pass - def insert_data(self): + @classmethod + def insert_data(cls): _load_fixtures() diff --git a/test/orm/alltests.py b/test/orm/alltests.py deleted file mode 100644 index 9458ca5236..0000000000 --- a/test/orm/alltests.py +++ /dev/null @@ -1,60 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -import inheritance.alltests as inheritance -import sharding.alltests as sharding - -def suite(): - modules_to_test = ( - 'orm.attributes', - 'orm.bind', - 'orm.extendedattr', - 'orm.instrumentation', - 'orm.query', - 'orm.lazy_relations', - 'orm.eager_relations', - 'orm.mapper', - 'orm.expire', - 'orm.selectable', - 'orm.collection', - 'orm.generative', - 'orm.lazytest1', - 'orm.assorted_eager', - - 'orm.naturalpks', - 'orm.defaults', - 'orm.unitofwork', - 'orm.session', - 'orm.transaction', - 'orm.scoping', - 'orm.cascade', - 'orm.relationships', - 'orm.association', - 'orm.merge', - 'orm.pickled', - 'orm.utils', - - 'orm.cycles', - - 'orm.compile', - 'orm.manytomany', - 'orm.onetoone', - 'orm.dynamic', - - 'orm.evaluator', - - 'orm.deprecations', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - alltests.addTest(inheritance.suite()) - alltests.addTest(sharding.suite()) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py deleted file mode 100644 index 41f0521dd6..0000000000 --- a/test/orm/inheritance/alltests.py +++ /dev/null @@ -1,31 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - 'orm.inheritance.basic', - 'orm.inheritance.query', - 'orm.inheritance.manytomany', - 'orm.inheritance.single', - 'orm.inheritance.concrete', - 'orm.inheritance.polymorph', - 'orm.inheritance.polymorph2', - 'orm.inheritance.poly_linked_list', - 'orm.inheritance.abc_polymorphic', - 'orm.inheritance.abc_inheritance', - 'orm.inheritance.productspec', - 'orm.inheritance.magazine', - 'orm.inheritance.selects', - - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py similarity index 95% rename from test/orm/inheritance/abc_inheritance.py rename to test/orm/inheritance/test_abc_inheritance.py index ee324e3811..4e55cf70ea 100644 --- a/test/orm/inheritance/abc_inheritance.py +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -1,10 +1,9 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE -from testlib import testing -from orm import _base +from sqlalchemy.test import testing +from test.orm import _base def produce_test(parent, child, direction): @@ -12,9 +11,10 @@ def produce_test(parent, child, direction): relationship between two of the classes, using either one-to-many or many-to-one.""" class ABCTest(_base.MappedTest): - def define_tables(self, meta): + @classmethod + def define_tables(cls, metadata): global ta, tb, tc - ta = ["a", meta] + ta = ["a", metadata] ta.append(Column('id', Integer, primary_key=True)), ta.append(Column('a_data', String(30))) if "a"== parent and direction == MANYTOONE: @@ -23,7 +23,7 @@ def produce_test(parent, child, direction): ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) ta = Table(*ta) - tb = ["b", meta] + tb = ["b", metadata] tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, )) tb.append(Column('b_data', String(30))) @@ -34,7 +34,7 @@ def produce_test(parent, child, direction): tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) tb = Table(*tb) - tc = ["c", meta] + tc = ["c", metadata] tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, )) tc.append(Column('c_data', String(30))) @@ -45,14 +45,14 @@ def produce_test(parent, child, direction): tc.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo"))) tc = Table(*tc) - def tearDown(self): + def teardown(self): if direction == MANYTOONE: parent_table = {"a":ta, "b":tb, "c": tc}[parent] parent_table.update(values={parent_table.c.child_id:None}).execute() elif direction == ONETOMANY: child_table = {"a":ta, "b":tb, "c": tc}[child] child_table.update(values={child_table.c.parent_id:None}).execute() - super(ABCTest, self).tearDown() + super(ABCTest, self).teardown() def test_roundtrip(self): parent_table = {"a":ta, "b":tb, "c": tc}[parent] @@ -167,5 +167,4 @@ for parent in ["a", "b", "c"]: exec("%s = testclass" % testclass.__name__) del testclass -if __name__ == "__main__": - testenv.main() +del produce_test \ No newline at end of file diff --git a/test/orm/inheritance/abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py similarity index 92% rename from test/orm/inheritance/abc_polymorphic.py rename to test/orm/inheritance/test_abc_polymorphic.py index 6fabbb24c2..8cad8ed781 100644 --- a/test/orm/inheritance/abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -1,13 +1,13 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import util from sqlalchemy.orm import * -from testlib import _function_named -from orm import _base, _fixtures +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures class ABCTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global a, b, c a = Table('a', metadata, Column('id', Integer, primary_key=True), @@ -78,7 +78,7 @@ class ABCTest(_base.MappedTest): C(cdata='c2', bdata='c2', adata='c2'), ] == sess.query(C).all() - test_roundtrip = _function_named( + test_roundtrip = function_named( test_roundtrip, 'test_%s' % fetchtype) return test_roundtrip @@ -86,5 +86,3 @@ class ABCTest(_base.MappedTest): test_none = make_test('none') -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/test_basic.py similarity index 95% rename from test/orm/inheritance/basic.py rename to test/orm/inheritance/test_basic.py index 150874477b..fc4aae17d5 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1,17 +1,17 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy import exc as sa_exc, util from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc -#from testlib import * -#from testlib import fixtures -from testlib import _function_named, testing, engines -from orm import _base, _fixtures +from sqlalchemy.test import testing, engines +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures class O2MTest(_base.MappedTest): """deals with inheritance and one-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq', optional=True), @@ -69,7 +69,8 @@ class O2MTest(_base.MappedTest): self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') class FalseDiscriminatorTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False)) @@ -87,7 +88,8 @@ class FalseDiscriminatorTest(_base.MappedTest): assert isinstance(sess.query(Foo).one(), Bar) class PolymorphicSynonymTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -118,8 +120,8 @@ class PolymorphicSynonymTest(_base.MappedTest): sess.add(at2) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(T2).filter(T2.info=='at2').one(), at2) - self.assertEquals(at2.info, "THE INFO IS:at2") + eq_(sess.query(T2).filter(T2.info=='at2').one(), at2) + eq_(at2.info, "THE INFO IS:at2") class CascadeTest(_base.MappedTest): @@ -127,7 +129,8 @@ class CascadeTest(_base.MappedTest): cascading along the path of the instance's mapper, not the base mapper.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2, t3, t4 t1= Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -191,7 +194,8 @@ class CascadeTest(_base.MappedTest): sess.flush() class GetTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, blub foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq', optional=True), @@ -209,7 +213,7 @@ class GetTest(_base.MappedTest): Column('bar_id', Integer, ForeignKey('bar.id')), Column('data', String(20))) - def create_test(polymorphic, name): + def _create_test(polymorphic, name): def test_get(self): class Foo(object): pass @@ -271,16 +275,17 @@ class GetTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 3) - test_get = _function_named(test_get, name) + test_get = function_named(test_get, name) return test_get - test_get_polymorphic = create_test(True, 'test_get_polymorphic') - test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic') + test_get_polymorphic = _create_test(True, 'test_get_polymorphic') + test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic') class EagerLazyTest(_base.MappedTest): """tests eager load/lazy load of child items off inheritance mappers, tests that LazyLoader constructs the right query condition.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, bar_foo foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq', optional=True), @@ -325,7 +330,8 @@ class EagerLazyTest(_base.MappedTest): class FlushTest(_base.MappedTest): """test dependency sorting among inheriting mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global users, roles, user_roles, admins users = Table('users', metadata, Column('id', Integer, primary_key=True), @@ -413,7 +419,8 @@ class FlushTest(_base.MappedTest): assert user_roles.count().scalar() == 1 class VersioningTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, subtable, stuff base = Table('base', metadata, Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ), @@ -522,7 +529,8 @@ class DistinctPKTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global person_table, employee_table, Person, Employee person_table = Table("persons", metadata, @@ -542,7 +550,8 @@ class DistinctPKTest(_base.MappedTest): class Employee(Person): pass - def insert_data(self): + @classmethod + def insert_data(cls): person_insert = person_table.insert() person_insert.execute(id=1, name='alice') person_insert.execute(id=2, name='bob') @@ -593,7 +602,8 @@ class DistinctPKTest(_base.MappedTest): class SyncCompileTest(_base.MappedTest): """test that syncrules compile properly on custom inherit conds""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global _a_table, _b_table, _c_table _a_table = Table('a', metadata, @@ -660,7 +670,8 @@ class SyncCompileTest(_base.MappedTest): class OverrideColKeyTest(_base.MappedTest): """test overriding of column attributes.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, subtable base = Table('base', metadata, @@ -686,7 +697,7 @@ class OverrideColKeyTest(_base.MappedTest): # Sub gets a "base_id" property using the "base_id" # column of both tables. - self.assertEquals( + eq_( class_mapper(Sub).get_property('base_id').columns, [base.c.base_id, subtable.c.base_id] ) @@ -710,7 +721,7 @@ class OverrideColKeyTest(_base.MappedTest): 'id':[base.c.base_id, subtable.c.base_id] }) - self.assertEquals( + eq_( class_mapper(Sub).get_property('id').columns, [base.c.base_id, subtable.c.base_id] ) @@ -733,12 +744,12 @@ class OverrideColKeyTest(_base.MappedTest): }) mapper(Sub, subtable, inherits=Base) - self.assertEquals( + eq_( class_mapper(Sub).get_property('id').columns, [base.c.base_id] ) - self.assertEquals( + eq_( class_mapper(Sub).get_property('base_id').columns, [subtable.c.base_id] ) @@ -782,7 +793,7 @@ class OverrideColKeyTest(_base.MappedTest): # it has its own "id" property. Sub's "id" property # gets joined normally with the extra column. - self.assertEquals( + eq_( class_mapper(Sub).get_property('id').columns, [base.c.base_id, subtable.c.base_id] ) @@ -892,7 +903,8 @@ class OptimizedLoadTest(_base.MappedTest): a column in the join condition is not available. """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, sub base = Table('base', metadata, Column('id', Integer, primary_key=True), @@ -931,7 +943,8 @@ class OptimizedLoadTest(_base.MappedTest): assert s1.sub == 's1sub' class PKDiscriminatorTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): parents = Table('parents', metadata, Column('id', Integer, primary_key=True), Column('name', String(60))) @@ -974,7 +987,8 @@ class PKDiscriminatorTest(_base.MappedTest): class DeleteOrphanTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global single, parent single = Table('single', metadata, Column('id', Integer, primary_key=True), @@ -1007,9 +1021,7 @@ class DeleteOrphanTest(_base.MappedTest): sess = create_session() s1 = SubClass(data='s1') sess.add(s1) - self.assertRaisesMessage(orm_exc.FlushError, + assert_raises_message(orm_exc.FlushError, "is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/concrete.py b/test/orm/inheritance/test_concrete.py similarity index 95% rename from test/orm/inheritance/concrete.py rename to test/orm/inheritance/test_concrete.py index 6cdaed7e69..4a884cb86c 100644 --- a/test/orm/inheritance/concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -1,12 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc -from testlib import * -from testlib import sa, testing -from orm import _base +from sqlalchemy.test import * +import sqlalchemy as sa +from sqlalchemy.test import testing +from test.orm import _base from sqlalchemy.orm import attributes -from testlib.testing import eq_ +from sqlalchemy.test.testing import eq_ class Employee(object): def __init__(self, name): @@ -42,7 +43,8 @@ class Company(object): class ConcreteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global managers_table, engineers_table, hackers_table, companies, employees_table companies = Table('companies', metadata, @@ -103,7 +105,7 @@ class ConcreteTest(_base.MappedTest): manager = session.query(Manager).one() session.expire(manager, ['manager_data']) - self.assertEquals(manager.manager_data, "knows how to manage things") + eq_(manager.manager_data, "knows how to manage things") def test_multi_level_no_base(self): pjoin = polymorphic_union({ @@ -144,8 +146,8 @@ class ConcreteTest(_base.MappedTest): assert 'name' not in attributes.instance_state(hacker).expired_attributes assert 'nickname' not in attributes.instance_state(hacker).expired_attributes def go(): - self.assertEquals(jerry.name, "Jerry") - self.assertEquals(hacker.nickname, "Badass") + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") self.assert_sql_count(testing.db, go, 0) session.expunge_all() @@ -194,8 +196,8 @@ class ConcreteTest(_base.MappedTest): session.flush() def go(): - self.assertEquals(jerry.name, "Jerry") - self.assertEquals(hacker.nickname, "Badass") + eq_(jerry.name, "Jerry") + eq_(hacker.nickname, "Badass") self.assert_sql_count(testing.db, go, 0) session.expunge_all() @@ -315,7 +317,8 @@ class ConcreteTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) class PropertyInheritanceTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a_table', metadata, Column('id', Integer, primary_key=True), Column('some_c_id', Integer, ForeignKey('c_table.id')), @@ -332,7 +335,8 @@ class PropertyInheritanceTest(_base.MappedTest): ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.ComparableEntity): pass @@ -352,7 +356,7 @@ class PropertyInheritanceTest(_base.MappedTest): b = B() c = C() - self.assertRaises(AttributeError, setattr, b, 'some_c', c) + assert_raises(AttributeError, setattr, b, 'some_c', c) clear_mappers() mapper(A, a_table, properties={ @@ -361,7 +365,7 @@ class PropertyInheritanceTest(_base.MappedTest): mapper(B, b_table,inherits=A, concrete=True) mapper(C, c_table) b = B() - self.assertRaises(AttributeError, setattr, b, 'a_id', 3) + assert_raises(AttributeError, setattr, b, 'a_id', 3) clear_mappers() mapper(A, a_table, properties={ @@ -392,8 +396,8 @@ class PropertyInheritanceTest(_base.MappedTest): b1 = B(some_c=c1, bname='b1') b2 = B(some_c=c1, bname='b2') - self.assertRaises(AttributeError, setattr, b1, 'aname', 'foo') - self.assertRaises(AttributeError, getattr, A, 'bname') + assert_raises(AttributeError, setattr, b1, 'aname', 'foo') + assert_raises(AttributeError, getattr, A, 'bname') assert c2.many_a == [a2] assert c1.many_a == [a1] @@ -463,7 +467,8 @@ class PropertyInheritanceTest(_base.MappedTest): class ColKeysTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global offices_table, refugees_table refugees_table = Table('refugee', metadata, Column('refugee_fid', Integer, primary_key=True), @@ -473,7 +478,8 @@ class ColKeysTest(_base.MappedTest): Column('office_fid', Integer, primary_key=True), Column('office_name', Unicode(30), key='name')) - def insert_data(self): + @classmethod + def insert_data(cls): refugees_table.insert().execute( dict(refugee_fid=1, name=u"refugee1"), dict(refugee_fid=2, name=u"refugee2") @@ -511,5 +517,3 @@ class ColKeysTest(_base.MappedTest): eq_(sess.query(Office).get(1).name, "office1") eq_(sess.query(Office).get(2).name, "office2") -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/magazine.py b/test/orm/inheritance/test_magazine.py similarity index 97% rename from test/orm/inheritance/magazine.py rename to test/orm/inheritance/test_magazine.py index 34374c887e..0673012511 100644 --- a/test/orm/inheritance/magazine.py +++ b/test/orm/inheritance/test_magazine.py @@ -1,9 +1,9 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing, _function_named -from orm import _base +from sqlalchemy.test import testing +from sqlalchemy.util import function_named +from test.orm import _base class BaseObject(object): def __init__(self, *args, **kwargs): @@ -70,7 +70,8 @@ class ClassifiedPage(MagazinePage): class MagazineTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global publication_table, issue_table, location_table, location_name_table, magazine_table, \ page_table, magazine_page_table, classified_page_table, page_size_table @@ -208,7 +209,7 @@ def generate_round_trip_test(use_unions=False, use_joins=False): print [page, page2, page3] assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages) - test_roundtrip = _function_named( + test_roundtrip = function_named( test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions")) setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip) @@ -216,5 +217,3 @@ for (use_union, use_join) in [(True, False), (False, True), (False, False)]: generate_round_trip_test(use_union, use_join) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/test_manytomany.py similarity index 96% rename from test/orm/inheritance/manytomany.py rename to test/orm/inheritance/test_manytomany.py index 5dbf69ba56..f7e676bbbc 100644 --- a/test/orm/inheritance/manytomany.py +++ b/test/orm/inheritance/test_manytomany.py @@ -1,14 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm import _base +from sqlalchemy.test import testing +from test.orm import _base class InheritTest(_base.MappedTest): """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global principals global users global groups @@ -67,7 +68,8 @@ class InheritTest(_base.MappedTest): class InheritTest2(_base.MappedTest): """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, foo_bar foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_id_seq', optional=True), @@ -140,7 +142,8 @@ class InheritTest2(_base.MappedTest): class InheritTest3(_base.MappedTest): """deals with inheritance and many-to-many relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, blub, bar_foo, blub_bar, blub_foo # the 'data' columns are to appease SQLite which cant handle a blank INSERT @@ -196,7 +199,7 @@ class InheritTest3(_base.MappedTest): l = sess.query(Bar).all() print repr(l[0]) + repr(l[0].foos) found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos])) - self.assertEqual(found, compare) + eq_(found, compare) @testing.fails_on('maxdb', 'FIXME: unknown') def testadvanced(self): @@ -244,5 +247,3 @@ class InheritTest3(_base.MappedTest): self.assert_(repr(x) == compare) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/poly_linked_list.py b/test/orm/inheritance/test_poly_linked_list.py similarity index 93% rename from test/orm/inheritance/poly_linked_list.py rename to test/orm/inheritance/test_poly_linked_list.py index 2cf0519494..67b543f31c 100644 --- a/test/orm/inheritance/poly_linked_list.py +++ b/test/orm/inheritance/test_poly_linked_list.py @@ -1,15 +1,15 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from orm import _base -from testlib import testing +from test.orm import _base +from sqlalchemy.test import testing class PolymorphicCircularTest(_base.MappedTest): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global Table1, Table1B, Table2, Table3, Data table1 = Table('table1', metadata, Column('id', Integer, primary_key=True), @@ -115,26 +115,26 @@ class PolymorphicCircularTest(_base.MappedTest): @testing.fails_on('maxdb', 'FIXME: unknown') def testone(self): - self.do_testlist([Table1, Table2, Table1, Table2]) + self._testlist([Table1, Table2, Table1, Table2]) @testing.fails_on('maxdb', 'FIXME: unknown') def testtwo(self): - self.do_testlist([Table3]) + self._testlist([Table3]) @testing.fails_on('maxdb', 'FIXME: unknown') def testthree(self): - self.do_testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1]) + self._testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1]) @testing.fails_on('maxdb', 'FIXME: unknown') def testfour(self): - self.do_testlist([ + self._testlist([ Table2('t2', [Data('data1'), Data('data2')]), Table1('t1', []), Table3('t3', [Data('data3')]), Table1B('t1b', [Data('data4'), Data('data5')]) ]) - def do_testlist(self, classes): + def _testlist(self, classes): sess = create_session( ) # create objects in a linked list @@ -195,5 +195,3 @@ class PolymorphicCircularTest(_base.MappedTest): # everything should match ! assert original == forwards == backwards -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/test_polymorph.py similarity index 90% rename from test/orm/inheritance/polymorph.py rename to test/orm/inheritance/test_polymorph.py index 81f6c82a1e..cd3b2d89e3 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/test_polymorph.py @@ -1,11 +1,12 @@ """tests basic polymorphic mapper loading/saving, minimal relations""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc -from testlib import _function_named, Column, testing -from orm import _fixtures, _base +from sqlalchemy.test import Column, testing +from sqlalchemy.util import function_named +from test.orm import _fixtures, _base class Person(_fixtures.Base): pass @@ -19,7 +20,8 @@ class Company(_fixtures.Base): pass class PolymorphTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global companies, people, engineers, managers, boss companies = Table('companies', metadata, @@ -83,7 +85,7 @@ class InsertOrderTest(PolymorphTest): session.add(c) session.flush() session.expunge_all() - self.assertEquals(session.query(Company).get(c.company_id), c) + eq_(session.query(Company).get(c.company_id), c) class RelationToSubclassTest(PolymorphTest): def test_basic(self): @@ -115,13 +117,13 @@ class RelationToSubclassTest(PolymorphTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Company).filter_by(company_id=c.company_id).one(), c) + eq_(sess.query(Company).filter_by(company_id=c.company_id).one(), c) assert c.managers[0].company is c class RoundTripTest(PolymorphTest): pass -def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic): +def _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic): """generates a round trip test. include_base - whether or not to include the base 'person' type in the union. @@ -205,15 +207,15 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with session.flush() session.expunge_all() - self.assertEquals(session.query(Person).get(dilbert.person_id), dilbert) + eq_(session.query(Person).get(dilbert.person_id), dilbert) session.expunge_all() - self.assertEquals(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert) + eq_(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert) session.expunge_all() def go(): cc = session.query(Company).get(c.company_id) - self.assertEquals(cc.employees, employees) + eq_(cc.employees, employees) if not lazy_relation: if with_polymorphic != 'none': @@ -229,14 +231,14 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with # test selecting from the query, using the base mapped table (people) as the selection criterion. # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" - self.assertEquals( + eq_( session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first(), dilbert ) assert session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first().person_id - self.assertEquals( + eq_( session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first(), dilbert ) @@ -268,22 +270,22 @@ def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with # test standalone orphans daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'}) session.add(daboss) - self.assertRaises(orm_exc.FlushError, session.flush) + assert_raises(orm_exc.FlushError, session.flush) c = session.query(Company).first() daboss.company = c manager_list = [e for e in c.employees if isinstance(e, Manager)] session.flush() session.expunge_all() - self.assertEquals(session.query(Manager).order_by(Manager.person_id).all(), manager_list) + eq_(session.query(Manager).order_by(Manager.person_id).all(), manager_list) c = session.query(Company).first() session.delete(c) session.flush() - self.assertEquals(people.count().scalar(), 0) + eq_(people.count().scalar(), 0) - test_roundtrip = _function_named( + test_roundtrip = function_named( test_roundtrip, "test_%s%s%s_%s" % ( (lazy_relation and "lazy" or "eager"), (include_base and "_inclbase" or ""), @@ -296,9 +298,7 @@ for lazy_relation in [True, False]: for with_polymorphic in ['unions', 'joins', 'auto', 'none']: if with_polymorphic == 'unions': for include_base in [True, False]: - generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic) + _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic) else: - generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic) + _generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/polymorph2.py b/test/orm/inheritance/test_polymorph2.py similarity index 96% rename from test/orm/inheritance/polymorph2.py rename to test/orm/inheritance/test_polymorph2.py index aec162b75c..51b6d4970a 100644 --- a/test/orm/inheritance/polymorph2.py +++ b/test/orm/inheritance/test_polymorph2.py @@ -2,14 +2,15 @@ inheritance setups for which we maintain compatibility. """ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy import util from sqlalchemy.orm import * -from testlib import _function_named, TestBase, AssertsExecutionResults, testing -from orm import _base, _fixtures -from testlib.testing import eq_ +from sqlalchemy.test import TestBase, AssertsExecutionResults, testing +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ class AttrSettable(object): def __init__(self, **kwargs): @@ -20,7 +21,8 @@ class AttrSettable(object): class RelationTest1(_base.MappedTest): """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers people = Table('people', metadata, @@ -35,11 +37,11 @@ class RelationTest1(_base.MappedTest): Column('manager_name', String(50)) ) - def tearDown(self): + def teardown(self): people.update(values={people.c.manager_id:None}).execute() - super(RelationTest1, self).tearDown() + super(RelationTest1, self).teardown() - def testparentrefsdescendant(self): + def test_parent_refs_descendant(self): class Person(AttrSettable): pass class Manager(Person): @@ -55,7 +57,7 @@ class RelationTest1(_base.MappedTest): mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id) - self.assertEquals(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)]) + eq_(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)]) session = create_session() p = Person(name='some person') @@ -70,7 +72,7 @@ class RelationTest1(_base.MappedTest): print p, m, p.manager assert p.manager is m - def testdescendantrefsparent(self): + def test_descendant_refs_parent(self): class Person(AttrSettable): pass class Manager(Person): @@ -99,7 +101,8 @@ class RelationTest1(_base.MappedTest): class RelationTest2(_base.MappedTest): """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -194,7 +197,8 @@ class RelationTest2(_base.MappedTest): class RelationTest3(_base.MappedTest): """test self-referential relationships on polymorphic mappers""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -212,7 +216,7 @@ class RelationTest3(_base.MappedTest): Column('data', String(30)) ) -def generate_test(jointype="join1", usedata=False): +def _generate_test(jointype="join1", usedata=False): def do_test(self): class Person(AttrSettable): pass @@ -287,19 +291,20 @@ def generate_test(jointype="join1", usedata=False): assert p.data.data == 'ps data' assert m.data.data == 'ms data' - do_test = _function_named( - do_test, 'test_relationonbaseclass_%s_%s' % ( + do_test = function_named( + do_test, 'test_relation_on_base_class_%s_%s' % ( jointype, data and "nodata" or "data")) return do_test for jointype in ["join1", "join2", "join3", "join4"]: for data in (True, False): - func = generate_test(jointype, data) + func = _generate_test(jointype, data) setattr(RelationTest3, func.__name__, func) - +del func class RelationTest4(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, Column('person_id', Integer, primary_key=True), @@ -411,7 +416,8 @@ class RelationTest4(_base.MappedTest): assert c.car_id==car1.car_id class RelationTest5(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers, cars people = Table('people', metadata, Column('person_id', Integer, primary_key=True), @@ -472,7 +478,8 @@ class RelationTest5(_base.MappedTest): class RelationTest6(_base.MappedTest): """test self-referential relationships on a single joined-table inheritance mapper""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -514,7 +521,8 @@ class RelationTest6(_base.MappedTest): assert m.colleague is m2 class RelationTest7(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers, cars, offroad_cars cars = Table('cars', metadata, Column('car_id', Integer, primary_key=True), @@ -613,7 +621,8 @@ class RelationTest7(_base.MappedTest): assert p.car_id == p.car.car_id class RelationTest8(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global taggable, users taggable = Table('taggable', metadata, Column('id', Integer, primary_key=True), @@ -658,7 +667,8 @@ class RelationTest8(_base.MappedTest): ) class GenerativeTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): # cars---owned by--- people (abstract) --- has a --- status # | ^ ^ | # | | | | @@ -693,9 +703,10 @@ class GenerativeTest(TestBase, AssertsExecutionResults): metadata.create_all() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() - def tearDown(self): + def teardown(self): clear_mappers() for t in reversed(metadata.sorted_tables): t.delete().execute() @@ -784,7 +795,8 @@ class GenerativeTest(TestBase, AssertsExecutionResults): assert str(list(r)) == "[Engineer E4, field X, status Status dead]" class MultiLevelTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global table_Employee, table_Engineer, table_Manager table_Employee = Table( 'Employee', metadata, Column( 'name', type_= String(100), ), @@ -861,7 +873,8 @@ class MultiLevelTest(_base.MappedTest): assert session.query( Manager).all() == [c] class ManyToManyPolyTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base_item_table, item_table, base_item_collection_table, collection_table base_item_table = Table( 'base_item', metadata, @@ -911,7 +924,8 @@ class ManyToManyPolyTest(_base.MappedTest): class_mapper(BaseItem) class CustomPKTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -994,7 +1008,8 @@ class CustomPKTest(_base.MappedTest): sess.flush() class InheritingEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, employees, tags, peopleTags people = Table('people', metadata, @@ -1055,7 +1070,8 @@ class InheritingEagerTest(_base.MappedTest): assert len(instance.tags) == 2 class MissingPolymorphicOnTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global tablea, tableb, tablec, tabled tablea = Table('tablea', metadata, Column('id', Integer, primary_key=True), @@ -1101,7 +1117,5 @@ class MissingPolymorphicOnTest(_base.MappedTest): sess.add(d) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')]) + eq_(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')]) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/productspec.py b/test/orm/inheritance/test_productspec.py similarity index 98% rename from test/orm/inheritance/productspec.py rename to test/orm/inheritance/test_productspec.py index b6a8c51468..b2bcb85d54 100644 --- a/test/orm/inheritance/productspec.py +++ b/test/orm/inheritance/test_productspec.py @@ -1,16 +1,16 @@ -import testenv; testenv.configure_for_tests() from datetime import datetime from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm import _base +from sqlalchemy.test import testing +from test.orm import _base class InheritTest(_base.MappedTest): """tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global products_table, specification_table, documents_table global Product, Detail, Assembly, SpecLine, Document, RasterDocument @@ -316,5 +316,3 @@ class InheritTest(_base.MappedTest): print new assert orig == new == ' specification=[>] documents=[, ]' -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/test_query.py similarity index 79% rename from test/orm/inheritance/query.py rename to test/orm/inheritance/test_query.py index 58d2054558..5b57e8f457 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/test_query.py @@ -1,13 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy import exc as sa_exc from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.engine import default -from testlib import AssertsCompiledSQL, testing -from orm import _base, _fixtures -from testlib.testing import eq_ +from sqlalchemy.test import AssertsCompiledSQL, testing +from test.orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ class Company(_fixtures.Base): pass @@ -27,13 +27,14 @@ class Machine(_fixtures.Base): class Paperwork(_fixtures.Base): pass -def make_test(select_type): +def _produce_test(select_type): class PolymorphicQueryTest(_base.MappedTest, AssertsCompiledSQL): run_inserts = 'once' run_setup_mappers = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global companies, people, engineers, managers, boss, paperwork, machines companies = Table('companies', metadata, @@ -128,7 +129,8 @@ def make_test(select_type): mapper(Paperwork, paperwork) - def insert_data(self): + @classmethod + def insert_data(cls): global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2 c1 = Company(name="MegaCorp, Inc.") @@ -178,14 +180,14 @@ def make_test(select_type): sess = create_session() def go(): - self.assertEquals(sess.query(Person).all(), all_employees) + eq_(sess.query(Person).all(), all_employees) self.assert_sql_count(testing.db, go, {'':14, 'Polymorphic':9}.get(select_type, 10)) def test_primary_eager_aliasing(self): sess = create_session() def go(): - self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) + eq_(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4)) sess = create_session() @@ -194,7 +196,7 @@ def make_test(select_type): assert sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().subquery().count().scalar() == 2 def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) + eq_(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3]) self.assert_sql_count(testing.db, go, 3) @@ -203,9 +205,9 @@ def make_test(select_type): # for all mappers, ensure the primary key has been calculated as just the "person_id" # column - self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) - self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) - self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore")) + eq_(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) + eq_(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java")) + eq_(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore")) def test_multi_join(self): sess = create_session() @@ -216,8 +218,8 @@ def make_test(select_type): q = sess.query(Company, Person, c, e).join((Person, Company.employees)).join((e, c.employees)).\ filter(Person.name=='dilbert').filter(e.name=='wally') - self.assertEquals(q.count(), 1) - self.assertEquals(q.all(), [ + eq_(q.count(), 1) + eq_(q.all(), [ ( Company(company_id=1,name=u'MegaCorp, Inc.'), Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), @@ -228,99 +230,99 @@ def make_test(select_type): def test_filter_on_subclass(self): sess = create_session() - self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert")) + eq_(sess.query(Engineer).all()[0], Engineer(name="dilbert")) - self.assertEquals(sess.query(Engineer).first(), Engineer(name="dilbert")) + eq_(sess.query(Engineer).first(), Engineer(name="dilbert")) - self.assertEquals(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert")) + eq_(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert")) - self.assertEquals(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert")) + eq_(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert")) - self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + eq_(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) - self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) + eq_(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss")) def test_join_from_polymorphic(self): sess = create_session() for aliased in (True, False): - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) - self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) + eq_(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) - self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) def test_join_from_with_polymorphic(self): sess = create_session() for aliased in (True, False): sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + eq_(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) sess.expunge_all() - self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) def test_join_to_polymorphic(self): sess = create_session() - self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) + eq_(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2) - self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) + eq_(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2) def test_polymorphic_any(self): sess = create_session() - self.assertEquals( + eq_( sess.query(Company).\ filter(Company.employees.any(Person.name=='vlad')).all(), [c2] ) # test that the aliasing on "Person" does not bleed into the # EXISTS clause generated by any() - self.assertEquals( + eq_( sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ filter(Company.employees.any(Person.name=='wally')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ filter(Company.employees.any(Person.name=='vlad')).all(), [] ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), c2 ) calias = aliased(Company) - self.assertEquals( + eq_( sess.query(calias).filter(calias.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), c2 ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(), c1 ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(), c1 ) if select_type != '': - self.assertEquals( + eq_( sess.query(Person).filter(Engineer.machines.any(Machine.name=="Commodore 64")).all(), [e2, e3] ) - self.assertEquals( + eq_( sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1] ) - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(), c2 ) @@ -328,48 +330,48 @@ def make_test(select_type): def test_join_from_columns_or_subclass(self): sess = create_session() - self.assertEquals( + eq_( sess.query(Manager.name).order_by(Manager.name).all(), [(u'dogbert',), (u'pointy haired boss',)] ) - self.assertEquals( + eq_( sess.query(Manager.name).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] ) - self.assertEquals( + eq_( sess.query(Person.name).join((Paperwork, Person.paperwork)).order_by(Person.name).all(), [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)] ) - self.assertEquals( + eq_( sess.query(Person.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Person.name).all(), [(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)] ) - self.assertEquals( + eq_( sess.query(Manager).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(), [m1, b1] ) - self.assertEquals( + eq_( sess.query(Manager.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), [(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)] ) - self.assertEquals( + eq_( sess.query(Manager.person_id).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(), [(4,), (4,), (3,)] ) - self.assertEquals( + eq_( sess.query(Manager.name, Paperwork.description).join((Paperwork, Manager.person_id==Paperwork.person_id)).all(), [(u'pointy haired boss', u'review #1'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3')] ) malias = aliased(Manager) - self.assertEquals( + eq_( sess.query(malias.name).join((paperwork, malias.person_id==paperwork.c.person_id)).all(), [(u'pointy haired boss',), (u'dogbert',), (u'dogbert',)] ) @@ -391,9 +393,9 @@ def make_test(select_type): sess = create_session() - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork) - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss) - self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person) + assert_raises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork) + assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss) + assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person) # compare to entities without related collections to prevent additional lazy SQL from firing on # loaded entities @@ -404,32 +406,32 @@ def make_test(select_type): Manager(name="dogbert", manager_name="dogbert", status="regular manager"), Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") ] - self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) + eq_(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) self.assert_sql_count(testing.db, go, 3) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) self.assert_sql_count(testing.db, go, 3) sess.expunge_all() def go(): # limit the polymorphic join down to just "Person", overriding select_table - self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) + eq_(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) self.assert_sql_count(testing.db, go, 6) def test_relation_to_polymorphic(self): @@ -449,14 +451,14 @@ def make_test(select_type): def go(): # test load Companies with lazy load to 'employees' - self.assertEquals(sess.query(Company).all(), assert_result) + eq_(sess.query(Company).all(), assert_result) self.assert_sql_count(testing.db, go, {'':9, 'Polymorphic':4}.get(select_type, 5)) sess = create_session() def go(): # currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer). eagerloader doesn't # pick up on the "of_type()" as of yet. - self.assertEquals(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result) + eq_(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result) # in the case of select_type='', the eagerload doesn't take in this case; # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines" @@ -466,78 +468,78 @@ def make_test(select_type): sess = create_session() def go(): # test load People with eagerload to engineers + machines - self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), + eq_(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(), [Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])] ) self.assert_sql_count(testing.db, go, 1) def test_join_to_subclass(self): sess = create_session() - self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) if select_type == '': - self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1]) ealias = aliased(Engineer) - self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) - self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2]) - self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2]) + eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) else: - self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) - self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2]) - self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2]) + eq_(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) # non-polymorphic - self.assertEquals(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) - self.assertEquals(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) + eq_(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3]) + eq_(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3]) # here's the new way - self.assertEquals(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1]) - self.assertEquals(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) + eq_(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1]) + eq_(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1]) def test_join_through_polymorphic(self): sess = create_session() for aliased in (True, False): - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(), [c1, c2] ) - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).\ join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(), [c1, c2] ) - self.assertEquals( + eq_( sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [c1] ) - self.assertEquals( + eq_( sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\ join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(), [c1, c2] @@ -549,14 +551,14 @@ def make_test(select_type): # ORMJoin using regular table foreign key connections. Engineer # is expressed as "(select * people join engineers) as anon_1" # so the join is contained. - self.assertEquals( + eq_( sess.query(Company).join(Engineer).filter(Engineer.engineer_name=='vlad').one(), c2 ) # same, using explicit join condition. Query.join() must adapt the on clause # here to match the subquery wrapped around "people join engineers". - self.assertEquals( + eq_( sess.query(Company).join((Engineer, Company.company_id==Engineer.company_id)).filter(Engineer.engineer_name=='vlad').one(), c2 ) @@ -565,17 +567,17 @@ def make_test(select_type): def test_filter_on_baseclass(self): sess = create_session() - self.assertEquals(sess.query(Person).all(), all_employees) + eq_(sess.query(Person).all(), all_employees) - self.assertEquals(sess.query(Person).first(), all_employees[0]) + eq_(sess.query(Person).first(), all_employees[0]) - self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) + eq_(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2) def test_from_alias(self): sess = create_session() palias = aliased(Person) - self.assertEquals( + eq_( sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(), [e1, e2] ) @@ -586,7 +588,7 @@ def make_test(select_type): c1_employees = [e1, e2, b1, m1] palias = aliased(Person) - self.assertEquals( + eq_( sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), [ @@ -596,7 +598,7 @@ def make_test(select_type): ] ) - self.assertEquals( + eq_( sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(), [ @@ -614,30 +616,30 @@ def make_test(select_type): # the subquery and usually results in recursion overflow errors within the adaption. subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar() - self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1) + eq_(sess.query(Person).filter(Person.person_id==subq).one(), e1) def test_mixed_entities(self): sess = create_session() - self.assertEquals( + eq_( sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), [(u'Elbonia, Inc.', Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))] ) - self.assertEquals( + eq_( sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), u'Elbonia, Inc.')] ) - self.assertEquals( + eq_( sess.query(Manager.name).all(), [('pointy haired boss', ), ('dogbert',)] ) - self.assertEquals( + eq_( sess.query(Manager.name + " foo").all(), [('pointy haired boss foo', ), ('dogbert foo',)] ) @@ -647,12 +649,12 @@ def make_test(select_type): assert row.primary_language == 'java' - self.assertEquals( + eq_( sess.query(Engineer.name, Engineer.primary_language).all(), [(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')] ) - self.assertEquals( + eq_( sess.query(Boss.name, Boss.golf_swing).all(), [(u'pointy haired boss', u'fore')] ) @@ -670,18 +672,18 @@ def make_test(select_type): # [] # ) - self.assertEquals( + eq_( sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(), [(u'vlad',u'Elbonia, Inc.')] ) - self.assertEquals( + eq_( sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(), [(u'java',), (u'c++',), (u'cobol',)] ) if select_type != '': - self.assertEquals( + eq_( sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(), [ (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'), @@ -690,20 +692,20 @@ def make_test(select_type): ] ) - self.assertEquals( + eq_( sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(), [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')] ) palias = aliased(Person) - self.assertEquals( + eq_( sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'), u'Elbonia, Inc.', Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))] ) - self.assertEquals( + eq_( sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'Elbonia, Inc.', @@ -711,13 +713,13 @@ def make_test(select_type): ] ) - self.assertEquals( + eq_( sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(), [(u'vlad', u'Elbonia, Inc.', u'dilbert')] ) palias = aliased(Person) - self.assertEquals( + eq_( sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\ filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), [(u'manager', u'dogbert', u'engineer', u'dilbert'), @@ -725,7 +727,7 @@ def make_test(select_type): (u'manager', u'dogbert', u'boss', u'pointy haired boss')] ) - self.assertEquals( + eq_( sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(), [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3'), @@ -737,17 +739,17 @@ def make_test(select_type): ) if select_type != '': - self.assertEquals( + eq_( sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(), [(1, )] ) - self.assertEquals( + eq_( sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(), [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)] ) - self.assertEquals( + eq_( sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(), [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)] ) @@ -757,7 +759,7 @@ def make_test(select_type): return PolymorphicQueryTest for select_type in ('', 'Polymorphic', 'Unions', 'AliasedJoins', 'Joins'): - testclass = make_test(select_type) + testclass = _produce_test(select_type) exec("%s = testclass" % testclass.__name__) del testclass @@ -765,7 +767,8 @@ del testclass class SelfReferentialTestJoinedToBase(_base.MappedTest): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -778,7 +781,8 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): Column('reports_to_id', Integer, ForeignKey('people.person_id')) ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') mapper(Engineer, engineers, inherits=Person, inherit_condition=engineers.c.person_id==people.c.person_id, @@ -796,7 +800,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert')) + eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert')) def test_oftype_aliases_in_exists(self): e1 = Engineer(name='dilbert', primary_language='java') @@ -805,7 +809,7 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): sess.add_all([e1, e2]) sess.flush() - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2) + eq_(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2) def test_join(self): p1 = Person(name='dogbert') @@ -816,14 +820,15 @@ class SelfReferentialTestJoinedToBase(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), Engineer(name='dilbert')) class SelfReferentialJ2JTest(_base.MappedTest): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, managers people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), @@ -840,7 +845,8 @@ class SelfReferentialJ2JTest(_base.MappedTest): Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') mapper(Manager, managers, inherits=Person, polymorphic_identity='manager') @@ -859,7 +865,7 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert')) + eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert')) def test_join(self): m1 = Manager(name='dogbert') @@ -870,7 +876,7 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(), Engineer(name='dilbert')) @@ -886,17 +892,17 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.expunge_all() # filter aliasing applied to Engineer doesn't whack Manager - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(), [m1] ) - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(), [m2] ) - self.assertEquals( + eq_( sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(), [ (m2, e2), @@ -919,12 +925,12 @@ class SelfReferentialJ2JTest(_base.MappedTest): sess.flush() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(), [] ) - self.assertEquals( + eq_( sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(), [m1] ) @@ -936,7 +942,8 @@ class M2MFilterTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, organizations, engineers_to_org organizations = Table('organizations', metadata, @@ -958,7 +965,8 @@ class M2MFilterTest(_base.MappedTest): Column('primary_language', String(50)), ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): global Organization class Organization(_fixtures.Base): pass @@ -970,7 +978,8 @@ class M2MFilterTest(_base.MappedTest): mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') - def insert_data(self): + @classmethod + def insert_data(cls): e1 = Engineer(name='e1') e2 = Engineer(name='e2') e3 = Engineer(name='e3') @@ -989,20 +998,21 @@ class M2MFilterTest(_base.MappedTest): e1 = sess.query(Person).filter(Engineer.name=='e1').one() # this works - self.assertEquals(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')]) + eq_(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')]) # this had a bug - self.assertEquals(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')]) + eq_(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')]) def test_any(self): sess = create_session() - self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')]) - self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')]) + eq_(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')]) + eq_(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')]) class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global Parent, Child1, Child2 Base = declarative_base(metadata=metadata) @@ -1101,5 +1111,3 @@ class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL): assert q.first() is c1 -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/inheritance/selects.py b/test/orm/inheritance/test_selects.py similarity index 88% rename from test/orm/inheritance/selects.py rename to test/orm/inheritance/test_selects.py index e54a0ad13f..a151af4fa2 100644 --- a/test/orm/inheritance/selects.py +++ b/test/orm/inheritance/test_selects.py @@ -1,14 +1,14 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm._fixtures import Base -from orm._base import MappedTest +from sqlalchemy.test import testing +from test.orm._fixtures import Base +from test.orm._base import MappedTest class InheritingSelectablesTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global foo, bar, baz foo = Table('foo', metadata, Column('a', String(30), primary_key=1), @@ -49,5 +49,3 @@ class InheritingSelectablesTest(MappedTest): assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all() assert [Bar(), Bar()] == s.query(Bar).all() -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/inheritance/single.py b/test/orm/inheritance/test_single.py similarity index 85% rename from test/orm/inheritance/single.py rename to test/orm/inheritance/test_single.py index 7aee250318..7058268857 100644 --- a/test/orm/inheritance/single.py +++ b/test/orm/inheritance/test_single.py @@ -1,14 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ from sqlalchemy import * from sqlalchemy.orm import * -from testlib import testing -from orm import _fixtures -from orm._base import MappedTest, ComparableEntity +from sqlalchemy.test import testing +from test.orm import _fixtures +from test.orm._base import MappedTest, ComparableEntity class SingleInheritanceTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('employees', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), @@ -22,7 +23,8 @@ class SingleInheritanceTest(MappedTest): Column('name', String(50)), ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Employee(ComparableEntity): pass class Manager(Employee): @@ -32,8 +34,9 @@ class SingleInheritanceTest(MappedTest): class JuniorEngineer(Engineer): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Employee, employees, polymorphic_on=employees.c.type) mapper(Manager, inherits=Employee, polymorphic_identity='manager') mapper(Engineer, inherits=Employee, polymorphic_identity='engineer') @@ -57,7 +60,7 @@ class SingleInheritanceTest(MappedTest): m1 = session.query(Manager).one() session.expire(m1, ['manager_data']) - self.assertEquals(m1.manager_data, "knows how to manage things") + eq_(m1.manager_data, "knows how to manage things") row = session.query(Engineer.name, Engineer.employee_id).filter(Engineer.name=='Kurt').first() assert row.name == 'Kurt' @@ -75,32 +78,32 @@ class SingleInheritanceTest(MappedTest): session.flush() ealias = aliased(Engineer) - self.assertEquals( + eq_( session.query(Manager, ealias).all(), [(m1, e1), (m1, e2)] ) - self.assertEquals( + eq_( session.query(Manager.name).all(), [("Tom",)] ) - self.assertEquals( + eq_( session.query(Manager.name, ealias.name).all(), [("Tom", "Kurt"), ("Tom", "Ed")] ) - self.assertEquals( + eq_( session.query(func.upper(Manager.name), func.upper(ealias.name)).all(), [("TOM", "KURT"), ("TOM", "ED")] ) - self.assertEquals( + eq_( session.query(Manager).add_entity(ealias).all(), [(m1, e1), (m1, e2)] ) - self.assertEquals( + eq_( session.query(Manager.name).add_column(ealias.name).all(), [("Tom", "Kurt"), ("Tom", "Ed")] ) @@ -121,7 +124,7 @@ class SingleInheritanceTest(MappedTest): sess.add_all([m1, m2, e1, e2]) sess.flush() - self.assertEquals( + eq_( sess.query(Manager).select_from(employees.select().limit(10)).all(), [m1, m2] ) @@ -136,12 +139,12 @@ class SingleInheritanceTest(MappedTest): sess.add_all([m1, m2, e1, e2]) sess.flush() - self.assertEquals(sess.query(Manager).count(), 2) - self.assertEquals(sess.query(Engineer).count(), 2) - self.assertEquals(sess.query(Employee).count(), 4) + eq_(sess.query(Manager).count(), 2) + eq_(sess.query(Engineer).count(), 2) + eq_(sess.query(Employee).count(), 4) - self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) - self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) + eq_(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2) + eq_(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3) @testing.resolve_artifact_names def test_type_filtering(self): @@ -180,7 +183,8 @@ class SingleInheritanceTest(MappedTest): class RelationToSingleTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('employees', metadata, Column('employee_id', Integer, primary_key=True), Column('name', String(50)), @@ -195,7 +199,8 @@ class RelationToSingleTest(MappedTest): Column('name', String(50)), ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Company(ComparableEntity): pass @@ -229,14 +234,14 @@ class RelationToSingleTest(MappedTest): sess.add_all([c1, c2, m1, m2, e1, e2]) sess.commit() sess.expunge_all() - self.assertEquals( + eq_( sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(), [ Company(name='c1'), ] ) - self.assertEquals( + eq_( sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(), [ Company(name='c1'), @@ -267,11 +272,11 @@ class RelationToSingleTest(MappedTest): sess.add_all([c1, c2, m1, m2, e1, e2]) sess.commit() - self.assertEquals(c1.engineers, [e2]) - self.assertEquals(c2.engineers, [e1]) + eq_(c1.engineers, [e2]) + eq_(c2.engineers, [e1]) sess.expunge_all() - self.assertEquals(sess.query(Company).order_by(Company.name).all(), + eq_(sess.query(Company).order_by(Company.name).all(), [ Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), Company(name='c2', engineers=[Engineer(name='Kurt')]) @@ -280,7 +285,7 @@ class RelationToSingleTest(MappedTest): # eager load join should limit to only "Engineer" sess.expunge_all() - self.assertEquals(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), + eq_(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(), [ Company(name='c1', engineers=[JuniorEngineer(name='Ed')]), Company(name='c2', engineers=[Engineer(name='Kurt')]) @@ -289,7 +294,7 @@ class RelationToSingleTest(MappedTest): # join() to Company.engineers, Employee as the requested entity sess.expunge_all() - self.assertEquals(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(), + eq_(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(), [ (Company(name='c1'), JuniorEngineer(name='Ed')), (Company(name='c2'), Engineer(name='Kurt')) @@ -299,7 +304,7 @@ class RelationToSingleTest(MappedTest): # join() to Company.engineers, Engineer as the requested entity. # this actually applies the IN criterion twice which is less than ideal. sess.expunge_all() - self.assertEquals(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(), + eq_(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(), [ (Company(name='c1'), JuniorEngineer(name='Ed')), (Company(name='c2'), Engineer(name='Kurt')) @@ -308,7 +313,7 @@ class RelationToSingleTest(MappedTest): # join() to Company.engineers without any Employee/Engineer entity sess.expunge_all() - self.assertEquals(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), + eq_(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), [ Company(name='c2') ] @@ -323,7 +328,7 @@ class RelationToSingleTest(MappedTest): @testing.fails_on_everything_except() def go(): sess.expunge_all() - self.assertEquals(sess.query(Company).\ + eq_(sess.query(Company).\ filter(Company.company_id==Engineer.company_id).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(), [ Company(name='c2') @@ -332,7 +337,8 @@ class RelationToSingleTest(MappedTest): go() class SingleOnJoinedTest(MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global persons_table, employees_table persons_table = Table('persons', metadata, @@ -366,31 +372,29 @@ class SingleOnJoinedTest(MappedTest): sess.flush() sess.expunge_all() - self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [ + eq_(sess.query(Person).order_by(Person.person_id).all(), [ Person(name='p1'), Employee(name='e1', employee_data='ed1'), Manager(name='m1', employee_data='ed2', manager_data='md1') ]) sess.expunge_all() - self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [ + eq_(sess.query(Employee).order_by(Person.person_id).all(), [ Employee(name='e1', employee_data='ed1'), Manager(name='m1', employee_data='ed2', manager_data='md1') ]) sess.expunge_all() - self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [ + eq_(sess.query(Manager).order_by(Person.person_id).all(), [ Manager(name='m1', employee_data='ed2', manager_data='md1') ]) sess.expunge_all() def go(): - self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ + eq_(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [ Person(name='p1'), Employee(name='e1', employee_data='ed1'), Manager(name='m1', employee_data='ed2', manager_data='md1') ]) self.assert_sql_count(testing.db, go, 1) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/sharding/alltests.py b/test/orm/sharding/alltests.py deleted file mode 100644 index 09fa862126..0000000000 --- a/test/orm/sharding/alltests.py +++ /dev/null @@ -1,18 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - 'orm.sharding.shard', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/orm/sharding/shard.py b/test/orm/sharding/test_shard.py similarity index 94% rename from test/orm/sharding/shard.py rename to test/orm/sharding/test_shard.py index 10aaee131b..89e23fb759 100644 --- a/test/orm/sharding/shard.py +++ b/test/orm/sharding/test_shard.py @@ -1,17 +1,17 @@ -import testenv; testenv.configure_for_tests() import datetime, os from sqlalchemy import * from sqlalchemy import sql from sqlalchemy.orm import * from sqlalchemy.orm.shard import ShardedSession from sqlalchemy.sql import operators -from testlib import * -from testlib.testing import eq_ +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ # TODO: ShardTest can be turned into a base for further subclasses class ShardTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global db1, db2, db3, db4, weather_locations, weather_reports db1 = create_engine('sqlite:///shard1.db') @@ -48,16 +48,18 @@ class ShardTest(TestBase): db1.execute(ids.insert(), nextid=1) - self.setup_session() - self.setup_mappers() + cls.setup_session() + cls.setup_mappers() - def tearDownAll(self): + @classmethod + def teardown_class(cls): for db in (db1, db2, db3, db4): db.connect().invalidate() for i in range(1,5): os.remove("shard%d.db" % i) - def setup_session(self): + @classmethod + def setup_session(cls): global create_session shard_lookup = { @@ -104,7 +106,8 @@ class ShardTest(TestBase): }, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): global WeatherLocation, Report class WeatherLocation(object): @@ -159,5 +162,3 @@ class ShardTest(TestBase): -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/association.py b/test/orm/test_association.py similarity index 89% rename from test/orm/association.py rename to test/orm/test_association.py index d9265ffb10..ee7fb7af94 100644 --- a/test/orm/association.py +++ b/test/orm/test_association.py @@ -1,17 +1,19 @@ -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base -from testlib.testing import eq_ +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base +from sqlalchemy.test.testing import eq_ class AssociationTest(_base.MappedTest): run_setup_classes = 'once' run_setup_mappers = 'once' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('items', metadata, Column('item_id', Integer, primary_key=True), Column('name', String(40))) @@ -23,7 +25,8 @@ class AssociationTest(_base.MappedTest): Column('keyword_id', Integer, primary_key=True), Column('name', String(40))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Item(_base.BasicEntity): def __init__(self, name): self.name = name @@ -45,9 +48,10 @@ class AssociationTest(_base.MappedTest): return "KeywordAssociation itemid=%d keyword=%r data=%s" % ( self.item_id, self.keyword, self.data) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): - items, item_keywords, keywords = self.tables.get_all( + def setup_mappers(cls): + items, item_keywords, keywords = cls.tables.get_all( 'items', 'item_keywords', 'keywords') mapper(Keyword, keywords) @@ -133,13 +137,11 @@ class AssociationTest(_base.MappedTest): item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc')) sess.add_all((item1, item2)) sess.flush() - eq_(self.tables.item_keywords.count().scalar(), 3) + eq_(item_keywords.count().scalar(), 3) sess.delete(item1) sess.delete(item2) sess.flush() - eq_(self.tables.item_keywords.count().scalar(), 0) + eq_(item_keywords.count().scalar(), 0) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/assorted_eager.py b/test/orm/test_assorted_eager.py similarity index 93% rename from test/orm/assorted_eager.py rename to test/orm/test_assorted_eager.py index 8dc95fa5b2..09f0075479 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -3,21 +3,25 @@ Derived from mailing list-reported problems and trac tickets. """ -import testenv; testenv.configure_for_tests() import datetime -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, backref, create_session -from testlib.testing import eq_ -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base class EagerTest(_base.MappedTest): run_deletes = None run_inserts = "once" + run_setup_mappers = "once" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): # determine a literal value for "false" based on the dialect # FIXME: this DefaultClause setup is bogus. @@ -30,7 +34,7 @@ class EagerTest(_base.MappedTest): false = text('FALSE') else: false = str(False) - self.other_artifacts['false'] = false + cls.other_artifacts['false'] = false Table('owners', metadata , Column('id', Integer, primary_key=True, nullable=False), @@ -55,30 +59,32 @@ class EagerTest(_base.MappedTest): Column('someoption', sa.Boolean, server_default=false, nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Owner(_base.BasicEntity): pass class Category(_base.BasicEntity): pass - class Test(_base.BasicEntity): + class Thing(_base.BasicEntity): pass class Option(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Owner, owners) mapper(Category, categories) mapper(Option, options, properties=dict( owner=relation(Owner), - test=relation(Test))) + test=relation(Thing))) - mapper(Test, tests, properties=dict( + mapper(Thing, tests, properties=dict( owner=relation(Owner, backref='tests'), category=relation(Category), owner_option=relation(Option, @@ -87,16 +93,17 @@ class EagerTest(_base.MappedTest): foreign_keys=[options.c.test_id, options.c.owner_id], uselist=False))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): session = create_session() o = Owner() c = Category(name='Some Category') session.add_all(( - Test(owner=o, category=c), - Test(owner=o, category=c, owner_option=Option(someoption=True)), - Test(owner=o, category=c, owner_option=Option()))) + Thing(owner=o, category=c), + Thing(owner=o, category=c, owner_option=Option(someoption=True)), + Thing(owner=o, category=c, owner_option=Option()))) session.flush() @@ -129,7 +136,7 @@ class EagerTest(_base.MappedTest): @testing.resolve_artifact_names def test_withouteagerload(self): s = create_session() - l = (s.query(Test). + l = (s.query(Thing). select_from(tests.outerjoin(options, sa.and_(tests.c.id == options.c.test_id, tests.c.owner_id == @@ -150,7 +157,7 @@ class EagerTest(_base.MappedTest): """ s = create_session() - q=s.query(Test).options(sa.orm.eagerload('category')) + q=s.query(Thing).options(sa.orm.eagerload('category')) l=(q.select_from(tests.outerjoin(options, sa.and_(tests.c.id == @@ -168,7 +175,7 @@ class EagerTest(_base.MappedTest): def test_dslish(self): """test the same as witheagerload except using generative""" s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) + q = s.query(Thing).options(sa.orm.eagerload('category')) l = q.filter ( sa.and_(tests.c.owner_id == 1, sa.or_(options.c.someoption == None, @@ -182,7 +189,7 @@ class EagerTest(_base.MappedTest): @testing.resolve_artifact_names def test_without_outerjoin_literal(self): s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) + q = s.query(Thing).options(sa.orm.eagerload('category')) l = (q.filter( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false)). @@ -194,7 +201,7 @@ class EagerTest(_base.MappedTest): @testing.resolve_artifact_names def test_withoutouterjoin(self): s = create_session() - q = s.query(Test).options(sa.orm.eagerload('category')) + q = s.query(Thing).options(sa.orm.eagerload('category')) l = q.filter( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) @@ -205,7 +212,8 @@ class EagerTest(_base.MappedTest): class EagerTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('left', metadata, Column('id', Integer, ForeignKey('middle.id'), primary_key=True), Column('data', String(50), primary_key=True)) @@ -218,7 +226,8 @@ class EagerTest2(_base.MappedTest): Column('id', Integer, ForeignKey('middle.id'), primary_key=True), Column('data', String(50), primary_key=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Left(_base.BasicEntity): def __init__(self, data): self.data = data @@ -231,8 +240,9 @@ class EagerTest2(_base.MappedTest): def __init__(self, data): self.data = data + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): # set up bi-directional eager loads mapper(Left, left) mapper(Right, right) @@ -267,7 +277,8 @@ class EagerTest2(_base.MappedTest): class EagerTest3(_base.MappedTest): """Eager loading combined with nested SELECT statements, functions, and aggregates.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('datas', metadata, Column('id', Integer, primary_key=True, nullable=False), Column('a', Integer, nullable=False)) @@ -283,7 +294,8 @@ class EagerTest3(_base.MappedTest): Column('data_id', Integer, ForeignKey('datas.id')), Column('somedata', Integer, nullable=False )) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Data(_base.BasicEntity): pass @@ -349,7 +361,8 @@ class EagerTest3(_base.MappedTest): class EagerTest4(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('departments', metadata, Column('department_id', Integer, primary_key=True), Column('name', String(50))) @@ -360,7 +373,8 @@ class EagerTest4(_base.MappedTest): Column('department_id', Integer, ForeignKey('departments.department_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Department(_base.BasicEntity): pass @@ -401,7 +415,8 @@ class EagerTest4(_base.MappedTest): class EagerTest5(_base.MappedTest): """Construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('base', metadata, Column('uid', String(30), primary_key=True), Column('x', String(30))) @@ -421,7 +436,8 @@ class EagerTest5(_base.MappedTest): Column('uid', String(30), ForeignKey('base.uid')), Column('comment', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Base(_base.BasicEntity): def __init__(self, uid, x): self.uid = uid @@ -486,7 +502,8 @@ class EagerTest5(_base.MappedTest): class EagerTest6(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('design_types', metadata, Column('design_type_id', Integer, primary_key=True)) @@ -506,7 +523,8 @@ class EagerTest6(_base.MappedTest): Column('part_id', Integer, ForeignKey('parts.part_id')), Column('design_id', Integer, ForeignKey('design.design_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Part(_base.BasicEntity): pass @@ -552,7 +570,8 @@ class EagerTest6(_base.MappedTest): class EagerTest7(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('companies', metadata, Column('company_id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -584,7 +603,8 @@ class EagerTest7(_base.MappedTest): Column('code', String(20)), Column('qty', Integer)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Company(_base.ComparableEntity): pass @@ -699,7 +719,8 @@ class EagerTest7(_base.MappedTest): class EagerTest8(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('prj', metadata, Column('id', Integer, primary_key=True), Column('created', sa.DateTime ), @@ -731,8 +752,9 @@ class EagerTest8(_base.MappedTest): Column('name', sa.Unicode(20)), Column('display_name', sa.Unicode(20))) + @classmethod @testing.resolve_artifact_names - def fixtures(self): + def fixtures(cls): return dict( prj=(('id',), (1,)), @@ -746,7 +768,8 @@ class EagerTest8(_base.MappedTest): task=(('title', 'task_type_id', 'status_id', 'prj_id'), (u'task 1', 1, 1, 1))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Task_Type(_base.BasicEntity): pass @@ -788,7 +811,8 @@ class EagerTest9(_base.MappedTest): throughout the query setup/mapper instances process. """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('accounts', metadata, Column('account_id', Integer, primary_key=True), Column('name', String(40))) @@ -805,7 +829,8 @@ class EagerTest9(_base.MappedTest): Column('transaction_id', Integer, ForeignKey('transactions.transaction_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Account(_base.BasicEntity): pass @@ -815,8 +840,9 @@ class EagerTest9(_base.MappedTest): class Entry(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Account, accounts) mapper(Transaction, transactions) @@ -874,5 +900,3 @@ class EagerTest9(_base.MappedTest): -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/attributes.py b/test/orm/test_attributes.py similarity index 99% rename from test/orm/attributes.py rename to test/orm/test_attributes.py index 7c116fcf78..3b1b42dadc 100644 --- a/test/orm/attributes.py +++ b/test/orm/test_attributes.py @@ -1,12 +1,11 @@ -import testenv; testenv.configure_for_tests() import pickle import sqlalchemy.orm.attributes as attributes from sqlalchemy.orm.collections import collection from sqlalchemy.orm.interfaces import AttributeExtension from sqlalchemy import exc as sa_exc -from testlib import * -from testlib.testing import eq_ -from orm import _base +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ +from test.orm import _base import gc # global for pickling tests @@ -15,12 +14,12 @@ MyTest2 = None class AttributesTest(_base.ORMTest): - def setUp(self): + def setup(self): global MyTest, MyTest2 class MyTest(object): pass class MyTest2(object): pass - def tearDown(self): + def teardown(self): global MyTest, MyTest2 MyTest, MyTest2 = None, None @@ -588,7 +587,7 @@ class BackrefTest(_base.ORMTest): self.assert_(p.jack is None) class PendingBackrefTest(_base.ORMTest): - def setUp(self): + def setup(self): global Post, Blog, called, lazy_load class Post(object): @@ -1327,5 +1326,3 @@ class ListenerTest(_base.ORMTest): assert f1.barset.pop().data == "some bar appended" -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/bind.py b/test/orm/test_bind.py similarity index 71% rename from test/orm/bind.py rename to test/orm/test_bind.py index 33d028d22e..9b1c20b605 100644 --- a/test/orm/bind.py +++ b/test/orm/test_bind.py @@ -1,23 +1,29 @@ -import testenv; testenv.configure_for_tests() -from testlib.sa import MetaData, Table, Column, Integer -from testlib.sa.orm import mapper, create_session -from testlib import sa, testing -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy import MetaData, Integer +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +import sqlalchemy as sa +from sqlalchemy.test import testing +from test.orm import _base class BindTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('test_table', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', Integer)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): meta = MetaData() test_table.tometadata(meta) @@ -44,12 +50,10 @@ class BindTest(_base.MappedTest): def test_session_unbound(self): sess = create_session() sess.add(Foo()) - self.assertRaisesMessage( + assert_raises_message( sa.exc.UnboundExecutionError, ('Could not locate a bind configured on Mapper|Foo|test_table ' 'or this Session'), sess.flush) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/cascade.py b/test/orm/test_cascade.py similarity index 95% rename from test/orm/cascade.py rename to test/orm/test_cascade.py index c827a85ced..d0a7b9ded6 100644 --- a/test/orm/cascade.py +++ b/test/orm/test_cascade.py @@ -1,18 +1,21 @@ -import testenv; testenv.configure_for_tests() -from testlib.sa import Table, Column, Integer, String, ForeignKey, Sequence, exc as sa_exc -from testlib.sa.orm import mapper, relation, create_session, class_mapper, backref -from testlib.sa.orm import attributes, exc as orm_exc -from testlib import testing -from testlib.testing import eq_ -from orm import _base, _fixtures +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref +from sqlalchemy.orm import attributes, exc as orm_exc +from sqlalchemy.test import testing +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures class O2MCascadeTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Address, addresses) mapper(User, users, properties = dict( addresses = relation(Address, cascade="all, delete-orphan", backref="user"), @@ -188,8 +191,9 @@ class O2MCascadeTest(_fixtures.FixtureTest): class O2OCascadeTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Address, addresses) mapper(User, users, properties = { 'address':relation(Address, backref=backref("user", single_parent=True), uselist=False) @@ -200,7 +204,7 @@ class O2OCascadeTest(_fixtures.FixtureTest): a1 = Address(email_address='some address') u1 = User(name='u1', address=a1) - self.assertRaises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1) + assert_raises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1) a2 = Address(email_address='asd') u1.address = a2 @@ -212,8 +216,9 @@ class O2OCascadeTest(_fixtures.FixtureTest): class O2MBackrefTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties = dict( orders = relation( mapper(Order, orders), cascade="all, delete-orphan", backref="user") @@ -316,8 +321,9 @@ class NoSaveCascadeTest(_fixtures.FixtureTest): class O2MCascadeNoOrphanTest(_fixtures.FixtureTest): run_inserts = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties = dict( orders = relation( mapper(Order, orders), cascade="all") @@ -342,7 +348,8 @@ class O2MCascadeNoOrphanTest(_fixtures.FixtureTest): class M2OCascadeTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("extra", metadata, Column("id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True), @@ -359,7 +366,8 @@ class M2OCascadeTest(_base.MappedTest): Column('name', String(40)), Column('pref_id', Integer, ForeignKey('prefs.id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_fixtures.Base): pass class Pref(_fixtures.Base): @@ -367,8 +375,9 @@ class M2OCascadeTest(_base.MappedTest): class Extra(_fixtures.Base): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Extra, extra) mapper(Pref, prefs, properties=dict( extra = relation(Extra, cascade="all, delete") @@ -377,8 +386,9 @@ class M2OCascadeTest(_base.MappedTest): pref = relation(Pref, lazy=False, cascade="all, delete-orphan", single_parent=True ) )) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()])) u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()])) u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()])) @@ -447,7 +457,8 @@ class M2OCascadeTest(_base.MappedTest): [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) class M2OCascadeDeleteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), @@ -460,7 +471,8 @@ class M2OCascadeDeleteTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('data', String(50))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_fixtures.Base): pass class T2(_fixtures.Base): @@ -468,8 +480,9 @@ class M2OCascadeDeleteTest(_base.MappedTest): class T3(_fixtures.Base): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(T1, t1, properties={'t2': relation(T2, cascade="all")}) mapper(T2, t2, properties={'t3': relation(T3, cascade="all")}) mapper(T3, t3) @@ -565,7 +578,8 @@ class M2OCascadeDeleteTest(_base.MappedTest): class M2OCascadeDeleteOrphanTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), @@ -578,7 +592,8 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('data', String(50))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_fixtures.Base): pass class T2(_fixtures.Base): @@ -586,8 +601,9 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): class T3(_fixtures.Base): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(T1, t1, properties=dict( t2=relation(T2, cascade="all, delete-orphan", single_parent=True))) mapper(T2, t2, properties=dict( @@ -655,7 +671,7 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): y = T2(data='T2a') x = T1(data='T1a', t2=y) - self.assertRaises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y) + assert_raises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y) @testing.resolve_artifact_names def test_single_parent_backref(self): @@ -666,7 +682,7 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): x = T2(data='T2a', t3=y) # cant attach the T3 to another T2 - self.assertRaises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y) + assert_raises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y) # set via backref tho is OK, unsets from previous parent # first @@ -677,7 +693,8 @@ class M2OCascadeDeleteOrphanTest(_base.MappedTest): assert x.t3 is None class M2MCascadeTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -703,7 +720,8 @@ class M2MCascadeTest(_base.MappedTest): ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_fixtures.Base): pass class B(_fixtures.Base): @@ -784,7 +802,7 @@ class M2MCascadeTest(_base.MappedTest): b1 =B(data='b1') a1 = A(data='a1', bs=[b1]) - self.assertRaises(sa_exc.InvalidRequestError, + assert_raises(sa_exc.InvalidRequestError, A, data='a2', bs=[b1] ) @@ -804,7 +822,7 @@ class M2MCascadeTest(_base.MappedTest): b1 =B(data='b1') a1 = A(data='a1', bs=[b1]) - self.assertRaises( + assert_raises( sa_exc.InvalidRequestError, A, data='a2', bs=[b1] ) @@ -817,7 +835,8 @@ class M2MCascadeTest(_base.MappedTest): class UnsavedOrphansTest(_base.MappedTest): """Pending entities that are orphans""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('user_id', Integer, Sequence('user_id_seq', optional=True), @@ -831,7 +850,8 @@ class UnsavedOrphansTest(_base.MappedTest): Column('user_id', Integer, ForeignKey('users.user_id')), Column('email_address', String(40))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_fixtures.Base): pass class Address(_fixtures.Base): @@ -900,7 +920,8 @@ class UnsavedOrphansTest(_base.MappedTest): class UnsavedOrphansTest2(_base.MappedTest): """same test as UnsavedOrphans only three levels deep""" - def define_tables(self, meta): + @classmethod + def define_tables(cls, meta): Table('orders', meta, Column('id', Integer, Sequence('order_id_seq'), primary_key=True), @@ -958,7 +979,8 @@ class UnsavedOrphansTest2(_base.MappedTest): class UnsavedOrphansTest3(_base.MappedTest): """test not expunging double parents""" - def define_tables(self, meta): + @classmethod + def define_tables(cls, meta): Table('sales_reps', meta, Column('sales_rep_id', Integer, Sequence('sales_rep_id_seq'), @@ -1062,7 +1084,8 @@ class UnsavedOrphansTest3(_base.MappedTest): class DoubleParentOrphanTest(_base.MappedTest): """test orphan detection for an entity with two parent relations""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('addresses', metadata, Column('address_id', Integer, primary_key=True), Column('street', String(30)), @@ -1133,7 +1156,8 @@ class DoubleParentOrphanTest(_base.MappedTest): assert True class CollectionAssignmentOrphanTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table_a', metadata, Column('id', Integer, primary_key=True), Column('name', String(30))) @@ -1181,7 +1205,8 @@ class PartialFlushTest(_base.MappedTest): """test cascade behavior as it relates to object lists passed to flush(). """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("base", metadata, Column("id", Integer, primary_key=True), Column("descr", String(50)) @@ -1288,5 +1313,3 @@ class PartialFlushTest(_base.MappedTest): assert c1 not in sess.new assert c2 in sess.new -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/collection.py b/test/orm/test_collection.py similarity index 98% rename from test/orm/collection.py rename to test/orm/test_collection.py index 23f643597a..12ff25c460 100644 --- a/test/orm/collection.py +++ b/test/orm/test_collection.py @@ -1,17 +1,19 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import sys from operator import and_ import sqlalchemy.orm.collections as collections from sqlalchemy.orm.collections import collection -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa import util, exc as sa_exc -from testlib.sa.orm import create_session, mapper, relation, \ - attributes -from orm import _base -from testlib.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy import util, exc as sa_exc +from sqlalchemy.orm import create_session, mapper, relation, attributes +from test.orm import _base +from sqlalchemy.test.testing import eq_ class Canary(sa.orm.interfaces.AttributeExtension): def __init__(self): @@ -45,12 +47,14 @@ class CollectionsTest(_base.ORMTest): def __repr__(self): return str((id(self), self.a, self.b, self.c)) - def setUpAll(self): - attributes.register_class(self.Entity) + @classmethod + def setup_class(cls): + attributes.register_class(cls.Entity) - def tearDownAll(self): - attributes.unregister_class(self.Entity) - _base.ORMTest.tearDownAll(self) + @classmethod + def teardown_class(cls): + attributes.unregister_class(cls.Entity) + super(CollectionsTest, cls).teardown_class() _entity_id = 1 @@ -937,7 +941,7 @@ class CollectionsTest(_base.ORMTest): pass self.assert_(obj.attr is not real_dict) self.assert_('badkey' not in obj.attr) - self.assertEquals(set(collections.collection_adapter(obj.attr)), + eq_(set(collections.collection_adapter(obj.attr)), set([e2])) self.assert_(e3 not in canary.added) else: @@ -945,13 +949,13 @@ class CollectionsTest(_base.ORMTest): obj.attr = real_dict self.assert_(obj.attr is not real_dict) self.assert_('keyignored1' not in obj.attr) - self.assertEquals(set(collections.collection_adapter(obj.attr)), + eq_(set(collections.collection_adapter(obj.attr)), set([e3])) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) obj.attr = typecallable() - self.assertEquals(list(collections.collection_adapter(obj.attr)), []) + eq_(list(collections.collection_adapter(obj.attr)), []) e4 = creator() try: @@ -1336,7 +1340,8 @@ class CollectionsTest(_base.ORMTest): class DictHelpersTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('parents', metadata, Column('id', Integer, primary_key=True), Column('label', String(128))) @@ -1348,7 +1353,8 @@ class DictHelpersTest(_base.MappedTest): Column('b', String(128)), Column('c', String(128))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Parent(_base.BasicEntity): def __init__(self, label=None): self.label = label @@ -1378,7 +1384,7 @@ class DictHelpersTest(_base.MappedTest): p = session.query(Parent).get(pid) - self.assertEquals(set(p.children.keys()), set(['foo', 'bar'])) + eq_(set(p.children.keys()), set(['foo', 'bar'])) cid = p.children['foo'].id collections.collection_adapter(p.children).append_with_event( @@ -1519,7 +1525,8 @@ class DictHelpersTest(_base.MappedTest): # remove if so class CustomCollectionsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('sometable', metadata, Column('col1',Integer, primary_key=True), Column('data', String(30))) @@ -1830,5 +1837,3 @@ class InstrumentationTest(_base.ORMTest): instrumented = collections._instrument_class(Touchy) assert True -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/compile.py b/test/orm/test_compile.py similarity index 97% rename from test/orm/compile.py rename to test/orm/test_compile.py index 7c9bed4ecc..7a5b636157 100644 --- a/test/orm/compile.py +++ b/test/orm/test_compile.py @@ -1,15 +1,14 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import exc as sa_exc from sqlalchemy.orm import * -from testlib import * -from orm import _base +from sqlalchemy.test import * +from test.orm import _base class CompileTest(_base.ORMTest): """test various mapper compilation scenarios""" - def tearDown(self): + def teardown(self): clear_mappers() def testone(self): @@ -182,5 +181,3 @@ class CompileTest(_base.ORMTest): assert str(e).index("Error creating backref") > -1 -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/cycles.py b/test/orm/test_cycles.py similarity index 94% rename from test/orm/cycles.py rename to test/orm/test_cycles.py index 3e36360852..fe77b36018 100644 --- a/test/orm/cycles.py +++ b/test/orm/test_cycles.py @@ -5,19 +5,21 @@ T1<->T2, with o2m or m2o between them, and a third T3 with o2m/m2o to one/both T1/T2. """ -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, backref, create_session -from testlib.testing import eq_ -from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf -from orm import _base +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf +from test.orm import _base class SelfReferentialTest(_base.MappedTest): """A self-referential mapper with an additional list of child objects.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), @@ -29,7 +31,8 @@ class SelfReferentialTest(_base.MappedTest): Column('c1id', Integer, ForeignKey('t1.c1')), Column('data', String(20))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class C1(_base.BasicEntity): def __init__(self, data=None): self.data = data @@ -132,20 +135,23 @@ class SelfReferentialTest(_base.MappedTest): class SelfReferentialNoPKTest(_base.MappedTest): """A self-referential relationship that joins on a column other than the primary key column""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('item', metadata, Column('id', Integer, primary_key=True), Column('uuid', String(32), unique=True, nullable=False), Column('parent_uuid', String(32), ForeignKey('item.uuid'), nullable=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class TT(_base.BasicEntity): def __init__(self): self.uuid = hex(id(self)) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(TT, item, properties={ 'children': relation( TT, @@ -181,7 +187,8 @@ class SelfReferentialNoPKTest(_base.MappedTest): class InheritTestOne(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("parent", metadata, Column("id", Integer, primary_key=True), Column("parent_data", String(50)), @@ -199,7 +206,8 @@ class InheritTestOne(_base.MappedTest): nullable=False), Column("child2_data", String(50))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Parent(_base.BasicEntity): pass @@ -209,8 +217,9 @@ class InheritTestOne(_base.MappedTest): class Child2(Parent): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Parent, parent) mapper(Child1, child1, inherits=Parent) mapper(Child2, child2, inherits=Parent, properties=dict( @@ -250,7 +259,8 @@ class InheritTestTwo(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -266,7 +276,8 @@ class InheritTestTwo(_base.MappedTest): Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo"))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.BasicEntity): pass @@ -297,7 +308,8 @@ class InheritTestTwo(_base.MappedTest): class BiDirectionalManyToOneTest(_base.MappedTest): run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -313,7 +325,8 @@ class BiDirectionalManyToOneTest(_base.MappedTest): Column('t1id', Integer, ForeignKey('t1.id'), nullable=False), Column('t2id', Integer, ForeignKey('t2.id'), nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_base.BasicEntity): pass class T2(_base.BasicEntity): @@ -321,8 +334,9 @@ class BiDirectionalManyToOneTest(_base.MappedTest): class T3(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(T1, t1, properties={ 't2':relation(T2, primaryjoin=t1.c.t2id == t2.c.id)}) mapper(T2, t2, properties={ @@ -385,7 +399,8 @@ class BiDirectionalOneToManyTest(_base.MappedTest): run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True, test_needs_autoincrement=True), @@ -397,7 +412,8 @@ class BiDirectionalOneToManyTest(_base.MappedTest): Column('c2', Integer, ForeignKey('t1.c1', use_alter=True, name='t1c1_fk'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class C1(_base.BasicEntity): pass @@ -434,7 +450,8 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True), Column('c2', Integer, ForeignKey('t2.c1')), @@ -452,7 +469,8 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): Column('data', String(20)), test_needs_autoincrement=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class C1(_base.BasicEntity): pass @@ -462,8 +480,9 @@ class BiDirectionalOneToManyTest2(_base.MappedTest): class C1Data(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(C2, t2, properties={ 'c1s': relation(C1, primaryjoin=t2.c.c1 == t1.c.c2, @@ -508,7 +527,8 @@ class OneToManyManyToOneTest(_base.MappedTest): """ run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('ball', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -522,7 +542,8 @@ class OneToManyManyToOneTest(_base.MappedTest): Column('favorite_ball_id', Integer, ForeignKey('ball.id')), Column('data', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Person(_base.BasicEntity): pass @@ -709,7 +730,8 @@ class OneToManyManyToOneTest(_base.MappedTest): class SelfReferentialPostUpdateTest(_base.MappedTest): """Post_update on a single self-referential mapper""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('node', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -721,7 +743,8 @@ class SelfReferentialPostUpdateTest(_base.MappedTest): Column('next_sibling_id', Integer, ForeignKey('node.id'), nullable=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Node(_base.BasicEntity): def __init__(self, path=''): self.path = path @@ -815,13 +838,15 @@ class SelfReferentialPostUpdateTest(_base.MappedTest): class SelfReferentialPostUpdateTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("a_table", metadata, Column("id", Integer(), primary_key=True), Column("fui", String(128)), Column("b", Integer(), ForeignKey("a_table.id"))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.BasicEntity): pass @@ -858,5 +883,3 @@ class SelfReferentialPostUpdateTest2(_base.MappedTest): assert f2.foo is f1 -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/defaults.py b/test/orm/test_defaults.py similarity index 88% rename from test/orm/defaults.py rename to test/orm/test_defaults.py index 8dc1925195..b063780ac7 100644 --- a/test/orm/defaults.py +++ b/test/orm/test_defaults.py @@ -1,16 +1,19 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base -from testlib.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base +from sqlalchemy.test.testing import eq_ class TriggerDefaultsTest(_base.MappedTest): __requires__ = ('row_triggers',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): dt = Table('dt', metadata, Column('id', Integer, primary_key=True), Column('col1', String(20)), @@ -63,12 +66,14 @@ class TriggerDefaultsTest(_base.MappedTest): sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Default(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Default, dt) @testing.resolve_artifact_names @@ -107,7 +112,8 @@ class TriggerDefaultsTest(_base.MappedTest): eq_(d1.col4, 'up') class ExcludedDefaultsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): dt = Table('dt', metadata, Column('id', Integer, primary_key=True), Column('col1', String(20), default="hello"), @@ -125,5 +131,3 @@ class ExcludedDefaultsTest(_base.MappedTest): sess.flush() eq_(dt.select().execute().fetchall(), [(1, "hello")]) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/deprecations.py b/test/orm/test_deprecations.py similarity index 96% rename from test/orm/deprecations.py rename to test/orm/test_deprecations.py index 483e8f556b..00d64119ea 100644 --- a/test/orm/deprecations.py +++ b/test/orm/test_deprecations.py @@ -5,11 +5,12 @@ modern (i.e. not deprecated) alternative to them. The tests snippets here can be migrated directly to the wiki, docs, etc. """ -import testenv; testenv.configure_for_tests() -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, func -from testlib.sa.orm import mapper, relation, create_session, sessionmaker -from orm import _base +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, sessionmaker +from test.orm import _base class QueryAlternativesTest(_base.MappedTest): @@ -44,7 +45,8 @@ class QueryAlternativesTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users_table', metadata, Column('id', Integer, primary_key=True), Column('name', String(64))) @@ -56,21 +58,24 @@ class QueryAlternativesTest(_base.MappedTest): Column('purpose', String(16)), Column('bounces', Integer, default=0)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.BasicEntity): pass class Address(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users_table, properties=dict( addresses=relation(Address, backref='user'), )) mapper(Address, addresses_table) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( users_table=( ('id', 'name'), @@ -479,5 +484,3 @@ class QueryAlternativesTest(_base.MappedTest): assert len(users) == 1 and users[0].name == 'ed' -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/dynamic.py b/test/orm/test_dynamic.py similarity index 96% rename from test/orm/dynamic.py rename to test/orm/test_dynamic.py index 3bd94b7c0e..f2089a4351 100644 --- a/test/orm/dynamic.py +++ b/test/orm/test_dynamic.py @@ -1,13 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import operator from sqlalchemy.orm import dynamic_loader, backref -from testlib import testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func -from testlib.sa.orm import mapper, relation, create_session, Query, attributes +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, desc, select, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, Query, attributes from sqlalchemy.orm.dynamic import AppenderMixin -from testlib.testing import eq_ -from testlib.compat import _function_named -from orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ +from sqlalchemy.util import function_named +from test.orm import _base, _fixtures class DynamicTest(_fixtures.FixtureTest): @@ -281,7 +283,7 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() from sqlalchemy.orm import attributes - self.assertEquals(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], [])) + eq_(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], [])) sess.expunge_all() @@ -452,7 +454,7 @@ class SessionTest(_fixtures.FixtureTest): sess.close() -def create_backref_test(autoflush, saveuser): +def _create_backref_test(autoflush, saveuser): @testing.resolve_artifact_names def test_backref(self): @@ -487,17 +489,18 @@ def create_backref_test(autoflush, saveuser): sess.flush() self.assert_(list(u.addresses) == []) - test_backref = _function_named( + test_backref = function_named( test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""), (saveuser and "_saveuser" or "_savead"))) setattr(SessionTest, test_backref.__name__, test_backref) for autoflush in (False, True): for saveuser in (False, True): - create_backref_test(autoflush, saveuser) + _create_backref_test(autoflush, saveuser) class DontDereferenceTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(40)), @@ -509,8 +512,9 @@ class DontDereferenceTest(_base.MappedTest): Column('email_address', String(100), nullable=False), Column('user_id', Integer, ForeignKey('users.id'))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class User(_base.ComparableEntity): pass @@ -555,5 +559,3 @@ class DontDereferenceTest(_base.MappedTest): eq_(query3(), [Address(email_address='joe@joesdomain.example')]) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/eager_relations.py b/test/orm/test_eager_relations.py similarity index 97% rename from test/orm/eager_relations.py rename to test/orm/test_eager_relations.py index 87c2442cc4..384e0472f6 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -1,13 +1,16 @@ """basic tests of eager loaded attributes""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing +from sqlalchemy.test.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing from sqlalchemy.orm import eagerload, deferred, undefer -from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func -from testlib.sa.orm import mapper, relation, create_session, lazyload, aliased -from testlib.testing import eq_ -from testlib.assertsql import CompiledSQL -from orm import _base, _fixtures +from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased +from sqlalchemy.test.testing import eq_ +from sqlalchemy.test.assertsql import CompiledSQL +from test.orm import _base, _fixtures import datetime class EagerTest(_fixtures.FixtureTest): @@ -23,7 +26,7 @@ class EagerTest(_fixtures.FixtureTest): q = sess.query(User) assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() - self.assertEquals(self.static.user_address_result, q.order_by(User.id).all()) + eq_(self.static.user_address_result, q.order_by(User.id).all()) @testing.resolve_artifact_names def test_late_compile(self): @@ -287,7 +290,7 @@ class EagerTest(_fixtures.FixtureTest): assert sa.orm.class_mapper(Address).get_property('user').lazy is False sess = create_session() - self.assertEquals(self.static.user_address_result, sess.query(User).order_by(User.id).all()) + eq_(self.static.user_address_result, sess.query(User).order_by(User.id).all()) @testing.resolve_artifact_names def test_double(self): @@ -615,13 +618,13 @@ class EagerTest(_fixtures.FixtureTest): def go(): o1 = sess.query(Order).options(lazyload('address')).filter(Order.id==5).one() - self.assertEquals(o1.address, None) + eq_(o1.address, None) self.assert_sql_count(testing.db, go, 2) sess.expunge_all() def go(): o1 = sess.query(Order).filter(Order.id==5).one() - self.assertEquals(o1.address, None) + eq_(o1.address, None) self.assert_sql_count(testing.db, go, 1) @testing.resolve_artifact_names @@ -817,7 +820,8 @@ class AddEntityTest(_fixtures.FixtureTest): self.assert_sql_count(testing.db, go, 1) class OrderBySecondaryTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('m2m', metadata, Column('id', Integer, primary_key=True), Column('aid', Integer, ForeignKey('a.id')), @@ -830,7 +834,8 @@ class OrderBySecondaryTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('data', String(50))) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( a=(('id', 'data'), (1, 'a1'), @@ -865,7 +870,8 @@ class OrderBySecondaryTest(_base.MappedTest): class SelfReferentialEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('nodes', metadata, Column('id', Integer, sa.Sequence('node_id_seq', optional=True), primary_key=True), @@ -980,7 +986,7 @@ class SelfReferentialEagerTest(_base.MappedTest): sess.expunge_all() def go(): - self.assertEquals( + eq_( Node(data='n1', children=[Node(data='n11'), Node(data='n12')]), sess.query(Node).order_by(Node.id).first(), ) @@ -1079,7 +1085,8 @@ class SelfReferentialEagerTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 3) class MixedSelfReferentialEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('a_table', metadata, Column('id', Integer, primary_key=True) ) @@ -1091,8 +1098,9 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): Column('parent_b2_id', Integer, ForeignKey('b_table.id'))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class A(_base.ComparableEntity): pass class B(_base.ComparableEntity): @@ -1113,8 +1121,9 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): ) }); + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3)) b_table.insert().execute( dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None), @@ -1149,7 +1158,8 @@ class MixedSelfReferentialEagerTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) class SelfReferentialM2MEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('widget', metadata, Column('id', Integer, primary_key=True), Column('name', sa.Unicode(40), nullable=False, unique=True), @@ -1189,8 +1199,9 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): run_inserts = 'once' run_deletes = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user'), 'orders':relation(Order, backref='user'), # o2m, m2o @@ -1323,7 +1334,8 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): class CyclicalInheritingEagerTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('c1', Integer, primary_key=True), Column('c2', String(30)), @@ -1361,7 +1373,8 @@ class CyclicalInheritingEagerTest(_base.MappedTest): create_session().query(SubT).all() class SubqueryTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users_table', metadata, Column('id', Integer, primary_key=True), Column('name', String(16)) @@ -1449,7 +1462,8 @@ class CorrelatedSubqueryTest(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): users = Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(50)) @@ -1460,8 +1474,9 @@ class CorrelatedSubqueryTest(_base.MappedTest): Column('date', Date), Column('user_id', Integer, ForeignKey('users.id'))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): users.insert().execute( {'id':1, 'name':'user1'}, {'id':2, 'name':'user2'}, @@ -1591,5 +1606,3 @@ class CorrelatedSubqueryTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/evaluator.py b/test/orm/test_evaluator.py similarity index 86% rename from test/orm/evaluator.py rename to test/orm/test_evaluator.py index 3527c93d77..af6a3f89e3 100644 --- a/test/orm/evaluator.py +++ b/test/orm/test_evaluator.py @@ -1,10 +1,12 @@ """Evluating SQL expressions on ORM objects""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, String, Integer, select -from testlib.sa.orm import mapper, create_session -from testlib.testing import eq_ -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import String, Integer, select +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base from sqlalchemy import and_, or_, not_ from sqlalchemy.orm import evaluator @@ -20,17 +22,20 @@ def eval_eq(clause, testcases=None): return testeval class EvaluateTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(64))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users) @testing.resolve_artifact_names @@ -90,5 +95,3 @@ class EvaluateTest(_base.MappedTest): (User(id=None, name=None), None), ]) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/expire.py b/test/orm/test_expire.py similarity index 95% rename from test/orm/expire.py rename to test/orm/test_expire.py index c11fb69dfe..6593498978 100644 --- a/test/orm/expire.py +++ b/test/orm/test_expire.py @@ -1,11 +1,14 @@ """Attribute/instance expiration, deferral of attributes, etc.""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import gc -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, exc as sa_exc -from testlib.sa.orm import mapper, relation, create_session, attributes, deferred -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, attributes, deferred +from test.orm import _base, _fixtures class ExpireTest(_fixtures.FixtureTest): @@ -56,7 +59,7 @@ class ExpireTest(_fixtures.FixtureTest): u = s.query(User).get(7) s.expunge_all() - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u) + assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u) @testing.resolve_artifact_names def test_get_refreshes(self): @@ -69,7 +72,7 @@ class ExpireTest(_fixtures.FixtureTest): u = s.query(User).get(10) # get() refreshes self.assert_sql_count(testing.db, go, 1) def go(): - self.assertEquals(u.name, 'chuck') # attributes unexpired + eq_(u.name, 'chuck') # attributes unexpired self.assert_sql_count(testing.db, go, 0) def go(): u = s.query(User).get(10) # expire flag reset, so not expired @@ -86,7 +89,7 @@ class ExpireTest(_fixtures.FixtureTest): # add it back s.add(u) # nope, raises ObjectDeletedError - self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name') + assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name') # do a get()/remove u from session again assert s.query(User).get(10) is None @@ -97,7 +100,7 @@ class ExpireTest(_fixtures.FixtureTest): assert u in s # but now its back, rollback has occured, the _remove_newly_deleted # is reverted - self.assertEquals(u.name, 'chuck') + eq_(u.name, 'chuck') @testing.resolve_artifact_names def test_deferred(self): @@ -122,7 +125,7 @@ class ExpireTest(_fixtures.FixtureTest): s = create_session(autoflush=True, autocommit=False) u = s.query(User).get(8) adlist = u.addresses - self.assertEquals(adlist, [ + eq_(adlist, [ Address(email_address='ed@bettyboop.com'), Address(email_address='ed@lala.com'), Address(email_address='ed@wood.com'), @@ -130,7 +133,7 @@ class ExpireTest(_fixtures.FixtureTest): a1 = u.addresses[2] a1.email_address = 'aaaaa' s.expire(u, ['addresses']) - self.assertEquals(u.addresses, [ + eq_(u.addresses, [ Address(email_address='aaaaa'), Address(email_address='ed@bettyboop.com'), Address(email_address='ed@lala.com'), @@ -146,10 +149,10 @@ class ExpireTest(_fixtures.FixtureTest): mapper(Address, addresses) s = create_session(autoflush=True, autocommit=False) u = s.query(User).get(8) - self.assertRaisesMessage(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses']) + assert_raises_message(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses']) # in contrast to a regular query with no columns - self.assertRaisesMessage(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all) + assert_raises_message(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all) @testing.resolve_artifact_names def test_refresh_cancels_expire(self): @@ -161,7 +164,7 @@ class ExpireTest(_fixtures.FixtureTest): def go(): u = s.query(User).get(7) - self.assertEquals(u.name, 'jack') + eq_(u.name, 'jack') self.assert_sql_count(testing.db, go, 0) @testing.resolve_artifact_names @@ -187,7 +190,7 @@ class ExpireTest(_fixtures.FixtureTest): sess.expire(u, attribute_names=['name']) sess.expunge(u) - self.assertRaises(sa.exc.UnboundExecutionError, getattr, u, 'name') + assert_raises(sa.exc.UnboundExecutionError, getattr, u, 'name') @testing.resolve_artifact_names def test_pending_raises(self): @@ -197,7 +200,7 @@ class ExpireTest(_fixtures.FixtureTest): sess = create_session() u = User(id=15) sess.add(u) - self.assertRaises(sa.exc.InvalidRequestError, sess.expire, u, ['name']) + assert_raises(sa.exc.InvalidRequestError, sess.expire, u, ['name']) @testing.resolve_artifact_names def test_no_instance_key(self): @@ -668,14 +671,15 @@ class ExpireTest(_fixtures.FixtureTest): userlist = sess.query(User).order_by(User.id).all() u = userlist[1] - self.assertEquals(self.static.user_address_result, userlist) + eq_(self.static.user_address_result, userlist) assert len(list(sess)) == 9 class PolymorphicExpireTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global people, engineers, Person, Engineer people = Table('people', metadata, @@ -690,14 +694,16 @@ class PolymorphicExpireTest(_base.MappedTest): Column('status', String(30)), ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Person(_base.ComparableEntity): pass class Engineer(Person): pass + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): people.insert().execute( {'person_id':1, 'name':'person1', 'type':'person'}, {'person_id':2, 'name':'engineer1', 'type':'engineer'}, @@ -745,7 +751,7 @@ class PolymorphicExpireTest(_base.MappedTest): assert e1.status == 'new engineer' assert e2.status == 'old engineer' self.assert_sql_count(testing.db, go, 2) - self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1'])) + eq_(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1'])) class ExpiredPendingTest(_fixtures.FixtureTest): run_define_tables = 'once' @@ -837,7 +843,7 @@ class RefreshTest(_fixtures.FixtureTest): s = create_session() u = s.query(User).get(7) s.expunge_all() - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u)) + assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u)) @testing.resolve_artifact_names def test_refresh_expired(self): @@ -908,5 +914,3 @@ class RefreshTest(_fixtures.FixtureTest): s.refresh(u) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/extendedattr.py b/test/orm/test_extendedattr.py similarity index 87% rename from test/orm/extendedattr.py rename to test/orm/test_extendedattr.py index aec6c181f2..e0c64bf64a 100644 --- a/test/orm/extendedattr.py +++ b/test/orm/test_extendedattr.py @@ -1,4 +1,4 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import pickle from sqlalchemy import util import sqlalchemy.orm.attributes as attributes @@ -6,8 +6,8 @@ from sqlalchemy.orm.collections import collection from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import InstrumentationManager -from testlib import * -from orm import _base +from sqlalchemy.test import * +from test.orm import _base class MyTypesManager(InstrumentationManager): @@ -100,7 +100,8 @@ class MyClass(object): del self._goofy_dict[key] class UserDefinedExtensionTest(_base.ORMTest): - def tearDownAll(self): + @classmethod + def teardown_class(cls): clear_mappers() attributes._install_lookup_strategy(util.symbol('native')) @@ -161,30 +162,30 @@ class UserDefinedExtensionTest(_base.ORMTest): assert Foo in attributes.instrumentation_registry._state_finders f = Foo() attributes.instance_state(f).expire_attributes(None) - self.assertEquals(f.a, "this is a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is a") + eq_(f.b, 12) f.a = "this is some new a" attributes.instance_state(f).expire_attributes(None) - self.assertEquals(f.a, "this is a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is a") + eq_(f.b, 12) attributes.instance_state(f).expire_attributes(None) f.a = "this is another new a" - self.assertEquals(f.a, "this is another new a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is another new a") + eq_(f.b, 12) attributes.instance_state(f).expire_attributes(None) - self.assertEquals(f.a, "this is a") - self.assertEquals(f.b, 12) + eq_(f.a, "this is a") + eq_(f.b, 12) del f.a - self.assertEquals(f.a, None) - self.assertEquals(f.b, 12) + eq_(f.a, None) + eq_(f.b, 12) attributes.instance_state(f).commit_all(attributes.instance_dict(f)) - self.assertEquals(f.a, None) - self.assertEquals(f.b, 12) + eq_(f.a, None) + eq_(f.b, 12) def test_inheritance(self): """tests that attributes are polymorphic""" @@ -265,27 +266,27 @@ class UserDefinedExtensionTest(_base.ORMTest): f1 = Foo() f1.name = 'f1' - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ())) + eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ())) b1 = Bar() b1.name = 'b1' f1.bars.append(b1) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) attributes.instance_state(f1).commit_all(attributes.instance_dict(f1)) attributes.instance_state(b1).commit_all(attributes.instance_dict(b1)) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ())) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ())) + eq_(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ())) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ())) f1.name = 'f1mod' b2 = Bar() b2.name = 'b2' f1.bars.append(b2) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1'])) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], [])) + eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1'])) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], [])) f1.bars.remove(b1) - self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1])) + eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1])) def test_null_instrumentation(self): class Foo(MyBaseClass): @@ -311,9 +312,9 @@ class UserDefinedExtensionTest(_base.ORMTest): assert attributes.manager_of_class(None) is None assert attributes.instance_state(k) is not None - self.assertRaises((AttributeError, KeyError), + assert_raises((AttributeError, KeyError), attributes.instance_state, u) - self.assertRaises((AttributeError, KeyError), + assert_raises((AttributeError, KeyError), attributes.instance_state, None) diff --git a/test/orm/generative.py b/test/orm/test_generative.py similarity index 91% rename from test/orm/generative.py rename to test/orm/test_generative.py index 9952367414..0efc1814ed 100644 --- a/test/orm/generative.py +++ b/test/orm/test_generative.py @@ -1,28 +1,34 @@ -import testenv; testenv.configure_for_tests() -from testlib import testing, sa -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func +from sqlalchemy.test.testing import eq_ +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, MetaData, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column from sqlalchemy.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base, _fixtures +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures class GenerativeQueryTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foo', metadata, Column('id', Integer, sa.Sequence('foo_id_seq'), primary_key=True), Column('bar', Integer), Column('range', Integer)) - def fixtures(self): + @classmethod + def fixtures(cls): rows = tuple([(i, i % 10) for i in range(100)]) foo_data = (('bar', 'range'),) + rows return dict(foo=foo_data) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Foo(_base.BasicEntity): pass @@ -131,7 +137,8 @@ class GenerativeQueryTest(_base.MappedTest): class GenerativeTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('Table1', metadata, Column('id', Integer, primary_key=True)) Table('Table2', metadata, @@ -139,8 +146,9 @@ class GenerativeTest2(_base.MappedTest): primary_key=True), Column('num', Integer, primary_key=True)) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Obj1(_base.BasicEntity): pass class Obj2(_base.BasicEntity): @@ -149,7 +157,8 @@ class GenerativeTest2(_base.MappedTest): mapper(Obj1, Table1) mapper(Obj2, Table2) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( Table1=(('id',), (1,), @@ -182,8 +191,9 @@ class RelationsTest(_fixtures.FixtureTest): run_inserts = 'once' run_deletes = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users, properties={ 'orders':relation(mapper(Order, orders, properties={ 'addresses':relation(mapper(Address, addresses))}))}) @@ -232,7 +242,8 @@ class RelationsTest(_fixtures.FixtureTest): class CaseSensitiveTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('Table1', metadata, Column('ID', Integer, primary_key=True)) Table('Table2', metadata, @@ -240,8 +251,9 @@ class CaseSensitiveTest(_base.MappedTest): primary_key=True), Column('NUM', Integer, primary_key=True)) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Obj1(_base.BasicEntity): pass class Obj2(_base.BasicEntity): @@ -250,7 +262,8 @@ class CaseSensitiveTest(_base.MappedTest): mapper(Obj1, Table1) mapper(Obj2, Table2) - def fixtures(self): + @classmethod + def fixtures(cls): return dict( Table1=(('ID',), (1,), @@ -272,8 +285,6 @@ class CaseSensitiveTest(_base.MappedTest): res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)) assert res.count() == 3 res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)).distinct() - self.assertEqual(res.count(), 1) + eq_(res.count(), 1) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/instrumentation.py b/test/orm/test_instrumentation.py similarity index 94% rename from test/orm/instrumentation.py rename to test/orm/test_instrumentation.py index fd15420d0a..b4c8f8601c 100644 --- a/test/orm/instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -1,11 +1,13 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa -from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util -from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers -from testlib.testing import eq_, ne_ -from testlib.compat import _function_named -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy import MetaData, Integer, ForeignKey, util +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers +from sqlalchemy.test.testing import eq_, ne_ +from sqlalchemy.util import function_named +from test.orm import _base def modifies_instrumentation_finders(fn): @@ -16,7 +18,7 @@ def modifies_instrumentation_finders(fn): finally: del attributes.instrumentation_finders[:] attributes.instrumentation_finders.extend(pristine) - return _function_named(decorated, fn.func_name) + return function_named(decorated, fn.func_name) def with_lookup_strategy(strategy): def decorate(fn): @@ -26,7 +28,7 @@ def with_lookup_strategy(strategy): return fn(*args, **kw) finally: attributes._install_lookup_strategy(sa.util.symbol('native')) - return _function_named(wrapped, fn.func_name) + return function_named(wrapped, fn.func_name) return decorate @@ -459,10 +461,10 @@ class MapperInitTest(_base.ORMTest): m = mapper(A, self.fixture()) # B is not mapped in the current implementation - self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B) + assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, B) # C is not mapped in the current implementation - self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C) + assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, C) class InstrumentationCollisionTest(_base.ORMTest): def test_none(self): @@ -486,7 +488,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) - self.assertRaises(TypeError, attributes.register_class, B) + assert_raises(TypeError, attributes.register_class, B) def test_single_up(self): @@ -497,7 +499,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class B(A): __sa_instrumentation_manager__ = staticmethod(mgr_factory) attributes.register_class(B) - self.assertRaises(TypeError, attributes.register_class, A) + assert_raises(TypeError, attributes.register_class, A) def test_diamond_b1(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -508,7 +510,7 @@ class InstrumentationCollisionTest(_base.ORMTest): __sa_instrumentation_manager__ = mgr_factory class C(object): pass - self.assertRaises(TypeError, attributes.register_class, B1) + assert_raises(TypeError, attributes.register_class, B1) def test_diamond_b2(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -519,7 +521,7 @@ class InstrumentationCollisionTest(_base.ORMTest): __sa_instrumentation_manager__ = mgr_factory class C(object): pass - self.assertRaises(TypeError, attributes.register_class, B2) + assert_raises(TypeError, attributes.register_class, B2) def test_diamond_c_b(self): mgr_factory = lambda cls: attributes.ClassManager(cls) @@ -531,7 +533,7 @@ class InstrumentationCollisionTest(_base.ORMTest): class C(object): pass attributes.register_class(C) - self.assertRaises(TypeError, attributes.register_class, B1) + assert_raises(TypeError, attributes.register_class, B1) class OnLoadTest(_base.ORMTest): @@ -557,7 +559,8 @@ class OnLoadTest(_base.ORMTest): finally: del A - def tearDownAll(self): + @classmethod + def teardown_class(cls): clear_mappers() attributes._install_lookup_strategy(util.symbol('native')) @@ -593,7 +596,7 @@ class NativeInstrumentationTest(_base.ORMTest): sa = attributes.ClassManager.STATE_ATTR ma = attributes.ClassManager.MANAGER_ATTR - fails = lambda method, attr: self.assertRaises( + fails = lambda method, attr: assert_raises( KeyError, getattr(manager, method), attr, property()) fails('install_member', sa) @@ -609,7 +612,7 @@ class NativeInstrumentationTest(_base.ORMTest): class T(object): pass - self.assertRaises(KeyError, mapper, T, t) + assert_raises(KeyError, mapper, T, t) @with_lookup_strategy(sa.util.symbol('native')) def test_mapped_managerattr(self): @@ -618,7 +621,7 @@ class NativeInstrumentationTest(_base.ORMTest): Column(attributes.ClassManager.MANAGER_ATTR, Integer)) class T(object): pass - self.assertRaises(KeyError, mapper, T, t) + assert_raises(KeyError, mapper, T, t) class MiscTest(_base.ORMTest): @@ -761,5 +764,3 @@ class FinderTest(_base.ORMTest): eq_(type(attributes.manager_of_class(A)), attributes.ClassManager) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/lazy_relations.py b/test/orm/test_lazy_relations.py similarity index 96% rename from test/orm/lazy_relations.py rename to test/orm/test_lazy_relations.py index b5c3b3669e..819f29911e 100644 --- a/test/orm/lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -1,14 +1,17 @@ """basic tests of lazy loaded attributes""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message import datetime from sqlalchemy import exc as sa_exc from sqlalchemy.orm import attributes -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base, _fixtures class LazyTest(_fixtures.FixtureTest): @@ -35,7 +38,7 @@ class LazyTest(_fixtures.FixtureTest): q = sess.query(User) u = q.filter(users.c.id == 7).first() sess.expunge(u) - self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses') + assert_raises(sa_exc.InvalidRequestError, getattr, u, 'addresses') @testing.resolve_artifact_names def test_orderby(self): @@ -363,6 +366,7 @@ class M2OGetTest(_fixtures.FixtureTest): class CorrelatedTest(_base.MappedTest): + @classmethod def define_tables(self, meta): Table('user_t', meta, Column('id', Integer, primary_key=True), @@ -373,8 +377,9 @@ class CorrelatedTest(_base.MappedTest): Column('date', sa.Date), Column('user_id', Integer, ForeignKey('user_t.id'))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): user_t.insert().execute( {'id':1, 'name':'user1'}, {'id':2, 'name':'user2'}, @@ -412,5 +417,3 @@ class CorrelatedTest(_base.MappedTest): ]) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/lazytest1.py b/test/orm/test_lazytest1.py similarity index 88% rename from test/orm/lazytest1.py rename to test/orm/test_lazytest1.py index 5ebb8feeba..f76cb32035 100644 --- a/test/orm/lazytest1.py +++ b/test/orm/test_lazytest1.py @@ -1,12 +1,15 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base class LazyTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('infos', metadata, Column('pk', Integer, primary_key=True), Column('info', String(128))) @@ -25,8 +28,9 @@ class LazyTest(_base.MappedTest): Column('start', Integer), Column('finish', Integer)) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): infos.insert().execute( {'pk':1, 'info':'pk_1_info'}, {'pk':2, 'info':'pk_2_info'}, @@ -86,5 +90,3 @@ class LazyTest(_base.MappedTest): assert len(info.rels[0].datas) == 3 -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/manytomany.py b/test/orm/test_manytomany.py similarity index 93% rename from test/orm/manytomany.py rename to test/orm/test_manytomany.py index 23af3bd1f8..dcd547f80c 100644 --- a/test/orm/manytomany.py +++ b/test/orm/test_manytomany.py @@ -1,12 +1,16 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base class M2MTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('place', metadata, Column('place_id', Integer, sa.Sequence('pid_seq', optional=True), primary_key=True), @@ -40,7 +44,8 @@ class M2MTest(_base.MappedTest): Column('pl1_id', Integer, ForeignKey('place.place_id')), Column('pl2_id', Integer, ForeignKey('place.place_id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Place(_base.BasicEntity): def __init__(self, name=None): self.name = name @@ -70,7 +75,7 @@ class M2MTest(_base.MappedTest): mapper(Transition, transition, properties={ 'places':relation(Place, secondary=place_input, backref='transitions') }) - self.assertRaisesMessage(sa.exc.ArgumentError, "Error creating backref", + assert_raises_message(sa.exc.ArgumentError, "Error creating backref", sa.orm.compile_mappers) @testing.resolve_artifact_names @@ -187,7 +192,8 @@ class M2MTest(_base.MappedTest): class M2MTest2(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('student', metadata, Column('name', String(20), primary_key=True)) @@ -200,7 +206,8 @@ class M2MTest2(_base.MappedTest): Column('course_id', String(20), ForeignKey('course.name'), primary_key=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Student(_base.BasicEntity): def __init__(self, name=''): self.name = name @@ -248,7 +255,7 @@ class M2MTest2(_base.MappedTest): s1.courses.append(c1) s1.courses.append(c1) sess.add(s1) - self.assertRaises(sa.exc.DBAPIError, sess.flush) + assert_raises(sa.exc.DBAPIError, sess.flush) @testing.resolve_artifact_names def test_delete(self): @@ -274,7 +281,8 @@ class M2MTest2(_base.MappedTest): assert enroll.count().scalar() == 0 class M2MTest3(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('c', metadata, Column('c1', Integer, primary_key = True), Column('c2', String(20))) @@ -320,5 +328,3 @@ class M2MTest3(_base.MappedTest): # how about some data/inserts/queries/assertions for this one -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/mapper.py b/test/orm/test_mapper.py similarity index 97% rename from test/orm/mapper.py rename to test/orm/test_mapper.py index 13e02a38a0..025b96424d 100644 --- a/test/orm/mapper.py +++ b/test/orm/test_mapper.py @@ -1,14 +1,16 @@ """General mapper operations with an emphasis on selecting/loading.""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, func -from testlib.sa.engine import default -from testlib.sa.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased -from testlib.sa.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property -from testlib.testing import eq_, AssertsCompiledSQL -import pickleable -from orm import _base, _fixtures +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing, pickleable +from sqlalchemy import MetaData, Integer, String, ForeignKey, func +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.engine import default +from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased +from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property +from sqlalchemy.test.testing import eq_, AssertsCompiledSQL +from test.orm import _base, _fixtures class MapperTest(_fixtures.FixtureTest): @@ -22,7 +24,7 @@ class MapperTest(_fixtures.FixtureTest): properties={ 'addresses':relation(Address, backref='email_address') }) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) @testing.resolve_artifact_names def test_update_attr_keys(self): @@ -74,14 +76,14 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_prop_accessor(self): mapper(User, users) - self.assertRaises(NotImplementedError, + assert_raises(NotImplementedError, getattr, sa.orm.class_mapper(User), 'properties') @testing.resolve_artifact_names def test_bad_cascade(self): mapper(Address, addresses) - self.assertRaises(sa.exc.ArgumentError, + assert_raises(sa.exc.ArgumentError, relation, Address, cascade="fake, all, delete-orphan") @testing.resolve_artifact_names @@ -93,7 +95,7 @@ class MapperTest(_fixtures.FixtureTest): }) hasattr(Address.user, 'property') - self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) + assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers) @testing.resolve_artifact_names def test_column_prefix(self): @@ -111,7 +113,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_no_pks_1(self): s = sa.select([users.c.name]).alias('foo') - self.assertRaises(sa.exc.ArgumentError, mapper, User, s) + assert_raises(sa.exc.ArgumentError, mapper, User, s) @testing.emits_warning( 'mapper Mapper|User|Select object creating an alias for ' @@ -119,7 +121,7 @@ class MapperTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_no_pks_2(self): s = sa.select([users.c.name]) - self.assertRaises(sa.exc.ArgumentError, mapper, User, s) + assert_raises(sa.exc.ArgumentError, mapper, User, s) @testing.resolve_artifact_names def test_recompile_on_other_mapper(self): @@ -167,9 +169,9 @@ class MapperTest(_fixtures.FixtureTest): create_session).extension) sess = create_session() - self.assertRaises(TypeError, Foo, 'one', _sa_session=sess) + assert_raises(TypeError, Foo, 'one', _sa_session=sess) eq_(len(list(sess)), 0) - self.assertRaises(TypeError, Foo, 'one') + assert_raises(TypeError, Foo, 'one') Foo('one', 'two', _sa_session=sess) eq_(len(list(sess)), 1) @@ -197,7 +199,7 @@ class MapperTest(_fixtures.FixtureTest): raise Exception("this exception should be stated as a warning") sess.expunge = bad_expunge - self.assertRaises(sa.exc.SAWarning, Foo, _sa_session=sess) + assert_raises(sa.exc.SAWarning, Foo, _sa_session=sess) @testing.resolve_artifact_names def test_constructor_exc_2(self): @@ -211,8 +213,8 @@ class MapperTest(_fixtures.FixtureTest): mapper(Foo, users) mapper(Bar, addresses) - self.assertRaises(TypeError, Foo, x=5) - self.assertRaises(TypeError, Bar, x=5) + assert_raises(TypeError, Foo, x=5) + assert_raises(TypeError, Bar, x=5) @testing.resolve_artifact_names def test_props(self): @@ -499,7 +501,7 @@ class MapperTest(_fixtures.FixtureTest): # excluding the discriminator column is currently not allowed class Foo(Person): pass - self.assertRaises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) ) + assert_raises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) ) @testing.resolve_artifact_names def test_mapping_to_join(self): @@ -643,7 +645,7 @@ class MapperTest(_fixtures.FixtureTest): properties=dict( name=relation(mapper(Address, addresses)))) - self.assertRaises(sa.exc.ArgumentError, go) + assert_raises(sa.exc.ArgumentError, go) @testing.resolve_artifact_names def test_override_2(self): @@ -739,7 +741,7 @@ class MapperTest(_fixtures.FixtureTest): mapper(User, users, properties={ 'not_name':synonym('_name', map_column=True)}) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, ("Can't compile synonym '_name': no column on table " "'users' named 'not_name'"), @@ -834,7 +836,7 @@ class MapperTest(_fixtures.FixtureTest): eq_(User.uc_name.method1(), "method1") eq_(User.uc_name.method2('x'), "method2") - self.assertRaisesMessage( + assert_raises_message( AttributeError, "Neither 'extendedproperty' object nor 'UCComparator' object has an attribute 'nonexistent'", getattr, User.uc_name, 'nonexistent') @@ -879,7 +881,7 @@ class MapperTest(_fixtures.FixtureTest): 'name':sa.orm.column_property(users.c.name, comparator_factory=MyComparator) }) - self.assertRaisesMessage( + assert_raises_message( AttributeError, "Neither 'InstrumentedAttribute' object nor 'MyComparator' object has an attribute 'nonexistent'", getattr, User.name, "nonexistent") @@ -966,7 +968,7 @@ class MapperTest(_fixtures.FixtureTest): 'addresses':relation(Address) }) - self.assertRaises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers) + assert_raises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers) @testing.resolve_artifact_names def test_oldstyle_mixin(self): @@ -1148,8 +1150,9 @@ class OptionsTest(_fixtures.FixtureTest): class DeepOptionsTest(_fixtures.FixtureTest): + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Keyword, keywords) mapper(Item, items, properties=dict( @@ -1204,7 +1207,7 @@ class DeepOptionsTest(_fixtures.FixtureTest): def test_deep_options_4(self): sess = create_session() - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, r"Can't find entity Mapper\|Order\|orders in Query. " r"Current list: \['Mapper\|User\|users'\]", @@ -1232,7 +1235,7 @@ class ValidatorTest(_fixtures.FixtureTest): sess = create_session() u1 = User(name='ed') eq_(u1.name, 'ed modified') - self.assertRaises(AssertionError, setattr, u1, "name", "fred") + assert_raises(AssertionError, setattr, u1, "name", "fred") eq_(u1.name, 'ed modified') sess.add(u1) sess.flush() @@ -1252,7 +1255,7 @@ class ValidatorTest(_fixtures.FixtureTest): mapper(Address, addresses) sess = create_session() u1 = User(name='edward') - self.assertRaises(AssertionError, u1.addresses.append, Address(email_address='noemail')) + assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail')) u1.addresses.append(Address(id=15, email_address='foo@bar.com')) sess.add(u1) sess.flush() @@ -1629,7 +1632,8 @@ class DeferredTest(_fixtures.FixtureTest): eq_(item.description, 'item 4') class DeferredPopulationTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("thing", metadata, Column("id", Integer, primary_key=True), Column("name", String(20))) @@ -1639,16 +1643,18 @@ class DeferredPopulationTest(_base.MappedTest): Column("thing_id", Integer, ForeignKey("thing.id")), Column("name", String(20))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Human(_base.BasicEntity): pass class Thing(_base.BasicEntity): pass mapper(Human, human, properties={"thing": relation(Thing)}) mapper(Thing, thing, properties={"name": deferred(thing.c.name)}) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): thing.insert().execute([ {"id": 1, "name": "Chair"}, ]) @@ -1714,7 +1720,8 @@ class DeferredPopulationTest(_base.MappedTest): class CompositeTypesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('graphs', metadata, Column('id', Integer, primary_key=True), Column('version_id', Integer, primary_key=True, nullable=True), @@ -2246,7 +2253,8 @@ class MapperExtensionTest(_fixtures.FixtureTest): class RequirementsTest(_base.MappedTest): """Tests the contract for user classes.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('ht1', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -2280,9 +2288,9 @@ class RequirementsTest(_base.MappedTest): class OldStyle: pass - self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1) + assert_raises(sa.exc.ArgumentError, mapper, OldStyle, ht1) - self.assertRaises(sa.exc.ArgumentError, mapper, 123) + assert_raises(sa.exc.ArgumentError, mapper, 123) class NoWeakrefSupport(str): pass @@ -2388,7 +2396,8 @@ class RequirementsTest(_base.MappedTest): class MagicNamesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('cartographers', metadata, Column('id', Integer, primary_key=True), Column('name', String(50)), @@ -2401,7 +2410,8 @@ class MagicNamesTest(_base.MappedTest): Column('state', String(2)), Column('data', sa.Text)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Cartographer(_base.BasicEntity): pass @@ -2441,7 +2451,7 @@ class MagicNamesTest(_base.MappedTest): class T(object): pass - self.assertRaisesMessage( + assert_raises_message( KeyError, ('%r: requested attribute name conflicts with ' 'instrumentation attribute of the same name.' % reserved), @@ -2454,7 +2464,7 @@ class MagicNamesTest(_base.MappedTest): class M(object): pass - self.assertRaisesMessage( + assert_raises_message( KeyError, ('requested attribute name conflicts with ' 'instrumentation attribute of the same name'), @@ -2463,5 +2473,3 @@ class MagicNamesTest(_base.MappedTest): -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/merge.py b/test/orm/test_merge.py similarity index 98% rename from test/orm/merge.py rename to test/orm/test_merge.py index fd553f2bf7..70097cbee2 100644 --- a/test/orm/merge.py +++ b/test/orm/test_merge.py @@ -1,9 +1,10 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa.util import OrderedSet -from testlib.sa.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property -from testlib.testing import eq_, ne_ -from orm import _base, _fixtures +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy.util import OrderedSet +from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property +from sqlalchemy.test.testing import eq_, ne_ +from test.orm import _base, _fixtures class MergeTest(_fixtures.FixtureTest): @@ -447,7 +448,7 @@ class MergeTest(_fixtures.FixtureTest): sess = create_session() u = User() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) + assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True) @testing.resolve_artifact_names @@ -732,5 +733,3 @@ class MergeTest(_fixtures.FixtureTest): -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/naturalpks.py b/test/orm/test_naturalpks.py similarity index 84% rename from test/orm/naturalpks.py rename to test/orm/test_naturalpks.py index 8efce660c3..1376c402e7 100644 --- a/test/orm/naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -2,16 +2,20 @@ Primary key changing capabilities and passive/non-passive cascading updates. """ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from testlib.testing import eq_ -from orm import _base +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base class NaturalPKTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): users = Table('users', metadata, Column('username', String(50), primary_key=True), Column('fullname', String(100)), @@ -32,7 +36,8 @@ class NaturalPKTest(_base.MappedTest): Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Address(_base.ComparableEntity): @@ -62,7 +67,7 @@ class NaturalPKTest(_base.MappedTest): sess.expunge_all() u1 = sess.query(User).get('ed') - self.assertEquals(User(username='ed', fullname='jack'), u1) + eq_(User(username='ed', fullname='jack'), u1) @testing.resolve_artifact_names def test_load_after_expire(self): @@ -81,7 +86,7 @@ class NaturalPKTest(_base.MappedTest): # in this case so theres no way to look it up. criterion- # based session invalidation could solve this [ticket:911] sess.expire(u1) - self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username') + assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username') sess.expunge_all() assert sess.query(User).get('jack') is None @@ -132,7 +137,7 @@ class NaturalPKTest(_base.MappedTest): assert u1.addresses[0].username == 'ed' sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) u1 = sess.query(User).get('ed') u1.username = 'jack' @@ -152,7 +157,7 @@ class NaturalPKTest(_base.MappedTest): sess.expunge_all() assert sess.query(Address).get('jack1').username is None u1 = sess.query(User).get('fred') - self.assertEquals(User(username='fred', fullname='jack'), u1) + eq_(User(username='fred', fullname='jack'), u1) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') @@ -195,7 +200,7 @@ class NaturalPKTest(_base.MappedTest): assert a1.username == a2.username == 'ed' sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_onetoone_passive(self): @@ -236,7 +241,7 @@ class NaturalPKTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 0) sess.expunge_all() - self.assertEquals([Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') def test_bidirectional_passive(self): @@ -265,7 +270,7 @@ class NaturalPKTest(_base.MappedTest): u1.username = 'ed' (ad1, ad2) = sess.query(Address).all() - self.assertEquals([Address(username='jack'), Address(username='jack')], [ad1, ad2]) + eq_([Address(username='jack'), Address(username='jack')], [ad1, ad2]) def go(): sess.flush() if passive_updates: @@ -273,9 +278,9 @@ class NaturalPKTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 1) else: self.assert_sql_count(testing.db, go, 3) - self.assertEquals([Address(username='ed'), Address(username='ed')], [ad1, ad2]) + eq_([Address(username='ed'), Address(username='ed')], [ad1, ad2]) sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) u1 = sess.query(User).get('ed') assert len(u1.addresses) == 2 # load addresses @@ -289,7 +294,7 @@ class NaturalPKTest(_base.MappedTest): else: self.assert_sql_count(testing.db, go, 3) sess.expunge_all() - self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) + eq_([Address(username='fred'), Address(username='fred')], sess.query(Address).all()) @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE') @@ -323,10 +328,10 @@ class NaturalPKTest(_base.MappedTest): r = sess.query(Item).all() # ComparableEntity can't handle a comparison with the backrefs # involved.... - self.assertEquals(Item(itemname='item1'), r[0]) - self.assertEquals(['jack'], [u.username for u in r[0].users]) - self.assertEquals(Item(itemname='item2'), r[1]) - self.assertEquals(['jack', 'fred'], [u.username for u in r[1].users]) + eq_(Item(itemname='item1'), r[0]) + eq_(['jack'], [u.username for u in r[0].users]) + eq_(Item(itemname='item2'), r[1]) + eq_(['jack', 'fred'], [u.username for u in r[1].users]) u2.username='ed' def go(): @@ -338,29 +343,31 @@ class NaturalPKTest(_base.MappedTest): sess.expunge_all() r = sess.query(Item).all() - self.assertEquals(Item(itemname='item1'), r[0]) - self.assertEquals(['jack'], [u.username for u in r[0].users]) - self.assertEquals(Item(itemname='item2'), r[1]) - self.assertEquals(['ed', 'jack'], sorted([u.username for u in r[1].users])) + eq_(Item(itemname='item1'), r[0]) + eq_(['jack'], [u.username for u in r[0].users]) + eq_(Item(itemname='item2'), r[1]) + eq_(['ed', 'jack'], sorted([u.username for u in r[1].users])) sess.expunge_all() u2 = sess.query(User).get(u2.username) u2.username='wendy' sess.flush() r = sess.query(Item).with_parent(u2).all() - self.assertEquals(Item(itemname='item2'), r[0]) + eq_(Item(itemname='item2'), r[0]) class SelfRefTest(_base.MappedTest): __unsupported_on__ = 'mssql' # mssql doesn't allow ON UPDATE on self-referential keys - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('nodes', metadata, Column('name', String(50), primary_key=True), Column('parent', String(50), ForeignKey('nodes.name', onupdate='cascade'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Node(_base.ComparableEntity): pass @@ -391,7 +398,8 @@ class SelfRefTest(_base.MappedTest): class NonPKCascadeTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('username', String(50), unique=True), @@ -406,7 +414,8 @@ class NonPKCascadeTest(_base.MappedTest): test_needs_fk=True ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Address(_base.ComparableEntity): @@ -433,17 +442,17 @@ class NonPKCascadeTest(_base.MappedTest): sess.flush() a1 = u1.addresses[0] - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) + eq_(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)]) assert sess.query(Address).get(a1.id) is u1.addresses[0] u1.username = 'ed' sess.flush() assert u1.addresses[0].username == 'ed' - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) + eq_(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)]) sess.expunge_all() - self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) + eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all()) u1 = sess.query(User).get(u1.id) u1.username = 'jack' @@ -463,13 +472,11 @@ class NonPKCascadeTest(_base.MappedTest): sess.flush() sess.expunge_all() a1 = sess.query(Address).get(a1.id) - self.assertEquals(a1.username, None) + eq_(a1.username, None) - self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) + eq_(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)]) u1 = sess.query(User).get(u1.id) - self.assertEquals(User(username='fred', fullname='jack'), u1) + eq_(User(username='fred', fullname='jack'), u1) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/onetoone.py b/test/orm/test_onetoone.py similarity index 81% rename from test/orm/onetoone.py rename to test/orm/test_onetoone.py index be0375e48b..0d66915ea5 100644 --- a/test/orm/onetoone.py +++ b/test/orm/test_onetoone.py @@ -1,12 +1,15 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session -from orm import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session +from test.orm import _base class O2OTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('jack', metadata, Column('id', Integer, primary_key=True), Column('number', String(50)), @@ -19,8 +22,9 @@ class O2OTest(_base.MappedTest): Column('description', String(100)), Column('jack_id', Integer, ForeignKey("jack.id"))) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Jack(_base.BasicEntity): pass class Port(_base.BasicEntity): @@ -70,5 +74,3 @@ class O2OTest(_base.MappedTest): session.delete(j) session.flush() -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/pickled.py b/test/orm/test_pickled.py similarity index 80% rename from test/orm/pickled.py rename to test/orm/test_pickled.py index 878fe931e3..5343cc15b9 100644 --- a/test/orm/pickled.py +++ b/test/orm/test_pickled.py @@ -1,9 +1,12 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import pickle -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, create_session, attributes -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, attributes +from test.orm import _base, _fixtures User, EmailUser = None, None @@ -28,7 +31,7 @@ class PickleTest(_fixtures.FixtureTest): sess.expunge_all() - self.assertEquals(u1, sess.query(User).get(u2.id)) + eq_(u1, sess.query(User).get(u2.id)) @testing.resolve_artifact_names def test_class_deferred_cols(self): @@ -52,14 +55,14 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() sess2.add(u2) - self.assertEquals(u2.name, 'ed') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(u2.name, 'ed') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() u2 = sess2.merge(u2, dont_load=True) - self.assertEquals(u2.name, 'ed') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(u2.name, 'ed') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) @testing.resolve_artifact_names def test_instance_deferred_cols(self): @@ -82,22 +85,22 @@ class PickleTest(_fixtures.FixtureTest): u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() sess2.add(u2) - self.assertEquals(u2.name, 'ed') + eq_(u2.name, 'ed') assert 'addresses' not in u2.__dict__ ad = u2.addresses[0] assert 'email_address' not in ad.__dict__ - self.assertEquals(ad.email_address, 'ed@bar.com') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(ad.email_address, 'ed@bar.com') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) u2 = pickle.loads(pickle.dumps(u1)) sess2 = create_session() u2 = sess2.merge(u2, dont_load=True) - self.assertEquals(u2.name, 'ed') + eq_(u2.name, 'ed') assert 'addresses' not in u2.__dict__ ad = u2.addresses[0] assert 'email_address' in ad.__dict__ # mapper options dont transmit over merge() right now - self.assertEquals(ad.email_address, 'ed@bar.com') - self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) + eq_(ad.email_address, 'ed@bar.com') + eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')])) @testing.resolve_artifact_names def test_options_with_descriptors(self): @@ -122,7 +125,7 @@ class PickleTest(_fixtures.FixtureTest): sa.orm.eagerload(["addresses", User.addresses]), ]: opt2 = pickle.loads(pickle.dumps(opt)) - self.assertEquals(opt.key, opt2.key) + eq_(opt.key, opt2.key) u1 = sess.query(User).options(opt).first() @@ -130,7 +133,8 @@ class PickleTest(_fixtures.FixtureTest): class PolymorphicDeferredTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(30)), @@ -139,7 +143,8 @@ class PolymorphicDeferredTest(_base.MappedTest): Column('id', Integer, ForeignKey('users.id'), primary_key=True), Column('email_address', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): global User, EmailUser class User(_base.BasicEntity): pass @@ -147,10 +152,11 @@ class PolymorphicDeferredTest(_base.MappedTest): class EmailUser(User): pass - def tearDownAll(self): + @classmethod + def teardown_class(cls): global User, EmailUser User, EmailUser = None, None - _base.MappedTest.tearDownAll(self) + super(PolymorphicDeferredTest, cls).teardown_class() @testing.resolve_artifact_names def test_polymorphic_deferred(self): @@ -168,9 +174,9 @@ class PolymorphicDeferredTest(_base.MappedTest): sess2 = create_session() sess2.add(eu2) assert 'email_address' not in eu2.__dict__ - self.assertEquals(eu2.email_address, 'foo@bar.com') + eq_(eu2.email_address, 'foo@bar.com') -class CustomSetupTeardowntest(_fixtures.FixtureTest): +class CustomSetupTeardownTest(_fixtures.FixtureTest): @testing.resolve_artifact_names def test_rebuild_state(self): """not much of a 'test', but illustrate how to @@ -186,5 +192,3 @@ class CustomSetupTeardowntest(_fixtures.FixtureTest): attributes.manager_of_class(User).setup_instance(u2) assert attributes.instance_state(u2) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/query.py b/test/orm/test_query.py similarity index 88% rename from test/orm/query.py rename to test/orm/test_query.py index 33c3e39d71..66c219b10c 100644 --- a/test/orm/query.py +++ b/test/orm/test_query.py @@ -1,4 +1,4 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import operator from sqlalchemy import * from sqlalchemy import exc as sa_exc, util @@ -7,17 +7,18 @@ from sqlalchemy.engine import default from sqlalchemy.orm import * from sqlalchemy.orm import attributes -from testlib.testing import eq_ +from sqlalchemy.test.testing import eq_ -from testlib import sa, testing, AssertsCompiledSQL, Column, engines +import sqlalchemy as sa +from sqlalchemy.test import testing, AssertsCompiledSQL, Column, engines -from orm import _fixtures -from orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \ +from test.orm import _fixtures +from test.orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \ Dingaling, item_keywords, dingalings, User, items,\ orders, Address, users, nodes, \ order_items, Item, Order, Node -from orm import _base +from test.orm import _base from sqlalchemy.orm.util import join, outerjoin, with_parent @@ -27,7 +28,8 @@ class QueryTest(_fixtures.FixtureTest): run_deletes = None - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', order_by=addresses.c.id), 'orders':relation(Order, backref='user', order_by=orders.c.id), # o2m, m2o @@ -82,8 +84,8 @@ class GetTest(QueryTest): s = create_session() q = s.query(User).join('addresses').filter(Address.user_id==8) - self.assertRaises(sa_exc.InvalidRequestError, q.get, 7) - self.assertRaises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19) + assert_raises(sa_exc.InvalidRequestError, q.get, 7) + assert_raises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19) # order_by()/get() doesn't raise s.query(User).order_by(User.id).get(8) @@ -142,7 +144,7 @@ class GetTest(QueryTest): class LocalFoo(Base): pass mapper(LocalFoo, table) - self.assertEquals(create_session().query(LocalFoo).get(ustring), + eq_(create_session().query(LocalFoo).get(ustring), LocalFoo(id=ustring, data=ustring)) finally: metadata.drop_all() @@ -183,7 +185,7 @@ class GetTest(QueryTest): def test_query_str(self): s = create_session() q = s.query(User).filter(User.id==1) - self.assertEquals( + eq_( str(q).replace('\n',''), 'SELECT users.id AS users_id, users.name AS users_name FROM users WHERE users.id = ?' ) @@ -197,29 +199,29 @@ class InvalidGenerationsTest(QueryTest): s.query(User).offset(2), s.query(User).limit(2).offset(2) ): - self.assertRaises(sa_exc.InvalidRequestError, q.join, "addresses") + assert_raises(sa_exc.InvalidRequestError, q.join, "addresses") - self.assertRaises(sa_exc.InvalidRequestError, q.filter, User.name=='ed') + assert_raises(sa_exc.InvalidRequestError, q.filter, User.name=='ed') - self.assertRaises(sa_exc.InvalidRequestError, q.filter_by, name='ed') + assert_raises(sa_exc.InvalidRequestError, q.filter_by, name='ed') - self.assertRaises(sa_exc.InvalidRequestError, q.order_by, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.order_by, 'foo') - self.assertRaises(sa_exc.InvalidRequestError, q.group_by, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.group_by, 'foo') - self.assertRaises(sa_exc.InvalidRequestError, q.having, 'foo') + assert_raises(sa_exc.InvalidRequestError, q.having, 'foo') def test_no_from(self): s = create_session() q = s.query(User).select_from(users) - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) q = s.query(User).join('addresses') - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) q = s.query(User).order_by(User.id) - self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users) + assert_raises(sa_exc.InvalidRequestError, q.select_from, users) # this is fine, however q.from_self() @@ -227,43 +229,43 @@ class InvalidGenerationsTest(QueryTest): def test_invalid_select_from(self): s = create_session() q = s.query(User) - self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id==5) - self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id) + assert_raises(sa_exc.ArgumentError, q.select_from, User.id==5) + assert_raises(sa_exc.ArgumentError, q.select_from, User.id) def test_invalid_from_statement(self): s = create_session() q = s.query(User) - self.assertRaises(sa_exc.ArgumentError, q.from_statement, User.id==5) - self.assertRaises(sa_exc.ArgumentError, q.from_statement, users.join(addresses)) + assert_raises(sa_exc.ArgumentError, q.from_statement, User.id==5) + assert_raises(sa_exc.ArgumentError, q.from_statement, users.join(addresses)) def test_invalid_column(self): s = create_session() q = s.query(User) - self.assertRaises(sa_exc.InvalidRequestError, q.add_column, object()) + assert_raises(sa_exc.InvalidRequestError, q.add_column, object()) def test_mapper_zero(self): s = create_session() q = s.query(User, Address) - self.assertRaises(sa_exc.InvalidRequestError, q.get, 5) + assert_raises(sa_exc.InvalidRequestError, q.get, 5) def test_from_statement(self): s = create_session() q = s.query(User).filter(User.id==5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).filter_by(id=5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).limit(5) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).group_by(User.name) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") q = s.query(User).order_by(User.name) - self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x") + assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x") class OperatorTest(QueryTest, AssertsCompiledSQL): """test sql.Comparator implementation for MapperProperties""" @@ -431,7 +433,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "users.id IN (:id_1, :id_2)") def test_in_on_relation_not_supported(self): - self.assertRaises(NotImplementedError, Address.user.in_, [User(id=5)]) + assert_raises(NotImplementedError, Address.user.in_, [User(id=5)]) def test_between(self): self._test(User.id.between('a', 'b'), @@ -705,8 +707,8 @@ class FilterTest(QueryTest): assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all() # m2m - self.assertEquals(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) - self.assertEquals(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) + eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)]) + eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)]) def test_filter_by(self): sess = create_session() @@ -723,16 +725,16 @@ class FilterTest(QueryTest): sess = create_session() # o2o - self.assertEquals([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) - self.assertEquals([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) + eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all()) + eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all()) # m2o - self.assertEquals([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) - self.assertEquals([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all()) + eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all()) + eq_([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all()) # o2m - self.assertEquals([User(id=10)], sess.query(User).filter(User.addresses==None).all()) - self.assertEquals([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all()) + eq_([User(id=10)], sess.query(User).filter(User.addresses==None).all()) + eq_([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all()) class FromSelfTest(QueryTest, AssertsCompiledSQL): @@ -818,7 +820,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): def test_multiple_entities(self): sess = create_session() - self.assertEquals( + eq_( sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(), [ (User(id=8), Address(id=2)), @@ -826,7 +828,7 @@ class FromSelfTest(QueryTest, AssertsCompiledSQL): ] ) - self.assertEquals( + eq_( sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(), # order_by(User.id, Address.id).first(), @@ -842,11 +844,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): ed = s.query(User).filter(User.name=='ed') jack = s.query(User).filter(User.name=='jack') - self.assertEquals(fred.union(ed).order_by(User.name).all(), + eq_(fred.union(ed).order_by(User.name).all(), [User(name='ed'), User(name='fred')] ) - self.assertEquals(fred.union(ed, jack).order_by(User.name).all(), + eq_(fred.union(ed, jack).order_by(User.name).all(), [User(name='ed'), User(name='fred'), User(name='jack')] ) @@ -857,11 +859,11 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): fred = s.query(User).filter(User.name=='fred') ed = s.query(User).filter(User.name=='ed') jack = s.query(User).filter(User.name=='jack') - self.assertEquals(fred.intersect(ed, jack).all(), + eq_(fred.intersect(ed, jack).all(), [] ) - self.assertEquals(fred.union(ed).intersect(ed.union(jack)).all(), + eq_(fred.union(ed).intersect(ed.union(jack)).all(), [User(name='ed')] ) @@ -873,7 +875,7 @@ class SetOpsTest(QueryTest, AssertsCompiledSQL): jack = s.query(User).filter(User.name=='jack') def go(): - self.assertEquals( + eq_( fred.union(ed).order_by(User.name).options(eagerload(User.addresses)).all(), [ User(name='ed', addresses=[Address(), Address(), Address()]), @@ -888,8 +890,8 @@ class AggregateTest(QueryTest): def test_sum(self): sess = create_session() orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) - self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,)) - self.assertEquals(orders.value(func.sum(Order.user_id * Order.address_id)), 79) + eq_(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,)) + eq_(orders.value(func.sum(Order.user_id * Order.address_id)), 79) def test_apply(self): sess = create_session() @@ -987,13 +989,13 @@ class YieldTest(QueryTest): q = iter(sess.query(User).yield_per(1).from_statement("select * from users")) ret = [] - self.assertEquals(len(sess.identity_map), 0) + eq_(len(sess.identity_map), 0) ret.append(q.next()) ret.append(q.next()) - self.assertEquals(len(sess.identity_map), 2) + eq_(len(sess.identity_map), 2) ret.append(q.next()) ret.append(q.next()) - self.assertEquals(len(sess.identity_map), 4) + eq_(len(sess.identity_map), 4) try: q.next() assert False @@ -1019,7 +1021,7 @@ class TextTest(QueryTest): def test_as_column(self): s = create_session() - self.assertRaises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name")) + assert_raises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name")) eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')]) @@ -1091,24 +1093,24 @@ class JoinTest(QueryTest): sess = create_session() for oalias,ialias in [(True, True), (False, False), (True, False), (False, True)]: - self.assertEquals( + eq_( sess.query(User).join('orders', aliased=oalias).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description == 'item 4').all(), [User(name='jack')] ) # use middle criterion - self.assertEquals( + eq_( sess.query(User).join('orders', aliased=oalias).filter(Order.user_id==9).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description=='item 4').all(), [] ) orderalias = aliased(Order) itemalias = aliased(Item) - self.assertEquals( + eq_( sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(itemalias.description == 'item 4').all(), [User(name='jack')] ) - self.assertEquals( + eq_( sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(orderalias.user_id==9).filter(itemalias.description=='item 4').all(), [] ) @@ -1119,28 +1121,28 @@ class JoinTest(QueryTest): sess = create_session() - self.assertEquals( + eq_( sess.query(User).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), [User(id=8,name=u'ed')] ) # its actually not so controversial if you view it in terms # of multiple entities. - self.assertEquals( + eq_( sess.query(User, Address).join(Address.user).filter(Address.email_address=='ed@wood.com').all(), [(User(id=8,name=u'ed'), Address(email_address='ed@wood.com'))] ) # this was the controversial part. now, raise an error if the feature is abused. # before the error raise was added, this would silently work..... - self.assertRaises( + assert_raises( sa_exc.InvalidRequestError, sess.query(User).join, (Address, Address.user), ) # but this one would silently fail adalias = aliased(Address) - self.assertRaises( + assert_raises( sa_exc.InvalidRequestError, sess.query(User).join, (adalias, Address.user), ) @@ -1153,7 +1155,7 @@ class JoinTest(QueryTest): oalias2 = aliased(Order) result = sess.query(ualias).join((oalias1, ualias.orders), (oalias2, ualias.orders)).\ filter(or_(oalias1.user_id==9, oalias2.user_id==7)).all() - self.assertEquals(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')]) + eq_(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')]) def test_orderby_arg_bug(self): sess = create_session() @@ -1163,17 +1165,17 @@ class JoinTest(QueryTest): def test_no_onclause(self): sess = create_session() - self.assertEquals( + eq_( sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), [User(name='jack')] ) - self.assertEquals( + eq_( sess.query(User.name).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(), [('jack',)] ) - self.assertEquals( + eq_( sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(), [User(name='jack')] ) @@ -1181,7 +1183,7 @@ class JoinTest(QueryTest): def test_clause_onclause(self): sess = create_session() - self.assertEquals( + eq_( sess.query(User).join( (Order, User.id==Order.user_id), (order_items, Order.id==order_items.c.order_id), @@ -1190,7 +1192,7 @@ class JoinTest(QueryTest): [User(name='jack')] ) - self.assertEquals( + eq_( sess.query(User.name).join( (Order, User.id==Order.user_id), (order_items, Order.id==order_items.c.order_id), @@ -1200,7 +1202,7 @@ class JoinTest(QueryTest): ) ualias = aliased(User) - self.assertEquals( + eq_( sess.query(ualias.name).join( (Order, ualias.id==Order.user_id), (order_items, Order.id==order_items.c.order_id), @@ -1212,7 +1214,7 @@ class JoinTest(QueryTest): # explicit onclause with from_self(), means # the onclause must be aliased against the query's custom # FROM object - self.assertEquals( + eq_( sess.query(User).order_by(User.id).offset(2).from_self().join( (Order, User.id==Order.user_id) ).all(), @@ -1220,7 +1222,7 @@ class JoinTest(QueryTest): ) # same with an explicit select_from() - self.assertEquals( + eq_( sess.query(User).select_from(select([users]).order_by(User.id).offset(2).alias()).join( (Order, User.id==Order.user_id) ).all(), @@ -1244,34 +1246,34 @@ class JoinTest(QueryTest): AdAlias = aliased(Address) q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias)) l = q.order_by(User.id, AdAlias.id).all() - self.assertEquals(l, expected) + eq_(l, expected) sess.expunge_all() q = sess.query(User).add_entity(AdAlias) l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) + eq_(l, [(user8, address3)]) l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) + eq_(l, [(user8, address3)]) l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all() - self.assertEquals(l, [(user8, address3)]) + eq_(l, [(user8, address3)]) # this is the first test where we are joining "backwards" - from AdAlias to User even though # the query is against User q = sess.query(User, AdAlias) l = q.join(AdAlias.user).filter(User.name=='ed') - self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) + eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed') - self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) + eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),]) def test_implicit_joins_from_aliases(self): sess = create_session() OrderAlias = aliased(Order) - self.assertEquals( + eq_( sess.query(OrderAlias).join('items').filter_by(description='item 3').\ order_by(OrderAlias.id).all(), [ @@ -1281,7 +1283,7 @@ class JoinTest(QueryTest): ] ) - self.assertEquals( + eq_( sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').\ order_by(User.id, OrderAlias.id).all(), [ @@ -1314,12 +1316,12 @@ class JoinTest(QueryTest): q = sess.query(Order) q = q.add_entity(Item).select_from(join(Order, Item, 'items')).order_by(Order.id, Item.id) l = q.all() - self.assertEquals(l, expected) + eq_(l, expected) IAlias = aliased(Item) q = sess.query(Order, IAlias).select_from(join(Order, IAlias, 'items')).filter(IAlias.description=='item 3') l = q.all() - self.assertEquals(l, + eq_(l, [ (order1, item3), (order2, item3), @@ -1385,7 +1387,7 @@ class JoinTest(QueryTest): sess = create_session() ualias = aliased(User) - self.assertEquals( + eq_( sess.query(User, ualias).filter(User.id > ualias.id).order_by(desc(ualias.id), User.name).all(), [ (User(id=10,name=u'chuck'), User(id=9,name=u'fred')), @@ -1401,14 +1403,15 @@ class JoinTest(QueryTest): sess = create_session() - self.assertEquals( + eq_( sess.query(User.name).join((addresses, User.id==addresses.c.user_id)).order_by(User.id).all(), [(u'jack',), (u'ed',), (u'ed',), (u'ed',), (u'fred',)] ) class MultiplePathTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1, t2, t1t2_1, t1t2_2 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), @@ -1440,7 +1443,7 @@ class MultiplePathTest(_base.MappedTest): mapper(T2, t2) q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint() - self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.", + assert_raises_message(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.", q.join, 't2s_2' ) @@ -1449,7 +1452,8 @@ class MultiplePathTest(_base.MappedTest): class SynonymTest(QueryTest): - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(User, users, properties={ 'name_syn':synonym('name'), 'addresses':relation(Address), @@ -1547,7 +1551,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): adalias = addresses.alias() q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias)) def go(): - self.assertEquals(self.static.user_address_result, q.order_by(User.id).all()) + eq_(self.static.user_address_result, q.order_by(User.id).all()) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() @@ -1675,35 +1679,35 @@ class MixedEntitiesTest(QueryTest): sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User) q2 = q.select_from(sel).values(User.name) - self.assertEquals(list(q2), [(u'jack',), (u'ed',)]) + eq_(list(q2), [(u'jack',), (u'ed',)]) q = sess.query(User) q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String)) - self.assertEquals(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) + eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')]) q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')]) adalias = aliased(Address) q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address) - self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) + eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')]) q2 = q.values(func.count(User.name)) assert q2.next() == (4,) q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name) - self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')]) + eq_(list(q2), [(u'ed', u'ed', u'ed')]) # using User.xxx is alised against "sel", so this query returns nothing q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name) - self.assertEquals(list(q2), []) + eq_(list(q2), []) # whereas this uses users.c.xxx, is not aliased and creates a new join q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name) - self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')]) + eq_(list(q2), [(u'ed', u'jack', u'jack')]) @testing.fails_on('mssql', 'FIXME: unknown') def test_values_specific_order_by(self): @@ -1715,7 +1719,7 @@ class MixedEntitiesTest(QueryTest): q = sess.query(User) u2 = aliased(User) q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name) - self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) + eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')]) @testing.fails_on('mssql', 'FIXME: unknown') def test_values_with_boolean_selects(self): @@ -1724,7 +1728,7 @@ class MixedEntitiesTest(QueryTest): q = sess.query(User) q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%'))) - self.assertEquals(list(q2), [(True, 1), (False, 3)]) + eq_(list(q2), [(True, 1), (False, 3)]) def test_correlated_subquery(self): """test that a subquery constructed from ORM attributes doesn't leak out @@ -1739,7 +1743,7 @@ class MixedEntitiesTest(QueryTest): label('count') # we don't want Address to be outside of the subquery here - self.assertEquals( + eq_( list(sess.query(User, subq)[0:3]), [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] ) @@ -1751,7 +1755,7 @@ class MixedEntitiesTest(QueryTest): label('count') # we don't want Address to be outside of the subquery here - self.assertEquals( + eq_( list(sess.query(User, subq)[0:3]), [(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)] ) @@ -1759,71 +1763,71 @@ class MixedEntitiesTest(QueryTest): def test_tuple_labeling(self): sess = create_session() for row in sess.query(User, Address).join(User.addresses).all(): - self.assertEquals(set(row.keys()), set(['User', 'Address'])) - self.assertEquals(row.User, row[0]) - self.assertEquals(row.Address, row[1]) + eq_(set(row.keys()), set(['User', 'Address'])) + eq_(row.User, row[0]) + eq_(row.Address, row[1]) for row in sess.query(User.name, User.id.label('foobar')): - self.assertEquals(set(row.keys()), set(['name', 'foobar'])) - self.assertEquals(row.name, row[0]) - self.assertEquals(row.foobar, row[1]) + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) for row in sess.query(User).values(User.name, User.id.label('foobar')): - self.assertEquals(set(row.keys()), set(['name', 'foobar'])) - self.assertEquals(row.name, row[0]) - self.assertEquals(row.foobar, row[1]) + eq_(set(row.keys()), set(['name', 'foobar'])) + eq_(row.name, row[0]) + eq_(row.foobar, row[1]) oalias = aliased(Order) for row in sess.query(User, oalias).join(User.orders).all(): - self.assertEquals(set(row.keys()), set(['User'])) - self.assertEquals(row.User, row[0]) + eq_(set(row.keys()), set(['User'])) + eq_(row.User, row[0]) oalias = aliased(Order, name='orders') for row in sess.query(User, oalias).join(User.orders).all(): - self.assertEquals(set(row.keys()), set(['User', 'orders'])) - self.assertEquals(row.User, row[0]) - self.assertEquals(row.orders, row[1]) + eq_(set(row.keys()), set(['User', 'orders'])) + eq_(row.User, row[0]) + eq_(row.orders, row[1]) def test_column_queries(self): sess = create_session() - self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)]) + eq_(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)]) sel = users.select(User.id.in_([7, 8])).alias() q = sess.query(User.name) q2 = q.select_from(sel).all() - self.assertEquals(list(q2), [(u'jack',), (u'ed',)]) + eq_(list(q2), [(u'jack',), (u'ed',)]) - self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [ + eq_(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [ (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com') ]) - self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), + eq_(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)] ) - self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), + eq_(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] ) - self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), + eq_(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))] ) adalias = aliased(Address) - self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), + eq_(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)] ) - self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(), + eq_(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(), [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))] ) # select from aliasing + explicit aliasing - self.assertEquals( + eq_( sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(), [ (User(name=u'jack',id=7), u'jack@bean.com'), @@ -1836,7 +1840,7 @@ class MixedEntitiesTest(QueryTest): ) # anon + select from aliasing - self.assertEquals( + eq_( sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(), [ User(name=u'ed',id=8), @@ -1849,7 +1853,7 @@ class MixedEntitiesTest(QueryTest): sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10), ]: - self.assertEquals( + eq_( q.all(), [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'), @@ -1875,7 +1879,7 @@ class MixedEntitiesTest(QueryTest): def go(): results = sess.query(User).limit(1).options(eagerload('addresses')).add_column(User.name).all() - self.assertEquals(results, [(User(name='jack'), 'jack')]) + eq_(results, [(User(name='jack'), 'jack')]) self.assert_sql_count(testing.db, go, 1) def test_self_referential(self): @@ -1898,7 +1902,7 @@ class MixedEntitiesTest(QueryTest): ]: - self.assertEquals( + eq_( q.all(), [ (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), @@ -1924,25 +1928,25 @@ class MixedEntitiesTest(QueryTest): sess = create_session() selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id]) - self.assertEquals(list(sess.query(User, Address).instances(selectquery.execute())), expected) + eq_(list(sess.query(User, Address).instances(selectquery.execute())), expected) sess.expunge_all() for address_entity in (Address, aliased(Address)): q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id) - self.assertEquals(q.all(), expected) + eq_(q.all(), expected) sess.expunge_all() q = sess.query(User).add_entity(address_entity) q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(q.all(), [(user8, address3)]) + eq_(q.all(), [(user8, address3)]) sess.expunge_all() q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(q.all(), [(user8, address3)]) + eq_(q.all(), [(user8, address3)]) sess.expunge_all() q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com') - self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)]) + eq_(list(util.OrderedSet(q.all())), [(user8, address3)]) sess.expunge_all() def test_aliased_multi_mappers(self): @@ -1979,7 +1983,7 @@ class MixedEntitiesTest(QueryTest): assert sess.query(User).add_column(add_col).all() == expected sess.expunge_all() - self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object()) + assert_raises(sa_exc.InvalidRequestError, sess.query(User).add_column, object()) def test_add_multi_columns(self): """test that add_column accepts a FROM clause.""" @@ -2004,13 +2008,13 @@ class MixedEntitiesTest(QueryTest): q = sess.query(User) q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count')) - self.assertEquals(q.all(), expected) + eq_(q.all(), expected) sess.expunge_all() adalias = aliased(Address) q = sess.query(User) q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count')) - self.assertEquals(q.all(), expected) + eq_(q.all(), expected) sess.expunge_all() s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id) @@ -2069,8 +2073,9 @@ class ImmediateTest(_fixtures.FixtureTest): run_inserts = 'once' run_deletes = None + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Address, addresses) mapper(User, users, properties=dict( @@ -2080,25 +2085,25 @@ class ImmediateTest(_fixtures.FixtureTest): def test_one(self): sess = create_session() - self.assertRaises(sa.orm.exc.NoResultFound, + assert_raises(sa.orm.exc.NoResultFound, sess.query(User).filter(User.id == 99).one) eq_(sess.query(User).filter(User.id == 7).one().id, 7) - self.assertRaises(sa.orm.exc.MultipleResultsFound, + assert_raises(sa.orm.exc.MultipleResultsFound, sess.query(User).one) - self.assertRaises( + assert_raises( sa.orm.exc.NoResultFound, sess.query(User.id, User.name).filter(User.id == 99).one) eq_(sess.query(User.id, User.name).filter(User.id == 7).one(), (7, 'jack')) - self.assertRaises(sa.orm.exc.MultipleResultsFound, + assert_raises(sa.orm.exc.MultipleResultsFound, sess.query(User.id, User.name).one) - self.assertRaises(sa.orm.exc.NoResultFound, + assert_raises(sa.orm.exc.NoResultFound, (sess.query(User, Address). join(User.addresses). filter(Address.id == 99)).one) @@ -2108,7 +2113,7 @@ class ImmediateTest(_fixtures.FixtureTest): filter(Address.id == 4)).one(), (User(id=8), Address(id=4))) - self.assertRaises(sa.orm.exc.MultipleResultsFound, + assert_raises(sa.orm.exc.MultipleResultsFound, sess.query(User, Address).join(User.addresses).one) @testing.future @@ -2133,7 +2138,7 @@ class ImmediateTest(_fixtures.FixtureTest): eq_(sess.query(User.id, User.name).filter_by(id=7).value(User.id), 7) eq_(sess.query(User).filter_by(id=0).value(User.id), None) - sess.bind = sa.testing.db + sess.bind = testing.db eq_(sess.query().value(sa.literal_column('1').label('x')), 1) @@ -2149,19 +2154,19 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])).alias() sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)]) + eq_(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)]) - self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)]) + eq_(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)]) - self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [ + eq_(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [ User(name='jack',id=7), User(name='ed',id=8) ]) - self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [ + eq_(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [ User(name='ed',id=8), User(name='jack',id=7) ]) - self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), + eq_(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), User(name='jack', addresses=[Address(id=1)]) ) @@ -2173,7 +2178,7 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])) sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).all(), + eq_(sess.query(User).select_from(sel).all(), [ User(name='jack',id=7), User(name='ed',id=8) ] @@ -2185,7 +2190,7 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])) sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).all(), + eq_(sess.query(User).select_from(sel).all(), [ User(name='jack',id=7), User(name='ed',id=8) ] @@ -2200,7 +2205,7 @@ class SelectFromTest(QueryTest): sel = users.select(users.c.id.in_([7, 8])) sess = create_session() - self.assertEquals(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(), + eq_(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(), [ (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), @@ -2210,7 +2215,7 @@ class SelectFromTest(QueryTest): ) adalias = aliased(Address) - self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(), + eq_(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(), [ (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)), (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)), @@ -2238,16 +2243,16 @@ class SelectFromTest(QueryTest): # TODO: remove sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all() - self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ User(name=u'jack',id=7) ]) - self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ User(name=u'jack',id=7) ]) def go(): - self.assertEquals(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ + eq_(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [ User(name=u'jack',orders=[ Order(description=u'order 1',items=[ Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]), @@ -2265,11 +2270,11 @@ class SelectFromTest(QueryTest): sess.expunge_all() sel2 = orders.select(orders.c.id.in_([1,2,3])) - self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [ + eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [ Order(description=u'order 1',id=1), Order(description=u'order 2',id=2), ]) - self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [ + eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [ Order(description=u'order 1',id=1), Order(description=u'order 2',id=2), ]) @@ -2285,7 +2290,7 @@ class SelectFromTest(QueryTest): sess = create_session() def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(), + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(), [ User(id=7, addresses=[Address(id=1)]), User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]) @@ -2295,14 +2300,14 @@ class SelectFromTest(QueryTest): sess.expunge_all() def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(), + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(), [User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])] ) self.assert_sql_count(testing.db, go, 1) sess.expunge_all() def go(): - self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) + eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])) self.assert_sql_count(testing.db, go, 1) class CustomJoinTest(QueryTest): @@ -2329,14 +2334,16 @@ class SelfReferentialTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global nodes nodes = Table('nodes', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('parent_id', Integer, ForeignKey('nodes.id')), Column('data', String(30))) - def insert_data(self): + @classmethod + def insert_data(cls): global Node class Node(Base): @@ -2399,7 +2406,7 @@ class SelfReferentialTest(_base.MappedTest): filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first() assert node.data == 'n122' - self.assertEquals( + eq_( list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\ filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)), [('n122', 'n12', 'n1')]) @@ -2410,13 +2417,13 @@ class SelfReferentialTest(_base.MappedTest): n1 = aliased(Node) # using 'n1.parent' implicitly joins to unaliased Node - self.assertEquals( + eq_( sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(), [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] ) # explicit (new syntax) - self.assertEquals( + eq_( sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(), [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)] ) @@ -2426,7 +2433,7 @@ class SelfReferentialTest(_base.MappedTest): parent = aliased(Node) grandparent = aliased(Node) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2434,7 +2441,7 @@ class SelfReferentialTest(_base.MappedTest): (Node(data='n122'), Node(data='n12'), Node(data='n1')) ) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2443,7 +2450,7 @@ class SelfReferentialTest(_base.MappedTest): ) # same, change order around - self.assertEquals( + eq_( sess.query(parent, grandparent, Node).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2451,7 +2458,7 @@ class SelfReferentialTest(_base.MappedTest): (Node(data='n12'), Node(data='n1'), Node(data='n122')) ) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2460,7 +2467,7 @@ class SelfReferentialTest(_base.MappedTest): (Node(data='n122'), Node(data='n12'), Node(data='n1')) ) - self.assertEquals( + eq_( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ filter(Node.data=='n122').filter(parent.data=='n12').\ @@ -2472,40 +2479,41 @@ class SelfReferentialTest(_base.MappedTest): def test_any(self): sess = create_session() - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) - self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), []) + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')]) + eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),]) def test_has(self): sess = create_session() - self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) - self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) - self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), []) + eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')]) def test_contains(self): sess = create_session() n122 = sess.query(Node).filter(Node.data=='n122').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) + eq_(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')]) n13 = sess.query(Node).filter(Node.data=='n13').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')]) + eq_(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')]) def test_eq_ne(self): sess = create_session() n12 = sess.query(Node).filter(Node.data=='n12').one() - self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) + eq_(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')]) - self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) + eq_(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')]) class SelfReferentialM2MTest(_base.MappedTest): run_setup_mappers = 'once' run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global nodes, node_to_nodes nodes = Table('nodes', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -2516,7 +2524,8 @@ class SelfReferentialM2MTest(_base.MappedTest): Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True), ) - def insert_data(self): + @classmethod + def insert_data(cls): global Node class Node(Base): @@ -2550,20 +2559,20 @@ class SelfReferentialM2MTest(_base.MappedTest): def test_any(self): sess = create_session() - self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) + eq_(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')]) def test_contains(self): sess = create_session() n4 = sess.query(Node).filter_by(data='n4').one() - self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) - self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) + eq_(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')]) + eq_(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')]) def test_explicit_join(self): sess = create_session() n1 = aliased(Node) - self.assertEquals( + eq_( sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data.in_(['n3', 'n7'])).order_by(Node.id).all(), [Node(data='n1'), Node(data='n2')] ) @@ -2575,7 +2584,7 @@ class ExternalColumnsTest(QueryTest): def test_external_columns_bad(self): - self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={ + assert_raises_message(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={ 'concat': (users.c.id * 2), }) clear_mappers() @@ -2596,7 +2605,7 @@ class ExternalColumnsTest(QueryTest): sess.query(Address).options(eagerload('user')).all() - self.assertEquals(sess.query(User).all(), + eq_(sess.query(User).all(), [ User(id=7, concat=14, count=1), User(id=8, concat=16, count=3), @@ -2612,22 +2621,22 @@ class ExternalColumnsTest(QueryTest): Address(id=4, user=User(id=8, concat=16, count=3)), Address(id=5, user=User(id=9, concat=18, count=1)) ] - self.assertEquals(sess.query(Address).all(), address_result) + eq_(sess.query(Address).all(), address_result) # run the eager version twice to test caching of aliased clauses for x in range(2): sess.expunge_all() def go(): - self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result) + eq_(sess.query(Address).options(eagerload('user')).all(), address_result) self.assert_sql_count(testing.db, go, 1) ualias = aliased(User) - self.assertEquals( + eq_( sess.query(Address, ualias).join(('user', ualias)).all(), [(address, address.user) for address in address_result] ) - self.assertEquals( + eq_( sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), [ (Address(id=1), 1), @@ -2638,7 +2647,7 @@ class ExternalColumnsTest(QueryTest): ] ) - self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), + eq_(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(), [ (Address(id=1), 14, 1), (Address(id=2), 16, 3), @@ -2649,7 +2658,7 @@ class ExternalColumnsTest(QueryTest): ) ua = aliased(User) - self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), + eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(), [ (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1), (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3), @@ -2659,11 +2668,11 @@ class ExternalColumnsTest(QueryTest): ] ) - self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), + eq_(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] ) - self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), + eq_(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)] ) @@ -2686,17 +2695,18 @@ class ExternalColumnsTest(QueryTest): sess = create_session() def go(): o1 = sess.query(Order).options(eagerload_all('address.user')).get(1) - self.assertEquals(o1.address.user.count, 1) + eq_(o1.address.user.count, 1) self.assert_sql_count(testing.db, go, 1) sess = create_session() def go(): o1 = sess.query(Order).options(eagerload_all('address.user')).first() - self.assertEquals(o1.address.user.count, 1) + eq_(o1.address.user.count, 1) self.assert_sql_count(testing.db, go, 1) class TestOverlyEagerEquivalentCols(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global base, sub1, sub2 base = Table('base', metadata, Column('id', Integer, primary_key=True), @@ -2747,14 +2757,15 @@ class TestOverlyEagerEquivalentCols(_base.MappedTest): q = sess.query(Base).outerjoin('sub2', aliased=True) assert sub1.c.id not in q._filter_aliases.equivalents - self.assertEquals( + eq_( sess.query(Base).join('sub1').outerjoin('sub2', aliased=True).\ filter(Sub1.id==1).one(), b1 ) class UpdateDeleteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(32)), @@ -2765,15 +2776,17 @@ class UpdateDeleteTest(_base.MappedTest): Column('user_id', None, ForeignKey('users.id')), Column('title', String(32))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Document(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): users.insert().execute([ dict(id=1, name='john', age=25), dict(id=2, name='jack', age=47), @@ -2789,8 +2802,9 @@ class UpdateDeleteTest(_base.MappedTest): dict(id=3, user_id=2, title='baz'), ]) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users) mapper(Document, documents, properties={ 'user': relation(User, lazy=False, backref=backref('documents', lazy=True)) @@ -2964,17 +2978,17 @@ class UpdateDeleteTest(_base.MappedTest): sess = create_session(bind=testing.db, autocommit=False) rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0}) - self.assertEquals(rowcount, 2) + eq_(rowcount, 2) rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10}) - self.assertEquals(rowcount, 2) + eq_(rowcount, 2) @testing.resolve_artifact_names def test_delete_returns_rowcount(self): sess = create_session(bind=testing.db, autocommit=False) rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False) - self.assertEquals(rowcount, 3) + eq_(rowcount, 3) @testing.resolve_artifact_names def test_update_with_eager_relations(self): @@ -3008,5 +3022,3 @@ class UpdateDeleteTest(_base.MappedTest): eq_(sess.query(Document.title).all(), zip(['baz'])) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/relationships.py b/test/orm/test_relationships.py similarity index 94% rename from test/orm/relationships.py rename to test/orm/test_relationships.py index a0a8900b2c..1bc074c314 100644 --- a/test/orm/relationships.py +++ b/test/orm/test_relationships.py @@ -1,10 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message import datetime -from testlib import sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_ -from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers -from testlib.testing import eq_, startswith_ -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import Integer, String, ForeignKey, MetaData, and_ +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers +from sqlalchemy.test.testing import eq_, startswith_ +from test.orm import _base, _fixtures class RelationTest(_base.MappedTest): @@ -26,7 +29,8 @@ class RelationTest(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("tbl_a", metadata, Column("id", Integer, primary_key=True), Column("name", String(128))) @@ -43,7 +47,8 @@ class RelationTest(_base.MappedTest): Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")), Column("name", String(128))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.Entity): pass class B(_base.Entity): @@ -53,8 +58,9 @@ class RelationTest(_base.MappedTest): class D(_base.Entity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(A, tbl_a, properties=dict( c_rows=relation(C, cascade="all, delete-orphan", backref="a_row"))) mapper(B, tbl_b) @@ -63,8 +69,9 @@ class RelationTest(_base.MappedTest): mapper(D, tbl_d, properties=dict( b_row=relation(B))) + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): session = create_session() a = A(name='a1') b = B(name='b1') @@ -102,7 +109,8 @@ class RelationTest2(_base.MappedTest): key where one column in the foreign key is 'joined to itself'. """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('company_t', metadata, Column('company_id', Integer, primary_key=True), Column('name', sa.Unicode(30))) @@ -218,7 +226,8 @@ class RelationTest2(_base.MappedTest): assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' class RelationTest3(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("jobs", metadata, Column("jobno", sa.Unicode(15), primary_key=True), Column("created", sa.DateTime, nullable=False, @@ -257,8 +266,9 @@ class RelationTest3(_base.MappedTest): ["jobno", "pagename"], ["pages.jobno", "pages.pagename"])) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): class Job(_base.Entity): def create_page(self, pagename): return Page(job=self, pagename=pagename) @@ -360,7 +370,8 @@ class RelationTest3(_base.MappedTest): class RelationTest4(_base.MappedTest): """Syncrules on foreign keys that are also primary""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("tableA", metadata, Column("id",Integer,primary_key=True), Column("foo",Integer,), @@ -369,7 +380,8 @@ class RelationTest4(_base.MappedTest): Column("id",Integer,ForeignKey("tableA.id"),primary_key=True), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class A(_base.Entity): pass @@ -537,7 +549,8 @@ class RelationTest4(_base.MappedTest): class RelationTest5(_base.MappedTest): """Test a map to a select that relates to a map to the table.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('items', metadata, Column('item_policy_num', String(10), primary_key=True, key='policyNum'), @@ -605,7 +618,8 @@ class RelationTest6(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('tags', metadata, Column("id", Integer, primary_key=True), Column("data", String(50)), ) @@ -659,7 +673,8 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): subscriber_table = Table('subscriber', metadata, Column('id', Integer, primary_key=True), Column('dummy', String(10)) # to appease older sqlite version @@ -671,8 +686,9 @@ class AmbiguousJoinInterpretedAsSelfRef(_base.MappedTest): Column('type', String(1), primary_key=True), ) + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): subscriber_and_address = subscriber.join(address, and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C']))) @@ -762,7 +778,7 @@ class ManualBackrefTest(_fixtures.FixtureTest): 'user':relation(User, back_populates='addresses') }) - self.assertRaises(sa.exc.InvalidRequestError, compile_mappers) + assert_raises(sa.exc.InvalidRequestError, compile_mappers) @testing.resolve_artifact_names def test_invalid_target(self): @@ -775,7 +791,7 @@ class ManualBackrefTest(_fixtures.FixtureTest): 'dingaling':relation(Dingaling) }) - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, r"reverse_property 'dingaling' on relation User.addresses references " "relation Address.dingaling, which does not reference mapper Mapper\|User\|users", compile_mappers) @@ -794,7 +810,7 @@ class JoinConditionErrorTest(testing.TestBase): c1id = Column('c1id', Integer, ForeignKey('c1.id')) c2 = relation(C1, primaryjoin=C1.id) - self.assertRaises(sa.exc.ArgumentError, compile_mappers) + assert_raises(sa.exc.ArgumentError, compile_mappers) def test_clauseelement_pj_false(self): from sqlalchemy.ext.declarative import declarative_base @@ -808,7 +824,7 @@ class JoinConditionErrorTest(testing.TestBase): c1id = Column('c1id', Integer, ForeignKey('c1.id')) c2 = relation(C1, primaryjoin="x"=="y") - self.assertRaises(sa.exc.ArgumentError, compile_mappers) + assert_raises(sa.exc.ArgumentError, compile_mappers) def test_fk_error_raised(self): @@ -834,7 +850,7 @@ class JoinConditionErrorTest(testing.TestBase): mapper(C1, t1, properties={'c2':relation(C2)}) mapper(C2, t3) - self.assertRaises(sa.exc.NoReferencedColumnError, compile_mappers) + assert_raises(sa.exc.NoReferencedColumnError, compile_mappers) def test_join_error_raised(self): m = MetaData() @@ -858,15 +874,16 @@ class JoinConditionErrorTest(testing.TestBase): mapper(C1, t1, properties={'c2':relation(C2)}) mapper(C2, t3) - self.assertRaises(sa.exc.ArgumentError, compile_mappers) + assert_raises(sa.exc.ArgumentError, compile_mappers) - def tearDown(self): + def teardown(self): clear_mappers() class TypeMatchTest(_base.MappedTest): """test errors raised when trying to add items whose type is not handled by a relation""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("a", metadata, Column('aid', Integer, primary_key=True), Column('data', String(30))) @@ -924,7 +941,7 @@ class TypeMatchTest(_base.MappedTest): sess.add(a1) sess.add(b1) sess.add(c1) - self.assertRaisesMessage(sa.orm.exc.FlushError, + assert_raises_message(sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush) @testing.resolve_artifact_names @@ -945,7 +962,7 @@ class TypeMatchTest(_base.MappedTest): sess.add(a1) sess.add(b1) sess.add(c1) - self.assertRaisesMessage(sa.orm.exc.FlushError, + assert_raises_message(sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush) @testing.resolve_artifact_names @@ -962,7 +979,7 @@ class TypeMatchTest(_base.MappedTest): sess = create_session() sess.add(b1) sess.add(d1) - self.assertRaisesMessage(sa.orm.exc.FlushError, + assert_raises_message(sa.orm.exc.FlushError, "Attempting to flush an item", sess.flush) @testing.resolve_artifact_names @@ -977,12 +994,13 @@ class TypeMatchTest(_base.MappedTest): d1 = D() d1.a = b1 sess = create_session() - self.assertRaisesMessage(AssertionError, + assert_raises_message(AssertionError, "doesn't handle objects of type", sess.add, d1) class TypedAssociationTable(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): class MySpecialType(sa.types.TypeDecorator): impl = String def process_bind_param(self, value, dialect): @@ -1033,7 +1051,8 @@ class TypedAssociationTable(_base.MappedTest): class ViewOnlyOverlappingNames(_base.MappedTest): """'viewonly' mappings with overlapping PK column names.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("t1", metadata, Column('id', Integer, primary_key=True), Column('data', String(40))) @@ -1092,7 +1111,8 @@ class ViewOnlyOverlappingNames(_base.MappedTest): class ViewOnlyUniqueNames(_base.MappedTest): """'viewonly' mappings with unique PK column names.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("t1", metadata, Column('t1id', Integer, primary_key=True), Column('data', String(40))) @@ -1182,7 +1202,8 @@ class ViewOnlyLocalRemoteM2M(testing.TestBase): class ViewOnlyNonEquijoin(_base.MappedTest): """'viewonly' mappings based on non-equijoins.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True)) Table('bars', metadata, @@ -1223,7 +1244,8 @@ class ViewOnlyNonEquijoin(_base.MappedTest): class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): """'viewonly' mappings that contain the same 'remote' column twice""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True), Column('bid1', Integer,ForeignKey('bars.id')), @@ -1270,7 +1292,8 @@ class ViewOnlyRepeatedRemoteColumn(_base.MappedTest): class ViewOnlyRepeatedLocalColumn(_base.MappedTest): """'viewonly' mappings that contain the same 'local' column twice""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) @@ -1317,7 +1340,8 @@ class ViewOnlyRepeatedLocalColumn(_base.MappedTest): class ViewOnlyComplexJoin(_base.MappedTest): """'viewonly' mappings with a complex join condition.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50))) @@ -1332,7 +1356,8 @@ class ViewOnlyComplexJoin(_base.MappedTest): Column('t2id', Integer, ForeignKey('t2.id')), Column('t3id', Integer, ForeignKey('t3.id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_base.ComparableEntity): pass class T2(_base.ComparableEntity): @@ -1379,14 +1404,15 @@ class ViewOnlyComplexJoin(_base.MappedTest): 't1':relation(T1), 't3s':relation(T3, secondary=t2tot3)}) mapper(T3, t3) - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, "Specify remote_side argument", sa.orm.compile_mappers) class ExplicitLocalRemoteTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', String(50), primary_key=True), Column('data', String(50))) @@ -1395,8 +1421,9 @@ class ExplicitLocalRemoteTest(_base.MappedTest): Column('data', String(50)), Column('t1id', String(50))) + @classmethod @testing.resolve_artifact_names - def setup_classes(self): + def setup_classes(cls): class T1(_base.ComparableEntity): pass class T2(_base.ComparableEntity): @@ -1508,7 +1535,7 @@ class ExplicitLocalRemoteTest(_base.MappedTest): foreign_keys=[t2.c.t1id], remote_side=[t2.c.t1id])}) mapper(T2, t2) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) @testing.resolve_artifact_names def test_escalation_2(self): @@ -1517,18 +1544,20 @@ class ExplicitLocalRemoteTest(_base.MappedTest): primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id), _local_remote_pairs=[(t1.c.id, t2.c.t1id)])}) mapper(T2, t2) - self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) + assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers) class InvalidRemoteSideTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)), Column('t_id', Integer, ForeignKey('t1.id')) ) + @classmethod @testing.resolve_artifact_names - def setup_classes(self): + def setup_classes(cls): class T1(_base.ComparableEntity): pass @@ -1538,7 +1567,7 @@ class InvalidRemoteSideTest(_base.MappedTest): 't1s':relation(T1, backref='parent') }) - self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " + assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " "both of the same direction . Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) @@ -1548,7 +1577,7 @@ class InvalidRemoteSideTest(_base.MappedTest): 't1s':relation(T1, backref=backref('parent', remote_side=t1.c.id), remote_side=t1.c.id) }) - self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " + assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are " "both of the same direction . Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) @@ -1560,7 +1589,7 @@ class InvalidRemoteSideTest(_base.MappedTest): }) # can't be sure of ordering here - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, "both of the same direction . Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) @@ -1572,14 +1601,15 @@ class InvalidRemoteSideTest(_base.MappedTest): }) # can't be sure of ordering here - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, "both of the same direction . Did you " "mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers) class InvalidRelationEscalationTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer)) @@ -1587,7 +1617,8 @@ class InvalidRelationEscalationTest(_base.MappedTest): Column('id', Integer, primary_key=True), Column('fid', Integer)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.Entity): pass class Bar(_base.Entity): @@ -1599,7 +1630,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): 'bars':relation(Bar)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child " "tables on relation", sa.orm.compile_mappers) @@ -1610,7 +1641,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): 'foos':relation(Foo)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child " "tables on relation", sa.orm.compile_mappers) @@ -1622,7 +1653,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id>bars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1635,7 +1666,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): foreign_keys=bars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not locate any equated, locally mapped column pairs " "for primaryjoin condition", sa.orm.compile_mappers) @@ -1648,7 +1679,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): foreign_keys=[foos.c.id, bars.c.fid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Do the columns in 'foreign_keys' represent only the " "'foreign' columns in this join condition ?", @@ -1665,7 +1696,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): )}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "could not determine any local/remote column pairs", sa.orm.compile_mappers) @@ -1681,7 +1712,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): )}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "could not determine any local/remote column pairs", sa.orm.compile_mappers) @@ -1694,7 +1725,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id>foos.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1707,7 +1738,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): foreign_keys=[foos.c.fid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not locate any equated, locally mapped column pairs " "for primaryjoin condition", sa.orm.compile_mappers) @@ -1720,7 +1751,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): viewonly=True)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1733,7 +1764,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): viewonly=True)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Specify the 'foreign_keys' argument to indicate which columns " "on the relation are foreign.", sa.orm.compile_mappers) @@ -1756,7 +1787,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id==bars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1767,7 +1798,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): 'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)}) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1779,7 +1810,7 @@ class InvalidRelationEscalationTest(_base.MappedTest): primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])}) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1787,7 +1818,8 @@ class InvalidRelationEscalationTest(_base.MappedTest): class InvalidRelationEscalationTestM2M(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('foos', metadata, Column('id', Integer, primary_key=True)) Table('foobars', metadata, @@ -1795,8 +1827,9 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): Table('bars', metadata, Column('id', Integer, primary_key=True)) + @classmethod @testing.resolve_artifact_names - def setup_classes(self): + def setup_classes(cls): class Foo(_base.Entity): pass class Bar(_base.Entity): @@ -1808,7 +1841,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): 'bars': relation(Bar, secondary=foobars)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child tables " "on relation", sa.orm.compile_mappers) @@ -1821,7 +1854,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): primaryjoin=foos.c.id > foobars.c.fid)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine join condition between parent/child tables " "on relation", @@ -1836,7 +1869,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): secondaryjoin=foobars.c.bid<=bars.c.id)}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", sa.orm.compile_mappers) @@ -1851,7 +1884,7 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): foreign_keys=[foobars.c.fid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not determine relation direction for secondaryjoin " "condition", sa.orm.compile_mappers) @@ -1866,11 +1899,9 @@ class InvalidRelationEscalationTestM2M(_base.MappedTest): foreign_keys=[foobars.c.fid, foobars.c.bid])}) mapper(Bar, bars) - self.assertRaisesMessage( + assert_raises_message( sa.exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for " "secondaryjoin condition", sa.orm.compile_mappers) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/scoping.py b/test/orm/test_scoping.py similarity index 83% rename from test/orm/scoping.py rename to test/orm/test_scoping.py index bdfc5a9d58..2117e8dccb 100644 --- a/test/orm/scoping.py +++ b/test/orm/test_scoping.py @@ -1,10 +1,13 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa, testing +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing from sqlalchemy.orm import scoped_session -from testlib.sa import Table, Column, Integer, String, ForeignKey -from testlib.sa.orm import mapper, relation, query -from testlib.testing import eq_ -from orm import _base +from sqlalchemy import Integer, String, ForeignKey +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, query +from sqlalchemy.test.testing import eq_ +from test.orm import _base class _ScopedTest(_base.MappedTest): @@ -15,18 +18,21 @@ class _ScopedTest(_base.MappedTest): _artifact_registries = ( _base.MappedTest._artifact_registries + ('scoping',)) - def setUpAll(self): - type(self).scoping = _base.adict() - _base.MappedTest.setUpAll(self) + @classmethod + def setup_class(cls): + cls.scoping = _base.adict() + super(_ScopedTest, cls).setup_class() - def tearDownAll(self): - self.scoping.clear() - _base.MappedTest.tearDownAll(self) + @classmethod + def teardown_class(cls): + cls.scoping.clear() + super(_ScopedTest, cls).teardown_class() class ScopedSessionTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30))) @@ -73,7 +79,8 @@ class ScopedSessionTest(_base.MappedTest): class ScopedMapperTest(_ScopedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30))) @@ -81,24 +88,27 @@ class ScopedMapperTest(_ScopedTest): Column('id', Integer, primary_key=True), Column('someid', None, ForeignKey('table1.id'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class SomeObject(_base.ComparableEntity): pass class SomeOtherObject(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): Session = scoped_session(sa.orm.create_session) Session.mapper(SomeObject, table1, properties={ 'options':relation(SomeOtherObject) }) Session.mapper(SomeOtherObject, table2) - self.scoping['Session'] = Session + cls.scoping['Session'] = Session + @classmethod @testing.resolve_artifact_names - def insert_data(self): + def insert_data(cls): s = SomeObject() s.id = 1 s.data = 'hello' @@ -145,7 +155,7 @@ class ScopedMapperTest(_ScopedTest): scope.mapper(B, table2) A(foo='bar') - self.assertRaises(TypeError, B, foo='bar') + assert_raises(TypeError, B, foo='bar') scope = scoped_session(sa.orm.sessionmaker()) @@ -158,7 +168,7 @@ class ScopedMapperTest(_ScopedTest): scope.mapper(C, table1) scope.mapper(D, table2) - self.assertRaises(TypeError, C, foo='bar') + assert_raises(TypeError, C, foo='bar') D(foo='bar') @testing.resolve_artifact_names @@ -170,7 +180,7 @@ class ScopedMapperTest(_ScopedTest): Session.mapper(ValidatedOtherObject, table2, validate=True) v1 = ValidatedOtherObject(someid=12) - self.assertRaises(sa.exc.ArgumentError, ValidatedOtherObject, + assert_raises(sa.exc.ArgumentError, ValidatedOtherObject, someid=12, bogus=345) @testing.resolve_artifact_names @@ -186,7 +196,8 @@ class ScopedMapperTest(_ScopedTest): class ScopedMapperTest2(_ScopedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('table1', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)), @@ -196,14 +207,16 @@ class ScopedMapperTest2(_ScopedTest): Column('someid', None, ForeignKey('table1.id')), Column('somedata', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class BaseClass(_base.ComparableEntity): pass class SubClass(BaseClass): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): Session = scoped_session(sa.orm.sessionmaker()) Session.mapper(BaseClass, table1, @@ -213,7 +226,7 @@ class ScopedMapperTest2(_ScopedTest): polymorphic_identity='sub', inherits=BaseClass) - self.scoping['Session'] = Session + cls.scoping['Session'] = Session @testing.resolve_artifact_names def test_inheritance(self): @@ -234,5 +247,3 @@ class ScopedMapperTest2(_ScopedTest): SubClass.query.all()) -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/selectable.py b/test/orm/test_selectable.py similarity index 71% rename from test/orm/selectable.py rename to test/orm/test_selectable.py index 74c41c8523..0a20253607 100644 --- a/test/orm/selectable.py +++ b/test/orm/test_selectable.py @@ -1,22 +1,27 @@ """Generic mapping to Select statements""" -import testenv; testenv.configure_for_tests() -from testlib import sa, testing -from testlib.sa import Table, Column, String, Integer, select -from testlib.sa.orm import mapper, create_session -from testlib.testing import eq_ -from orm import _base +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import String, Integer, select +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, create_session +from sqlalchemy.test.testing import eq_ +from test.orm import _base # TODO: more tests mapping to selects class SelectableNoFromsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('common', metadata, Column('id', Integer, primary_key=True), Column('data', Integer), Column('extra', String(45))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Subset(_base.ComparableEntity): pass @@ -24,7 +29,7 @@ class SelectableNoFromsTest(_base.MappedTest): def test_no_tables(self): selectable = select(["x", "y", "z"]) - self.assertRaisesMessage(sa.exc.InvalidRequestError, + assert_raises_message(sa.exc.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable) @@ -48,5 +53,3 @@ class SelectableNoFromsTest(_base.MappedTest): Subset(data=1)) -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/session.py b/test/orm/test_session.py similarity index 93% rename from test/orm/session.py rename to test/orm/test_session.py index 6cbd62a50e..3020d66e9d 100644 --- a/test/orm/session.py +++ b/test/orm/test_session.py @@ -1,14 +1,17 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import gc import inspect import pickle from sqlalchemy.orm import create_session, sessionmaker, attributes -from testlib import engines, sa, testing, config -from testlib.sa import Table, Column, Integer, String, Sequence -from testlib.sa.orm import mapper, relation, backref, eagerload -from testlib.testing import eq_ -from engine import _base as engine_base -from orm import _base, _fixtures +import sqlalchemy as sa +from sqlalchemy.test import engines, testing, config +from sqlalchemy import Integer, String, Sequence +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, backref, eagerload +from sqlalchemy.test.testing import eq_ +from test.engine import _base as engine_base +from test.orm import _base, _fixtures class SessionTest(_fixtures.FixtureTest): @@ -508,7 +511,7 @@ class SessionTest(_fixtures.FixtureTest): sess.commit() - self.assertEquals(set(sess.query(User).all()), set([u2])) + eq_(set(sess.query(User).all()), set([u2])) sess.begin() sess.begin_nested() @@ -518,7 +521,7 @@ class SessionTest(_fixtures.FixtureTest): sess.commit() # commit the nested transaction sess.rollback() - self.assertEquals(set(sess.query(User).all()), set([u2])) + eq_(set(sess.query(User).all()), set([u2])) sess.close() @@ -541,7 +544,7 @@ class SessionTest(_fixtures.FixtureTest): sess.close() - self.assertEquals(len(sess.query(User).all()), 1) + eq_(len(sess.query(User).all()), 1) t1 = sess.begin() t2 = sess.begin_nested() @@ -572,7 +575,7 @@ class SessionTest(_fixtures.FixtureTest): sess.close() - self.assertEquals(len(sess.query(User).all()), 1) + eq_(len(sess.query(User).all()), 1) @testing.resolve_artifact_names def test_error_on_using_inactive_session(self): @@ -587,7 +590,7 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() sess.rollback() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True) + assert_raises_message(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True) sess.close() @testing.resolve_artifact_names @@ -612,7 +615,7 @@ class SessionTest(_fixtures.FixtureTest): sess.flush() assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c - self.assertRaisesMessage(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect()) + assert_raises_message(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect()) transaction.rollback() assert len(sess.query(User).all()) == 0 @@ -667,8 +670,8 @@ class SessionTest(_fixtures.FixtureTest): user = User(name='u1') - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.update, user) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.delete, user) + assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.update, user) + assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.delete, user) s.add(user) s.flush() @@ -694,13 +697,13 @@ class SessionTest(_fixtures.FixtureTest): assert user in s assert user not in s.dirty - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already persistent", s.save, user) + assert_raises_message(sa.exc.InvalidRequestError, "is already persistent", s.save, user) s2 = create_session() - self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user) + assert_raises_message(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user) u2 = s2.query(User).get(user.id) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2) + assert_raises_message(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2) s.expire(user) s.expunge(user) @@ -1029,7 +1032,7 @@ class SessionTest(_fixtures.FixtureTest): u = User(name='u1') sess.add(u) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), User(name='u1') @@ -1037,7 +1040,7 @@ class SessionTest(_fixtures.FixtureTest): ) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), User(name='u1') @@ -1046,7 +1049,7 @@ class SessionTest(_fixtures.FixtureTest): u.name='u2' sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), User(name='another u2'), @@ -1056,7 +1059,7 @@ class SessionTest(_fixtures.FixtureTest): sess.delete(u) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='another u1'), ] @@ -1075,7 +1078,7 @@ class SessionTest(_fixtures.FixtureTest): u = User(name='u1') sess.add(u) sess.flush() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='u1') ] @@ -1084,7 +1087,7 @@ class SessionTest(_fixtures.FixtureTest): sess.add(User(name='u2')) sess.flush() sess.expunge_all() - self.assertEquals(sess.query(User).order_by(User.name).all(), + eq_(sess.query(User).order_by(User.name).all(), [ User(name='u1 modified'), User(name='u2') @@ -1102,7 +1105,7 @@ class SessionTest(_fixtures.FixtureTest): sess = create_session(extension=MyExt()) sess.add(User(name='foo')) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush) + assert_raises_message(sa.exc.InvalidRequestError, "already flushing", sess.flush) @testing.resolve_artifact_names def test_pickled_update(self): @@ -1113,7 +1116,7 @@ class SessionTest(_fixtures.FixtureTest): u1 = User(name='u1') sess1.add(u1) - self.assertRaisesMessage(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1) + assert_raises_message(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1) u2 = pickle.loads(pickle.dumps(u1)) @@ -1139,7 +1142,7 @@ class SessionTest(_fixtures.FixtureTest): assert u2 is not None and u2 is not u1 assert u2 in sess - self.assertRaises(Exception, lambda: sess.add(u1)) + assert_raises(Exception, lambda: sess.add(u1)) sess.expunge(u2) assert u2 not in sess @@ -1181,24 +1184,26 @@ class DisposedStates(_base.MappedTest): run_inserts = 'once' run_deletes = None - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): global t1 t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('data', String(50)) ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): global T class T(object): def __init__(self, data): self.data = data mapper(T, t1) - def tearDown(self): + def teardown(self): from sqlalchemy.orm.session import _sessions _sessions.clear() - super(DisposedStates, self).tearDown() + super(DisposedStates, self).teardown() def _set_imap_in_disposal(self, sess, *objs): """remove selected objects from the given session, as though they @@ -1291,7 +1296,7 @@ class SessionInterface(testing.TestBase): def x_raises_(obj, method, *args, **kw): watchdog.add(method) callable_ = getattr(obj, method) - self.assertRaises(sa.orm.exc.UnmappedInstanceError, + assert_raises(sa.orm.exc.UnmappedInstanceError, callable_, *args, **kw) def raises_(method, *args, **kw): @@ -1343,7 +1348,7 @@ class SessionInterface(testing.TestBase): def raises_(method, *args, **kw): watchdog.add(method) callable_ = getattr(create_session(), method) - self.assertRaises(sa.orm.exc.UnmappedClassError, + assert_raises(sa.orm.exc.UnmappedClassError, callable_, *args, **kw) raises_('connection', mapper=user_arg) @@ -1395,32 +1400,27 @@ class SessionInterface(testing.TestBase): class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): - def create_engine(self): + @classmethod + def create_engine(cls): return engines.testing_engine(options=dict(strategy='threadlocal')) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String(20)), test_needs_acid=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users) - def setUpAll(self): - engine_base.AltEngineTest.setUpAll(self) - _base.MappedTest.setUpAll(self) - - - def tearDownAll(self): - _base.MappedTest.tearDownAll(self) - engine_base.AltEngineTest.tearDownAll(self) - @testing.exclude('mysql', '<', (5, 0, 3), 'FIXME: unknown') @testing.resolve_artifact_names def test_session_nesting(self): @@ -1432,5 +1432,3 @@ class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest): self.engine.commit() -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/transaction.py b/test/orm/test_transaction.py similarity index 86% rename from test/orm/transaction.py rename to test/orm/test_transaction.py index 0fcd55df32..5aa541cdad 100644 --- a/test/orm/transaction.py +++ b/test/orm/test_transaction.py @@ -1,13 +1,13 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy.orm import attributes from sqlalchemy import exc as sa_exc from sqlalchemy.orm import * -from testlib import testing -from orm import _base -from orm._fixtures import FixtureTest, User, Address, users, addresses +from sqlalchemy.test import testing +from test.orm import _base +from test.orm._fixtures import FixtureTest, User, Address, users, addresses import gc @@ -16,7 +16,8 @@ class TransactionTest(FixtureTest): run_inserts = None session = sessionmaker() - def setup_mappers(self): + @classmethod + def setup_mappers(cls): mapper(User, users, properties={ 'addresses':relation(Address, backref='user', cascade="all, delete-orphan"), @@ -32,7 +33,7 @@ class FixtureDataTest(TransactionTest): u1 = sess.query(User).get(7) u1.name = 'ed' sess.rollback() - self.assertEquals(u1.name, 'jack') + eq_(u1.name, 'jack') def test_commit_persistent(self): sess = self.session() @@ -40,7 +41,7 @@ class FixtureDataTest(TransactionTest): u1.name = 'ed' sess.flush() sess.commit() - self.assertEquals(u1.name, 'ed') + eq_(u1.name, 'ed') def test_concurrent_commit_persistent(self): s1 = self.session() @@ -157,7 +158,7 @@ class AutoExpireTest(TransactionTest): u1.addresses.remove(a1) s.flush() - self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), []) + eq_(s.query(Address).filter(Address.email_address=='foo').all(), []) s.rollback() assert a1 not in s.deleted assert u1.addresses == [a1] @@ -168,7 +169,7 @@ class AutoExpireTest(TransactionTest): sess.add(u1) sess.flush() sess.commit() - self.assertEquals(u1.name, 'newuser') + eq_(u1.name, 'newuser') def test_concurrent_commit_pending(self): @@ -212,8 +213,8 @@ class RollbackRecoverTest(TransactionTest): u1.name = 'edward' a1.email_address = 'foober' s.add(u2) - self.assertRaises(sa_exc.FlushError, s.commit) - self.assertRaises(sa_exc.InvalidRequestError, s.commit) + assert_raises(sa_exc.FlushError, s.commit) + assert_raises(sa_exc.InvalidRequestError, s.commit) s.rollback() assert u2 not in s assert a2 not in s @@ -224,7 +225,7 @@ class RollbackRecoverTest(TransactionTest): u1.name = 'edward' a1.email_address = 'foober' s.commit() - self.assertEquals( + eq_( s.query(User).all(), [User(id=1, name='edward', addresses=[Address(email_address='foober')])] ) @@ -244,8 +245,8 @@ class RollbackRecoverTest(TransactionTest): a1.email_address = 'foober' s.begin_nested() s.add(u2) - self.assertRaises(sa_exc.FlushError, s.commit) - self.assertRaises(sa_exc.InvalidRequestError, s.commit) + assert_raises(sa_exc.FlushError, s.commit) + assert_raises(sa_exc.InvalidRequestError, s.commit) s.rollback() assert u2 not in s assert a2 not in s @@ -271,15 +272,15 @@ class SavepointTest(TransactionTest): u1.name = 'edward' u2.name = 'jackward' s.add_all([u3, u4]) - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) s.rollback() assert u1.name == 'ed' assert u2.name == 'jack' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) s.commit() assert u1.name == 'ed' assert u2.name == 'jack' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)]) @testing.requires.savepoints def test_savepoint_delete(self): @@ -287,11 +288,11 @@ class SavepointTest(TransactionTest): u1 = User(name='ed') s.add(u1) s.commit() - self.assertEquals(s.query(User).filter_by(name='ed').count(), 1) + eq_(s.query(User).filter_by(name='ed').count(), 1) s.begin_nested() s.delete(u1) s.commit() - self.assertEquals(s.query(User).filter_by(name='ed').count(), 0) + eq_(s.query(User).filter_by(name='ed').count(), 0) s.commit() @testing.requires.savepoints @@ -307,16 +308,16 @@ class SavepointTest(TransactionTest): u1.name = 'edward' u2.name = 'jackward' s.add_all([u3, u4]) - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) s.commit() def go(): assert u1.name == 'edward' assert u2.name == 'jackward' - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) self.assert_sql_count(testing.db, go, 1) s.commit() - self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) + eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)]) @testing.requires.savepoints def test_savepoint_rollback_collections(self): @@ -330,20 +331,20 @@ class SavepointTest(TransactionTest): s.begin_nested() u2 = User(name='jack', addresses=[Address(email_address='bat')]) s.add(u2) - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) ] ) s.rollback() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), ] ) s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), ] @@ -361,21 +362,21 @@ class SavepointTest(TransactionTest): s.begin_nested() u2 = User(name='jack', addresses=[Address(email_address='bat')]) s.add(u2) - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) ] ) s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) ] ) s.commit() - self.assertEquals(s.query(User).order_by(User.id).all(), + eq_(s.query(User).order_by(User.id).all(), [ User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]), User(name='jack', addresses=[Address(email_address='bat')]) @@ -476,7 +477,7 @@ class AccountingFlagsTest(TransactionTest): class AutoCommitTest(TransactionTest): def test_begin_nested_requires_trans(self): sess = create_session(autocommit=True) - self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested) + assert_raises(sa_exc.InvalidRequestError, sess.begin_nested) def test_begin_preflush(self): sess = create_session(autocommit=True) @@ -495,5 +496,3 @@ class AutoCommitTest(TransactionTest): -if __name__ == '__main__': - testenv.main() diff --git a/test/orm/unitofwork.py b/test/orm/test_unitofwork.py similarity index 95% rename from test/orm/unitofwork.py rename to test/orm/test_unitofwork.py index c5e3afd014..f95346902b 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -1,19 +1,21 @@ # coding: utf-8 """Tests unitofwork operations.""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime import operator from sqlalchemy.orm import mapper as orm_mapper -from testlib import engines, sa, testing -from testlib.sa import Table, Column, Integer, String, ForeignKey, literal_column -from testlib.sa.orm import mapper, relation, create_session, column_property -from testlib.testing import eq_, ne_ -from orm import _base, _fixtures -from engine import _base as engine_base -import pickleable -from testlib.assertsql import AllOf, CompiledSQL +import sqlalchemy as sa +from sqlalchemy.test import engines, testing, pickleable +from sqlalchemy import Integer, String, ForeignKey, literal_column +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +from sqlalchemy.orm import mapper, relation, create_session, column_property +from sqlalchemy.test.testing import eq_, ne_ +from test.orm import _base, _fixtures +from test.engine import _base as engine_base +from sqlalchemy.test.assertsql import AllOf, CompiledSQL import gc class UnitOfWorkTest(object): @@ -22,7 +24,8 @@ class UnitOfWorkTest(object): class HistoryTest(_fixtures.FixtureTest): run_inserts = None - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass class Address(_base.ComparableEntity): @@ -51,14 +54,16 @@ class HistoryTest(_fixtures.FixtureTest): class VersioningTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('version_table', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('version_id', Integer, nullable=False), Column('value', String(40), nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.ComparableEntity): pass @@ -86,7 +91,7 @@ class VersioningTest(_base.MappedTest): # Only dialects with a sane rowcount can detect the # ConcurrentModificationError if testing.db.dialect.supports_sane_rowcount: - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit) + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit) s1.rollback() else: s1.commit() @@ -102,7 +107,7 @@ class VersioningTest(_base.MappedTest): s1.delete(f2) if testing.db.dialect.supports_sane_multi_rowcount: - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit) + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit) else: s1.commit() @@ -124,7 +129,7 @@ class VersioningTest(_base.MappedTest): s2.commit() # load, version is wrong - self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id) + assert_raises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id) # reload it s1.query(Foo).populate_existing().get(f1s1.id) @@ -153,7 +158,8 @@ class VersioningTest(_base.MappedTest): class UnicodeTest(_base.MappedTest): __requires__ = ('unicode_connections',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('uni_t1', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -163,7 +169,8 @@ class UnicodeTest(_base.MappedTest): test_needs_autoincrement=True), Column('txt', sa.Unicode(50), ForeignKey('uni_t1'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Test(_base.BasicEntity): pass class Test2(_base.BasicEntity): @@ -205,10 +212,12 @@ class UnicodeTest(_base.MappedTest): class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): __requires__ = ('unicode_connections', 'unicode_ddl',) - def create_engine(self): + @classmethod + def create_engine(cls): return engines.utf8_engine() - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): t1 = Table('unitable1', metadata, Column(u'méil', Integer, primary_key=True, key='a', test_needs_autoincrement=True), Column(u'\u6e2c\u8a66', Integer, key='b'), @@ -223,16 +232,16 @@ class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): test_needs_fk=True, test_needs_autoincrement=True) - self.tables['t1'] = t1 - self.tables['t2'] = t2 + cls.tables['t1'] = t1 + cls.tables['t2'] = t2 - def setUpAll(self): - engine_base.AltEngineTest.setUpAll(self) - _base.MappedTest.setUpAll(self) + @classmethod + def setup_class(cls): + super(UnicodeSchemaTest, cls).setup_class() - def tearDownAll(self): - _base.MappedTest.tearDownAll(self) - engine_base.AltEngineTest.tearDownAll(self) + @classmethod + def teardown_class(cls): + super(UnicodeSchemaTest, cls).teardown_class() @testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.') @testing.resolve_artifact_names @@ -298,19 +307,22 @@ class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest): class MutableTypesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mutable_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', sa.PickleType), Column('val', sa.Unicode(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Foo, mutable_t) @testing.resolve_artifact_names @@ -433,20 +445,23 @@ class MutableTypesTest(_base.MappedTest): self.sql_count_(0, session.commit) -class PickledDicts(_base.MappedTest): +class PickledDictsTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mutable_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', sa.PickleType(comparator=operator.eq))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Foo(_base.BasicEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(Foo, mutable_t) @testing.resolve_artifact_names @@ -519,7 +534,8 @@ class PickledDicts(_base.MappedTest): class PKTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('multipk1', metadata, Column('multi_id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -537,7 +553,8 @@ class PKTest(_base.MappedTest): Column('date_assigned', sa.Date, key='assigned', primary_key=True), Column('data', String(30))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Entry(_base.BasicEntity): pass @@ -587,7 +604,8 @@ class PKTest(_base.MappedTest): class ForeignPKTest(_base.MappedTest): """Detection of the relationship direction on PK joins.""" - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("people", metadata, Column('person', String(10), primary_key=True), Column('firstname', String(10)), @@ -598,7 +616,8 @@ class ForeignPKTest(_base.MappedTest): primary_key=True), Column('site', String(10))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Person(_base.BasicEntity): pass class PersonSite(_base.BasicEntity): @@ -629,19 +648,22 @@ class ForeignPKTest(_base.MappedTest): class ClauseAttributesTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('users_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), Column('counter', Integer, default=1)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class User(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): mapper(User, users_t) @testing.resolve_artifact_names @@ -697,7 +719,8 @@ class ClauseAttributesTest(_base.MappedTest): class PassiveDeletesTest(_base.MappedTest): __requires__ = ('foreign_keys',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mytable', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), @@ -712,7 +735,8 @@ class PassiveDeletesTest(_base.MappedTest): ondelete="CASCADE"), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class MyClass(_base.BasicEntity): pass class MyOtherClass(_base.BasicEntity): @@ -773,7 +797,8 @@ class PassiveDeletesTest(_base.MappedTest): class ExtraPassiveDeletesTest(_base.MappedTest): __requires__ = ('foreign_keys',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('mytable', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(30)), @@ -788,7 +813,8 @@ class ExtraPassiveDeletesTest(_base.MappedTest): ['mytable.id']), test_needs_fk=True) - def setup_classes(self): + @classmethod + def setup_classes(cls): class MyClass(_base.BasicEntity): pass class MyOtherClass(_base.BasicEntity): @@ -829,7 +855,7 @@ class ExtraPassiveDeletesTest(_base.MappedTest): assert myothertable.count().scalar() == 4 mc = session.query(MyClass).get(mc.id) session.delete(mc) - self.assertRaises(sa.exc.DBAPIError, session.flush) + assert_raises(sa.exc.DBAPIError, session.flush) @testing.resolve_artifact_names def test_extra_passive_2(self): @@ -851,7 +877,7 @@ class ExtraPassiveDeletesTest(_base.MappedTest): mc = session.query(MyClass).get(mc.id) session.delete(mc) mc.children[0].data = 'some new data' - self.assertRaises(sa.exc.DBAPIError, session.flush) + assert_raises(sa.exc.DBAPIError, session.flush) class DefaultTest(_base.MappedTest): @@ -864,7 +890,8 @@ class DefaultTest(_base.MappedTest): """ - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql') if use_string_defaults: @@ -876,8 +903,8 @@ class DefaultTest(_base.MappedTest): hohoval = 9 althohoval = 15 - self.other_artifacts['hohoval'] = hohoval - self.other_artifacts['althohoval'] = althohoval + cls.other_artifacts['hohoval'] = hohoval + cls.other_artifacts['althohoval'] = althohoval dt = Table('default_t', metadata, Column('id', Integer, primary_key=True, @@ -906,7 +933,8 @@ class DefaultTest(_base.MappedTest): st.append_column( Column('hoho', hohotype, ForeignKey('default_t.hoho'))) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Hoho(_base.ComparableEntity): pass class Secondary(_base.ComparableEntity): @@ -1037,7 +1065,8 @@ class DefaultTest(_base.MappedTest): Secondary(data='s2')])) class ColumnPropertyTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('data', metadata, Column('id', Integer, primary_key=True), Column('a', String(50)), @@ -1049,7 +1078,8 @@ class ColumnPropertyTest(_base.MappedTest): Column('c', String(50)), ) - def setup_mappers(self): + @classmethod + def setup_mappers(cls): class Data(_base.BasicEntity): pass @@ -1079,7 +1109,7 @@ class ColumnPropertyTest(_base.MappedTest): sd1 = SubData(a="hello", b="there", c="hi") sess.add(sd1) sess.flush() - self.assertEquals(sd1.aplusb, "hello there") + eq_(sd1.aplusb, "hello there") @testing.resolve_artifact_names def _test(self): @@ -1089,16 +1119,16 @@ class ColumnPropertyTest(_base.MappedTest): sess.add(d1) sess.flush() - self.assertEquals(d1.aplusb, "hello there") + eq_(d1.aplusb, "hello there") d1.b = "bye" sess.flush() - self.assertEquals(d1.aplusb, "hello bye") + eq_(d1.aplusb, "hello bye") d1.b = 'foobar' d1.aplusb = 'im setting this explicitly' sess.flush() - self.assertEquals(d1.aplusb, "im setting this explicitly") + eq_(d1.aplusb, "im setting this explicitly") class OneToManyTest(_fixtures.FixtureTest): run_inserts = None @@ -1596,7 +1626,7 @@ class SaveTest(_fixtures.FixtureTest): u1 = User(name='user1') u2 = User(name='user2') session.add_all((u1, u2)) - self.assertRaises(AssertionError, session.flush) + assert_raises(AssertionError, session.flush) class ManyToOneTest(_fixtures.FixtureTest): @@ -2029,7 +2059,8 @@ class SaveTest2(_fixtures.FixtureTest): ) class SaveTest3(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('items', metadata, Column('item_id', Integer, primary_key=True, test_needs_autoincrement=True), @@ -2045,7 +2076,8 @@ class SaveTest3(_base.MappedTest): Column('keyword_id', Integer, ForeignKey("keywords")), Column('foo', sa.Boolean, default=True)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class Keyword(_base.BasicEntity): pass class Item(_base.BasicEntity): @@ -2076,7 +2108,8 @@ class SaveTest3(_base.MappedTest): assert assoc.count().scalar() == 0 class BooleanColTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('t1_t', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30)), @@ -2118,7 +2151,8 @@ class BooleanColTest(_base.MappedTest): class RowSwitchTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): # parent Table('t5', metadata, Column('id', Integer, primary_key=True), @@ -2140,7 +2174,8 @@ class RowSwitchTest(_base.MappedTest): Column('t5id', Integer, ForeignKey('t5.id'),nullable=False), Column('t7id', Integer, ForeignKey('t7.id'),nullable=False)) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T5(_base.ComparableEntity): pass @@ -2240,7 +2275,8 @@ class RowSwitchTest(_base.MappedTest): assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)] class InheritingRowSwitchTest(_base.MappedTest): - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table('parent', metadata, Column('id', Integer, primary_key=True), Column('pdata', String(30)) @@ -2251,7 +2287,8 @@ class InheritingRowSwitchTest(_base.MappedTest): Column('cdata', String(30)) ) - def setup_classes(self): + @classmethod + def setup_classes(cls): class P(_base.ComparableEntity): pass @@ -2290,7 +2327,8 @@ class TransactionTest(_base.MappedTest): # be specified. it'll raise immediately post-INSERT, instead of at # COMMIT. either way, this test should pass. - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): t1 = Table('t1', metadata, Column('id', Integer, primary_key=True)) @@ -2299,15 +2337,17 @@ class TransactionTest(_base.MappedTest): Column('t1_id', Integer, ForeignKey('t1.id', deferrable=True, initially='deferred') )) - def setup_classes(self): + @classmethod + def setup_classes(cls): class T1(_base.ComparableEntity): pass class T2(_base.ComparableEntity): pass + @classmethod @testing.resolve_artifact_names - def setup_mappers(self): + def setup_mappers(cls): orm_mapper(T1, t1) orm_mapper(T2, t2) @@ -2332,5 +2372,3 @@ class TransactionTest(_base.MappedTest): if testing.against('postgres'): t1.bind.engine.dispose() -if __name__ == "__main__": - testenv.main() diff --git a/test/orm/utils.py b/test/orm/test_utils.py similarity index 95% rename from test/orm/utils.py rename to test/orm/test_utils.py index 813121a446..06533a243b 100644 --- a/test/orm/utils.py +++ b/test/orm/test_utils.py @@ -1,4 +1,4 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy.orm import interfaces, util from sqlalchemy import Column from sqlalchemy import Integer @@ -8,10 +8,10 @@ from sqlalchemy.orm import aliased from sqlalchemy.orm import mapper, create_session -from testlib import TestBase, testing +from sqlalchemy.test import TestBase, testing -from orm import _fixtures -from testlib.testing import eq_ +from test.orm import _fixtures +from sqlalchemy.test.testing import eq_ class ExtensionCarrierTest(TestBase): @@ -22,7 +22,7 @@ class ExtensionCarrierTest(TestBase): assert carrier.translate_row() is interfaces.EXT_CONTINUE assert 'translate_row' not in carrier - self.assertRaises(AttributeError, lambda: carrier.snickysnack) + assert_raises(AttributeError, lambda: carrier.snickysnack) class Partial(object): def __init__(self, marker): @@ -74,7 +74,7 @@ class AliasedClassTest(TestBase): table = self.point_map(Point) alias = aliased(Point) - self.assertRaises(TypeError, alias) + assert_raises(TypeError, alias) def test_instancemethods(self): class Point(object): @@ -236,6 +236,4 @@ class IdentityKeyTest(_fixtures.FixtureTest): key = util.identity_key(User, row=row) eq_(key, (User, (1,))) -if __name__ == '__main__': - testenv.main() diff --git a/test/profiling/alltests.py b/test/profiling/alltests.py deleted file mode 100644 index 19401098c1..0000000000 --- a/test/profiling/alltests.py +++ /dev/null @@ -1,26 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - - -def suite(): - modules_to_test = ( - 'profiling.memusage', - 'profiling.compiler', - 'profiling.pool', - 'profiling.zoomark', - 'profiling.zoomark_orm', - ) - alltests = unittest.TestSuite() - if testenv.testlib.config.coverage_enabled: - return alltests - - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/sql/_base.py b/test/sql/_base.py index c1a107eeb3..48879ae7e3 100644 --- a/test/sql/_base.py +++ b/test/sql/_base.py @@ -1,4 +1,4 @@ -from engine import _base as engine_base +from test.engine import _base as engine_base TablesTest = engine_base.TablesTest diff --git a/test/sql/alltests.py b/test/sql/alltests.py deleted file mode 100644 index f01b0e6202..0000000000 --- a/test/sql/alltests.py +++ /dev/null @@ -1,38 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - - -def suite(): - modules_to_test = ( - 'sql.testtypes', - 'sql.columns', - 'sql.constraints', - - 'sql.generative', - - # SQL syntax - 'sql.select', - 'sql.selectable', - 'sql.case_statement', - 'sql.labels', - 'sql.unicode', - - # assorted round-trip tests - 'sql.functions', - 'sql.query', - 'sql.quote', - 'sql.rowcount', - - # defaults, sequences (postgres/oracle) - 'sql.defaults', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/sql/case_statement.py b/test/sql/test_case_statement.py similarity index 94% rename from test/sql/case_statement.py rename to test/sql/test_case_statement.py index 1d53837495..3f3abe7e19 100644 --- a/test/sql/case_statement.py +++ b/test/sql/test_case_statement.py @@ -1,14 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message import sys from sqlalchemy import * -from testlib import * +from sqlalchemy.test import * from sqlalchemy import util, exc from sqlalchemy.sql import table, column class CaseTest(TestBase, AssertsCompiledSQL): - def setUpAll(self): + @classmethod + def setup_class(cls): metadata = MetaData(testing.db) global info_table info_table = Table('infos', metadata, @@ -24,7 +25,8 @@ class CaseTest(TestBase, AssertsCompiledSQL): {'pk':4, 'info':'pk_4_data'}, {'pk':5, 'info':'pk_5_data'}, {'pk':6, 'info':'pk_6_data'}) - def tearDownAll(self): + @classmethod + def teardown_class(cls): info_table.drop() @testing.fails_on('firebird', 'FIXME: unknown') @@ -93,7 +95,7 @@ class CaseTest(TestBase, AssertsCompiledSQL): def test_literal_interpretation(self): t = table('test', column('col1')) - self.assertRaises(exc.ArgumentError, case, [("x", "y")]) + assert_raises(exc.ArgumentError, case, [("x", "y")]) self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END") self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END") @@ -133,5 +135,3 @@ class CaseTest(TestBase, AssertsCompiledSQL): ('other', 3), ] -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/columns.py b/test/sql/test_columns.py similarity index 79% rename from test/sql/columns.py rename to test/sql/test_columns.py index 661be891ae..e9dabe1421 100644 --- a/test/sql/columns.py +++ b/test/sql/test_columns.py @@ -1,7 +1,7 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy import exc, sql -from testlib import * +from sqlalchemy.test import * from sqlalchemy import Table, Column # don't use testlib's wrappers @@ -37,7 +37,7 @@ class ColumnDefinitionTest(TestBase): def test_incomplete(self): c = self.columns() - self.assertRaises(exc.ArgumentError, Table, 't', MetaData(), *c) + assert_raises(exc.ArgumentError, Table, 't', MetaData(), *c) def test_incomplete_key(self): c = Column(Integer) @@ -52,9 +52,7 @@ class ColumnDefinitionTest(TestBase): def test_bogus(self): - self.assertRaises(exc.ArgumentError, Column, 'foo', name='bar') - self.assertRaises(exc.ArgumentError, Column, 'foo', Integer, + assert_raises(exc.ArgumentError, Column, 'foo', name='bar') + assert_raises(exc.ArgumentError, Column, 'foo', Integer, type_=Integer()) -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/constraints.py b/test/sql/test_constraints.py similarity index 94% rename from test/sql/constraints.py rename to test/sql/test_constraints.py index d019aa0378..8abeb35338 100644 --- a/test/sql/constraints.py +++ b/test/sql/test_constraints.py @@ -1,16 +1,16 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy import exc -from testlib import * -from testlib import config, engines +from sqlalchemy.test import * +from sqlalchemy.test import config, engines class ConstraintTest(TestBase, AssertsExecutionResults): - def setUp(self): + def setup(self): global metadata metadata = MetaData(testing.db) - def tearDown(self): + def teardown(self): metadata.drop_all() def test_constraint(self): @@ -33,7 +33,7 @@ class ConstraintTest(TestBase, AssertsExecutionResults): def test_double_fk_usage_raises(self): f = ForeignKey('b.id') - self.assertRaises(exc.InvalidRequestError, Table, "a", metadata, + assert_raises(exc.InvalidRequestError, Table, "a", metadata, Column('x', Integer, f), Column('y', Integer, f) ) @@ -219,14 +219,14 @@ class ConstraintTest(TestBase, AssertsExecutionResults): t1 = Table("sometable", MetaData(), Column("foo", Integer)) schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) - self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)") + eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)") schemagen.buffer.truncate(0) schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo)) - self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)") + eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)") schemadrop = dialect.schemadropper(dialect, None) schemadrop.execute = lambda: None - self.assertRaises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) + assert_raises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo)) class ConstraintCompilationTest(TestBase, AssertsExecutionResults): @@ -245,7 +245,7 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): def clear(self): del self.statements[:] - def setUp(self): + def setup(self): self.sql = self.accum() opts = config.db_opts.copy() opts['strategy'] = 'mock' @@ -333,5 +333,3 @@ class ConstraintCompilationTest(TestBase, AssertsExecutionResults): assert 'INITIALLY DEFERRED' in self.sql, self.sql -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/defaults.py b/test/sql/test_defaults.py similarity index 96% rename from test/sql/defaults.py rename to test/sql/test_defaults.py index bea6dc04be..9641574665 100644 --- a/test/sql/defaults.py +++ b/test/sql/test_defaults.py @@ -1,16 +1,19 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime from sqlalchemy import Sequence, Column, func from sqlalchemy.sql import select, text -from testlib import sa, testing -from testlib.sa import MetaData, Table, Integer, String, ForeignKey, Boolean -from testlib.testing import eq_ -from sql import _base +import sqlalchemy as sa +from sqlalchemy.test import testing +from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean +from sqlalchemy.test.schema import Table +from sqlalchemy.test.testing import eq_ +from test.sql import _base class DefaultTest(testing.TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global t, f, f2, ts, currenttime, metadata, default_generator db = testing.db @@ -117,10 +120,11 @@ class DefaultTest(testing.TestBase): server_default='ddl')) t.create() - def tearDownAll(self): + @classmethod + def teardown_class(cls): t.drop() - def tearDown(self): + def teardown(self): default_generator['x'] = 50 t.delete().execute() @@ -139,7 +143,7 @@ class DefaultTest(testing.TestBase): fn4 = FN4() for fn in fn1, fn2, fn3, fn4: - self.assertRaisesMessage(sa.exc.ArgumentError, + assert_raises_message(sa.exc.ArgumentError, ex_msg, sa.ColumnDefault, fn) @@ -387,7 +391,8 @@ class DefaultTest(testing.TestBase): class PKDefaultTest(_base.TablesTest): __requires__ = ('subqueries',) - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): t2 = Table('t2', metadata, Column('nextid', Integer)) @@ -411,7 +416,8 @@ class PKDefaultTest(_base.TablesTest): class PKIncrementTest(_base.TablesTest): run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): Table("aitable", metadata, Column('id', Integer, Sequence('ai_id_seq', optional=True), primary_key=True), @@ -484,8 +490,8 @@ class EmptyInsertTest(testing.TestBase): try: result = t1.insert().execute() - self.assertEquals(1, select([func.count(text('*'))], from_obj=t1).scalar()) - self.assertEquals(True, t1.select().scalar()) + eq_(1, select([func.count(text('*'))], from_obj=t1).scalar()) + eq_(True, t1.select().scalar()) finally: metadata.drop_all() @@ -493,7 +499,8 @@ class AutoIncrementTest(_base.TablesTest): __requires__ = ('identity',) run_define_tables = 'each' - def define_tables(self, metadata): + @classmethod + def define_tables(cls, metadata): """Each test manipulates self.metadata individually.""" @testing.exclude('sqlite', '<', (3, 4), 'no database support') @@ -542,7 +549,8 @@ class AutoIncrementTest(_base.TablesTest): class SequenceTest(testing.TestBase): __requires__ = ('sequences',) - def setUpAll(self): + @classmethod + def setup_class(cls): global cartitems, sometable, metadata metadata = MetaData(testing.db) cartitems = Table("cartitems", metadata, @@ -626,9 +634,8 @@ class SequenceTest(testing.TestBase): x = cartitems.c.cart_id.sequence.execute() self.assert_(1 <= x <= 4) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/functions.py b/test/sql/test_functions.py similarity index 97% rename from test/sql/functions.py rename to test/sql/test_functions.py index 17d8a35e97..e9bf49ce30 100644 --- a/test/sql/functions.py +++ b/test/sql/test_functions.py @@ -1,15 +1,15 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_ import datetime from sqlalchemy import * from sqlalchemy.sql import table, column from sqlalchemy import databases, sql, util from sqlalchemy.sql.compiler import BIND_TEMPLATES from sqlalchemy.engine import default -from testlib.engines import all_dialects +from sqlalchemy.test.engines import all_dialects from sqlalchemy import types as sqltypes -from testlib import * +from sqlalchemy.test import * from sqlalchemy.sql.functions import GenericFunction -from testlib.testing import eq_ +from sqlalchemy.test.testing import eq_ from decimal import Decimal as _python_Decimal from sqlalchemy.databases import * @@ -237,7 +237,7 @@ class ExecuteTest(TestBase): t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi") res = exec_sorted(select([t2.c.value, t2.c.stuff])) - self.assertEquals(res, [(-14, 'hi'), (3, None), (7, None)]) + eq_(res, [(-14, 'hi'), (3, None), (7, None)]) t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff") assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] @@ -315,5 +315,3 @@ def exec_sorted(statement, *args, **kw): return sorted([tuple(row) for row in statement.execute(*args, **kw).fetchall()]) -if __name__ == '__main__': - testenv.main() diff --git a/test/sql/generative.py b/test/sql/test_generative.py similarity index 99% rename from test/sql/generative.py rename to test/sql/test_generative.py index 3947a450fe..ca427ca5f5 100644 --- a/test/sql/generative.py +++ b/test/sql/test_generative.py @@ -1,8 +1,7 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.sql import table, column, ClauseElement from sqlalchemy.sql.expression import _clone, _from_objects -from testlib import * +from sqlalchemy.test import * from sqlalchemy.sql.visitors import * from sqlalchemy import util from sqlalchemy.sql import util as sql_util @@ -12,7 +11,8 @@ class TraversalTest(TestBase, AssertsExecutionResults): """test ClauseVisitor's traversal, particularly its ability to copy and modify a ClauseElement in place.""" - def setUpAll(self): + @classmethod + def setup_class(cls): global A, B # establish two ficticious ClauseElements. @@ -162,7 +162,8 @@ class TraversalTest(TestBase, AssertsExecutionResults): class ClauseTest(TestBase, AssertsCompiledSQL): """test copy-in-place behavior of various ClauseElements.""" - def setUpAll(self): + @classmethod + def setup_class(cls): global t1, t2 t1 = table("table1", column("col1"), @@ -361,7 +362,8 @@ class ClauseTest(TestBase, AssertsCompiledSQL): self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1") class ClauseAdapterTest(TestBase, AssertsCompiledSQL): - def setUpAll(self): + @classmethod + def setup_class(cls): global t1, t2 t1 = table("table1", column("col1"), @@ -630,7 +632,8 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL): ) class SpliceJoinsTest(TestBase, AssertsCompiledSQL): - def setUpAll(self): + @classmethod + def setup_class(cls): global table1, table2, table3, table4 def _table(name): return table(name, column("col1"), column("col2"),column("col3")) @@ -691,7 +694,8 @@ class SpliceJoinsTest(TestBase, AssertsCompiledSQL): class SelectTest(TestBase, AssertsCompiledSQL): """tests the generative capability of Select""" - def setUpAll(self): + @classmethod + def setup_class(cls): global t1, t2 t1 = table("table1", column("col1"), @@ -772,7 +776,8 @@ class InsertTest(TestBase, AssertsCompiledSQL): # fixme: consolidate converage from elsewhere here and expand - def setUpAll(self): + @classmethod + def setup_class(cls): global t1, t2 t1 = table("table1", column("col1"), @@ -811,5 +816,3 @@ class InsertTest(TestBase, AssertsCompiledSQL): "table1 (col1, col2, col3) " "VALUES (:col1, :col2, :col3)") -if __name__ == '__main__': - testenv.main() diff --git a/test/sql/labels.py b/test/sql/test_labels.py similarity index 95% rename from test/sql/labels.py rename to test/sql/test_labels.py index 94ee20342e..b946b0ae98 100644 --- a/test/sql/labels.py +++ b/test/sql/test_labels.py @@ -1,7 +1,7 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy import exc as exceptions -from testlib import * +from sqlalchemy.test import * from sqlalchemy.engine import default IDENT_LENGTH = 29 @@ -16,7 +16,8 @@ class LabelTypeTest(TestBase): assert isinstance(select([t.c.col2]).as_scalar().label('lala').type, Float) class LongLabelsTest(TestBase, AssertsCompiledSQL): - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, table1, table2, maxlen metadata = MetaData(testing.db) table1 = Table("some_large_named_table", metadata, @@ -34,20 +35,21 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): maxlen = testing.db.dialect.max_identifier_length testing.db.dialect.max_identifier_length = IDENT_LENGTH - def tearDown(self): + def teardown(self): table1.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() testing.db.dialect.max_identifier_length = maxlen def test_too_long_name_disallowed(self): m = MetaData(testing.db) t1 = Table("this_name_is_too_long_for_what_were_doing_in_this_test", m, Column('foo', Integer)) - self.assertRaises(exceptions.IdentifierError, m.create_all) - self.assertRaises(exceptions.IdentifierError, m.drop_all) - self.assertRaises(exceptions.IdentifierError, t1.create) - self.assertRaises(exceptions.IdentifierError, t1.drop) + assert_raises(exceptions.IdentifierError, m.create_all) + assert_raises(exceptions.IdentifierError, m.drop_all) + assert_raises(exceptions.IdentifierError, t1.create) + assert_raises(exceptions.IdentifierError, t1.drop) def test_result(self): table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"}) @@ -191,5 +193,3 @@ class LongLabelsTest(TestBase, AssertsCompiledSQL): "FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :_1) AS _1", dialect=compile_dialect) -if __name__ == '__main__': - testenv.main() diff --git a/test/sql/query.py b/test/sql/test_query.py similarity index 94% rename from test/sql/query.py rename to test/sql/test_query.py index b428d8991c..c9305b615f 100644 --- a/test/sql/query.py +++ b/test/sql/test_query.py @@ -1,14 +1,14 @@ -import testenv; testenv.configure_for_tests() import datetime from sqlalchemy import * from sqlalchemy import exc, sql from sqlalchemy.engine import default -from testlib import * -from testlib.testing import eq_ +from sqlalchemy.test import * +from sqlalchemy.test.testing import eq_ class QueryTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global users, users2, addresses, metadata metadata = MetaData(testing.db) users = Table('query_users', metadata, @@ -31,7 +31,8 @@ class QueryTest(TestBase): users.delete().execute() users2.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_insert(self): @@ -174,19 +175,19 @@ class QueryTest(TestBase): ) concat = ("test: " + users.c.user_name).label('thedata') - self.assertEquals( + eq_( select([concat]).order_by(concat).execute().fetchall(), [("test: ed",), ("test: fred",), ("test: jack",)] ) concat = ("test: " + users.c.user_name).label('thedata') - self.assertEquals( + eq_( select([concat]).order_by(desc(concat)).execute().fetchall(), [("test: jack",), ("test: fred",), ("test: ed",)] ) concat = ("test: " + users.c.user_name).label('thedata') - self.assertEquals( + eq_( select([concat]).order_by(concat + "x").execute().fetchall(), [("test: ed",), ("test: fred",), ("test: jack",)] ) @@ -211,11 +212,11 @@ class QueryTest(TestBase): def test_or_and_as_columns(self): true, false = literal(True), literal(False) - self.assertEquals(testing.db.execute(select([and_(true, false)])).scalar(), False) - self.assertEquals(testing.db.execute(select([and_(true, true)])).scalar(), True) - self.assertEquals(testing.db.execute(select([or_(true, false)])).scalar(), True) - self.assertEquals(testing.db.execute(select([or_(false, false)])).scalar(), False) - self.assertEquals(testing.db.execute(select([not_(or_(false, false))])).scalar(), True) + eq_(testing.db.execute(select([and_(true, false)])).scalar(), False) + eq_(testing.db.execute(select([and_(true, true)])).scalar(), True) + eq_(testing.db.execute(select([or_(true, false)])).scalar(), True) + eq_(testing.db.execute(select([or_(false, false)])).scalar(), False) + eq_(testing.db.execute(select([not_(or_(false, false))])).scalar(), True) row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone() assert row.x == False @@ -272,13 +273,13 @@ class QueryTest(TestBase): {'user_id':4, 'user_name':'OnE'}, ) - self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )]) + eq_(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )]) - self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )]) + eq_(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )]) if testing.against('postgres'): - self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )]) - self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), []) + eq_(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )]) + eq_(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), []) def test_compiled_execute(self): @@ -388,7 +389,7 @@ class QueryTest(TestBase): def a_eq(executable, wanted): got = list(executable.execute()) - self.assertEquals(got, wanted) + eq_(got, wanted) for labels in False, True: a_eq(users.select(order_by=[users.c.user_id], @@ -511,23 +512,23 @@ class QueryTest(TestBase): def test_keys(self): users.insert().execute(user_id=1, user_name='foo') r = users.select().execute().fetchone() - self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) + eq_([x.lower() for x in r.keys()], ['user_id', 'user_name']) def test_items(self): users.insert().execute(user_id=1, user_name='foo') r = users.select().execute().fetchone() - self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) + eq_([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')]) def test_len(self): users.insert().execute(user_id=1, user_name='foo') r = users.select().execute().fetchone() - self.assertEqual(len(r), 2) + eq_(len(r), 2) r.close() r = testing.db.execute('select user_name, user_id from query_users').fetchone() - self.assertEqual(len(r), 2) + eq_(len(r), 2) r.close() r = testing.db.execute('select user_name from query_users').fetchone() - self.assertEqual(len(r), 1) + eq_(len(r), 1) r.close() def test_cant_execute_join(self): @@ -542,19 +543,19 @@ class QueryTest(TestBase): # should return values in column definition order users.insert().execute(user_id=1, user_name='foo') r = users.select(users.c.user_id==1).execute().fetchone() - self.assertEqual(r[0], 1) - self.assertEqual(r[1], 'foo') - self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name']) - self.assertEqual(r.values(), [1, 'foo']) + eq_(r[0], 1) + eq_(r[1], 'foo') + eq_([x.lower() for x in r.keys()], ['user_id', 'user_name']) + eq_(r.values(), [1, 'foo']) def test_column_order_with_text_query(self): # should return values in query order users.insert().execute(user_id=1, user_name='foo') r = testing.db.execute('select user_name, user_id from query_users').fetchone() - self.assertEqual(r[0], 'foo') - self.assertEqual(r[1], 1) - self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id']) - self.assertEqual(r.values(), ['foo', 1]) + eq_(r[0], 'foo') + eq_(r[1], 1) + eq_([x.lower() for x in r.keys()], ['user_name', 'user_id']) + eq_(r.values(), ['foo', 1]) @testing.crashes('oracle', 'FIXME: unknown, varify not fails_on()') @testing.crashes('firebird', 'An identifier must begin with a letter') @@ -657,9 +658,10 @@ class PercentSchemaNamesTest(TestBase): operation the same way we do for text() and column labels. """ + @classmethod @testing.crashes('mysql', 'mysqldb calls name % (params)') @testing.crashes('postgres', 'postgres calls name % (params)') - def setUpAll(self): + def setup_class(cls): global percent_table, metadata metadata = MetaData(testing.db) percent_table = Table('percent%table', metadata, @@ -669,9 +671,10 @@ class PercentSchemaNamesTest(TestBase): ) metadata.create_all() + @classmethod @testing.crashes('mysql', 'mysqldb calls name % (params)') @testing.crashes('postgres', 'postgres calls name % (params)') - def tearDownAll(self): + def teardown_class(cls): metadata.drop_all() @testing.crashes('mysql', 'mysqldb calls name % (params)') @@ -734,7 +737,8 @@ class PercentSchemaNamesTest(TestBase): class LimitTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global users, addresses, metadata metadata = MetaData(testing.db) users = Table('query_users', metadata, @@ -746,9 +750,7 @@ class LimitTest(TestBase): Column('user_id', Integer, ForeignKey('query_users.user_id')), Column('address', String(30))) metadata.create_all() - self._data() - - def _data(self): + users.insert().execute(user_id=1, user_name='john') addresses.insert().execute(address_id=1, user_id=1, address='addr1') users.insert().execute(user_id=2, user_name='jack') @@ -764,7 +766,8 @@ class LimitTest(TestBase): users.insert().execute(user_id=7, user_name='fido') addresses.insert().execute(address_id=7, user_id=7, address='addr5') - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_select_limit(self): @@ -805,7 +808,8 @@ class LimitTest(TestBase): class CompoundTest(TestBase): """test compound statements like UNION, INTERSECT, particularly their ability to nest on different databases.""" - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, t1, t2, t3 metadata = MetaData(testing.db) t1 = Table('t1', metadata, @@ -842,7 +846,8 @@ class CompoundTest(TestBase): dict(col2="t3col2r3", col3="ccc", col4="bbb"), ]) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def _fetchall_sorted(self, executed): @@ -861,10 +866,10 @@ class CompoundTest(TestBase): wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] found1 = self._fetchall_sorted(u.execute()) - self.assertEquals(found1, wanted) + eq_(found1, wanted) found2 = self._fetchall_sorted(u.alias('bar').select().execute()) - self.assertEquals(found2, wanted) + eq_(found2, wanted) def test_union_ordered(self): (s1, s2) = ( @@ -877,7 +882,7 @@ class CompoundTest(TestBase): wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - self.assertEquals(u.execute().fetchall(), wanted) + eq_(u.execute().fetchall(), wanted) @testing.fails_on('maxdb', 'FIXME: unknown') @testing.requires.subqueries @@ -892,7 +897,7 @@ class CompoundTest(TestBase): wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] - self.assertEquals(u.alias('bar').select().execute().fetchall(), wanted) + eq_(u.alias('bar').select().execute().fetchall(), wanted) @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') @testing.fails_on('mysql', 'FIXME: unknown') @@ -908,10 +913,10 @@ class CompoundTest(TestBase): wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)] found1 = self._fetchall_sorted(e.execute()) - self.assertEquals(found1, wanted) + eq_(found1, wanted) found2 = self._fetchall_sorted(e.alias('foo').select().execute()) - self.assertEquals(found2, wanted) + eq_(found2, wanted) @testing.crashes('firebird', 'Does not support intersect') @testing.crashes('sybase', 'FIXME: unknown, verify not fails_on') @@ -925,10 +930,10 @@ class CompoundTest(TestBase): wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] found1 = self._fetchall_sorted(i.execute()) - self.assertEquals(found1, wanted) + eq_(found1, wanted) found2 = self._fetchall_sorted(i.alias('bar').select().execute()) - self.assertEquals(found2, wanted) + eq_(found2, wanted) @testing.crashes('firebird', 'Does not support except') @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') @@ -945,7 +950,7 @@ class CompoundTest(TestBase): ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] found = self._fetchall_sorted(e.alias('bar').select().execute()) - self.assertEquals(found, wanted) + eq_(found, wanted) @testing.crashes('firebird', 'Does not support except') @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') @@ -962,10 +967,10 @@ class CompoundTest(TestBase): ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] found1 = self._fetchall_sorted(e.execute()) - self.assertEquals(found1, wanted) + eq_(found1, wanted) found2 = self._fetchall_sorted(e.alias('bar').select().execute()) - self.assertEquals(found2, wanted) + eq_(found2, wanted) @testing.crashes('firebird', 'Does not support except') @testing.crashes('oracle', 'FIXME: unknown, verify not fails_on') @@ -981,8 +986,8 @@ class CompoundTest(TestBase): select([t3.c.col3], t3.c.col3 == 'ccc'), #ccc ) ) - self.assertEquals(e.execute().fetchall(), [('ccc',)]) - self.assertEquals(e.alias('foo').select().execute().fetchall(), + eq_(e.execute().fetchall(), [('ccc',)]) + eq_(e.alias('foo').select().execute().fetchall(), [('ccc',)]) @testing.crashes('firebird', 'Does not support intersect') @@ -999,7 +1004,7 @@ class CompoundTest(TestBase): wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] found = self._fetchall_sorted(u.execute()) - self.assertEquals(found, wanted) + eq_(found, wanted) @testing.crashes('firebird', 'Does not support intersect') @testing.fails_on('mysql', 'FIXME: unknown') @@ -1015,7 +1020,7 @@ class CompoundTest(TestBase): wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')] found = self._fetchall_sorted(ua.select().execute()) - self.assertEquals(found, wanted) + eq_(found, wanted) class JoinTest(TestBase): @@ -1028,7 +1033,8 @@ class JoinTest(TestBase): database seems to be sensitive to this. """ - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata global t1, t2, t3 @@ -1057,7 +1063,8 @@ class JoinTest(TestBase): {'t2_id': 21, 't1_id': 11, 'name': 't2 #21'}) t3.insert().execute({'t3_id': 30, 't2_id': 20, 'name': 't3 #30'}) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def assertRows(self, statement, expected): @@ -1066,7 +1073,7 @@ class JoinTest(TestBase): found = sorted([tuple(row) for row in statement.execute().fetchall()]) - self.assertEquals(found, sorted(expected)) + eq_(found, sorted(expected)) def test_join_x1(self): """Joins t1->t2.""" @@ -1289,7 +1296,8 @@ class JoinTest(TestBase): class OperatorTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global metadata, flds metadata = MetaData(testing.db) flds = Table('flds', metadata, @@ -1304,18 +1312,14 @@ class OperatorTest(TestBase): dict(intcol=13, strcol='bar') ]) - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() @testing.fails_on('maxdb', 'FIXME: unknown') def test_modulo(self): - self.assertEquals( + eq_( select([flds.c.intcol % 3], order_by=flds.c.idcol).execute().fetchall(), [(2,),(1,)] ) - - - -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/quote.py b/test/sql/test_quote.py similarity index 97% rename from test/sql/quote.py rename to test/sql/test_quote.py index 106189afe0..64e097b85f 100644 --- a/test/sql/quote.py +++ b/test/sql/test_quote.py @@ -1,12 +1,12 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import sql from sqlalchemy.sql import compiler -from testlib import * +from sqlalchemy.test import * class QuoteTest(TestBase, AssertsCompiledSQL): - def setUpAll(self): + @classmethod + def setup_class(cls): # TODO: figure out which databases/which identifiers allow special # characters to be used, such as: spaces, quote characters, # punctuation characters, set up tests for those as well. @@ -24,11 +24,12 @@ class QuoteTest(TestBase, AssertsCompiledSQL): table1.create() table2.create() - def tearDown(self): + def teardown(self): table1.delete().execute() table2.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): table1.drop() table2.drop() @@ -207,5 +208,3 @@ class PreparerTest(TestBase): a_eq(unformat('`foo`.bar'), ['foo', 'bar']) a_eq(unformat('`foo`.`b``a``r`.`baz`'), ['foo', 'b`a`r', 'baz']) -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/rowcount.py b/test/sql/test_rowcount.py similarity index 91% rename from test/sql/rowcount.py rename to test/sql/test_rowcount.py index 3c9caad754..82301a4a5c 100644 --- a/test/sql/rowcount.py +++ b/test/sql/test_rowcount.py @@ -1,11 +1,11 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * -from testlib import * +from sqlalchemy.test import * class FoundRowsTest(TestBase, AssertsExecutionResults): """tests rowcount functionality""" - def setUpAll(self): + @classmethod + def setup_class(cls): metadata = MetaData(testing.db) global employees_table @@ -17,7 +17,7 @@ class FoundRowsTest(TestBase, AssertsExecutionResults): ) employees_table.create() - def setUp(self): + def setup(self): global data data = [ ('Angela', 'A'), ('Andrew', 'A'), @@ -31,10 +31,11 @@ class FoundRowsTest(TestBase, AssertsExecutionResults): i = employees_table.insert() i.execute(*[{'name':n, 'department':d} for n, d in data]) - def tearDown(self): + def teardown(self): employees_table.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): employees_table.drop() def testbasic(self): @@ -67,5 +68,3 @@ class FoundRowsTest(TestBase, AssertsExecutionResults): if testing.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 -if __name__ == '__main__': - testenv.main() diff --git a/test/sql/select.py b/test/sql/test_select.py similarity index 98% rename from test/sql/select.py rename to test/sql/test_select.py index 2ec5b8da51..1d9e531de3 100644 --- a/test/sql/select.py +++ b/test/sql/test_select.py @@ -1,4 +1,4 @@ -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import datetime, re, operator from sqlalchemy import * from sqlalchemy import exc, sql, util @@ -6,7 +6,7 @@ from sqlalchemy.sql import table, column, label, compiler from sqlalchemy.sql.expression import ClauseList from sqlalchemy.engine import default from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql -from testlib import * +from sqlalchemy.test import * table1 = table('mytable', column('myid', Integer), @@ -185,7 +185,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A t2 = table('t2', column('c'), column('d')) s = select([t.c.a]).where(t.c.a==t2.c.d).as_scalar() s2 =select([t, t2, s]) - self.assertRaises(exc.InvalidRequestError, str, s2) + assert_raises(exc.InvalidRequestError, str, s2) # intentional again s = s.correlate(t, t2) @@ -1149,10 +1149,10 @@ UNION SELECT mytable.myid FROM mytable" # check that conflicts with "unique" params are caught s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('myid_1'))) - self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s) + assert_raises_message(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s) s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('myid_1'))) - self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s) + assert_raises_message(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s) def test_binds_no_hash_collision(self): """test that construct_params doesn't corrupt dict due to hash collisions""" @@ -1287,20 +1287,20 @@ UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)") ) def check_results(dialect, expected_results, literal): - self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results') - self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) - self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) - self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2]) - self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) - self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4])) + eq_(len(expected_results), 5, 'Incorrect number of expected results') + eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0]) + eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1]) + eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2]) + eq_(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3])) + eq_(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4])) # fixme: shoving all of this dialect-specific stuff in one test # is now officialy completely ridiculous AND non-obviously omits # coverage on other dialects. sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect) if isinstance(dialect, type(mysql.dialect())): - self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest") + eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest") else: - self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest") + eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest") # first test with Postgres engine check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s') @@ -1548,5 +1548,3 @@ class SchemaTest(TestBase, AssertsCompiledSQL): self.assert_compile(table4.insert(values=(2, 5, 'test')), "INSERT INTO remote_owner.remotetable (rem_id, datatype_id, value) VALUES "\ "(:rem_id, :datatype_id, :value)") -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/selectable.py b/test/sql/test_selectable.py old mode 100755 new mode 100644 similarity index 98% rename from test/sql/selectable.py rename to test/sql/test_selectable.py index e9ed5f5653..a172eb4523 --- a/test/sql/selectable.py +++ b/test/sql/test_selectable.py @@ -1,8 +1,8 @@ """Test various algorithmic properties of selectables.""" -import testenv; testenv.configure_for_tests() +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * -from testlib import * +from sqlalchemy.test import * from sqlalchemy.sql import util as sql_util, visitors from sqlalchemy import exc from sqlalchemy.sql import table, column @@ -228,7 +228,7 @@ class SelectableTest(TestBase, AssertsExecutionResults): s = select([t2, t3], use_labels=True) - self.assertRaises(exc.NoReferencedTableError, s.join, t1) + assert_raises(exc.NoReferencedTableError, s.join, t1) class PrimaryKeyTest(TestBase, AssertsExecutionResults): def test_join_pk_collapse_implicit(self): @@ -304,12 +304,12 @@ class PrimaryKeyTest(TestBase, AssertsExecutionResults): Column('id', Integer, ForeignKey( 'Employee.id', ), primary_key=True), ) - self.assertEquals( + eq_( util.column_set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key), util.column_set([employee.c.id]) ) - self.assertEquals( + eq_( util.column_set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key), util.column_set([employee.c.id]) ) @@ -329,7 +329,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): Column('t3data', String(30))) - self.assertEquals( + eq_( util.column_set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])), util.column_set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data]) ) @@ -349,7 +349,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): s = select([engineers, managers]).where(engineers.c.engineer_name==managers.c.manager_name) - self.assertEquals(util.column_set(sql_util.reduce_columns(list(s.c), s)), + eq_(util.column_set(sql_util.reduce_columns(list(s.c), s)), util.column_set([s.c.engineer_id, s.c.engineer_name, s.c.manager_id]) ) @@ -374,7 +374,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): ) pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin') - self.assertEquals( + eq_( util.column_set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])), util.column_set([pjoin.c.people_person_id]) ) @@ -398,7 +398,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): 'Item':base_item_table.join(item_table), }, None, 'item_join') - self.assertEquals( + eq_( util.column_set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])), util.column_set([item_join.c.id, item_join.c.dummy, item_join.c.child_name]) ) @@ -423,7 +423,7 @@ class ReduceTest(TestBase, AssertsExecutionResults): 'c': page_table.join(magazine_page_table).join(classified_page_table), }, None, 'page_join') - self.assertEquals( + eq_( util.column_set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])), util.column_set([pjoin.c.id]) ) @@ -522,5 +522,3 @@ class AnnotationsTest(TestBase): assert b4.left is bin.left # since column is immutable assert b4.right is not bin.right is not b2.right is not b3.right -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/testtypes.py b/test/sql/test_types.py similarity index 94% rename from test/sql/testtypes.py rename to test/sql/test_types.py index e5cffe3282..13b6d0954e 100644 --- a/test/sql/testtypes.py +++ b/test/sql/test_types.py @@ -1,13 +1,13 @@ +from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import decimal -import testenv; testenv.configure_for_tests() -import datetime, os, pickleable, re +import datetime, os, re from sqlalchemy import * from sqlalchemy import exc, types, util from sqlalchemy.sql import operators -from testlib.testing import eq_ +from sqlalchemy.test.testing import eq_ import sqlalchemy.engine.url as url from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird -from testlib import * +from sqlalchemy.test import * class AdaptTest(TestBase): @@ -121,13 +121,14 @@ class UserDefinedTest(TestBase): l ): for col in row[1:5]: - self.assertEquals(col, assertstr) - self.assertEquals(row[5], assertint) - self.assertEquals(row[6], assertint2) + eq_(col, assertstr) + eq_(row[5], assertint) + eq_(row[6], assertint2) for col in row[3], row[4]: assert isinstance(col, unicode) - def setUpAll(self): + @classmethod + def setup_class(cls): global users, metadata class MyType(types.TypeEngine): @@ -226,7 +227,8 @@ class UserDefinedTest(TestBase): metadata.create_all() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() class ColumnsTest(TestBase, AssertsExecutionResults): @@ -263,14 +265,15 @@ class ColumnsTest(TestBase, AssertsExecutionResults): ) for aCol in testTable.c: - self.assertEquals( + eq_( expectedResults[aCol.name], db.dialect.schemagenerator(db.dialect, db, None, None).\ get_column_specification(aCol)) class UnicodeTest(TestBase, AssertsExecutionResults): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" - def setUpAll(self): + @classmethod + def setup_class(cls): global unicode_table metadata = MetaData(testing.db) unicode_table = Table('unicode_table', metadata, @@ -280,10 +283,11 @@ class UnicodeTest(TestBase, AssertsExecutionResults): Column('plain_varchar', String(250)) ) unicode_table.create() - def tearDownAll(self): + @classmethod + def teardown_class(cls): unicode_table.drop() - def tearDown(self): + def teardown(self): unicode_table.delete().execute() def test_round_trip(self): @@ -387,7 +391,8 @@ class BinaryTest(TestBase, AssertsExecutionResults): ('mysql', '<', (4, 1, 1)), # screwy varbinary types ) - def setUpAll(self): + @classmethod + def setup_class(cls): global binary_table, MyPickleType class MyPickleType(types.TypeDecorator): @@ -416,10 +421,11 @@ class BinaryTest(TestBase, AssertsExecutionResults): ) binary_table.create() - def tearDown(self): + def teardown(self): binary_table.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): binary_table.drop() @testing.fails_on('mssql', 'MSSQl BINARY type right pads the fixed length with \x00') @@ -439,21 +445,22 @@ class BinaryTest(TestBase, AssertsExecutionResults): text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db) ): l = stmt.execute().fetchall() - self.assertEquals(list(stream1), list(l[0]['data'])) - self.assertEquals(list(stream1[0:100]), list(l[0]['data_slice'])) - self.assertEquals(list(stream2), list(l[1]['data'])) - self.assertEquals(testobj1, l[0]['pickled']) - self.assertEquals(testobj2, l[1]['pickled']) - self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata) - self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff') + eq_(list(stream1), list(l[0]['data'])) + eq_(list(stream1[0:100]), list(l[0]['data_slice'])) + eq_(list(stream2), list(l[1]['data'])) + eq_(testobj1, l[0]['pickled']) + eq_(testobj2, l[1]['pickled']) + eq_(testobj3.moredata, l[0]['mypickle'].moredata) + eq_(l[0]['mypickle'].stuff, 'this is the right stuff') def load_stream(self, name, len=12579): - f = os.path.join(os.path.dirname(testenv.__file__), name) + f = os.path.join(os.path.dirname(__file__), "..", name) # put a number less than the typical MySQL default BLOB size return file(f).read(len) class ExpressionTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global test_table, meta class MyCustomType(types.TypeEngine): @@ -481,7 +488,8 @@ class ExpressionTest(TestBase, AssertsExecutionResults): test_table.insert().execute({'id':1, 'data':'somedata', 'atimestamp':datetime.date(2007, 10, 15), 'avalue':25}) - def tearDownAll(self): + @classmethod + def teardown_class(cls): meta.drop_all() def test_control(self): @@ -523,7 +531,8 @@ class ExpressionTest(TestBase, AssertsExecutionResults): assert testing.db.execute(select([expr])).scalar() == -15 class DateTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global users_with_date, insert_data db = testing.db @@ -605,7 +614,8 @@ class DateTest(TestBase, AssertsExecutionResults): for idict in insert_dicts: users_with_date.insert().execute(**idict) - def tearDownAll(self): + @classmethod + def teardown_class(cls): users_with_date.drop() def testdate(self): @@ -649,8 +659,8 @@ class DateTest(TestBase, AssertsExecutionResults): # test mismatched date/datetime t.insert().execute(adate=d2, adatetime=d2) - self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)]) - self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)]) + eq_(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)]) + eq_(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)]) finally: t.drop(checkfirst=True) @@ -674,7 +684,8 @@ def _missing_decimal(): return True class NumericTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global numeric_table, metadata metadata = MetaData(testing.db) numeric_table = Table('numeric_table', metadata, @@ -686,10 +697,11 @@ class NumericTest(TestBase, AssertsExecutionResults): ) metadata.create_all() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() - def tearDown(self): + def teardown(self): numeric_table.delete().execute() @testing.fails_if(_missing_decimal) @@ -723,7 +735,7 @@ class NumericTest(TestBase, AssertsExecutionResults): assert isinstance(row['fcasdec'], decimal.Decimal) def test_length_deprecation(self): - self.assertRaises(exc.SADeprecationWarning, Numeric, length=8) + assert_raises(exc.SADeprecationWarning, Numeric, length=8) @testing.uses_deprecated(".*is deprecated for Numeric") def go(): @@ -751,7 +763,8 @@ class NumericTest(TestBase, AssertsExecutionResults): class IntervalTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global interval_table, metadata metadata = MetaData(testing.db) interval_table = Table("intervaltable", metadata, @@ -760,10 +773,11 @@ class IntervalTest(TestBase, AssertsExecutionResults): ) metadata.create_all() - def tearDown(self): + def teardown(self): interval_table.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): metadata.drop_all() def test_roundtrip(self): @@ -776,14 +790,16 @@ class IntervalTest(TestBase, AssertsExecutionResults): assert interval_table.select().execute().fetchone()['interval'] is None class BooleanTest(TestBase, AssertsExecutionResults): - def setUpAll(self): + @classmethod + def setup_class(cls): global bool_table metadata = MetaData(testing.db) bool_table = Table('booltest', metadata, Column('id', Integer, primary_key=True), Column('value', Boolean)) bool_table.create() - def tearDownAll(self): + @classmethod + def teardown_class(cls): bool_table.drop() def testbasic(self): bool_table.insert().execute(id=1, value=True) @@ -802,11 +818,11 @@ class PickleTest(TestBase): def test_noeq_deprecation(self): p1 = PickleType() - self.assertRaises(DeprecationWarning, + assert_raises(DeprecationWarning, p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2) ) - self.assertRaises(DeprecationWarning, + assert_raises(DeprecationWarning, p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2) ) @@ -833,7 +849,7 @@ class PickleTest(TestBase): ): assert p1.compare_values(p1.copy_value(obj), obj) - self.assertRaises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo')) + assert_raises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo')) def test_nonmutable_comparison(self): p1 = PickleType() @@ -846,11 +862,13 @@ class PickleTest(TestBase): assert p1.compare_values(p1.copy_value(obj), obj) class CallableTest(TestBase): - def setUpAll(self): + @classmethod + def setup_class(cls): global meta meta = MetaData(testing.db) - def tearDownAll(self): + @classmethod + def teardown_class(cls): meta.drop_all() def test_callable_as_arg(self): @@ -871,5 +889,3 @@ class CallableTest(TestBase): assert isinstance(thang_table.c.name.type, Unicode) thang_table.create() -if __name__ == "__main__": - testenv.main() diff --git a/test/sql/unicode.py b/test/sql/test_unicode.py similarity index 95% rename from test/sql/unicode.py rename to test/sql/test_unicode.py index c5002aaffb..d759132678 100644 --- a/test/sql/unicode.py +++ b/test/sql/test_unicode.py @@ -1,16 +1,16 @@ # coding: utf-8 """verrrrry basic unicode column name testing""" -import testenv; testenv.configure_for_tests() from sqlalchemy import * -from testlib import * -from testlib.engines import utf8_engine +from sqlalchemy.test import * +from sqlalchemy.test.engines import utf8_engine from sqlalchemy.sql import column class UnicodeSchemaTest(TestBase): __requires__ = ('unicode_ddl',) - def setUpAll(self): + @classmethod + def setup_class(cls): global unicode_bind, metadata, t1, t2, t3 unicode_bind = utf8_engine() @@ -56,13 +56,14 @@ class UnicodeSchemaTest(TestBase): ) metadata.create_all() - def tearDown(self): + def teardown(self): if metadata.tables: t3.delete().execute() t2.delete().execute() t1.delete().execute() - def tearDownAll(self): + @classmethod + def teardown_class(cls): global unicode_bind metadata.drop_all() del unicode_bind @@ -135,5 +136,3 @@ class EscapesDefaultsTest(testing.TestBase): t1.drop() -if __name__ == '__main__': - testenv.main() diff --git a/test/testenv.py b/test/testenv.py deleted file mode 100644 index 808a3c5f0e..0000000000 --- a/test/testenv.py +++ /dev/null @@ -1,36 +0,0 @@ -"""First import for all test cases, sets sys.path and loads configuration.""" - -import sys, os, logging, warnings - -if sys.version_info < (2, 4): - warnings.filterwarnings('ignore', category=FutureWarning) - - -from testlib.testing import main -import testlib.config - - -_setup = False - -def configure_for_tests(): - """import testenv; testenv.configure_for_tests()""" - - global _setup - if not _setup: - sys.path.insert(0, os.path.join(os.getcwd(), 'lib')) - logging.basicConfig() - - testlib.config.configure() - _setup = True - -def simple_setup(): - """import testenv; testenv.simple_setup()""" - - global _setup - if not _setup: - sys.path.insert(0, os.path.join(os.getcwd(), 'lib')) - logging.basicConfig() - - testlib.config.configure_defaults() - _setup = True - diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py deleted file mode 100644 index 5b8075ddb4..0000000000 --- a/test/testlib/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Enhance unittest and instrument SQLAlchemy classes for testing. - -Load after sqlalchemy imports to use instrumented stand-ins like Table. -""" - -import sys -import testlib.config -from testlib.schema import Table, Column -import testlib.testing as testing -from testlib.testing import \ - AssertsCompiledSQL, \ - AssertsExecutionResults, \ - ComparesTables, \ - TestBase, \ - rowset -from testlib.orm import mapper -import testlib.profiling as profiling -import testlib.engines as engines -import testlib.requires as requires -from testlib.compat import _function_named - - -__all__ = ('testing', - 'mapper', - 'Table', 'Column', - 'rowset', - 'TestBase', 'AssertsExecutionResults', - 'AssertsCompiledSQL', 'ComparesTables', - 'profiling', 'engines', - '_function_named') - - -testing.requires = requires - -sys.modules['testlib.sa'] = sa = testing.CompositeModule( - 'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule( - 'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm')) -sys.modules['testlib.sa.orm'] = sa.orm diff --git a/test/testlib/compat.py b/test/testlib/compat.py deleted file mode 100644 index 73eb2d651f..0000000000 --- a/test/testlib/compat.py +++ /dev/null @@ -1,19 +0,0 @@ -import types -import __builtin__ - -__all__ = '_function_named', 'callable' - - -def _function_named(fn, newname): - try: - fn.__name__ = newname - except: - fn = types.FunctionType(fn.func_code, fn.func_globals, newname, - fn.func_defaults, fn.func_closure) - return fn - -try: - callable = __builtin__.callable -except NameError: - def callable(fn): return hasattr(fn, '__call__') - diff --git a/test/testlib/config.py b/test/testlib/config.py deleted file mode 100644 index cef4c6e1dc..0000000000 --- a/test/testlib/config.py +++ /dev/null @@ -1,344 +0,0 @@ -import optparse, os, sys, re, ConfigParser, StringIO, time, warnings -logging, require = None, None - - -__all__ = 'parser', 'configure', 'options', - -db = None -db_label, db_url, db_opts = None, None, {} - -options = None -file_config = None -coverage_enabled = False - -base_config = """ -[db] -sqlite=sqlite:///:memory: -sqlite_file=sqlite:///querytest.db -postgres=postgres://scott:tiger@127.0.0.1:5432/test -mysql=mysql://scott:tiger@127.0.0.1:3306/test -oracle=oracle://scott:tiger@127.0.0.1:1521 -oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 -mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test -firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb -maxdb=maxdb://MONA:RED@/maxdb1 -""" - -parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]") - -def configure(): - global options, config - global getopts_options, file_config - - file_config = ConfigParser.ConfigParser() - file_config.readfp(StringIO.StringIO(base_config)) - file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) - - # Opt parsing can fire immediate actions, like logging and coverage - (options, args) = parser.parse_args() - sys.argv[1:] = args - - # Lazy setup of other options (post coverage) - for fn in post_configure: - fn(options, file_config) - - return options, file_config - -def configure_defaults(): - global options, config - global getopts_options, file_config - global db - - file_config = ConfigParser.ConfigParser() - file_config.readfp(StringIO.StringIO(base_config)) - file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')]) - (options, args) = parser.parse_args([]) - - # make error messages raised by decorators that depend on a default - # database clearer. - class _engine_bomb(object): - def __getattr__(self, key): - raise RuntimeError('No default engine available, testlib ' - 'was configured with defaults only.') - - db = _engine_bomb() - import testlib.testing - testlib.testing.db = db - - return options, file_config - -def _log(option, opt_str, value, parser): - global logging - if not logging: - import logging - logging.basicConfig() - - if opt_str.endswith('-info'): - logging.getLogger(value).setLevel(logging.INFO) - elif opt_str.endswith('-debug'): - logging.getLogger(value).setLevel(logging.DEBUG) - -def _start_cumulative_coverage(option, opt_str, value, parser): - _start_coverage(option, opt_str, value, parser, erase=False) - -def _start_coverage(option, opt_str, value, parser, erase=True): - import sys, atexit, coverage - true_out = sys.stdout - - global coverage_enabled - coverage_enabled = True - - def _iter_covered_files(mod, recursive=True): - - if recursive: - ff = os.walk - else: - ff = os.listdir - - for rec in ff(os.path.dirname(mod.__file__)): - for x in rec[2]: - if x.endswith('.py'): - yield os.path.join(rec[0], x) - - def _stop(): - coverage.stop() - true_out.write("\nPreparing coverage report...\n") - - from sqlalchemy import sql, orm, engine, \ - ext, databases, log - - import sqlalchemy - - for modset in [ - _iter_covered_files(sqlalchemy, recursive=False), - _iter_covered_files(databases), - _iter_covered_files(engine), - _iter_covered_files(ext), - _iter_covered_files(orm), - ]: - coverage.report(list(modset), - show_missing=False, ignore_errors=False, - file=true_out) - atexit.register(_stop) - if erase: - coverage.erase() - coverage.start() - -def _list_dbs(*args): - print "Available --db options (use --dburi to override)" - for macro in sorted(file_config.options('db')): - print "%20s\t%s" % (macro, file_config.get('db', macro)) - sys.exit(0) - -def _server_side_cursors(options, opt_str, value, parser): - db_opts['server_side_cursors'] = True - -def _engine_strategy(options, opt_str, value, parser): - if value: - db_opts['strategy'] = value - -opt = parser.add_option -opt("--verbose", action="store_true", dest="verbose", - help="enable stdout echoing/printing") -opt("--quiet", action="store_true", dest="quiet", help="suppress output") -opt("--log-info", action="callback", type="string", callback=_log, - help="turn on info logging for (multiple OK)") -opt("--log-debug", action="callback", type="string", callback=_log, - help="turn on debug logging for (multiple OK)") -opt("--require", action="append", dest="require", default=[], - help="require a particular driver or module version (multiple OK)") -opt("--db", action="store", dest="db", default="sqlite", - help="Use prefab database uri") -opt('--dbs', action='callback', callback=_list_dbs, - help="List available prefab dbs") -opt("--dburi", action="store", dest="dburi", - help="Database uri (overrides --db)") -opt("--dropfirst", action="store_true", dest="dropfirst", - help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)") -opt("--mockpool", action="store_true", dest="mockpool", - help="Use mock pool (asserts only one connection used)") -opt("--enginestrategy", action="callback", type="string", - callback=_engine_strategy, - help="Engine strategy (plain or threadlocal, defaults to plain)") -opt("--reversetop", action="store_true", dest="reversetop", default=False, - help="Reverse the collection ordering for topological sorts (helps " - "reveal dependency issues)") -opt("--unhashable", action="store_true", dest="unhashable", default=False, - help="Disallow SQLAlchemy from performing a hash() on mapped test objects.") -opt("--noncomparable", action="store_true", dest="noncomparable", default=False, - help="Disallow SQLAlchemy from performing == on mapped test objects.") -opt("--truthless", action="store_true", dest="truthless", default=False, - help="Disallow SQLAlchemy from truth-evaluating mapped test objects.") -opt("--serverside", action="callback", callback=_server_side_cursors, - help="Turn on server side cursors for PG") -opt("--mysql-engine", action="store", dest="mysql_engine", default=None, - help="Use the specified MySQL storage engine for all tables, default is " - "a db-default/InnoDB combo.") -opt("--table-option", action="append", dest="tableopts", default=[], - help="Add a dialect-specific table option, key=value") -opt("--coverage", action="callback", callback=_start_coverage, - help="Dump a full coverage report after running tests") -opt("--cumulative-coverage", action="callback", callback=_start_cumulative_coverage, - help="Like --coverage, but accumlate coverage into the current DB") -opt("--profile", action="append", dest="profile_targets", default=[], - help="Enable a named profile target (multiple OK.)") -opt("--profile-sort", action="store", dest="profile_sort", default=None, - help="Sort profile stats with this comma-separated sort order") -opt("--profile-limit", type="int", action="store", dest="profile_limit", - default=None, - help="Limit function count in profile stats") - -class _ordered_map(object): - def __init__(self): - self._keys = list() - self._data = dict() - - def __setitem__(self, key, value): - if key not in self._keys: - self._keys.append(key) - self._data[key] = value - - def __iter__(self): - for key in self._keys: - yield self._data[key] - -# at one point in refactoring, modules were injecting into the config -# process. this could probably just become a list now. -post_configure = _ordered_map() - -def _engine_uri(options, file_config): - global db_label, db_url - db_label = 'sqlite' - if options.dburi: - db_url = options.dburi - db_label = db_url[:db_url.index(':')] - elif options.db: - db_label = options.db - db_url = None - - if db_url is None: - if db_label not in file_config.options('db'): - raise RuntimeError( - "Unknown engine. Specify --dbs for known engines.") - db_url = file_config.get('db', db_label) -post_configure['engine_uri'] = _engine_uri - -def _require(options, file_config): - if not(options.require or - (file_config.has_section('require') and - file_config.items('require'))): - return - - try: - import pkg_resources - except ImportError: - raise RuntimeError("setuptools is required for version requirements") - - cmdline = [] - for requirement in options.require: - pkg_resources.require(requirement) - cmdline.append(re.split('\s*(=)', requirement, 1)[0]) - - if file_config.has_section('require'): - for label, requirement in file_config.items('require'): - if not label == db_label or label.startswith('%s.' % db_label): - continue - seen = [c for c in cmdline if requirement.startswith(c)] - if seen: - continue - pkg_resources.require(requirement) -post_configure['require'] = _require - -def _engine_pool(options, file_config): - if options.mockpool: - from sqlalchemy import pool - db_opts['poolclass'] = pool.AssertionPool -post_configure['engine_pool'] = _engine_pool - -def _create_testing_engine(options, file_config): - from testlib import engines, testing - global db - db = engines.testing_engine(db_url, db_opts) - testing.db = db -post_configure['create_engine'] = _create_testing_engine - -def _prep_testing_database(options, file_config): - from testlib import engines - from sqlalchemy import schema - - try: - # also create alt schemas etc. here? - if options.dropfirst: - e = engines.utf8_engine() - existing = e.table_names() - if existing: - if not options.quiet: - print "Dropping existing tables in database: " + db_url - try: - print "Tables: %s" % ', '.join(existing) - except: - pass - print "Abort within 5 seconds..." - time.sleep(5) - md = schema.MetaData(e, reflect=True) - md.drop_all() - e.dispose() - except (KeyboardInterrupt, SystemExit): - raise - except Exception, e: - if not options.quiet: - warnings.warn(RuntimeWarning( - "Error checking for existing tables in testing " - "database: %s" % e)) -post_configure['prep_db'] = _prep_testing_database - -def _set_table_options(options, file_config): - import testlib.schema - - table_options = testlib.schema.table_options - for spec in options.tableopts: - key, value = spec.split('=') - table_options[key] = value - - if options.mysql_engine: - table_options['mysql_engine'] = options.mysql_engine -post_configure['table_options'] = _set_table_options - -def _reverse_topological(options, file_config): - if options.reversetop: - from sqlalchemy.orm import unitofwork - from sqlalchemy import topological - class RevQueueDepSort(topological.QueueDependencySorter): - def __init__(self, tuples, allitems): - self.tuples = list(tuples) - self.allitems = list(allitems) - self.tuples.reverse() - self.allitems.reverse() - topological.QueueDependencySorter = RevQueueDepSort - unitofwork.DependencySorter = RevQueueDepSort -post_configure['topological'] = _reverse_topological - -def _set_profile_targets(options, file_config): - from testlib import profiling - - profile_config = profiling.profile_config - - for target in options.profile_targets: - profile_config['targets'].add(target) - - if options.profile_sort: - profile_config['sort'] = options.profile_sort.split(',') - - if options.profile_limit: - profile_config['limit'] = options.profile_limit - - if options.quiet: - profile_config['report'] = False - - # magic "all" target - if 'all' in profiling.all_targets: - targets = profile_config['targets'] - if 'all' in targets and len(targets) != 1: - targets.clear() - targets.add('all') -post_configure['profile_targets'] = _set_profile_targets diff --git a/test/testlib/coverage.py b/test/testlib/coverage.py deleted file mode 100644 index fc0f2c2360..0000000000 --- a/test/testlib/coverage.py +++ /dev/null @@ -1,1098 +0,0 @@ -#!/usr/bin/python -# -# Perforce Defect Tracking Integration Project -# -# -# COVERAGE.PY -- COVERAGE TESTING -# -# Gareth Rees, Ravenbrook Limited, 2001-12-04 -# Ned Batchelder, 2004-12-12 -# http://nedbatchelder.com/code/modules/coverage.html -# -# -# 1. INTRODUCTION -# -# This module provides coverage testing for Python code. -# -# The intended readership is all Python developers. -# -# This document is not confidential. -# -# See [GDR 2001-12-04a] for the command-line interface, programmatic -# interface and limitations. See [GDR 2001-12-04b] for requirements and -# design. - -r"""\ -Usage: - -coverage.py -x [-p] MODULE.py [ARG1 ARG2 ...] - Execute module, passing the given command-line arguments, collecting - coverage data. With the -p option, write to a temporary file containing - the machine name and process ID. - -coverage.py -e - Erase collected coverage data. - -coverage.py -c - Collect data from multiple coverage files (as created by -p option above) - and store it into a single file representing the union of the coverage. - -coverage.py -r [-m] [-o dir1,dir2,...] FILE1 FILE2 ... - Report on the statement coverage for the given files. With the -m - option, show line numbers of the statements that weren't executed. - -coverage.py -a [-d dir] [-o dir1,dir2,...] FILE1 FILE2 ... - Make annotated copies of the given files, marking statements that - are executed with > and statements that are missed with !. With - the -d option, make the copies in that directory. Without the -d - option, make each copy in the same directory as the original. - --o dir,dir2,... - Omit reporting or annotating files when their filename path starts with - a directory listed in the omit list. - e.g. python coverage.py -i -r -o c:\python23,lib\enthought\traits - -Coverage data is saved in the file .coverage by default. Set the -COVERAGE_FILE environment variable to save it somewhere else.""" - -__version__ = "2.75.20070722" # see detailed history at the end of this file. - -import compiler -import compiler.visitor -import glob -import os -import re -import string -import symbol -import sys -import threading -import token -import types -from socket import gethostname - - -# 2. IMPLEMENTATION -# -# This uses the "singleton" pattern. -# -# The word "morf" means a module object (from which the source file can -# be deduced by suitable manipulation of the __file__ attribute) or a -# filename. -# -# When we generate a coverage report we have to canonicalize every -# filename in the coverage dictionary just in case it refers to the -# module we are reporting on. It seems a shame to throw away this -# information so the data in the coverage dictionary is transferred to -# the 'cexecuted' dictionary under the canonical filenames. -# -# The coverage dictionary is called "c" and the trace function "t". The -# reason for these short names is that Python looks up variables by name -# at runtime and so execution time depends on the length of variables! -# In the bottleneck of this application it's appropriate to abbreviate -# names to increase speed. - -class StatementFindingAstVisitor(compiler.visitor.ASTVisitor): - """ A visitor for a parsed Abstract Syntax Tree which finds executable - statements. - """ - def __init__(self, statements, excluded, suite_spots): - compiler.visitor.ASTVisitor.__init__(self) - self.statements = statements - self.excluded = excluded - self.suite_spots = suite_spots - self.excluding_suite = 0 - - def doRecursive(self, node): - for n in node.getChildNodes(): - self.dispatch(n) - - visitStmt = visitModule = doRecursive - - def doCode(self, node): - if hasattr(node, 'decorators') and node.decorators: - self.dispatch(node.decorators) - self.recordAndDispatch(node.code) - else: - self.doSuite(node, node.code) - - visitFunction = visitClass = doCode - - def getFirstLine(self, node): - # Find the first line in the tree node. - lineno = node.lineno - for n in node.getChildNodes(): - f = self.getFirstLine(n) - if lineno and f: - lineno = min(lineno, f) - else: - lineno = lineno or f - return lineno - - def getLastLine(self, node): - # Find the first line in the tree node. - lineno = node.lineno - for n in node.getChildNodes(): - lineno = max(lineno, self.getLastLine(n)) - return lineno - - def doStatement(self, node): - self.recordLine(self.getFirstLine(node)) - - visitAssert = visitAssign = visitAssTuple = visitPrint = \ - visitPrintnl = visitRaise = visitSubscript = visitDecorators = \ - doStatement - - def visitPass(self, node): - # Pass statements have weird interactions with docstrings. If this - # pass statement is part of one of those pairs, claim that the statement - # is on the later of the two lines. - l = node.lineno - if l: - lines = self.suite_spots.get(l, [l,l]) - self.statements[lines[1]] = 1 - - def visitDiscard(self, node): - # Discard nodes are statements that execute an expression, but then - # discard the results. This includes function calls, so we can't - # ignore them all. But if the expression is a constant, the statement - # won't be "executed", so don't count it now. - if node.expr.__class__.__name__ != 'Const': - self.doStatement(node) - - def recordNodeLine(self, node): - # Stmt nodes often have None, but shouldn't claim the first line of - # their children (because the first child might be an ignorable line - # like "global a"). - if node.__class__.__name__ != 'Stmt': - return self.recordLine(self.getFirstLine(node)) - else: - return 0 - - def recordLine(self, lineno): - # Returns a bool, whether the line is included or excluded. - if lineno: - # Multi-line tests introducing suites have to get charged to their - # keyword. - if lineno in self.suite_spots: - lineno = self.suite_spots[lineno][0] - # If we're inside an excluded suite, record that this line was - # excluded. - if self.excluding_suite: - self.excluded[lineno] = 1 - return 0 - # If this line is excluded, or suite_spots maps this line to - # another line that is exlcuded, then we're excluded. - elif self.excluded.has_key(lineno) or \ - self.suite_spots.has_key(lineno) and \ - self.excluded.has_key(self.suite_spots[lineno][1]): - return 0 - # Otherwise, this is an executable line. - else: - self.statements[lineno] = 1 - return 1 - return 0 - - default = recordNodeLine - - def recordAndDispatch(self, node): - self.recordNodeLine(node) - self.dispatch(node) - - def doSuite(self, intro, body, exclude=0): - exsuite = self.excluding_suite - if exclude or (intro and not self.recordNodeLine(intro)): - self.excluding_suite = 1 - self.recordAndDispatch(body) - self.excluding_suite = exsuite - - def doPlainWordSuite(self, prevsuite, suite): - # Finding the exclude lines for else's is tricky, because they aren't - # present in the compiler parse tree. Look at the previous suite, - # and find its last line. If any line between there and the else's - # first line are excluded, then we exclude the else. - lastprev = self.getLastLine(prevsuite) - firstelse = self.getFirstLine(suite) - for l in range(lastprev+1, firstelse): - if self.suite_spots.has_key(l): - self.doSuite(None, suite, exclude=self.excluded.has_key(l)) - break - else: - self.doSuite(None, suite) - - def doElse(self, prevsuite, node): - if node.else_: - self.doPlainWordSuite(prevsuite, node.else_) - - def visitFor(self, node): - self.doSuite(node, node.body) - self.doElse(node.body, node) - - visitWhile = visitFor - - def visitIf(self, node): - # The first test has to be handled separately from the rest. - # The first test is credited to the line with the "if", but the others - # are credited to the line with the test for the elif. - self.doSuite(node, node.tests[0][1]) - for t, n in node.tests[1:]: - self.doSuite(t, n) - self.doElse(node.tests[-1][1], node) - - def visitTryExcept(self, node): - self.doSuite(node, node.body) - for i in range(len(node.handlers)): - a, b, h = node.handlers[i] - if not a: - # It's a plain "except:". Find the previous suite. - if i > 0: - prev = node.handlers[i-1][2] - else: - prev = node.body - self.doPlainWordSuite(prev, h) - else: - self.doSuite(a, h) - self.doElse(node.handlers[-1][2], node) - - def visitTryFinally(self, node): - self.doSuite(node, node.body) - self.doPlainWordSuite(node.body, node.final) - - def visitGlobal(self, node): - # "global" statements don't execute like others (they don't call the - # trace function), so don't record their line numbers. - pass - -the_coverage = None - -class CoverageException(Exception): pass - -class coverage: - # Name of the cache file (unless environment variable is set). - cache_default = ".coverage" - - # Environment variable naming the cache file. - cache_env = "COVERAGE_FILE" - - # A dictionary with an entry for (Python source file name, line number - # in that file) if that line has been executed. - c = {} - - # A map from canonical Python source file name to a dictionary in - # which there's an entry for each line number that has been - # executed. - cexecuted = {} - - # Cache of results of calling the analysis2() method, so that you can - # specify both -r and -a without doing double work. - analysis_cache = {} - - # Cache of results of calling the canonical_filename() method, to - # avoid duplicating work. - canonical_filename_cache = {} - - def __init__(self): - global the_coverage - if the_coverage: - raise CoverageException, "Only one coverage object allowed." - self.usecache = 1 - self.cache = None - self.parallel_mode = False - self.exclude_re = '' - self.nesting = 0 - self.cstack = [] - self.xstack = [] - self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.sep) - self.exclude('# *pragma[: ]*[nN][oO] *[cC][oO][vV][eE][rR]') - - # t(f, x, y). This method is passed to sys.settrace as a trace function. - # See [van Rossum 2001-07-20b, 9.2] for an explanation of sys.settrace and - # the arguments and return value of the trace function. - # See [van Rossum 2001-07-20a, 3.2] for a description of frame and code - # objects. - - def t(self, f, w, unused): #pragma: no cover - if w == 'line': - #print "Executing %s @ %d" % (f.f_code.co_filename, f.f_lineno) - self.c[(f.f_code.co_filename, f.f_lineno)] = 1 - for c in self.cstack: - c[(f.f_code.co_filename, f.f_lineno)] = 1 - return self.t - - def help(self, error=None): #pragma: no cover - if error: - print error - print - print __doc__ - sys.exit(1) - - def command_line(self, argv, help_fn=None): - import getopt - help_fn = help_fn or self.help - settings = {} - optmap = { - '-a': 'annotate', - '-c': 'collect', - '-d:': 'directory=', - '-e': 'erase', - '-h': 'help', - '-i': 'ignore-errors', - '-m': 'show-missing', - '-p': 'parallel-mode', - '-r': 'report', - '-x': 'execute', - '-o:': 'omit=', - } - short_opts = string.join(map(lambda o: o[1:], optmap.keys()), '') - long_opts = optmap.values() - options, args = getopt.getopt(argv, short_opts, long_opts) - for o, a in options: - if optmap.has_key(o): - settings[optmap[o]] = 1 - elif optmap.has_key(o + ':'): - settings[optmap[o + ':']] = a - elif o[2:] in long_opts: - settings[o[2:]] = 1 - elif o[2:] + '=' in long_opts: - settings[o[2:]+'='] = a - else: #pragma: no cover - pass # Can't get here, because getopt won't return anything unknown. - - if settings.get('help'): - help_fn() - - for i in ['erase', 'execute']: - for j in ['annotate', 'report', 'collect']: - if settings.get(i) and settings.get(j): - help_fn("You can't specify the '%s' and '%s' " - "options at the same time." % (i, j)) - - args_needed = (settings.get('execute') - or settings.get('annotate') - or settings.get('report')) - action = (settings.get('erase') - or settings.get('collect') - or args_needed) - if not action: - help_fn("You must specify at least one of -e, -x, -c, -r, or -a.") - if not args_needed and args: - help_fn("Unexpected arguments: %s" % " ".join(args)) - - self.parallel_mode = settings.get('parallel-mode') - self.get_ready() - - if settings.get('erase'): - self.erase() - if settings.get('execute'): - if not args: - help_fn("Nothing to do.") - sys.argv = args - self.start() - import __main__ - sys.path[0] = os.path.dirname(sys.argv[0]) - execfile(sys.argv[0], __main__.__dict__) - if settings.get('collect'): - self.collect() - if not args: - args = self.cexecuted.keys() - - ignore_errors = settings.get('ignore-errors') - show_missing = settings.get('show-missing') - directory = settings.get('directory=') - - omit = settings.get('omit=') - if omit is not None: - omit = omit.split(',') - else: - omit = [] - - if settings.get('report'): - self.report(args, show_missing, ignore_errors, omit_prefixes=omit) - if settings.get('annotate'): - self.annotate(args, directory, ignore_errors, omit_prefixes=omit) - - def use_cache(self, usecache, cache_file=None): - self.usecache = usecache - if cache_file and not self.cache: - self.cache_default = cache_file - - def get_ready(self, parallel_mode=False): - if self.usecache and not self.cache: - self.cache = os.environ.get(self.cache_env, self.cache_default) - if self.parallel_mode: - self.cache += "." + gethostname() + "." + str(os.getpid()) - self.restore() - self.analysis_cache = {} - - def start(self, parallel_mode=False): - self.get_ready() - if self.nesting == 0: #pragma: no cover - sys.settrace(self.t) - if hasattr(threading, 'settrace'): - threading.settrace(self.t) - self.nesting += 1 - - def stop(self): - self.nesting -= 1 - if self.nesting == 0: #pragma: no cover - sys.settrace(None) - if hasattr(threading, 'settrace'): - threading.settrace(None) - - def erase(self): - self.get_ready() - self.c = {} - self.analysis_cache = {} - self.cexecuted = {} - if self.cache and os.path.exists(self.cache): - os.remove(self.cache) - - def exclude(self, re): - if self.exclude_re: - self.exclude_re += "|" - self.exclude_re += "(" + re + ")" - - def begin_recursive(self): - self.cstack.append(self.c) - self.xstack.append(self.exclude_re) - - def end_recursive(self): - self.c = self.cstack.pop() - self.exclude_re = self.xstack.pop() - - # save(). Save coverage data to the coverage cache. - - def save(self): - if self.usecache and self.cache: - self.canonicalize_filenames() - cache = open(self.cache, 'wb') - import marshal - marshal.dump(self.cexecuted, cache) - cache.close() - - # restore(). Restore coverage data from the coverage cache (if it exists). - - def restore(self): - self.c = {} - self.cexecuted = {} - assert self.usecache - if os.path.exists(self.cache): - self.cexecuted = self.restore_file(self.cache) - - def restore_file(self, file_name): - try: - cache = open(file_name, 'rb') - import marshal - cexecuted = marshal.load(cache) - cache.close() - if isinstance(cexecuted, types.DictType): - return cexecuted - else: - return {} - except: - return {} - - # collect(). Collect data in multiple files produced by parallel mode - - def collect(self): - cache_dir, local = os.path.split(self.cache) - for f in os.listdir(cache_dir or '.'): - if not f.startswith(local): - continue - - full_path = os.path.join(cache_dir, f) - cexecuted = self.restore_file(full_path) - self.merge_data(cexecuted) - - def merge_data(self, new_data): - for file_name, file_data in new_data.items(): - if self.cexecuted.has_key(file_name): - self.merge_file_data(self.cexecuted[file_name], file_data) - else: - self.cexecuted[file_name] = file_data - - def merge_file_data(self, cache_data, new_data): - for line_number in new_data.keys(): - if not cache_data.has_key(line_number): - cache_data[line_number] = new_data[line_number] - - # canonical_filename(filename). Return a canonical filename for the - # file (that is, an absolute path with no redundant components and - # normalized case). See [GDR 2001-12-04b, 3.3]. - - def canonical_filename(self, filename): - if not self.canonical_filename_cache.has_key(filename): - f = filename - if os.path.isabs(f) and not os.path.exists(f): - f = os.path.basename(f) - if not os.path.isabs(f): - for path in [os.curdir] + sys.path: - g = os.path.join(path, f) - if os.path.exists(g): - f = g - break - cf = os.path.normcase(os.path.abspath(f)) - self.canonical_filename_cache[filename] = cf - return self.canonical_filename_cache[filename] - - # canonicalize_filenames(). Copy results from "c" to "cexecuted", - # canonicalizing filenames on the way. Clear the "c" map. - - def canonicalize_filenames(self): - for filename, lineno in self.c.keys(): - if filename == '': - # Can't do anything useful with exec'd strings, so skip them. - continue - f = self.canonical_filename(filename) - if not self.cexecuted.has_key(f): - self.cexecuted[f] = {} - self.cexecuted[f][lineno] = 1 - self.c = {} - - # morf_filename(morf). Return the filename for a module or file. - - def morf_filename(self, morf): - if isinstance(morf, types.ModuleType): - if not hasattr(morf, '__file__'): - raise CoverageException, "Module has no __file__ attribute." - f = morf.__file__ - else: - f = morf - return self.canonical_filename(f) - - # analyze_morf(morf). Analyze the module or filename passed as - # the argument. If the source code can't be found, raise an error. - # Otherwise, return a tuple of (1) the canonical filename of the - # source code for the module, (2) a list of lines of statements - # in the source code, (3) a list of lines of excluded statements, - # and (4), a map of line numbers to multi-line line number ranges, for - # statements that cross lines. - - def analyze_morf(self, morf): - if self.analysis_cache.has_key(morf): - return self.analysis_cache[morf] - filename = self.morf_filename(morf) - ext = os.path.splitext(filename)[1] - if ext == '.pyc': - if not os.path.exists(filename[0:-1]): - raise CoverageException, ("No source for compiled code '%s'." - % filename) - filename = filename[0:-1] - elif ext != '.py': - raise CoverageException, "File '%s' not Python source." % filename - source = open(filename, 'r') - lines, excluded_lines, line_map = self.find_executable_statements( - source.read(), exclude=self.exclude_re - ) - source.close() - result = filename, lines, excluded_lines, line_map - self.analysis_cache[morf] = result - return result - - def first_line_of_tree(self, tree): - while True: - if len(tree) == 3 and type(tree[2]) == type(1): - return tree[2] - tree = tree[1] - - def last_line_of_tree(self, tree): - while True: - if len(tree) == 3 and type(tree[2]) == type(1): - return tree[2] - tree = tree[-1] - - def find_docstring_pass_pair(self, tree, spots): - for i in range(1, len(tree)): - if self.is_string_constant(tree[i]) and self.is_pass_stmt(tree[i+1]): - first_line = self.first_line_of_tree(tree[i]) - last_line = self.last_line_of_tree(tree[i+1]) - self.record_multiline(spots, first_line, last_line) - - def is_string_constant(self, tree): - try: - return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.expr_stmt - except: - return False - - def is_pass_stmt(self, tree): - try: - return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.pass_stmt - except: - return False - - def record_multiline(self, spots, i, j): - for l in range(i, j+1): - spots[l] = (i, j) - - def get_suite_spots(self, tree, spots): - """ Analyze a parse tree to find suite introducers which span a number - of lines. - """ - for i in range(1, len(tree)): - if type(tree[i]) == type(()): - if tree[i][0] == symbol.suite: - # Found a suite, look back for the colon and keyword. - lineno_colon = lineno_word = None - for j in range(i-1, 0, -1): - if tree[j][0] == token.COLON: - # Colons are never executed themselves: we want the - # line number of the last token before the colon. - lineno_colon = self.last_line_of_tree(tree[j-1]) - elif tree[j][0] == token.NAME: - if tree[j][1] == 'elif': - # Find the line number of the first non-terminal - # after the keyword. - t = tree[j+1] - while t and token.ISNONTERMINAL(t[0]): - t = t[1] - if t: - lineno_word = t[2] - else: - lineno_word = tree[j][2] - break - elif tree[j][0] == symbol.except_clause: - # "except" clauses look like: - # ('except_clause', ('NAME', 'except', lineno), ...) - if tree[j][1][0] == token.NAME: - lineno_word = tree[j][1][2] - break - if lineno_colon and lineno_word: - # Found colon and keyword, mark all the lines - # between the two with the two line numbers. - self.record_multiline(spots, lineno_word, lineno_colon) - - # "pass" statements are tricky: different versions of Python - # treat them differently, especially in the common case of a - # function with a doc string and a single pass statement. - self.find_docstring_pass_pair(tree[i], spots) - - elif tree[i][0] == symbol.simple_stmt: - first_line = self.first_line_of_tree(tree[i]) - last_line = self.last_line_of_tree(tree[i]) - if first_line != last_line: - self.record_multiline(spots, first_line, last_line) - self.get_suite_spots(tree[i], spots) - - def find_executable_statements(self, text, exclude=None): - # Find lines which match an exclusion pattern. - excluded = {} - suite_spots = {} - if exclude: - reExclude = re.compile(exclude) - lines = text.split('\n') - for i in range(len(lines)): - if reExclude.search(lines[i]): - excluded[i+1] = 1 - - # Parse the code and analyze the parse tree to find out which statements - # are multiline, and where suites begin and end. - import parser - tree = parser.suite(text+'\n\n').totuple(1) - self.get_suite_spots(tree, suite_spots) - #print "Suite spots:", suite_spots - - # Use the compiler module to parse the text and find the executable - # statements. We add newlines to be impervious to final partial lines. - statements = {} - ast = compiler.parse(text+'\n\n') - visitor = StatementFindingAstVisitor(statements, excluded, suite_spots) - compiler.walk(ast, visitor, walker=visitor) - - lines = statements.keys() - lines.sort() - excluded_lines = excluded.keys() - excluded_lines.sort() - return lines, excluded_lines, suite_spots - - # format_lines(statements, lines). Format a list of line numbers - # for printing by coalescing groups of lines as long as the lines - # represent consecutive statements. This will coalesce even if - # there are gaps between statements, so if statements = - # [1,2,3,4,5,10,11,12,13,14] and lines = [1,2,5,10,11,13,14] then - # format_lines will return "1-2, 5-11, 13-14". - - def format_lines(self, statements, lines): - pairs = [] - i = 0 - j = 0 - start = None - pairs = [] - while i < len(statements) and j < len(lines): - if statements[i] == lines[j]: - if start == None: - start = lines[j] - end = lines[j] - j = j + 1 - elif start: - pairs.append((start, end)) - start = None - i = i + 1 - if start: - pairs.append((start, end)) - def stringify(pair): - start, end = pair - if start == end: - return "%d" % start - else: - return "%d-%d" % (start, end) - ret = string.join(map(stringify, pairs), ", ") - return ret - - # Backward compatibility with version 1. - def analysis(self, morf): - f, s, _, m, mf = self.analysis2(morf) - return f, s, m, mf - - def analysis2(self, morf): - filename, statements, excluded, line_map = self.analyze_morf(morf) - self.canonicalize_filenames() - if not self.cexecuted.has_key(filename): - self.cexecuted[filename] = {} - missing = [] - for line in statements: - lines = line_map.get(line, [line, line]) - for l in range(lines[0], lines[1]+1): - if self.cexecuted[filename].has_key(l): - break - else: - missing.append(line) - return (filename, statements, excluded, missing, - self.format_lines(statements, missing)) - - def relative_filename(self, filename): - """ Convert filename to relative filename from self.relative_dir. - """ - return filename.replace(self.relative_dir, "") - - def morf_name(self, morf): - """ Return the name of morf as used in report. - """ - if isinstance(morf, types.ModuleType): - return morf.__name__ - else: - return self.relative_filename(os.path.splitext(morf)[0]) - - def filter_by_prefix(self, morfs, omit_prefixes): - """ Return list of morfs where the morf name does not begin - with any one of the omit_prefixes. - """ - filtered_morfs = [] - for morf in morfs: - for prefix in omit_prefixes: - if self.morf_name(morf).startswith(prefix): - break - else: - filtered_morfs.append(morf) - - return filtered_morfs - - def morf_name_compare(self, x, y): - return cmp(self.morf_name(x), self.morf_name(y)) - - def report(self, morfs, show_missing=1, ignore_errors=0, file=None, omit_prefixes=[]): - if not isinstance(morfs, types.ListType): - morfs = [morfs] - # On windows, the shell doesn't expand wildcards. Do it here. - globbed = [] - for morf in morfs: - if isinstance(morf, basestring): - globbed.extend(glob.glob(morf)) - else: - globbed.append(morf) - morfs = globbed - - morfs = self.filter_by_prefix(morfs, omit_prefixes) - morfs.sort(self.morf_name_compare) - - max_name = max([5,] + map(len, map(self.morf_name, morfs))) - fmt_name = "%%- %ds " % max_name - fmt_err = fmt_name + "%s: %s" - header = fmt_name % "Name" + " Stmts Exec Cover" - fmt_coverage = fmt_name + "% 6d % 6d % 5d%%" - if show_missing: - header = header + " Missing" - fmt_coverage = fmt_coverage + " %s" - if not file: - file = sys.stdout - print >>file, header - print >>file, "-" * len(header) - total_statements = 0 - total_executed = 0 - for morf in morfs: - name = self.morf_name(morf) - try: - _, statements, _, missing, readable = self.analysis2(morf) - n = len(statements) - m = n - len(missing) - if n > 0: - pc = 100.0 * m / n - else: - pc = 100.0 - args = (name, n, m, pc) - if show_missing: - args = args + (readable,) - print >>file, fmt_coverage % args - total_statements = total_statements + n - total_executed = total_executed + m - except KeyboardInterrupt: #pragma: no cover - raise - except: - if not ignore_errors: - typ, msg = sys.exc_info()[0:2] - print >>file, fmt_err % (name, typ, msg) - if len(morfs) > 1: - print >>file, "-" * len(header) - if total_statements > 0: - pc = 100.0 * total_executed / total_statements - else: - pc = 100.0 - args = ("TOTAL", total_statements, total_executed, pc) - if show_missing: - args = args + ("",) - print >>file, fmt_coverage % args - - # annotate(morfs, ignore_errors). - - blank_re = re.compile(r"\s*(#|$)") - else_re = re.compile(r"\s*else\s*:\s*(#|$)") - - def annotate(self, morfs, directory=None, ignore_errors=0, omit_prefixes=[]): - morfs = self.filter_by_prefix(morfs, omit_prefixes) - for morf in morfs: - try: - filename, statements, excluded, missing, _ = self.analysis2(morf) - self.annotate_file(filename, statements, excluded, missing, directory) - except KeyboardInterrupt: - raise - except: - if not ignore_errors: - raise - - def annotate_file(self, filename, statements, excluded, missing, directory=None): - source = open(filename, 'r') - if directory: - dest_file = os.path.join(directory, - os.path.basename(filename) - + ',cover') - else: - dest_file = filename + ',cover' - dest = open(dest_file, 'w') - lineno = 0 - i = 0 - j = 0 - covered = 1 - while 1: - line = source.readline() - if line == '': - break - lineno = lineno + 1 - while i < len(statements) and statements[i] < lineno: - i = i + 1 - while j < len(missing) and missing[j] < lineno: - j = j + 1 - if i < len(statements) and statements[i] == lineno: - covered = j >= len(missing) or missing[j] > lineno - if self.blank_re.match(line): - dest.write(' ') - elif self.else_re.match(line): - # Special logic for lines containing only 'else:'. - # See [GDR 2001-12-04b, 3.2]. - if i >= len(statements) and j >= len(missing): - dest.write('! ') - elif i >= len(statements) or j >= len(missing): - dest.write('> ') - elif statements[i] == missing[j]: - dest.write('! ') - else: - dest.write('> ') - elif lineno in excluded: - dest.write('- ') - elif covered: - dest.write('> ') - else: - dest.write('! ') - dest.write(line) - source.close() - dest.close() - -# Singleton object. -the_coverage = coverage() - -# Module functions call methods in the singleton object. -def use_cache(*args, **kw): - return the_coverage.use_cache(*args, **kw) - -def start(*args, **kw): - return the_coverage.start(*args, **kw) - -def stop(*args, **kw): - return the_coverage.stop(*args, **kw) - -def erase(*args, **kw): - return the_coverage.erase(*args, **kw) - -def begin_recursive(*args, **kw): - return the_coverage.begin_recursive(*args, **kw) - -def end_recursive(*args, **kw): - return the_coverage.end_recursive(*args, **kw) - -def exclude(*args, **kw): - return the_coverage.exclude(*args, **kw) - -def analysis(*args, **kw): - return the_coverage.analysis(*args, **kw) - -def analysis2(*args, **kw): - return the_coverage.analysis2(*args, **kw) - -def report(*args, **kw): - return the_coverage.report(*args, **kw) - -def annotate(*args, **kw): - return the_coverage.annotate(*args, **kw) - -def annotate_file(*args, **kw): - return the_coverage.annotate_file(*args, **kw) - -# Save coverage data when Python exits. (The atexit module wasn't -# introduced until Python 2.0, so use sys.exitfunc when it's not -# available.) -try: - import atexit - atexit.register(the_coverage.save) -except ImportError: - sys.exitfunc = the_coverage.save - -# Command-line interface. -if __name__ == '__main__': - the_coverage.command_line(sys.argv[1:]) - - -# A. REFERENCES -# -# [GDR 2001-12-04a] "Statement coverage for Python"; Gareth Rees; -# Ravenbrook Limited; 2001-12-04; -# . -# -# [GDR 2001-12-04b] "Statement coverage for Python: design and -# analysis"; Gareth Rees; Ravenbrook Limited; 2001-12-04; -# . -# -# [van Rossum 2001-07-20a] "Python Reference Manual (releae 2.1.1)"; -# Guide van Rossum; 2001-07-20; -# . -# -# [van Rossum 2001-07-20b] "Python Library Reference"; Guido van Rossum; -# 2001-07-20; . -# -# -# B. DOCUMENT HISTORY -# -# 2001-12-04 GDR Created. -# -# 2001-12-06 GDR Added command-line interface and source code -# annotation. -# -# 2001-12-09 GDR Moved design and interface to separate documents. -# -# 2001-12-10 GDR Open cache file as binary on Windows. Allow -# simultaneous -e and -x, or -a and -r. -# -# 2001-12-12 GDR Added command-line help. Cache analysis so that it -# only needs to be done once when you specify -a and -r. -# -# 2001-12-13 GDR Improved speed while recording. Portable between -# Python 1.5.2 and 2.1.1. -# -# 2002-01-03 GDR Module-level functions work correctly. -# -# 2002-01-07 GDR Update sys.path when running a file with the -x option, -# so that it matches the value the program would get if it were run on -# its own. -# -# 2004-12-12 NMB Significant code changes. -# - Finding executable statements has been rewritten so that docstrings and -# other quirks of Python execution aren't mistakenly identified as missing -# lines. -# - Lines can be excluded from consideration, even entire suites of lines. -# - The filesystem cache of covered lines can be disabled programmatically. -# - Modernized the code. -# -# 2004-12-14 NMB Minor tweaks. Return 'analysis' to its original behavior -# and add 'analysis2'. Add a global for 'annotate', and factor it, adding -# 'annotate_file'. -# -# 2004-12-31 NMB Allow for keyword arguments in the module global functions. -# Thanks, Allen. -# -# 2005-12-02 NMB Call threading.settrace so that all threads are measured. -# Thanks Martin Fuzzey. Add a file argument to report so that reports can be -# captured to a different destination. -# -# 2005-12-03 NMB coverage.py can now measure itself. -# -# 2005-12-04 NMB Adapted Greg Rogers' patch for using relative filenames, -# and sorting and omitting files to report on. -# -# 2006-07-23 NMB Applied Joseph Tate's patch for function decorators. -# -# 2006-08-21 NMB Applied Sigve Tjora and Mark van der Wal's fixes for argument -# handling. -# -# 2006-08-22 NMB Applied Geoff Bache's parallel mode patch. -# -# 2006-08-23 NMB Refactorings to improve testability. Fixes to command-line -# logic for parallel mode and collect. -# -# 2006-08-25 NMB "#pragma: nocover" is excluded by default. -# -# 2006-09-10 NMB Properly ignore docstrings and other constant expressions that -# appear in the middle of a function, a problem reported by Tim Leslie. -# Minor changes to avoid lint warnings. -# -# 2006-09-17 NMB coverage.erase() shouldn't clobber the exclude regex. -# Change how parallel mode is invoked, and fix erase() so that it erases the -# cache when called programmatically. -# -# 2007-07-21 NMB In reports, ignore code executed from strings, since we can't -# do anything useful with it anyway. -# Better file handling on Linux, thanks Guillaume Chazarain. -# Better shell support on Windows, thanks Noel O'Boyle. -# Python 2.2 support maintained, thanks Catherine Proulx. -# -# 2007-07-22 NMB Python 2.5 now fully supported. The method of dealing with -# multi-line statements is now less sensitive to the exact line that Python -# reports during execution. Pass statements are handled specially so that their -# disappearance during execution won't throw off the measurement. - -# C. COPYRIGHT AND LICENCE -# -# Copyright 2001 Gareth Rees. All rights reserved. -# Copyright 2004-2007 Ned Batchelder. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# 1. Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the -# distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDERS AND CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS -# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $ diff --git a/test/testlib/sa_unittest.py b/test/testlib/sa_unittest.py deleted file mode 100644 index 7eb2c07271..0000000000 --- a/test/testlib/sa_unittest.py +++ /dev/null @@ -1,787 +0,0 @@ -#!/usr/bin/env python -''' -unittest.py from Python 2.5. - -SQLAlchemy extends unittest internals to provide setUpAll()/tearDownAll() -so we include a fixed version here to insulate from changes. 2.6 and -3.0's unittest is incompatible with our changes. - -Approaches to removing this dependency are: - -* find a unittest-supported method of grouping UnitTest classes within -a setUpAll()/tearDownAll() pair, such that all tests within a single -UnitTest class are executed within a single execution of setUpAll()/ -tearDownAll(). It may be possible to create nested TestSuite objects -to accomplish this but it's not clear. -* migrate to a different system such as nose. - -Copyright (c) 1999-2003 Steve Purcell -This module is free software, and you may redistribute it and/or modify -it under the same terms as Python itself, so long as this copyright message -and disclaimer are retained in their original form. - -IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, -SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF -THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH -DAMAGE. - -THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, -AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE, -SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. -''' - -__author__ = "Steve Purcell" -__email__ = "stephen_purcell at yahoo dot com" -__version__ = "#Revision: 1.63 $"[11:-2] - -import time -import sys -import traceback -import os -import types -from testlib.compat import callable - -############################################################################## -# Exported classes and functions -############################################################################## -__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', - 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader'] - -# Expose obsolete functions for backwards compatibility -__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) - - -############################################################################## -# Test framework core -############################################################################## - -# All classes defined herein are 'new-style' classes, allowing use of 'super()' -__metaclass__ = type - -def _strclass(cls): - return "%s.%s" % (cls.__module__, cls.__name__) - -__unittest = 1 - -class TestResult: - """Holder for test result information. - - Test results are automatically managed by the TestCase and TestSuite - classes, and do not need to be explicitly manipulated by writers of tests. - - Each instance holds the total number of tests run, and collections of - failures and errors that occurred among those test runs. The collections - contain tuples of (testcase, exceptioninfo), where exceptioninfo is the - formatted traceback of the error that occurred. - """ - def __init__(self): - self.failures = [] - self.errors = [] - self.testsRun = 0 - self.shouldStop = 0 - - def startTest(self, test): - "Called when the given test is about to be run" - self.testsRun = self.testsRun + 1 - - def stopTest(self, test): - "Called when the given test has been run" - pass - - def addError(self, test, err): - """Called when an error has occurred. 'err' is a tuple of values as - returned by sys.exc_info(). - """ - self.errors.append((test, self._exc_info_to_string(err, test))) - - def addFailure(self, test, err): - """Called when an error has occurred. 'err' is a tuple of values as - returned by sys.exc_info().""" - self.failures.append((test, self._exc_info_to_string(err, test))) - - def addSuccess(self, test): - "Called when a test has completed successfully" - pass - - def wasSuccessful(self): - "Tells whether or not this result was a success" - return len(self.failures) == len(self.errors) == 0 - - def stop(self): - "Indicates that the tests should be aborted" - self.shouldStop = True - - def _exc_info_to_string(self, err, test): - """Converts a sys.exc_info()-style tuple of values into a string.""" - exctype, value, tb = err - # Skip test runner traceback levels - while tb and self._is_relevant_tb_level(tb): - tb = tb.tb_next - if exctype is test.failureException: - # Skip assert*() traceback levels - length = self._count_relevant_tb_levels(tb) - return ''.join(traceback.format_exception(exctype, value, tb, length)) - return ''.join(traceback.format_exception(exctype, value, tb)) - - def _is_relevant_tb_level(self, tb): - return tb.tb_frame.f_globals.has_key('__unittest') - - def _count_relevant_tb_levels(self, tb): - length = 0 - while tb and not self._is_relevant_tb_level(tb): - length += 1 - tb = tb.tb_next - return length - - def __repr__(self): - return "<%s run=%i errors=%i failures=%i>" % \ - (_strclass(self.__class__), self.testsRun, len(self.errors), - len(self.failures)) - -class TestCase: - """A class whose instances are single test cases. - - By default, the test code itself should be placed in a method named - 'runTest'. - - If the fixture may be used for many test cases, create as - many test methods as are needed. When instantiating such a TestCase - subclass, specify in the constructor arguments the name of the test method - that the instance is to execute. - - Test authors should subclass TestCase for their own tests. Construction - and deconstruction of the test's environment ('fixture') can be - implemented by overriding the 'setUp' and 'tearDown' methods respectively. - - If it is necessary to override the __init__ method, the base class - __init__ method must always be called. It is important that subclasses - should not change the signature of their __init__ method, since instances - of the classes are instantiated automatically by parts of the framework - in order to be run. - """ - - # This attribute determines which exception will be raised when - # the instance's assertion methods fail; test methods raising this - # exception will be deemed to have 'failed' rather than 'errored' - - failureException = AssertionError - - def __init__(self, methodName='runTest'): - """Create an instance of the class that will use the named test - method when executed. Raises a ValueError if the instance does - not have a method with the specified name. - """ - try: - self._testMethodName = methodName - testMethod = getattr(self, methodName) - self._testMethodDoc = testMethod.__doc__ - except AttributeError: - raise ValueError, "no such test method in %s: %s" % \ - (self.__class__, methodName) - - def setUp(self): - "Hook method for setting up the test fixture before exercising it." - pass - - def tearDown(self): - "Hook method for deconstructing the test fixture after testing it." - pass - - def countTestCases(self): - return 1 - - def defaultTestResult(self): - return TestResult() - - def shortDescription(self): - """Returns a one-line description of the test, or None if no - description has been provided. - - The default implementation of this method returns the first line of - the specified test method's docstring. - """ - doc = self._testMethodDoc - return doc and doc.split("\n")[0].strip() or None - - def id(self): - return "%s.%s" % (_strclass(self.__class__), self._testMethodName) - - def __str__(self): - return "%s (%s)" % (self._testMethodName, _strclass(self.__class__)) - - def __repr__(self): - return "<%s testMethod=%s>" % \ - (_strclass(self.__class__), self._testMethodName) - - def run(self, result=None): - if result is None: result = self.defaultTestResult() - result.startTest(self) - testMethod = getattr(self, self._testMethodName) - try: - try: - self.setUp() - except KeyboardInterrupt: - raise - except: - result.addError(self, self._exc_info()) - return - - ok = False - try: - testMethod() - ok = True - except self.failureException: - result.addFailure(self, self._exc_info()) - except KeyboardInterrupt: - raise - except: - result.addError(self, self._exc_info()) - - try: - self.tearDown() - except KeyboardInterrupt: - raise - except: - result.addError(self, self._exc_info()) - ok = False - if ok: result.addSuccess(self) - finally: - result.stopTest(self) - - def __call__(self, *args, **kwds): - return self.run(*args, **kwds) - - def debug(self): - """Run the test without collecting errors in a TestResult""" - self.setUp() - getattr(self, self._testMethodName)() - self.tearDown() - - def _exc_info(self): - """Return a version of sys.exc_info() with the traceback frame - minimised; usually the top level of the traceback frame is not - needed. - """ - exctype, excvalue, tb = sys.exc_info() - if sys.platform[:4] == 'java': ## tracebacks look different in Jython - return (exctype, excvalue, tb) - return (exctype, excvalue, tb) - - def fail(self, msg=None): - """Fail immediately, with the given message.""" - raise self.failureException, msg - - def failIf(self, expr, msg=None): - "Fail the test if the expression is true." - if expr: raise self.failureException, msg - - def failUnless(self, expr, msg=None): - """Fail the test unless the expression is true.""" - if not expr: raise self.failureException, msg - - def failUnlessRaises(self, excClass, callableObj, *args, **kwargs): - """Fail unless an exception of class excClass is thrown - by callableObj when invoked with arguments args and keyword - arguments kwargs. If a different type of exception is - thrown, it will not be caught, and the test case will be - deemed to have suffered an error, exactly as for an - unexpected exception. - """ - try: - callableObj(*args, **kwargs) - except excClass: - return - else: - if hasattr(excClass,'__name__'): excName = excClass.__name__ - else: excName = str(excClass) - raise self.failureException, "%s not raised" % excName - - def failUnlessEqual(self, first, second, msg=None): - """Fail if the two objects are unequal as determined by the '==' - operator. - """ - if not first == second: - raise self.failureException, \ - (msg or '%r != %r' % (first, second)) - - def failIfEqual(self, first, second, msg=None): - """Fail if the two objects are equal as determined by the '==' - operator. - """ - if first == second: - raise self.failureException, \ - (msg or '%r == %r' % (first, second)) - - def failUnlessAlmostEqual(self, first, second, places=7, msg=None): - """Fail if the two objects are unequal as determined by their - difference rounded to the given number of decimal places - (default 7) and comparing to zero. - - Note that decimal places (from zero) are usually not the same - as significant digits (measured from the most signficant digit). - """ - if round(second-first, places) != 0: - raise self.failureException, \ - (msg or '%r != %r within %r places' % (first, second, places)) - - def failIfAlmostEqual(self, first, second, places=7, msg=None): - """Fail if the two objects are equal as determined by their - difference rounded to the given number of decimal places - (default 7) and comparing to zero. - - Note that decimal places (from zero) are usually not the same - as significant digits (measured from the most signficant digit). - """ - if round(second-first, places) == 0: - raise self.failureException, \ - (msg or '%r == %r within %r places' % (first, second, places)) - - # Synonyms for assertion methods - - assertEqual = assertEquals = failUnlessEqual - - assertNotEqual = assertNotEquals = failIfEqual - - assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual - - assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual - - assertRaises = failUnlessRaises - - assert_ = assertTrue = failUnless - - assertFalse = failIf - - - -class TestSuite: - """A test suite is a composite test consisting of a number of TestCases. - - For use, create an instance of TestSuite, then add test case instances. - When all tests have been added, the suite can be passed to a test - runner, such as TextTestRunner. It will run the individual test cases - in the order in which they were added, aggregating the results. When - subclassing, do not forget to call the base class constructor. - """ - def __init__(self, tests=()): - self._tests = [] - self.addTests(tests) - - def __repr__(self): - return "<%s tests=%s>" % (_strclass(self.__class__), self._tests) - - __str__ = __repr__ - - def __iter__(self): - return iter(self._tests) - - def countTestCases(self): - cases = 0 - for test in self._tests: - cases += test.countTestCases() - return cases - - def addTest(self, test): - # sanity checks - if not callable(test): - raise TypeError("the test to add must be callable") - if (isinstance(test, (type, types.ClassType)) and - issubclass(test, (TestCase, TestSuite))): - raise TypeError("TestCases and TestSuites must be instantiated " - "before passing them to addTest()") - self._tests.append(test) - - def addTests(self, tests): - if isinstance(tests, basestring): - raise TypeError("tests must be an iterable of tests, not a string") - for test in tests: - self.addTest(test) - - def run(self, result): - for test in self._tests: - if result.shouldStop: - break - test(result) - return result - - def __call__(self, *args, **kwds): - return self.run(*args, **kwds) - - def debug(self): - """Run the tests without collecting errors in a TestResult""" - for test in self._tests: test.debug() - - -class FunctionTestCase(TestCase): - """A test case that wraps a test function. - - This is useful for slipping pre-existing test functions into the - PyUnit framework. Optionally, set-up and tidy-up functions can be - supplied. As with TestCase, the tidy-up ('tearDown') function will - always be called if the set-up ('setUp') function ran successfully. - """ - - def __init__(self, testFunc, setUp=None, tearDown=None, - description=None): - TestCase.__init__(self) - self.__setUpFunc = setUp - self.__tearDownFunc = tearDown - self.__testFunc = testFunc - self.__description = description - - def setUp(self): - if self.__setUpFunc is not None: - self.__setUpFunc() - - def tearDown(self): - if self.__tearDownFunc is not None: - self.__tearDownFunc() - - def runTest(self): - self.__testFunc() - - def id(self): - return self.__testFunc.__name__ - - def __str__(self): - return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__) - - def __repr__(self): - return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc) - - def shortDescription(self): - if self.__description is not None: return self.__description - doc = self.__testFunc.__doc__ - return doc and doc.split("\n")[0].strip() or None - - - -############################################################################## -# Locating and loading tests -############################################################################## - -class TestLoader: - """This class is responsible for loading tests according to various - criteria and returning them wrapped in a Test - """ - testMethodPrefix = 'test' - suiteClass = TestSuite - - def loadTestsFromTestCase(self, testCaseClass): - """Return a suite of all tests cases contained in testCaseClass""" - if issubclass(testCaseClass, TestSuite): - raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?") - testCaseNames = self.getTestCaseNames(testCaseClass) - if not testCaseNames and hasattr(testCaseClass, 'runTest'): - testCaseNames = ['runTest'] - return self.suiteClass(map(testCaseClass, testCaseNames)) - - def loadTestsFromModule(self, module): - """Return a suite of all tests cases contained in the given module""" - tests = [] - for name in dir(module): - obj = getattr(module, name) - if (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, TestCase)): - tests.append(self.loadTestsFromTestCase(obj)) - return self.suiteClass(tests) - - def loadTestsFromName(self, name, module=None): - """Return a suite of all tests cases given a string specifier. - - The name may resolve either to a module, a test case class, a - test method within a test case class, or a callable object which - returns a TestCase or TestSuite instance. - - The method optionally resolves the names relative to a given module. - """ - parts = name.split('.') - if module is None: - parts_copy = parts[:] - while parts_copy: - try: - module = __import__('.'.join(parts_copy)) - break - except ImportError: - del parts_copy[-1] - if not parts_copy: raise - parts = parts[1:] - obj = module - for part in parts: - parent, obj = obj, getattr(obj, part) - - if type(obj) == types.ModuleType: - return self.loadTestsFromModule(obj) - elif (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, TestCase)): - return self.loadTestsFromTestCase(obj) - elif type(obj) == types.UnboundMethodType: - return parent(obj.__name__) - elif isinstance(obj, TestSuite): - return obj - elif callable(obj): - test = obj() - if not isinstance(test, (TestCase, TestSuite)): - raise ValueError, \ - "calling %s returned %s, not a test" % (obj,test) - return test - else: - raise ValueError, "don't know how to make test from: %s" % obj - - def loadTestsFromNames(self, names, module=None): - """Return a suite of all tests cases found using the given sequence - of string specifiers. See 'loadTestsFromName()'. - """ - suites = [self.loadTestsFromName(name, module) for name in names] - return self.suiteClass(suites) - - def getTestCaseNames(self, testCaseClass): - """Return a sorted sequence of method names found within testCaseClass - """ - - def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): - return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname)) - testFnNames = filter(isTestMethod, dir(testCaseClass)) - for baseclass in testCaseClass.__bases__: - for testFnName in self.getTestCaseNames(baseclass): - if testFnName not in testFnNames: # handle overridden methods - testFnNames.append(testFnName) - testFnNames.sort() - return testFnNames - - - -defaultTestLoader = TestLoader() - - -############################################################################## -# Patches for old functions: these functions should be considered obsolete -############################################################################## - -def _makeLoader(prefix, sortUsing, suiteClass=None): - loader = TestLoader() - loader.testMethodPrefix = prefix - if suiteClass: loader.suiteClass = suiteClass - return loader - -def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp): - return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass) - -def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite): - return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass) - -def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite): - return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module) - - -############################################################################## -# Text UI -############################################################################## - -class _WritelnDecorator: - """Used to decorate file-like objects with a handy 'writeln' method""" - def __init__(self,stream): - self.stream = stream - - def __getattr__(self, attr): - return getattr(self.stream,attr) - - def writeln(self, arg=None): - if arg: self.write(arg) - self.write('\n') # text-mode streams translate to \r\n if needed - - -class _TextTestResult(TestResult): - """A test result class that can print formatted text results to a stream. - - Used by TextTestRunner. - """ - separator1 = '=' * 70 - separator2 = '-' * 70 - - def __init__(self, stream, descriptions, verbosity): - TestResult.__init__(self) - self.stream = stream - self.showAll = verbosity > 1 - self.dots = verbosity == 1 - self.descriptions = descriptions - - def getDescription(self, test): - if self.descriptions: - return test.shortDescription() or str(test) - else: - return str(test) - - def startTest(self, test): - TestResult.startTest(self, test) - if self.showAll: - self.stream.write(self.getDescription(test)) - self.stream.write(" ... ") - - def addSuccess(self, test): - TestResult.addSuccess(self, test) - if self.showAll: - self.stream.writeln("ok") - elif self.dots: - self.stream.write('.') - - def addError(self, test, err): - TestResult.addError(self, test, err) - if self.showAll: - self.stream.writeln("ERROR") - elif self.dots: - self.stream.write('E') - - def addFailure(self, test, err): - TestResult.addFailure(self, test, err) - if self.showAll: - self.stream.writeln("FAIL") - elif self.dots: - self.stream.write('F') - - def printErrors(self): - if self.dots or self.showAll: - self.stream.writeln() - self.printErrorList('ERROR', self.errors) - self.printErrorList('FAIL', self.failures) - - def printErrorList(self, flavour, errors): - for test, err in errors: - self.stream.writeln(self.separator1) - self.stream.writeln("%s: %s" % (flavour,self.getDescription(test))) - self.stream.writeln(self.separator2) - self.stream.writeln("%s" % err) - - -class TextTestRunner: - """A test runner class that displays results in textual form. - - It prints out the names of tests as they are run, errors as they - occur, and a summary of the results at the end of the test run. - """ - def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1): - self.stream = _WritelnDecorator(stream) - self.descriptions = descriptions - self.verbosity = verbosity - - def _makeResult(self): - return _TextTestResult(self.stream, self.descriptions, self.verbosity) - - def run(self, test): - "Run the given test case or test suite." - result = self._makeResult() - startTime = time.time() - test(result) - stopTime = time.time() - timeTaken = stopTime - startTime - result.printErrors() - self.stream.writeln(result.separator2) - run = result.testsRun - self.stream.writeln("Ran %d test%s in %.3fs" % - (run, run != 1 and "s" or "", timeTaken)) - self.stream.writeln() - if not result.wasSuccessful(): - self.stream.write("FAILED (") - failed, errored = map(len, (result.failures, result.errors)) - if failed: - self.stream.write("failures=%d" % failed) - if errored: - if failed: self.stream.write(", ") - self.stream.write("errors=%d" % errored) - self.stream.writeln(")") - else: - self.stream.writeln("OK") - return result - - - -############################################################################## -# Facilities for running tests from the command line -############################################################################## - -class TestProgram: - """A command-line program that runs a set of tests; this is primarily - for making test modules conveniently executable. - """ - USAGE = """\ -Usage: %(progName)s [options] [test] [...] - -Options: - -h, --help Show this message - -v, --verbose Verbose output - -q, --quiet Minimal output - -Examples: - %(progName)s - run default set of tests - %(progName)s MyTestSuite - run suite 'MyTestSuite' - %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething - %(progName)s MyTestCase - run all 'test*' test methods - in MyTestCase -""" - def __init__(self, module='__main__', defaultTest=None, - argv=None, testRunner=None, testLoader=defaultTestLoader): - if type(module) == type(''): - self.module = __import__(module) - for part in module.split('.')[1:]: - self.module = getattr(self.module, part) - else: - self.module = module - if argv is None: - argv = sys.argv - self.verbosity = 1 - self.defaultTest = defaultTest - self.testRunner = testRunner - self.testLoader = testLoader - self.progName = os.path.basename(argv[0]) - self.parseArgs(argv) - self.runTests() - - def usageExit(self, msg=None): - if msg: print msg - print self.USAGE % self.__dict__ - sys.exit(2) - - def parseArgs(self, argv): - import getopt - try: - options, args = getopt.getopt(argv[1:], 'hHvq', - ['help','verbose','quiet']) - for opt, value in options: - if opt in ('-h','-H','--help'): - self.usageExit() - if opt in ('-q','--quiet'): - self.verbosity = 0 - if opt in ('-v','--verbose'): - self.verbosity = 2 - if len(args) == 0 and self.defaultTest is None: - self.test = self.testLoader.loadTestsFromModule(self.module) - return - if len(args) > 0: - self.testNames = args - else: - self.testNames = (self.defaultTest,) - self.createTests() - except getopt.error, msg: - self.usageExit(msg) - - def createTests(self): - self.test = self.testLoader.loadTestsFromNames(self.testNames, - self.module) - - def runTests(self): - if self.testRunner is None: - self.testRunner = TextTestRunner(verbosity=self.verbosity) - result = self.testRunner.run(self.test) - sys.exit(not result.wasSuccessful()) - -main = TestProgram - - -############################################################################## -# Executing this module from the command line -############################################################################## - -if __name__ == "__main__": - main(module=None) diff --git a/test/zblog/alltests.py b/test/zblog/alltests.py deleted file mode 100644 index 34a188ed76..0000000000 --- a/test/zblog/alltests.py +++ /dev/null @@ -1,18 +0,0 @@ -import testenv; testenv.configure_for_tests() -from testlib import sa_unittest as unittest - -def suite(): - modules_to_test = ( - 'zblog.tests', - ) - alltests = unittest.TestSuite() - for name in modules_to_test: - mod = __import__(name) - for token in name.split('.')[1:]: - mod = getattr(mod, token) - alltests.addTest(unittest.findTestCases(mod, suiteClass=None)) - return alltests - - -if __name__ == '__main__': - testenv.main(suite()) diff --git a/test/zblog/mappers.py b/test/zblog/mappers.py index 0d789f3d0b..5203bd866a 100644 --- a/test/zblog/mappers.py +++ b/test/zblog/mappers.py @@ -1,8 +1,7 @@ """mapper.py - defines mappers for domain objects, mapping operations""" -import zblog.tables as tables -import zblog.user as user -from zblog.blog import * +import tables, user +from blog import * from sqlalchemy import * from sqlalchemy.orm import * import sqlalchemy.util as util diff --git a/test/zblog/tables.py b/test/zblog/tables.py index 4fce48a4c4..36c7aeb8b1 100644 --- a/test/zblog/tables.py +++ b/test/zblog/tables.py @@ -1,7 +1,6 @@ """application table metadata objects are described here.""" from sqlalchemy import * -from testlib import * metadata = MetaData() diff --git a/test/zblog/tests.py b/test/zblog/test_zblog.py similarity index 79% rename from test/zblog/tests.py rename to test/zblog/test_zblog.py index f784c27962..8170766cb2 100644 --- a/test/zblog/tests.py +++ b/test/zblog/test_zblog.py @@ -1,33 +1,39 @@ -import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy.orm import * -from testlib import * -from zblog import mappers, tables -from zblog.user import * -from zblog.blog import * +from sqlalchemy.test import * +import mappers, tables +from user import * +from blog import * class ZBlogTest(TestBase, AssertsExecutionResults): - def create_tables(self): + @classmethod + def create_tables(cls): tables.metadata.drop_all(bind=testing.db) tables.metadata.create_all(bind=testing.db) - def drop_tables(self): + + @classmethod + def drop_tables(cls): tables.metadata.drop_all(bind=testing.db) - def setUpAll(self): - self.create_tables() - def tearDownAll(self): - self.drop_tables() - def tearDown(self): + @classmethod + def setup_class(cls): + cls.create_tables() + @classmethod + def teardown_class(cls): + cls.drop_tables() + def teardown(self): pass - def setUp(self): + def setup(self): pass class SavePostTest(ZBlogTest): - def setUpAll(self): - super(SavePostTest, self).setUpAll() + @classmethod + def setup_class(cls): + super(SavePostTest, cls).setup_class() + mappers.zblog_mappers() global blog_id, user_id s = create_session(bind=testing.db) @@ -41,9 +47,10 @@ class SavePostTest(ZBlogTest): user_id = user.id s.close() - def tearDownAll(self): + @classmethod + def teardown_class(cls): clear_mappers() - super(SavePostTest, self).tearDownAll() + super(SavePostTest, cls).teardown_class() def testattach(self): """test that a transient/pending instance has proper bi-directional behavior. @@ -93,5 +100,3 @@ class SavePostTest(ZBlogTest): s.rollback() -if __name__ == "__main__": - testenv.main()