0.5.5
=======
+- general
+ - unit tests have been migrated from unittest to nose.
+ See README.unittests for information on how to run
+ the tests. [ticket:970]
- orm
- Fixed bug introduced in 0.5.4 whereby Composite types
fail when default-holding columns are flushed.
SQLALCHEMY UNIT TESTS
=====================
-SETUP
------
-SQLite support is required. These instructions assume standard Python 2.4 or
-higher.
-
-The 'test' directory must be on the PYTHONPATH.
+SQLAlchemy unit tests by default run using Python's built-in sqlite3
+module. If running on Python 2.4, pysqlite must be installed.
-cd into the SQLAlchemy distribution directory
+As of 0.5.5, unit tests are run using nose. Documentation and
+downloads for nose are available at:
-In bash:
+http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html
- $ export PYTHONPATH=./test/
-On windows:
+SQLAlchemy implements a nose plugin that must be present when tests are run.
+This plugin is available when SQLAlchemy is installed via setuptools.
- C:\sa\> set PYTHONPATH=test\
+SETUP
+-----
- Adjust any other use Unix-style paths in this README as needed.
+All that's required is for SQLAlchemy to be installed via setuptools.
+For example, to create a local install in a source distribution directory:
-The unittest framework will automatically prepend the lib/ directory to
-sys.path. This forces the local version of SQLAlchemy to be used, bypassing
-any setuptools-installed installations (setuptools places .egg files ahead of
-plain directories, even if on PYTHONPATH, unfortunately).
+ $ export PYTHONPATH=.
+ $ python setup.py develop -d .
+The above will create a setuptools "development" distribution in the local
+path, which makes the Nose plugin available when nosetests is run.
+The plugin is enabled using the "with-sqlalchemy=True" configuration
+in setup.cfg.
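+
+For reference, the [nosetests] section shipped in setup.cfg looks like the
+following (a sketch; check the setup.cfg in your checkout for the exact
+contents):
+
+    [nosetests]
+    with-sqlalchemy = true
+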
RUNNING ALL TESTS
-----------------
To run all tests:
- $ python test/alltests.py
+ $ nosetests
+
+Assuming all tests pass, this is a very unexciting output. To make it more
+interesting:
+ $ nosetests -v
RUNNING INDIVIDUAL TESTS
-------------------------
-Any unittest module can be run directly from the module file:
+Any test module can be run directly by specifying its module name:
- python test/orm/mapper.py
+ $ nosetests test.orm.test_mapper
-To run a specific test within the module, specify it as ClassName.methodname:
+To run a specific test within the module, specify it as module:ClassName.methodname:
- python test/orm/mapper.py MapperTest.testget
+ $ nosetests test.orm.test_mapper:MapperTest.test_utils
COMMAND LINE OPTIONS
--------------------
-Help is available via --help
+Help is available via --help:
- $ python test/alltests.py --help
+ $ nosetests --help
- usage: alltests.py [options] [tests...]
-
- Options:
- -h, --help show this help message and exit
- --verbose enable stdout echoing/printing
- --quiet suppress output
- [...]
-
-Command line options can applied to alltests.py or any individual test module.
-Many are available. The most commonly used are '--db' and '--dburi'.
+The --help screen is a combination of common nose options and options which
+the SQLAlchemy nose plugin adds. The most commonly used SQLAlchemy-specific
+options are '--db' and '--dburi'.
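+
+For example, to point the tests at an arbitrary database with --dburi (the
+URL below is just a placeholder; substitute your own driver, user and host):
+
+    $ nosetests --dburi=postgres://someuser:somepass@localhost/test
+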
DATABASE TARGETS
If you'll be running the tests frequently, database aliases can save a lot of
typing. The --dbs option lists the built-in aliases and their matching URLs:
- $ python test/alltests.py --dbs
+ $ nosetests --dbs
Available --db options (use --dburi to override)
mysql mysql://scott:tiger@127.0.0.1:3306/test
oracle oracle://scott:tiger@127.0.0.1:1521
To run tests against an aliased database:
- $ python test/alltests.py --db=postgres
+ $ nosetests --db=postgres
To customize the URLs with your own users or hostnames, make a simple .ini
file called `test.cfg` at the top level of the SQLAlchemy source distribution
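+
+As a sketch (the user names and passwords below are placeholders), such a
+file uses the same [db] section format as the built-in aliases; a
+`.satest.cfg` file in your home directory is read as well:
+
+    [db]
+    postgres=postgres://someuser:somepass@localhost/test
+    mysql=mysql://someuser:somepass@localhost/test
+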
Any log target can be directed to the console with command line options, such
as:
- $ python test/orm/unitofwork.py --log-info=sqlalchemy.orm.mapper \
+ $ nosetests test.orm.unitofwork --log-info=sqlalchemy.orm.mapper \
--log-debug=sqlalchemy.pool --log-info=sqlalchemy.engine
This would log mapper configuration, connection pool checkouts, and SQL
BUILT-IN COVERAGE REPORTING
------------------------------
-Coverage is tracked with coverage.py module, included in the './test/'
-directory. Running the test suite with the --coverage switch will generate a
-local file ".coverage" containing coverage details, and a report will be
-printed to standard output with an overview of the coverage gathered from the
-last unittest run (the file is deleted between runs).
-
-After the suite has been run with --coverage, an annotated version of any
-source file can be generated, marking statements that are executed with > and
-statements that are missed with !, by running the coverage.py utility with the
-"-a" (annotate) option, such as:
-
- $ python ./test/testlib/coverage.py -a ./lib/sqlalchemy/sql.py
+Coverage is tracked using Nose's coverage plugin. See the nose
+documentation for details. Basic usage is:
-This will create a new annotated file ./lib/sqlalchemy/sql.py,cover. Pretty
-cool!
+ $ nosetests test.sql.test_query --with-coverage
BIG COVERAGE TIP !!! There is an issue where existing .pyc files may
store the incorrect filepaths, which will break the coverage system. If
--- /dev/null
+import os
+import subprocess
+import re
+
+def walk():
+ for root, dirs, files in os.walk("./test/"):
+ if root.endswith("/perf"):
+ continue
+
+ for fname in files:
+ if not fname.endswith(".py"):
+ continue
+ if fname == "alltests.py":
+ subprocess.call(["svn", "remove", os.path.join(root, fname)])
+ elif fname.startswith("_") or fname == "__init__.py" or fname == "pickleable.py":
+ convert(os.path.join(root, fname))
+ elif not fname.startswith("test_"):
+ if os.path.exists(os.path.join(root, "test_" + fname)):
+ os.unlink(os.path.join(root, "test_" + fname))
+ subprocess.call(["svn", "rename", os.path.join(root, fname), os.path.join(root, "test_" + fname)])
+ convert(os.path.join(root, "test_" + fname))
+
+
+def convert(fname):
+ lines = list(file(fname))
+ replaced = []
+ flags = {}
+
+ while lines:
+ for reg, handler in handlers:
+ m = reg.match(lines[0])
+ if m:
+ handler(lines, replaced, flags)
+ break
+
+ post_handler(lines, replaced, flags)
+ f = file(fname, 'w')
+ f.write("".join(replaced))
+ f.close()
+
+handlers = []
+
+
+def post_handler(lines, replaced, flags):
+ imports = []
+ if "needs_eq" in flags:
+ imports.append("eq_")
+ if "needs_assert_raises" in flags:
+ imports += ["assert_raises", "assert_raises_message"]
+ if imports:
+ for i, line in enumerate(replaced):
+ if "import" in line:
+ replaced.insert(i, "from sqlalchemy.test.testing import %s\n" % ", ".join(imports))
+ break
+
+def remove_line(lines, replaced, flags):
+ lines.pop(0)
+
+handlers.append((re.compile(r"import testenv; testenv\.configure_for_tests"), remove_line))
+handlers.append((re.compile(r"(.*\s)?import sa_unittest"), remove_line))
+
+
+def import_testlib_sa(lines, replaced, flags):
+ line = lines.pop(0)
+ line = line.replace("import testlib.sa", "import sqlalchemy")
+ replaced.append(line)
+handlers.append((re.compile("import testlib\.sa"), import_testlib_sa))
+
+def from_testlib_sa(lines, replaced, flags):
+ line = lines.pop(0)
+ while True:
+ if line.endswith("\\\n"):
+ line = line[0:-2] + lines.pop(0)
+ else:
+ break
+
+ components = re.compile(r'from testlib\.sa import (.*)').match(line)
+ if components:
+ components = re.split(r"\s*,\s*", components.group(1))
+ line = "from sqlalchemy import %s\n" % (", ".join(c for c in components if c not in ("Table", "Column")))
+ replaced.append(line)
+ if "Table" in components:
+ replaced.append("from sqlalchemy.test.schema import Table\n")
+ if "Column" in components:
+ replaced.append("from sqlalchemy.test.schema import Column\n")
+ return
+
+ line = line.replace("testlib.sa", "sqlalchemy")
+ replaced.append(line)
+handlers.append((re.compile("from testlib\.sa.*import"), from_testlib_sa))
+
+def from_testlib(lines, replaced, flags):
+ line = lines.pop(0)
+
+ components = re.compile(r'from testlib import (.*)').match(line)
+ if components:
+ components = re.split(r"\s*,\s*", components.group(1))
+ if "sa" in components:
+ replaced.append("import sqlalchemy as sa\n")
+ replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa")))
+ return
+ elif "sa as tsa" in components:
+ replaced.append("import sqlalchemy as tsa\n")
+ replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa")))
+ return
+
+ line = line.replace("testlib", "sqlalchemy.test")
+ replaced.append(line)
+handlers.append((re.compile(r"from testlib"), from_testlib))
+
+def from_orm(lines, replaced, flags):
+ line = lines.pop(0)
+ line = line.replace("from orm import", "from test.orm import")
+ line = line.replace("from orm.", "from test.orm.")
+ replaced.append(line)
+handlers.append((re.compile(r'from orm( import|\.)'), from_orm))
+
+def assert_equals(lines, replaced, flags):
+ line = lines.pop(0)
+ line = line.replace("self.assertEquals", "eq_")
+ line = line.replace("self.assertEqual", "eq_")
+ replaced.append(line)
+ flags["needs_eq"] = True
+handlers.append((re.compile(r"\s*self\.assertEqual(s)?"), assert_equals))
+
+def assert_raises(lines, replaced, flags):
+ line = lines.pop(0)
+ line = line.replace("self.assertRaisesMessage", "assert_raises_message")
+ line = line.replace("self.assertRaises", "assert_raises")
+ replaced.append(line)
+ flags["needs_assert_raises"] = True
+handlers.append((re.compile(r"\s*self\.assertRaises(Message)?"), assert_raises))
+
+def setup_all(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def setUpAll\(self\)\:").match(line).group(1)
+ replaced.append("%s@classmethod\n" % whitespace)
+ replaced.append("%sdef setup_class(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setUpAll\(self\)"), setup_all))
+
+def teardown_all(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def tearDownAll\(self\)\:").match(line).group(1)
+ replaced.append("%s@classmethod\n" % whitespace)
+ replaced.append("%sdef teardown_class(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def tearDownAll\(self\)"), teardown_all))
+
+def setup(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def setUp\(self\)\:").match(line).group(1)
+ replaced.append("%sdef setup(self):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setUp\(self\)"), setup))
+
+def teardown(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def tearDown\(self\)\:").match(line).group(1)
+ replaced.append("%sdef teardown(self):\n" % whitespace)
+handlers.append((re.compile(r"\s*def tearDown\(self\)"), teardown))
+
+def define_tables(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def define_tables").match(line).group(1)
+ replaced.append("%s@classmethod\n" % whitespace)
+ replaced.append("%sdef define_tables(cls, metadata):\n" % whitespace)
+handlers.append((re.compile(r"\s*def define_tables\(self, metadata\)"), define_tables))
+
+def setup_mappers(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def setup_mappers").match(line).group(1)
+
+ i = -1
+ while re.match("\s*@testing", replaced[i]):
+ i -= 1
+
+ replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+ replaced.append("%sdef setup_mappers(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setup_mappers\(self\)"), setup_mappers))
+
+def setup_classes(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def setup_classes").match(line).group(1)
+
+ i = -1
+ while re.match("\s*@testing", replaced[i]):
+ i -= 1
+
+ replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+ replaced.append("%sdef setup_classes(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setup_classes\(self\)"), setup_classes))
+
+def insert_data(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def insert_data").match(line).group(1)
+
+ i = -1
+ while re.match("\s*@testing", replaced[i]):
+ i -= 1
+
+ replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+ replaced.append("%sdef insert_data(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def insert_data\(self\)"), insert_data))
+
+def fixtures(lines, replaced, flags):
+ line = lines.pop(0)
+ whitespace = re.compile(r"(\s*)def fixtures").match(line).group(1)
+
+ i = -1
+ while re.match("\s*@testing", replaced[i]):
+ i -= 1
+
+ replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+ replaced.append("%sdef fixtures(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def fixtures\(self\)"), fixtures))
+
+
+def call_main(lines, replaced, flags):
+ replaced.pop(-1)
+ lines.pop(0)
+handlers.append((re.compile(r"\s+testenv\.main\(\)"), call_main))
+
+def default(lines, replaced, flags):
+ replaced.append(lines.pop(0))
+handlers.append((re.compile(r".*"), default))
+
+
+if __name__ == '__main__':
+ convert("test/orm/inheritance/abc_inheritance.py")
+# walk()
--- /dev/null
+"""Testing environment and utilities.
+
+This package contains base classes and routines used by
+the unit tests. Tests are based on Nose and bootstrapped
+by noseplugin.NoseSQLAlchemy.
+
+"""
+
+from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config
+from sqlalchemy.test.schema import Column, Table
+from sqlalchemy.test.testing import \
+ AssertsCompiledSQL, \
+ AssertsExecutionResults, \
+ ComparesTables, \
+ TestBase, \
+ rowset
+
+
+__all__ = ('testing',
+ 'Column', 'Table',
+ 'rowset',
+ 'TestBase', 'AssertsExecutionResults',
+ 'AssertsCompiledSQL', 'ComparesTables',
+ 'engines', 'profiling', 'pickleable')
+
+
--- /dev/null
+import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
+logging = None
+
+__all__ = 'parser', 'configure', 'options',
+
+db = None
+db_label, db_url, db_opts = None, None, {}
+
+options = None
+file_config = None
+
+base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
+firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
+maxdb=maxdb://MONA:RED@/maxdb1
+"""
+
+def _log(option, opt_str, value, parser):
+ global logging
+ if not logging:
+ import logging
+ logging.basicConfig()
+
+ if opt_str.endswith('-info'):
+ logging.getLogger(value).setLevel(logging.INFO)
+ elif opt_str.endswith('-debug'):
+ logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+ print "Available --db options (use --dburi to override)"
+ for macro in sorted(file_config.options('db')):
+ print "%20s\t%s" % (macro, file_config.get('db', macro))
+ sys.exit(0)
+
+def _server_side_cursors(options, opt_str, value, parser):
+ db_opts['server_side_cursors'] = True
+
+def _engine_strategy(options, opt_str, value, parser):
+ if value:
+ db_opts['strategy'] = value
+
+class _ordered_map(object):
+ def __init__(self):
+ self._keys = list()
+ self._data = dict()
+
+ def __setitem__(self, key, value):
+ if key not in self._keys:
+ self._keys.append(key)
+ self._data[key] = value
+
+ def __iter__(self):
+ for key in self._keys:
+ yield self._data[key]
+
+# at one point in refactoring, modules were injecting into the config
+# process. this could probably just become a list now.
+post_configure = _ordered_map()
+
+def _engine_uri(options, file_config):
+ global db_label, db_url
+ db_label = 'sqlite'
+ if options.dburi:
+ db_url = options.dburi
+ db_label = db_url[:db_url.index(':')]
+ elif options.db:
+ db_label = options.db
+ db_url = None
+
+ if db_url is None:
+ if db_label not in file_config.options('db'):
+ raise RuntimeError(
+ "Unknown engine. Specify --dbs for known engines.")
+ db_url = file_config.get('db', db_label)
+post_configure['engine_uri'] = _engine_uri
+
+def _require(options, file_config):
+ if not(options.require or
+ (file_config.has_section('require') and
+ file_config.items('require'))):
+ return
+
+ try:
+ import pkg_resources
+ except ImportError:
+ raise RuntimeError("setuptools is required for version requirements")
+
+ cmdline = []
+ for requirement in options.require:
+ pkg_resources.require(requirement)
+ cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
+
+ if file_config.has_section('require'):
+ for label, requirement in file_config.items('require'):
+ if not label == db_label or label.startswith('%s.' % db_label):
+ continue
+ seen = [c for c in cmdline if requirement.startswith(c)]
+ if seen:
+ continue
+ pkg_resources.require(requirement)
+post_configure['require'] = _require
+
+def _engine_pool(options, file_config):
+ if options.mockpool:
+ from sqlalchemy import pool
+ db_opts['poolclass'] = pool.AssertionPool
+post_configure['engine_pool'] = _engine_pool
+
+def _create_testing_engine(options, file_config):
+ from sqlalchemy.test import engines, testing
+ global db
+ db = engines.testing_engine(db_url, db_opts)
+ testing.db = db
+post_configure['create_engine'] = _create_testing_engine
+
+def _prep_testing_database(options, file_config):
+ from sqlalchemy.test import engines
+ from sqlalchemy import schema
+
+ try:
+ # also create alt schemas etc. here?
+ if options.dropfirst:
+ e = engines.utf8_engine()
+ existing = e.table_names()
+ if existing:
+ print "Dropping existing tables in database: " + db_url
+ try:
+ print "Tables: %s" % ', '.join(existing)
+ except:
+ pass
+ print "Abort within 5 seconds..."
+ time.sleep(5)
+ md = schema.MetaData(e, reflect=True)
+ md.drop_all()
+ e.dispose()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception, e:
+ warnings.warn(RuntimeWarning(
+ "Error checking for existing tables in testing "
+ "database: %s" % e))
+post_configure['prep_db'] = _prep_testing_database
+
+def _set_table_options(options, file_config):
+ from sqlalchemy.test import schema
+
+ table_options = schema.table_options
+ for spec in options.tableopts:
+ key, value = spec.split('=')
+ table_options[key] = value
+
+ if options.mysql_engine:
+ table_options['mysql_engine'] = options.mysql_engine
+post_configure['table_options'] = _set_table_options
+
+def _reverse_topological(options, file_config):
+ if options.reversetop:
+ from sqlalchemy.orm import unitofwork
+ from sqlalchemy import topological
+ class RevQueueDepSort(topological.QueueDependencySorter):
+ def __init__(self, tuples, allitems):
+ self.tuples = list(tuples)
+ self.allitems = list(allitems)
+ self.tuples.reverse()
+ self.allitems.reverse()
+ topological.QueueDependencySorter = RevQueueDepSort
+ unitofwork.DependencySorter = RevQueueDepSort
+post_configure['topological'] = _reverse_topological
+
import sys, types, weakref
from collections import deque
-from testlib import config
-from testlib.compat import _function_named, callable
+import config
+from sqlalchemy.util import function_named, callable
class ConnectionKiller(object):
def __init__(self):
fn(*args, **kw)
finally:
testing_reaper.assert_all_closed()
- return _function_named(decorated, fn.__name__)
+ return function_named(decorated, fn.__name__)
def rollback_open_connections(fn):
"""Decorator that rolls back all open connections after fn execution."""
fn(*args, **kw)
finally:
testing_reaper.rollback_all()
- return _function_named(decorated, fn.__name__)
+ return function_named(decorated, fn.__name__)
def close_open_connections(fn):
"""Decorator that closes all connections after fn execution."""
fn(*args, **kw)
finally:
testing_reaper.close_all()
- return _function_named(decorated, fn.__name__)
+ return function_named(decorated, fn.__name__)
def all_dialects():
import sqlalchemy.databases as d
"""Produce an engine configured by --options with optional overrides."""
from sqlalchemy import create_engine
- from testlib.assertsql import asserter
+ from sqlalchemy.test.assertsql import asserter
url = url or config.db_url
options = options or config.db_opts
--- /dev/null
+import logging
+import os
+import re
+import sys
+import time
+import warnings
+import ConfigParser
+import StringIO
+from config import db, db_label, db_url, file_config, base_config, \
+ post_configure, \
+ _list_dbs, _server_side_cursors, _engine_strategy, \
+ _engine_uri, _require, _engine_pool, \
+ _create_testing_engine, _prep_testing_database, \
+ _set_table_options, _reverse_topological, _log
+from sqlalchemy.test import testing, config, requires
+from nose.plugins import Plugin
+from nose.util import tolist
+import nose.case
+
+log = logging.getLogger('nose.plugins.sqlalchemy')
+
+class NoseSQLAlchemy(Plugin):
+ """
+ Handles the setup and extra properties required for testing SQLAlchemy
+ """
+ enabled = True
+ name = 'sqlalchemy'
+ score = 100
+
+ def options(self, parser, env=os.environ):
+ Plugin.options(self, parser, env)
+ opt = parser.add_option
+ #opt("--verbose", action="store_true", dest="verbose",
+ #help="enable stdout echoing/printing")
+ #opt("--quiet", action="store_true", dest="quiet", help="suppress output")
+ opt("--log-info", action="callback", type="string", callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)")
+ opt("--log-debug", action="callback", type="string", callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)")
+ opt("--require", action="append", dest="require", default=[],
+ help="require a particular driver or module version (multiple OK)")
+ opt("--db", action="store", dest="db", default="sqlite",
+ help="Use prefab database uri")
+ opt('--dbs', action='callback', callback=_list_dbs,
+ help="List available prefab dbs")
+ opt("--dburi", action="store", dest="dburi",
+ help="Database uri (overrides --db)")
+ opt("--dropfirst", action="store_true", dest="dropfirst",
+ help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")
+ opt("--mockpool", action="store_true", dest="mockpool",
+ help="Use mock pool (asserts only one connection used)")
+ opt("--enginestrategy", action="callback", type="string",
+ callback=_engine_strategy,
+ help="Engine strategy (plain or threadlocal, defaults to plain)")
+ opt("--reversetop", action="store_true", dest="reversetop", default=False,
+ help="Reverse the collection ordering for topological sorts (helps "
+ "reveal dependency issues)")
+ opt("--unhashable", action="store_true", dest="unhashable", default=False,
+ help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
+ opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
+ help="Disallow SQLAlchemy from performing == on mapped test objects.")
+ opt("--truthless", action="store_true", dest="truthless", default=False,
+ help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
+ opt("--serverside", action="callback", callback=_server_side_cursors,
+ help="Turn on server side cursors for PG")
+ opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
+ help="Use the specified MySQL storage engine for all tables, default is "
+ "a db-default/InnoDB combo.")
+ opt("--table-option", action="append", dest="tableopts", default=[],
+ help="Add a dialect-specific table option, key=value")
+
+ global file_config
+ file_config = ConfigParser.ConfigParser()
+ file_config.readfp(StringIO.StringIO(base_config))
+ file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+ config.file_config = file_config
+
+ def configure(self, options, conf):
+ Plugin.configure(self, options, conf)
+
+ import testing, requires
+ testing.db = db
+ testing.requires = requires
+
+ # Lazy setup of other options (post coverage)
+ for fn in post_configure:
+ fn(options, file_config)
+
+ def describeTest(self, test):
+ return ""
+
+ def wantClass(self, cls):
+ """Return true if you want the main test selector to collect
+ tests from this class, false if you don't, and None if you don't
+ care.
+
+ :Parameters:
+ cls : class
+ The class being examined by the selector
+
+ """
+
+ if not issubclass(cls, testing.TestBase):
+ return False
+ else:
+ if (hasattr(cls, '__whitelist__') and
+ testing.db.name in cls.__whitelist__):
+ return True
+ else:
+ return not self.__should_skip_for(cls)
+
+ def __should_skip_for(self, cls):
+ if hasattr(cls, '__requires__'):
+ def test_suite(): return 'ok'
+ for requirement in cls.__requires__:
+ check = getattr(requires, requirement)
+ if check(test_suite)() != 'ok':
+ # The requirement will perform messaging.
+ return True
+ if (hasattr(cls, '__unsupported_on__') and
+ testing.db.name in cls.__unsupported_on__):
+ print "'%s' unsupported on DB implementation '%s'" % (
+                cls.__name__, testing.db.name)
+ return True
+ if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)):
+ print "'%s' unsupported on DB implementation '%s'" % (
+                cls.__name__, testing.db.name)
+ return True
+ if (getattr(cls, '__skip_if__', False)):
+ for c in getattr(cls, '__skip_if__'):
+ if c():
+ print "'%s' skipped by %s" % (
+                    cls.__name__, c.__name__)
+ return True
+ for rule in getattr(cls, '__excluded_on__', ()):
+ if testing._is_excluded(*rule):
+ print "'%s' unsupported on DB %s version %s" % (
+                cls.__name__, testing.db.name,
+ _server_version())
+ return True
+ return False
+
+ #def begin(self):
+ #pass
+
+ def beforeTest(self, test):
+ testing.resetwarnings()
+
+ def afterTest(self, test):
+ testing.resetwarnings()
+
+ #def handleError(self, test, err):
+ #pass
+
+ #def finalize(self, result=None):
+ #pass
import inspect, re
-from testlib import config, testing
-
-sa = None
-orm = None
+import config, testing
+from sqlalchemy import orm
__all__ = 'mapper',
return method
def mapper(type_, *args, **kw):
- global orm
- if orm is None:
- from sqlalchemy import orm
-
forbidden = [
('__hash__', 'unhashable', lambda s: id(s)),
('__eq__', 'noncomparable', lambda s, o: s is o),
-"""since the cPickle module as of py2.4 uses erroneous relative imports, define the various
-picklable classes here so we can test PickleType stuff without issue."""
+"""
+
+some objects used for pickle tests, declared in their own module so that they
+are easily pickleable.
+
+"""
class Foo(object):
-"""Profiling support for unit and performance tests."""
+"""Profiling support for unit and performance tests.
+
+These are special purpose profiling methods which operate
+in a more fine-grained way than nose's profiling plugin.
+
+"""
import os, sys
-from testlib.compat import _function_named
-import testlib.config
+from sqlalchemy.util import function_named
+import config
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
elapsed, load_stats, result = _profile(
filename, fn, *args, **kw)
- if not testlib.config.options.quiet:
- print "Profiled target '%s', wall time: %.2f seconds" % (
- target, elapsed)
-
report = target_opts.get('report', profile_config['report'])
- if report and testlib.config.options.verbose:
+ if report:
sort_ = target_opts.get('sort', profile_config['sort'])
limit = target_opts.get('limit', profile_config['limit'])
print "Profile report for target '%s' (%s)" % (
#stats.print_callers()
os.unlink(filename)
return result
- return _function_named(profiled, fn.__name__)
+ return function_named(profiled, fn.__name__)
return decorator
def function_call_count(count=None, versions={}, variance=0.05):
stats = stat_loader()
calls = stats.total_calls
- if testlib.config.options.verbose:
- stats.sort_stats('calls', 'cumulative')
- stats.print_stats()
- #stats.print_callers()
+ stats.sort_stats('calls', 'cumulative')
+ stats.print_stats()
+ #stats.print_callers()
deviance = int(count * variance)
if (calls < (count - deviance) or
calls > (count + deviance)):
finally:
if os.path.exists(filename):
os.unlink(filename)
- return _function_named(counted, fn.__name__)
+ return function_named(counted, fn.__name__)
return decorator
def conditional_call_count(discriminator, categories):
rewrapped = function_call_count(*criteria)(fn)
return rewrapped(*args, **kw)
- return _function_named(at_runtime, fn.__name__)
+ return function_named(at_runtime, fn.__name__)
return decorator
"""
-from testlib.testing import \
+from testing import \
_block_unconditionally as no_support, \
_chain_decorators_on, \
exclude, \
-from testlib import testing
+"""Enhanced versions of schema.Table and schema.Column which establish
+desired state for different backends.
+"""
-schema = None
+from sqlalchemy.test import testing
+from sqlalchemy import schema
__all__ = 'Table', 'Column',
def Table(*args, **kw):
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
- global schema
- if schema is None:
- from sqlalchemy import schema
-
test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
def Column(*args, **kw):
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
- global schema
- if schema is None:
- from sqlalchemy import schema
-
test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
"""TestCase and TestSuite artifacts and testing decorators."""
-# monkeypatches unittest.TestLoader.suiteClass at import time
-
import itertools
import operator
import re
import sys
import types
-from testlib import sa_unittest as unittest
import warnings
from cStringIO import StringIO
-import testlib.config as config
-from testlib.compat import _function_named, callable
-
-# Delayed imports
-MetaData = None
-Session = None
-clear_mappers = None
-sa_exc = None
-schema = None
-sqltypes = None
-util = None
+from sqlalchemy.test import config, assertsql
+from sqlalchemy.util import function_named
+from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema
_ops = { '<': operator.lt,
'>': operator.gt,
raise AssertionError(
"Unexpected success for '%s' (condition: %s)" %
(fn_name, description))
- return _function_named(maybe, fn_name)
+ return function_named(maybe, fn_name)
return decorate
else:
raise AssertionError(
"Unexpected success for future test '%s'" % fn_name)
- return _function_named(decorated, fn_name)
+ return function_named(decorated, fn_name)
def fails_on(dbs, reason):
"""Mark a test as expected to fail on the specified database
raise AssertionError(
"Unexpected success for '%s' on DB implementation '%s'" %
(fn_name, config.db.name))
- return _function_named(maybe, fn_name)
+ return function_named(maybe, fn_name)
return decorate
def fails_on_everything_except(*dbs):
raise AssertionError(
"Unexpected success for '%s' on DB implementation '%s'" %
(fn_name, config.db.name))
- return _function_named(maybe, fn_name)
+ return function_named(maybe, fn_name)
return decorate
def crashes(db, reason):
return True
else:
return fn(*args, **kw)
- return _function_named(maybe, fn_name)
+ return function_named(maybe, fn_name)
return decorate
def _block_unconditionally(db, reason):
return True
else:
return fn(*args, **kw)
- return _function_named(maybe, fn_name)
+ return function_named(maybe, fn_name)
return decorate
return True
else:
return fn(*args, **kw)
- return _function_named(maybe, fn_name)
+ return function_named(maybe, fn_name)
return decorate
def _should_carp_about_exclusion(reason):
return True
else:
return fn(*args, **kw)
- return _function_named(maybe, fn_name)
+ return function_named(maybe, fn_name)
return decorate
def emits_warning(*messages):
# - update: jython looks ok, it uses cpython's module
def decorate(fn):
def safe(*args, **kw):
- global sa_exc
- if sa_exc is None:
- import sqlalchemy.exc as sa_exc
-
# todo: should probably be strict about this, too
filters = [dict(action='ignore',
category=sa_exc.SAPendingDeprecationWarning)]
return fn(*args, **kw)
finally:
resetwarnings()
- return _function_named(safe, fn.__name__)
+ return function_named(safe, fn.__name__)
return decorate
def emits_warning_on(db, *warnings):
else:
wrapped = emits_warning(*warnings)(fn)
return wrapped(*args, **kw)
- return _function_named(maybe, fn.__name__)
+ return function_named(maybe, fn.__name__)
return decorate
def uses_deprecated(*messages):
def decorate(fn):
def safe(*args, **kw):
- global sa_exc
- if sa_exc is None:
- import sqlalchemy.exc as sa_exc
-
# todo: should probably be strict about this, too
filters = [dict(action='ignore',
category=sa_exc.SAPendingDeprecationWarning)]
return fn(*args, **kw)
finally:
resetwarnings()
- return _function_named(safe, fn.__name__)
+ return function_named(safe, fn.__name__)
return decorate
def resetwarnings():
"""Reset warning behavior to testing defaults."""
- global sa_exc
- if sa_exc is None:
- import sqlalchemy.exc as sa_exc
-
warnings.filterwarnings('ignore',
category=sa_exc.SAPendingDeprecationWarning)
warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
assert a.startswith(fragment), msg or "%r does not start with %r" % (
a, fragment)
+def assert_raises(except_cls, callable_, *args, **kw):
+ try:
+ callable_(*args, **kw)
+ assert False, "Callable did not raise an exception"
+ except except_cls, e:
+ pass
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ try:
+ callable_(*args, **kwargs)
+ assert False, "Callable did not raise an exception"
+ except except_cls, e:
+ assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
+
+def fail(msg):
+ assert False, msg
+
def fixture(table, columns, *rows):
"""Insert data into table after creation."""
def onload(event, schema_item, connection):
for column_values in rows])
table.append_ddl_listener('after-create', onload)
-def _import_by_name(name):
- submodule = name.split('.')[-1]
- return __import__(name, globals(), locals(), [submodule])
-
-class CompositeModule(types.ModuleType):
- """Merged attribute access for multiple modules."""
-
- # break the habit
- __all__ = ()
-
- def __init__(self, name, *modules, **overrides):
- """Construct a new lazy composite of modules.
-
- Modules may be string names or module-like instances. Individual
- attribute overrides may be specified as keyword arguments for
- convenience.
-
- The constructed module will resolve attribute access in reverse order:
- overrides, then each member of reversed(modules). Modules specified
- by name will be loaded lazily when encountered in attribute
- resolution.
-
- """
- types.ModuleType.__init__(self, name)
- self.__modules = list(reversed(modules))
- for key, value in overrides.iteritems():
- setattr(self, key, value)
-
- def __getattr__(self, key):
- for idx, mod in enumerate(self.__modules):
- if isinstance(mod, basestring):
- self.__modules[idx] = mod = _import_by_name(mod)
- if hasattr(mod, key):
- return getattr(mod, key)
- raise AttributeError(key)
-
-
def resolve_artifact_names(fn):
"""Decorator, augment function globals with tables and classes.
fn.func_code, context, fn.func_name, fn.func_defaults,
fn.func_closure)
return rebound(*args, **kwargs)
- return _function_named(resolved, fn.func_name)
+ return function_named(resolved, fn.func_name)
class adict(dict):
"""Dict keys available as attributes. Shadows."""
return tuple([self[key] for key in keys])
-class TestBase(unittest.TestCase):
+class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
__whitelist__ = ()
# skipped.
__skip_if__ = None
-
_artifact_registries = ()
- _sa_first_test = False
- _sa_last_test = False
-
- def __init__(self, *args, **params):
- unittest.TestCase.__init__(self, *args, **params)
-
- def setUpAll(self):
- pass
-
- def tearDownAll(self):
- pass
-
- def shortDescription(self):
- """overridden to not return docstrings"""
- return None
-
- def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
- try:
- callable_(*args, **kwargs)
- assert False, "Callable did not raise an exception"
- except except_cls, e:
- assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
-
- if not hasattr(unittest.TestCase, 'assertTrue'):
- assertTrue = unittest.TestCase.failUnless
- if not hasattr(unittest.TestCase, 'assertFalse'):
- assertFalse = unittest.TestCase.failIf
-
+ def assert_(self, val, msg=None):
+ assert val, msg
+
class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None):
if dialect is None:
cc = re.sub(r'\n', '', str(c))
- self.assertEquals(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+ eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
if checkparams is not None:
- self.assertEquals(c.construct_params(params), checkparams)
+ eq_(c.construct_params(params), checkparams)
class ComparesTables(object):
def assert_tables_equal(self, table, reflected_table):
- global sqltypes, schema
- if sqltypes is None:
- import sqlalchemy.types as sqltypes
- if schema is None:
- import sqlalchemy.schema as schema
base_mro = sqltypes.TypeEngine.__mro__
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
- self.assertEquals(c.name, reflected_c.name)
+ eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
- self.assertEquals(c.primary_key, reflected_c.primary_key)
- self.assertEquals(c.nullable, reflected_c.nullable)
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
assert len(
set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
set(type(c.type).__mro__).difference(base_mro)
) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
if isinstance(c.type, sqltypes.String):
- self.assertEquals(c.type.length, reflected_c.type.length)
+ eq_(c.type.length, reflected_c.type.length)
- self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
if c.default:
assert isinstance(reflected_c.server_default,
schema.FetchedValue)
numbers of rows that the test suite manipulates.
"""
- global util
- if util is None:
- from sqlalchemy import util
-
class frozendict(dict):
def __hash__(self):
return id(self)
expected = set([frozendict(e) for e in expected])
for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
- self.fail('Unexpected type "%s", expected "%s"' % (
+ fail('Unexpected type "%s", expected "%s"' % (
type(wrong).__name__, cls.__name__))
if len(found) != len(expected):
- self.fail('Unexpected object count "%s", expected "%s"' % (
+ fail('Unexpected object count "%s", expected "%s"' % (
len(found), len(expected)))
NOVALUE = object()
found.remove(found_item)
break
else:
- self.fail(
+ fail(
"Expected %s instance with attributes %s not found." % (
cls.__name__, repr(expected_item)))
return True
def assert_sql_execution(self, db, callable_, *rules):
- from testlib import assertsql
assertsql.asserter.add_rules(rules)
try:
callable_()
assertsql.asserter.clear_rules()
def assert_sql(self, db, callable_, list_, with_sequences=None):
- from testlib import assertsql
-
if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
rules = with_sequences
else:
self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
- from testlib import assertsql
self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
-
-class TTestSuite(unittest.TestSuite):
- """A TestSuite with once per TestCase setUpAll() and tearDownAll()"""
-
- def __init__(self, tests=()):
- if len(tests) > 0 and isinstance(tests[0], TestBase):
- self._initTest = tests[0]
- else:
- self._initTest = None
-
- for t in tests:
- if isinstance(t, TestBase):
- t._sa_first_test = True
- break
- for t in reversed(tests):
- if isinstance(t, TestBase):
- t._sa_last_test = True
- break
- unittest.TestSuite.__init__(self, tests)
-
- def run(self, result):
- init = getattr(self, '_initTest', None)
- if init is not None:
- if (hasattr(init, '__whitelist__') and
- config.db.name in init.__whitelist__):
- pass
- else:
- if self.__should_skip_for(init):
- return True
- try:
- resetwarnings()
- init.setUpAll()
- except:
- # skip tests if global setup fails
- ex = self.__exc_info()
- for test in self._tests:
- result.addError(test, ex)
- return False
- try:
- resetwarnings()
- for test in self._tests:
- if result.shouldStop:
- break
- test(result)
- return result
- finally:
- try:
- resetwarnings()
- if init is not None:
- init.tearDownAll()
- except:
- result.addError(init, self.__exc_info())
- pass
-
- def __should_skip_for(self, cls):
- if hasattr(cls, '__requires__'):
- global requires
- if requires is None:
- from testing import requires
- def test_suite(): return 'ok'
- for requirement in cls.__requires__:
- check = getattr(requires, requirement)
- if check(test_suite)() != 'ok':
- # The requirement will perform messaging.
- return True
- if (hasattr(cls, '__unsupported_on__') and
- config.db.name in cls.__unsupported_on__):
- print "'%s' unsupported on DB implementation '%s'" % (
- cls.__class__.__name__, config.db.name)
- return True
- if (getattr(cls, '__only_on__', None) not in (None,config.db.name)):
- print "'%s' unsupported on DB implementation '%s'" % (
- cls.__class__.__name__, config.db.name)
- return True
- if (getattr(cls, '__skip_if__', False)):
- for c in getattr(cls, '__skip_if__'):
- if c():
- print "'%s' skipped by %s" % (
- cls.__class__.__name__, c.__name__)
- return True
- for rule in getattr(cls, '__excluded_on__', ()):
- if _is_excluded(*rule):
- print "'%s' unsupported on DB %s version %s" % (
- cls.__class__.__name__, config.db.name,
- _server_version())
- return True
- return False
-
-
- def __exc_info(self):
- """Return a version of sys.exc_info() with the traceback frame
- minimised; usually the top level of the traceback frame is not
- needed.
- ripped off out of unittest module since its double __
- """
- exctype, excvalue, tb = sys.exc_info()
- if sys.platform[:4] == 'java': ## tracebacks look different in Jython
- return (exctype, excvalue, tb)
- return (exctype, excvalue, tb)
-
-# monkeypatch
-unittest.TestLoader.suiteClass = TTestSuite
-
-
-class DevNullWriter(object):
- def write(self, msg):
- pass
- def flush(self):
- pass
-
-def runTests(suite):
- verbose = config.options.verbose
- quiet = config.options.quiet
- orig_stdout = sys.stdout
-
- try:
- if not verbose or quiet:
- sys.stdout = DevNullWriter()
- runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2)
- return runner.run(suite)
- finally:
- if not verbose or quiet:
- sys.stdout = orig_stdout
-
-def main(suite=None):
- if not suite:
- if sys.argv[1:]:
- suite =unittest.TestLoader().loadTestsFromNames(
- sys.argv[1:], __import__('__main__'))
- else:
- suite = unittest.TestLoader().loadTestsFromModule(
- __import__('__main__'))
-
- result = runTests(suite)
- sys.exit(not result.wasSuccessful())
[egg_info]
tag_build = dev
tag_svn_revision = true
+
+[nosetests]
+with-sqlalchemy = true
\ No newline at end of file
packages = find_packages('lib'),
package_dir = {'':'lib'},
license = "MIT License",
+
+ entry_points = {
+ 'nose.plugins.0.10': [
+ 'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy',
+ ]
+ },
+
long_description = """\
SQLAlchemy is:
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
class CompileTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global t1, t2, metadata
metadata = MetaData()
t1 = Table('t1', metadata,
s = select([t1], t1.c.c2==t2.c.c1)
s.compile()
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import gc
from sqlalchemy.orm import mapper, relation, create_session, clear_mappers, sessionmaker
from sqlalchemy.orm.mapper import _mapper_registry
from sqlalchemy.orm.session import _sessions
import operator
-from testlib import testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, PickleType
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, PickleType
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
import sqlalchemy as sa
from sqlalchemy.sql import column
-from orm import _base
+from test.orm import _base
class A(_base.ComparableEntity):
assert len(_mapper_registry) == 0
class EnsureZeroed(_base.ORMTest):
- def setUp(self):
+ def setup(self):
_sessions.clear()
_mapper_registry.clear()
sess.expunge_all()
alist = sess.query(A).all()
- self.assertEquals(
+ eq_(
[
A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]),
A(col2="a2", bs=[]),
sess.expunge_all()
alist = sess.query(A).order_by(A.col1).all()
- self.assertEquals(
+ eq_(
[
A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]),
A(col2="a2", bs=[]),
sess.expunge_all()
alist = sess.query(A).order_by(A.col1).all()
- self.assertEquals(
+ eq_(
[
A(), A(), B(col3='b1'), B(col3='b2')
],
sess.expunge_all()
alist = sess.query(A).order_by(A.col1).all()
- self.assertEquals(
+ eq_(
[
A(bs=[B(col2='b1')]), A(bs=[B(col2='b2')])
],
cast.compile(dialect=dialect)
go()
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy.pool import QueuePool
def close(self):
pass
- def setUp(self):
+ def setup(self):
global pool
pool = QueuePool(creator=self.Connection,
pool_size=3, max_overflow=-1,
c2 = go()
-if __name__ == '__main__':
- testenv.main()
import datetime
import sys
import time
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
ITERATIONS = 1
self.test_baseline_8_drop()
-if __name__ == '__main__':
- testenv.main()
import datetime
import sys
import time
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import *
+from sqlalchemy.test import *
ITERATIONS = 1
self.test_baseline_7_drop()
-if __name__ == '__main__':
- testenv.main()
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-import orm.alltests as orm
-import base.alltests as base
-import sql.alltests as sql
-import engine.alltests as engine
-import dialect.alltests as dialect
-import ext.alltests as ext
-import zblog.alltests as zblog
-import profiling.alltests as profiling
-
-# The profiling tests are sensitive to foibles of CPython VM state, so
-# run them first. Ideally, each should be run in a fresh interpreter.
-
-def suite():
- alltests = unittest.TestSuite()
- for suite in (profiling, base, engine, sql, dialect, orm, ext, zblog):
- alltests.addTest(suite.suite())
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
- modules_to_test = (
- # core utilities
- 'base.dependency',
- 'base.utils',
- 'base.except',
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
-import testenv; testenv.configure_for_tests()
import sqlalchemy.topological as topological
-from testlib import TestBase
+from sqlalchemy.test import TestBase
class DependencySortTest(TestBase):
head = topological.sort_as_tree(tuples, [])
-if __name__ == "__main__":
- testenv.main()
"""Tests exceptions and DB-API exception wrapping."""
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
import exceptions as stdlib_exceptions
from sqlalchemy import exc as sa_exceptions
+from sqlalchemy.test import TestBase
class Error(stdlib_exceptions.StandardError):
pass
-class WrapTest(unittest.TestCase):
+class WrapTest(TestBase):
def test_db_error_normal(self):
try:
raise sa_exceptions.DBAPIError.instance(
self.assert_(True)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
-import copy, threading, unittest
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import copy, threading
from sqlalchemy import util, sql, exc
-from testlib import TestBase
-from testlib.testing import eq_, is_, ne_
+from sqlalchemy.test import TestBase
+from sqlalchemy.test.testing import eq_, is_, ne_
class OrderedDictTest(TestBase):
def test_odict(self):
return True
-class IdentitySetTest(unittest.TestCase):
+class IdentitySetTest(TestBase):
def assert_eq(self, identityset, expected_iterable):
expected = sorted([id(o) for o in expected_iterable])
found = sorted([id(o) for o in identityset])
ids.discard(o1)
ids.add(o1)
ids.remove(o1)
- self.assertRaises(KeyError, ids.remove, o1)
+ assert_raises(KeyError, ids.remove, o1)
eq_(ids.copy(), ids)
except TypeError:
assert True
- self.assertRaises(TypeError, cmp, ids)
- self.assertRaises(TypeError, hash, ids)
+ assert_raises(TypeError, cmp, ids)
+ assert_raises(TypeError, hash, ids)
def test_difference(self):
os1 = util.IdentitySet([1,2,3])
eq_(os1 - os2, util.IdentitySet([1, 2]))
eq_(os2 - os1, util.IdentitySet([4, 5]))
- self.assertRaises(TypeError, lambda: os1 - s2)
- self.assertRaises(TypeError, lambda: os1 - [3, 4, 5])
- self.assertRaises(TypeError, lambda: s1 - os2)
- self.assertRaises(TypeError, lambda: s1 - [3, 4, 5])
+ assert_raises(TypeError, lambda: os1 - s2)
+ assert_raises(TypeError, lambda: os1 - [3, 4, 5])
+ assert_raises(TypeError, lambda: s1 - os2)
+ assert_raises(TypeError, lambda: s1 - [3, 4, 5])
-class OrderedIdentitySetTest(unittest.TestCase):
+class OrderedIdentitySetTest(TestBase):
def assert_eq(self, identityset, expected_iterable):
expected = [id(o) for o in expected_iterable]
eq_(s1.union(s2).intersection(s3), [a, d, f])
-class DictlikeIteritemsTest(unittest.TestCase):
+class DictlikeIteritemsTest(TestBase):
baseline = set([('a', 1), ('b', 2), ('c', 3)])
def _ok(self, instance):
eq_(set(iterator), self.baseline)
def _notok(self, instance):
- self.assertRaises(TypeError,
+ assert_raises(TypeError,
util.dictlike_iteritems,
instance)
def test_update(self):
data, wim = self._fixture()
- self.assertRaises(NotImplementedError, wim.update)
+ assert_raises(NotImplementedError, wim.update)
def test_weak_clear(self):
data, wim = self._fixture()
def test_instance(self):
obj = object()
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
cls=self.Something)
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
methods=('foo'))
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
cls=self.Something, required=('foo'))
obj = self.Something()
for obj in partial, slotted:
eq_(obj, util.as_interface(obj, cls=self.Something))
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
methods=('foo'))
eq_(obj, util.as_interface(obj, methods=('bar',)))
eq_(obj, util.as_interface(obj, cls=self.Something,
required=('bar',)))
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
cls=self.Something, required=('foo',))
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
cls=self.Something, required=self.Something)
def test_dict(self):
obj = {}
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
cls=self.Something)
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
methods=('foo'))
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
cls=self.Something, required=('foo'))
def assertAdapted(obj, *methods):
res = util.as_interface(obj, methods=('foo', 'bar'), required=('foo',))
assertAdapted(res, 'foo', 'bar')
- self.assertRaises(TypeError, util.as_interface, obj, methods=('foo',))
+ assert_raises(TypeError, util.as_interface, obj, methods=('foo',))
- self.assertRaises(TypeError, util.as_interface, obj,
+ assert_raises(TypeError, util.as_interface, obj,
methods=('foo', 'bar', 'baz'), required=('baz',))
obj = {'foo': 123}
- self.assertRaises(TypeError, util.as_interface, obj, cls=self.Something)
+ assert_raises(TypeError, util.as_interface, obj, cls=self.Something)
class TestClassHierarchy(TestBase):
eq_(set(util.class_hierarchy(A)), set((A, B, object)))
-if __name__ == "__main__":
- testenv.main()
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
- modules_to_test = (
- 'dialect.access',
- 'dialect.firebird',
- 'dialect.informix',
- 'dialect.maxdb',
- 'dialect.mssql',
- 'dialect.mysql',
- 'dialect.oracle',
- 'dialect.postgres',
- 'dialect.sqlite',
- 'dialect.sybase',
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy import sql
from sqlalchemy.databases import access
-from testlib import *
+from sqlalchemy.test import *
class CompileTest(TestBase, AssertsCompiledSQL):
'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
from sqlalchemy import *
from sqlalchemy.databases import firebird
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql import table, column
-from testlib import *
+from sqlalchemy.test import *
class DomainReflectionTest(TestBase, AssertsExecutionResults):
__only_on__ = 'firebird'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
con = testing.db.connect()
try:
con.execute('CREATE DOMAIN int_domain AS INTEGER DEFAULT 42 NOT NULL')
NEW.question = gen_id(gen_testtable_id, 1);
END''')
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
con = testing.db.connect()
con.execute('DROP TABLE testtable')
con.execute('DROP DOMAIN int_domain')
def test_table_is_reflected(self):
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True)
- self.assertEquals(set(table.columns.keys()),
+ eq_(set(table.columns.keys()),
set(['question', 'answer', 'remark', 'photo', 'd', 't', 'dt']),
"Columns of reflected table didn't equal expected columns")
- self.assertEquals(table.c.question.primary_key, True)
- self.assertEquals(table.c.question.sequence.name, 'gen_testtable_id')
- self.assertEquals(table.c.question.type.__class__, firebird.FBInteger)
- self.assertEquals(table.c.question.server_default.arg.text, "42")
- self.assertEquals(table.c.answer.type.__class__, firebird.FBString)
- self.assertEquals(table.c.answer.server_default.arg.text, "'no answer'")
- self.assertEquals(table.c.remark.type.__class__, firebird.FBText)
- self.assertEquals(table.c.remark.server_default.arg.text, "''")
- self.assertEquals(table.c.photo.type.__class__, firebird.FBBinary)
+ eq_(table.c.question.primary_key, True)
+ eq_(table.c.question.sequence.name, 'gen_testtable_id')
+ eq_(table.c.question.type.__class__, firebird.FBInteger)
+ eq_(table.c.question.server_default.arg.text, "42")
+ eq_(table.c.answer.type.__class__, firebird.FBString)
+ eq_(table.c.answer.server_default.arg.text, "'no answer'")
+ eq_(table.c.remark.type.__class__, firebird.FBText)
+ eq_(table.c.remark.server_default.arg.text, "''")
+ eq_(table.c.photo.type.__class__, firebird.FBBinary)
# The following assume a Dialect 3 database
- self.assertEquals(table.c.d.type.__class__, firebird.FBDate)
- self.assertEquals(table.c.t.type.__class__, firebird.FBTime)
- self.assertEquals(table.c.dt.type.__class__, firebird.FBDateTime)
+ eq_(table.c.d.type.__class__, firebird.FBDate)
+ eq_(table.c.t.type.__class__, firebird.FBTime)
+ eq_(table.c.dt.type.__class__, firebird.FBDateTime)
class CompileTest(TestBase, AssertsCompiledSQL):
table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
- self.assertEqual(result.fetchall(), [(1,)])
+ eq_(result.fetchall(), [(1,)])
result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- self.assertEqual(result2.fetchall(), [(1,True),(2,False)])
+ eq_(result2.fetchall(), [(1,True),(2,False)])
finally:
table.drop()
try:
result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False})
- self.assertEqual(result.fetchall(), [(1,)])
+ eq_(result.fetchall(), [(1,)])
# Multiple inserts only return the last row
result2 = table.insert(firebird_returning=[table]).execute(
[{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
- self.assertEqual(result2.fetchall(), [(3,3,True)])
+ eq_(result2.fetchall(), [(3,3,True)])
result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False})
- self.assertEqual([dict(row) for row in result3], [{'ID':4}])
+ eq_([dict(row) for row in result3], [{'ID':4}])
result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons')
- self.assertEqual([dict(row) for row in result4], [{'PERSONS': 10}])
+ eq_([dict(row) for row in result4], [{'PERSONS': 10}])
finally:
table.drop()
table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
result = table.delete(table.c.persons > 4, firebird_returning=[table.c.id]).execute()
- self.assertEqual(result.fetchall(), [(1,)])
+ eq_(result.fetchall(), [(1,)])
result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- self.assertEqual(result2.fetchall(), [(2,False),])
+ eq_(result2.fetchall(), [(2,False),])
finally:
table.drop()
assert len(version) == 3, "Got strange version info: %s" % repr(version)
-if __name__ == '__main__':
- testenv.main()
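
The firebird DomainReflectionTest hunks above show the fixture rename repeated throughout this patch: unittest-style setUpAll(self) and tearDownAll(self) become nose-style setup_class(cls) and teardown_class(cls) classmethods. A rough sketch of the converted shape, using a hypothetical table and the same assumptions as above:

    from sqlalchemy import *
    from sqlalchemy.test import *
    from sqlalchemy.test.testing import eq_

    class ExampleReflectionTest(TestBase):
        @classmethod
        def setup_class(cls):
            # runs once before any test in the class (was setUpAll)
            global metadata, example_table
            metadata = MetaData(testing.db)
            example_table = Table('example_table', metadata,
                                  Column('id', Integer, primary_key=True),
                                  Column('data', String(30)))
            metadata.create_all()

        @classmethod
        def teardown_class(cls):
            # runs once after all tests in the class (was tearDownAll)
            metadata.drop_all()

        def test_columns_reflected(self):
            reflected = Table('example_table', MetaData(testing.db), autoload=True)
            eq_(set(reflected.columns.keys()), set(['id', 'data']))
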
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.databases import informix
-from testlib import *
+from sqlalchemy.test import *
class CompileTest(TestBase, AssertsCompiledSQL):
self.assert_compile(t1.update().values({t1.c.col1 : t1.c.col1 + 1}), 'UPDATE t1 SET col1=(t1.col1 + ?)')
-if __name__ == "__main__":
- testenv.main()
"""MaxDB-specific tests."""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import StringIO, sys
from sqlalchemy import *
from sqlalchemy import exc, sql
from decimal import Decimal
from sqlalchemy.databases import maxdb
-from testlib import *
+from sqlalchemy.test import *
# TODO
cols = ['d1','d2','n1','i1']
t.insert().execute(dict(zip(cols,vals)))
roundtrip = list(t.select().execute())
- self.assertEquals(roundtrip, [tuple([1] + vals)])
+ eq_(roundtrip, [tuple([1] + vals)])
t.insert().execute(dict(zip(['id'] + cols,
[2] + list(roundtrip[0][1:]))))
roundtrip2 = list(t.select(order_by=t.c.id).execute())
- self.assertEquals(roundtrip2, [tuple([1] + vals),
+ eq_(roundtrip2, [tuple([1] + vals),
tuple([2] + vals)])
finally:
try:
def test_modulo_operator(self):
st = str(select([sql.column('col') % 5]).compile(testing.db))
- self.assertEquals(st, 'SELECT mod(col, ?) FROM DUAL')
+ eq_(st, 'SELECT mod(col, ?) FROM DUAL')
-if __name__ == "__main__":
- testenv.main()
# -*- encoding: utf-8
-import testenv; testenv.configure_for_tests()
-import datetime, os, pickleable, re
+from sqlalchemy.test.testing import eq_
+import datetime, os, re
from sqlalchemy import *
from sqlalchemy import types, exc
from sqlalchemy.orm import *
from sqlalchemy.sql import table, column
from sqlalchemy.databases import mssql
import sqlalchemy.engine.url as url
-from testlib import *
-from testlib.testing import eq_
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
class CompileTest(TestBase, AssertsCompiledSQL):
__only_on__ = 'mssql'
__dialect__ = mssql.MSSQLDialect()
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, cattable
metadata = MetaData(testing.db)
PrimaryKeyConstraint('id', name='PK_cattable'),
)
- def setUp(self):
+ def setup(self):
metadata.create_all()
- def tearDown(self):
+ def teardown(self):
metadata.drop_all()
def test_compiled(self):
cattable.insert().values(id=9, description='Python').execute()
cats = cattable.select().order_by(cattable.c.id).execute()
- self.assertEqual([(9, 'Python')], list(cats))
+ eq_([(9, 'Python')], list(cats))
result = cattable.insert().values(description='PHP').execute()
- self.assertEqual([10], result.last_inserted_ids())
+ eq_([10], result.last_inserted_ids())
lastcat = cattable.select().order_by(desc(cattable.c.id)).execute()
- self.assertEqual((10, 'PHP'), lastcat.fetchone())
+ eq_((10, 'PHP'), lastcat.fetchone())
def test_executemany(self):
cattable.insert().execute([
])
cats = cattable.select().order_by(cattable.c.id).execute()
- self.assertEqual([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats))
+ eq_([(1, 'Java'), (3, 'Perl'), (8, 'Ruby'), (89, 'Python')], list(cats))
cattable.insert().execute([
{'description': 'PHP'},
])
lastcats = cattable.select().order_by(desc(cattable.c.id)).limit(2).execute()
- self.assertEqual([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats))
+ eq_([(91, 'Smalltalk'), (90, 'PHP')], list(lastcats))
class ReflectionTest(TestBase):
class GenerativeQueryTest(TestBase):
__only_on__ = 'mssql'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global foo, metadata
metadata = MetaData(testing.db)
foo = Table('foo', metadata,
sess.add(Foo(bar=i, range=i%10))
sess.flush()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
clear_mappers()
class SchemaTest(TestBase):
- def setUp(self):
+ def setup(self):
t = Table('sometable', MetaData(),
Column('pk_column', Integer),
Column('test_column', String)
__only_on__ = 'mssql'
__skip_if__ = (full_text_search_missing, )
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, cattable, matchtable
metadata = MetaData(testing.db)
])
DDL("WAITFOR DELAY '00:00:05'").execute(bind=engines.testing_engine())
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
connection = testing.db.connect()
connection.execute("DROP FULLTEXT CATALOG Catalog")
def test_simple_match(self):
results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([2, 5], [r.id for r in results])
+ eq_([2, 5], [r.id for r in results])
def test_simple_match_with_apostrophe(self):
results = matchtable.select().where(matchtable.c.title.match('"Matz''s"')).execute().fetchall()
- self.assertEquals([3], [r.id for r in results])
+ eq_([3], [r.id for r in results])
def test_simple_prefix_match(self):
results = matchtable.select().where(matchtable.c.title.match('"nut*"')).execute().fetchall()
- self.assertEquals([5], [r.id for r in results])
+ eq_([5], [r.id for r in results])
def test_simple_inflectional_match(self):
results = matchtable.select().where(matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")')).execute().fetchall()
- self.assertEquals([2], [r.id for r in results])
+ eq_([2], [r.id for r in results])
def test_or_match(self):
results1 = matchtable.select().where(or_(matchtable.c.title.match('nutshell'),
matchtable.c.title.match('ruby'))
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([3, 5], [r.id for r in results1])
+ eq_([3, 5], [r.id for r in results1])
results2 = matchtable.select().where(matchtable.c.title.match('nutshell OR ruby'),
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([3, 5], [r.id for r in results2])
+ eq_([3, 5], [r.id for r in results2])
def test_and_match(self):
results1 = matchtable.select().where(and_(matchtable.c.title.match('python'),
matchtable.c.title.match('nutshell'))
).execute().fetchall()
- self.assertEquals([5], [r.id for r in results1])
+ eq_([5], [r.id for r in results1])
results2 = matchtable.select().where(matchtable.c.title.match('python AND nutshell'),
).execute().fetchall()
- self.assertEquals([5], [r.id for r in results2])
+ eq_([5], [r.id for r in results2])
def test_match_across_joins(self):
results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id,
or_(cattable.c.description.match('Ruby'),
matchtable.c.title.match('nutshell')))
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([1, 3, 5], [r.id for r in results])
+ eq_([1, 3, 5], [r.id for r in results])
class ParseConnectTest(TestBase, AssertsCompiledSQL):
u = url.make_url('mssql://mydsn')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
+ eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
def test_pyodbc_connect_old_style_dsn_trusted(self):
u = url.make_url('mssql:///?dsn=mydsn')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
+ eq_([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
def test_pyodbc_connect_dsn_non_trusted(self):
u = url.make_url('mssql://username:password@mydsn')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
+ eq_([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_dsn_extra(self):
u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
+ eq_([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
def test_pyodbc_connect(self):
u = url.make_url('mssql://username:password@hostspec/database')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+ eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_comma_port(self):
u = url.make_url('mssql://username:password@hostspec:12345/database')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
+ eq_([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_connect_config_port(self):
u = url.make_url('mssql://username:password@hostspec/database?port=12345')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
+ eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
def test_pyodbc_extra_connect(self):
u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
+ eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
def test_pyodbc_odbc_connect(self):
u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+ eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_odbc_connect_with_dsn(self):
u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
+ eq_([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
def test_pyodbc_odbc_connect_ignores_other_values(self):
u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
dialect = mssql.MSSQLDialect_pyodbc()
connection = dialect.create_connect_args(u)
- self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+ eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
class TypesTest(TestBase):
__only_on__ = 'mssql'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global numeric_table, metadata
metadata = MetaData(testing.db)
- def tearDown(self):
+ def teardown(self):
metadata.drop_all()
def test_decimal_notation(self):
numeric_table.insert().execute(numericcol=value)
for value in select([numeric_table.c.numericcol]).execute():
- self.assertTrue(value[0] in test_items, "%s not in test_items" % value[0])
+ assert value[0] in test_items, "%s not in test_items" % value[0]
except Exception, e:
raise e
t.insert().execute(adate=d1, adatetime=d2, atime=t1)
- self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
+ eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
finally:
t.drop(checkfirst=True)
class BinaryTest(TestBase, AssertsExecutionResults):
"""Test the Binary and VarBinary types"""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global binary_table, MyPickleType
class MyPickleType(types.TypeDecorator):
)
binary_table.create()
- def tearDown(self):
+ def teardown(self):
binary_table.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
binary_table.drop()
def test_binary(self):
text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db)
):
l = stmt.execute().fetchall()
- self.assertEquals(list(stream1), list(l[0]['data']))
+ eq_(list(stream1), list(l[0]['data']))
paddedstream = list(stream1[0:100])
paddedstream.extend(['\x00'] * (100 - len(paddedstream)))
- self.assertEquals(paddedstream, list(l[0]['data_slice']))
+ eq_(paddedstream, list(l[0]['data_slice']))
- self.assertEquals(list(stream2), list(l[1]['data']))
- self.assertEquals(list(stream2), list(l[1]['data_image']))
- self.assertEquals(testobj1, l[0]['pickled'])
- self.assertEquals(testobj2, l[1]['pickled'])
- self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata)
- self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff')
+ eq_(list(stream2), list(l[1]['data']))
+ eq_(list(stream2), list(l[1]['data_image']))
+ eq_(testobj1, l[0]['pickled'])
+ eq_(testobj2, l[1]['pickled'])
+ eq_(testobj3.moredata, l[0]['mypickle'].moredata)
+ eq_(l[0]['mypickle'].stuff, 'this is the right stuff')
def load_stream(self, name, len=3000):
- f = os.path.join(os.path.dirname(testenv.__file__), name)
+ f = os.path.join(os.path.dirname(__file__), "..", name)
return file(f).read(len)
-if __name__ == "__main__":
- testenv.main()
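
The mssql hunks above also rename the per-test hooks: setUp and tearDown become the lower-case setup and teardown that run around each individual test method. For example, again with a hypothetical table:

    from sqlalchemy import *
    from sqlalchemy.test import *
    from sqlalchemy.test.testing import eq_

    class PerTestSetupTest(TestBase):
        @classmethod
        def setup_class(cls):
            global metadata, stuff
            metadata = MetaData(testing.db)
            stuff = Table('stuff', metadata,
                          Column('id', Integer, primary_key=True),
                          Column('data', String(30)))

        def setup(self):
            # was setUp(): runs before each individual test
            metadata.create_all()

        def teardown(self):
            # was tearDown(): runs after each individual test
            metadata.drop_all()

        def test_roundtrip(self):
            stuff.insert().execute(id=1, data='x')
            eq_(stuff.count().scalar(), 1)
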
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import sets
from sqlalchemy import *
from sqlalchemy import sql, exc
from sqlalchemy.databases import mysql
-from testlib.testing import eq_
-from testlib import *
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test import *
class TypesTest(TestBase, AssertsExecutionResults):
[set_table.c.s3],
set_table.c.s3.in_([set(['5']), set(['5', '7'])])).execute())
found = set([frozenset(row[0]) for row in rows])
- self.assertEquals(found,
+ eq_(found,
set([frozenset(['5']), frozenset(['5', '7'])]))
finally:
meta.drop_all()
if got != wanted:
print "Expected %s" % wanted
print "Found %s" % got
- self.assertEqual(got, wanted)
+ eq_(got, wanted)
class SQLTest(TestBase, AssertsCompiledSQL):
kw['prefixes'] = prefixes
return str(select(['q'], **kw).compile(dialect=dialect))
- self.assertEqual(gen(None), 'SELECT q')
- self.assertEqual(gen(True), 'SELECT DISTINCT q')
- self.assertEqual(gen(1), 'SELECT DISTINCT q')
- self.assertEqual(gen('diSTInct'), 'SELECT DISTINCT q')
- self.assertEqual(gen('DISTINCT'), 'SELECT DISTINCT q')
+ eq_(gen(None), 'SELECT q')
+ eq_(gen(True), 'SELECT DISTINCT q')
+ eq_(gen(1), 'SELECT DISTINCT q')
+ eq_(gen('diSTInct'), 'SELECT DISTINCT q')
+ eq_(gen('DISTINCT'), 'SELECT DISTINCT q')
# Standard SQL
- self.assertEqual(gen('all'), 'SELECT ALL q')
- self.assertEqual(gen('distinctrow'), 'SELECT DISTINCTROW q')
+ eq_(gen('all'), 'SELECT ALL q')
+ eq_(gen('distinctrow'), 'SELECT DISTINCTROW q')
# Interaction with MySQL prefix extensions
- self.assertEqual(
+ eq_(
gen(None, ['straight_join']),
'SELECT straight_join q')
- self.assertEqual(
+ eq_(
gen('all', ['HIGH_PRIORITY SQL_SMALL_RESULT']),
'SELECT HIGH_PRIORITY SQL_SMALL_RESULT ALL q')
- self.assertEqual(
+ eq_(
gen(True, ['high_priority', sql.text('sql_cache')]),
'SELECT high_priority sql_cache DISTINCT q')
class RawReflectionTest(TestBase):
- def setUp(self):
+ def setup(self):
self.dialect = mysql.dialect()
self.reflector = mysql.MySQLSchemaReflector(
self.dialect.identifier_preparer)
class MatchTest(TestBase, AssertsCompiledSQL):
__only_on__ = 'mysql'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, cattable, matchtable
metadata = MetaData(testing.db)
'category_id': 1}
])
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_expression(self):
order_by(matchtable.c.id).
execute().
fetchall())
- self.assertEquals([2, 5], [r.id for r in results])
+ eq_([2, 5], [r.id for r in results])
def test_simple_match_with_apostrophe(self):
results = (matchtable.select().
where(matchtable.c.title.match('"Matz''s"')).
execute().
fetchall())
- self.assertEquals([3], [r.id for r in results])
+ eq_([3], [r.id for r in results])
def test_or_match(self):
results1 = (matchtable.select().
order_by(matchtable.c.id).
execute().
fetchall())
- self.assertEquals([3, 5], [r.id for r in results1])
+ eq_([3, 5], [r.id for r in results1])
results2 = (matchtable.select().
where(matchtable.c.title.match('nutshell ruby')).
order_by(matchtable.c.id).
execute().
fetchall())
- self.assertEquals([3, 5], [r.id for r in results2])
+ eq_([3, 5], [r.id for r in results2])
def test_and_match(self):
matchtable.c.title.match('nutshell'))).
execute().
fetchall())
- self.assertEquals([5], [r.id for r in results1])
+ eq_([5], [r.id for r in results1])
results2 = (matchtable.select().
where(matchtable.c.title.match('+python +nutshell')).
execute().
fetchall())
- self.assertEquals([5], [r.id for r in results2])
+ eq_([5], [r.id for r in results2])
def test_match_across_joins(self):
results = (matchtable.select().
order_by(matchtable.c.id).
execute().
fetchall())
- self.assertEquals([1, 3, 5], [r.id for r in results])
+ eq_([1, 3, 5], [r.id for r in results])
def colspec(c):
return testing.db.dialect.schemagenerator(testing.db.dialect,
testing.db, None, None).get_column_specification(c)
-if __name__ == "__main__":
- testenv.main()
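
The mysql SQLTest hunk above compares compiled SQL strings directly with eq_; the AssertsCompiledSQL mixin used elsewhere in this patch appears to wrap the same comparison as assert_compile against a class-level __dialect__. A sketch of both styles, with the dialect and the expected strings taken from the hunks above:

    from sqlalchemy import *
    from sqlalchemy.databases import mysql
    from sqlalchemy.test import *
    from sqlalchemy.test.testing import eq_

    class ExampleSQLTest(TestBase, AssertsCompiledSQL):
        __dialect__ = mysql.dialect()

        def test_distinct_string(self):
            # plain string comparison, as in SQLTest above
            eq_(str(select(['q'], distinct=True).compile(dialect=mysql.dialect())),
                'SELECT DISTINCT q')

        def test_distinct_assert_compile(self):
            # the same check through the AssertsCompiledSQL helper
            self.assert_compile(select(['q'], distinct=True), 'SELECT DISTINCT q')
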
# coding: utf-8
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
from sqlalchemy import *
from sqlalchemy.sql import table, column
from sqlalchemy.databases import oracle
-from testlib import *
-from testlib.testing import eq_
-from testlib.engines import testing_engine
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.engines import testing_engine
import os
class OutParamTest(TestBase, AssertsExecutionResults):
__only_on__ = 'oracle'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
testing.db.execute("""
create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number, z_out OUT varchar) IS
retval number;
result = testing.db.execute(text("begin foo(:x_in, :x_out, :y_out, :z_out); end;", bindparams=[bindparam('x_in', Numeric), outparam('x_out', Numeric), outparam('y_out', Numeric), outparam('z_out', String)]), x_in=5)
assert result.out_parameters == {'x_out':10, 'y_out':75, 'z_out':None}, result.out_parameters
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
testing.db.execute("DROP PROCEDURE foo")
try:
parent.insert().execute({'pid':1})
child.insert().execute({'cid':1, 'pid':1})
- self.assertEquals(child.select().execute().fetchall(), [(1, 1)])
+ eq_(child.select().execute().fetchall(), [(1, 1)])
finally:
meta.drop_all()
try:
parent.insert().execute({'pid':1})
child.insert().execute({'cid':1, 'pid':1})
- self.assertEquals(child.select().execute().fetchall(), [(1, 1)])
+ eq_(child.select().execute().fetchall(), [(1, 1)])
finally:
meta.drop_all()
class BufferedColumnTest(TestBase, AssertsCompiledSQL):
__only_on__ = 'oracle'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global binary_table, stream, meta
meta = MetaData(testing.db)
binary_table = Table('binary_table', meta,
for i in range(1, 11):
binary_table.insert().execute(id=i, data=stream)
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
meta.drop_all()
def test_fetch(self):
- self.assertEquals(
+ eq_(
binary_table.select().execute().fetchall() ,
[(i, stream) for i in range(1, 11)],
)
def test_fetch_single_arraysize(self):
eng = testing_engine(options={'arraysize':1})
- self.assertEquals(
+ eq_(
eng.execute(binary_table.select()).fetchall(),
[(i, stream) for i in range(1, 11)],
)
def test_basic(self):
assert testing.db.execute("/*+ this is a comment */ SELECT 1 FROM DUAL").fetchall() == [(1,)]
-if __name__ == '__main__':
- testenv.main()
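
The oracle BufferedColumnTest above builds an extra engine with testing_engine from sqlalchemy.test.engines (also reachable as engines.testing_engine via the star import). The options dict appears to be handed through to engine construction, as with the arraysize option above; a small sketch using a generic option, which is an assumption here:

    from sqlalchemy import select
    from sqlalchemy.test.engines import testing_engine

    # build a second engine against the configured test database
    eng = testing_engine(options={'pool_recycle': 3600})
    try:
        assert eng.execute(select([1])).scalar() == 1
    finally:
        eng.dispose()
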
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy import exc
from sqlalchemy.databases import postgres
from sqlalchemy.engine.strategies import MockEngineStrategy
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy.sql import table, column
table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute()
- self.assertEqual(result.fetchall(), [(1,)])
+ eq_(result.fetchall(), [(1,)])
result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
- self.assertEqual(result2.fetchall(), [(1,True),(2,False)])
+ eq_(result2.fetchall(), [(1,True),(2,False)])
finally:
table.drop()
try:
result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False})
- self.assertEqual(result.fetchall(), [(1,)])
+ eq_(result.fetchall(), [(1,)])
@testing.fails_on('postgres', 'Known limitation of psycopg2')
def test_executemany():
# return value is documented as failing with psycopg2/executemany
result2 = table.insert(postgres_returning=[table]).execute(
[{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
- self.assertEqual(result2.fetchall(), [(2, 2, False), (3,3,True)])
+ eq_(result2.fetchall(), [(2, 2, False), (3,3,True)])
test_executemany()
result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
- self.assertEqual([dict(row) for row in result3], [{'double_id':8}])
+ eq_([dict(row) for row in result3], [{'double_id':8}])
result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, true) returning persons')
- self.assertEqual([dict(row) for row in result4], [{'persons': 10}])
+ eq_([dict(row) for row in result4], [{'persons': 10}])
finally:
table.drop()
class InsertTest(TestBase, AssertsExecutionResults):
__only_on__ = 'postgres'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata
metadata = MetaData(testing.db)
- def tearDown(self):
+ def teardown(self):
metadata.drop_all()
metadata.tables.clear()
__only_on__ = 'postgres'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
con = testing.db.connect()
for ddl in ('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42',
'CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0'):
con.execute('CREATE TABLE alt_schema.testtable(question integer, answer alt_schema.testdomain, anything integer)')
con.execute('CREATE TABLE crosschema (question integer, answer alt_schema.testdomain)')
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
con = testing.db.connect()
con.execute('DROP TABLE testtable')
con.execute('DROP TABLE alt_schema.testtable')
def test_table_is_reflected(self):
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True)
- self.assertEquals(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
- self.assertEquals(table.c.answer.type.__class__, postgres.PGInteger)
+ eq_(set(table.columns.keys()), set(['question', 'answer']), "Columns of reflected table didn't equal expected columns")
+ eq_(table.c.answer.type.__class__, postgres.PGInteger)
def test_domain_is_reflected(self):
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True)
- self.assertEquals(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value")
- self.assertFalse(table.columns.answer.nullable, "Expected reflected column to not be nullable.")
+ eq_(str(table.columns.answer.server_default.arg), '42', "Reflected default value didn't equal expected value")
+ assert not table.columns.answer.nullable, "Expected reflected column to not be nullable."
def test_table_is_reflected_alt_schema(self):
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True, schema='alt_schema')
- self.assertEquals(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
- self.assertEquals(table.c.anything.type.__class__, postgres.PGInteger)
+ eq_(set(table.columns.keys()), set(['question', 'answer', 'anything']), "Columns of reflected table didn't equal expected columns")
+ eq_(table.c.anything.type.__class__, postgres.PGInteger)
def test_schema_domain_is_reflected(self):
metadata = MetaData(testing.db)
table = Table('testtable', metadata, autoload=True, schema='alt_schema')
- self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
- self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
+ eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
+ assert table.columns.answer.nullable, "Expected reflected column to be nullable."
def test_crosschema_domain_is_reflected(self):
metadata = MetaData(testing.db)
table = Table('crosschema', metadata, autoload=True)
- self.assertEquals(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
- self.assertTrue(table.columns.answer.nullable, "Expected reflected column to be nullable.")
+ eq_(str(table.columns.answer.server_default.arg), '0', "Reflected default value didn't equal expected value")
+ assert table.columns.answer.nullable, "Expected reflected column to be nullable."
def test_unknown_types(self):
from sqlalchemy.databases import postgres
postgres.ischema_names = {}
try:
m2 = MetaData(testing.db)
- self.assertRaises(exc.SAWarning, Table, "testtable", m2, autoload=True)
+ assert_raises(exc.SAWarning, Table, "testtable", m2, autoload=True)
@testing.emits_warning('Did not recognize type')
def warns():
t = Table('mytable', MetaData(testing.db),
Column('id', Integer, primary_key=True),
Column('a', String(8)))
- self.assertEquals(
+ eq_(
str(t.select(distinct=t.c.a)),
'SELECT DISTINCT ON (mytable.a) mytable.id, mytable.a \n'
'FROM mytable')
- self.assertEquals(
+ eq_(
str(t.select(distinct=['id','a'])),
'SELECT DISTINCT ON (id, a) mytable.id, mytable.a \n'
'FROM mytable')
- self.assertEquals(
+ eq_(
str(t.select(distinct=[t.c.id, t.c.a])),
'SELECT DISTINCT ON (mytable.id, mytable.a) mytable.id, mytable.a \n'
'FROM mytable')
users.insert().execute(id=3, name='name3')
users.insert().execute(id=4, name='name4')
- self.assertEquals(users.select().where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
- self.assertEquals(users.select(use_labels=True).where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
+ eq_(users.select().where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
+ eq_(users.select(use_labels=True).where(users.c.name=='name2').execute().fetchall(), [(2, 'name2')])
users.delete().where(users.c.id==3).execute()
- self.assertEquals(users.select().where(users.c.name=='name3').execute().fetchall(), [])
+ eq_(users.select().where(users.c.name=='name3').execute().fetchall(), [])
users.update().where(users.c.name=='name4').execute(name='newname')
- self.assertEquals(users.select(use_labels=True).where(users.c.id==4).execute().fetchall(), [(4, 'newname')])
+ eq_(users.select(use_labels=True).where(users.c.id==4).execute().fetchall(), [(4, 'newname')])
finally:
users.drop()
__only_on__ = 'postgres'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global tztable, notztable, metadata
metadata = MetaData(testing.db)
Column("name", String(20)),
)
metadata.create_all()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_with_timezone(self):
class ArrayTest(TestBase, AssertsExecutionResults):
__only_on__ = 'postgres'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, arrtable
metadata = MetaData(testing.db)
Column('strarr', postgres.PGArray(String(convert_unicode=True)), nullable=False)
)
metadata.create_all()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_reflect_array_column(self):
metadata2 = MetaData(testing.db)
tbl = Table('arrtable', metadata2, autoload=True)
- self.assertTrue(isinstance(tbl.c.intarr.type, postgres.PGArray))
- self.assertTrue(isinstance(tbl.c.strarr.type, postgres.PGArray))
- self.assertTrue(isinstance(tbl.c.intarr.type.item_type, Integer))
- self.assertTrue(isinstance(tbl.c.strarr.type.item_type, String))
+ assert isinstance(tbl.c.intarr.type, postgres.PGArray)
+ assert isinstance(tbl.c.strarr.type, postgres.PGArray)
+ assert isinstance(tbl.c.intarr.type.item_type, Integer)
+ assert isinstance(tbl.c.strarr.type.item_type, String)
def test_insert_array(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
results = arrtable.select().execute().fetchall()
- self.assertEquals(len(results), 1)
- self.assertEquals(results[0]['intarr'], [1,2,3])
- self.assertEquals(results[0]['strarr'], ['abc','def'])
+ eq_(len(results), 1)
+ eq_(results[0]['intarr'], [1,2,3])
+ eq_(results[0]['strarr'], ['abc','def'])
arrtable.delete().execute()
def test_array_where(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
arrtable.insert().execute(intarr=[4,5,6], strarr='ABC')
results = arrtable.select().where(arrtable.c.intarr == [1,2,3]).execute().fetchall()
- self.assertEquals(len(results), 1)
- self.assertEquals(results[0]['intarr'], [1,2,3])
+ eq_(len(results), 1)
+ eq_(results[0]['intarr'], [1,2,3])
arrtable.delete().execute()
def test_array_concat(self):
arrtable.insert().execute(intarr=[1,2,3], strarr=['abc', 'def'])
results = select([arrtable.c.intarr + [4,5,6]]).execute().fetchall()
- self.assertEquals(len(results), 1)
- self.assertEquals(results[0][0], [1,2,3,4,5,6])
+ eq_(len(results), 1)
+ eq_(results[0][0], [1,2,3,4,5,6])
arrtable.delete().execute()
def test_array_subtype_resultprocessor(self):
arrtable.insert().execute(intarr=[4,5,6], strarr=[[u'm\xe4\xe4'], [u'm\xf6\xf6']])
arrtable.insert().execute(intarr=[1,2,3], strarr=[u'm\xe4\xe4', u'm\xf6\xf6'])
results = arrtable.select(order_by=[arrtable.c.intarr]).execute().fetchall()
- self.assertEquals(len(results), 2)
- self.assertEquals(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
- self.assertEquals(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
+ eq_(len(results), 2)
+ eq_(results[0]['strarr'], [u'm\xe4\xe4', u'm\xf6\xf6'])
+ eq_(results[1]['strarr'], [[u'm\xe4\xe4'], [u'm\xf6\xf6']])
arrtable.delete().execute()
def test_array_mutability(self):
sess.flush()
sess.expunge_all()
foo = sess.query(Foo).get(1)
- self.assertEquals(foo.intarr, [1,2,3])
+ eq_(foo.intarr, [1,2,3])
foo.intarr.append(4)
sess.flush()
sess.expunge_all()
foo = sess.query(Foo).get(1)
- self.assertEquals(foo.intarr, [1,2,3,4])
+ eq_(foo.intarr, [1,2,3,4])
foo.intarr = []
sess.flush()
sess.expunge_all()
- self.assertEquals(foo.intarr, [])
+ eq_(foo.intarr, [])
foo.intarr = None
sess.flush()
sess.expunge_all()
- self.assertEquals(foo.intarr, None)
+ eq_(foo.intarr, None)
# Errors in r4217:
foo = Foo()
connection = engine.connect()
s = select([func.TIMESTAMP("12/25/07").label("ts")])
result = connection.execute(s).fetchone()
- self.assertEqual(result[0], datetime.datetime(2007, 12, 25, 0, 0))
+ eq_(result[0], datetime.datetime(2007, 12, 25, 0, 0))
class ServerSideCursorsTest(TestBase, AssertsExecutionResults):
__only_on__ = 'postgres'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global ss_engine
ss_engine = engines.testing_engine(options={'server_side_cursors':True})
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
ss_engine.dispose()
def test_uses_ss(self):
nextid = ss_engine.execute(Sequence('test_table_id_seq'))
test_table.insert().execute(id=nextid, data='data2')
- self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')])
+ eq_(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2')])
test_table.update().where(test_table.c.id==2).values(data=test_table.c.data + ' updated').execute()
- self.assertEquals(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')])
+ eq_(test_table.select().execute().fetchall(), [(1, 'data1'), (2, 'data2 updated')])
test_table.delete().execute()
- self.assertEquals(test_table.count().scalar(), 0)
+ eq_(test_table.count().scalar(), 0)
finally:
test_table.drop(checkfirst=True)
__only_on__ = 'postgres'
__excluded_on__ = (('postgres', '<', (8, 3, 0)),)
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, table
metadata = MetaData(testing.db)
metadata.create_all()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_reflection(self):
__only_on__ = 'postgres'
__excluded_on__ = (('postgres', '<', (8, 3, 0)),)
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, cattable, matchtable
metadata = MetaData(testing.db)
{'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1}
])
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_expression(self):
def test_simple_match(self):
results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([2, 5], [r.id for r in results])
+ eq_([2, 5], [r.id for r in results])
def test_simple_match_with_apostrophe(self):
results = matchtable.select().where(matchtable.c.title.match("Matz''s")).execute().fetchall()
- self.assertEquals([3], [r.id for r in results])
+ eq_([3], [r.id for r in results])
def test_simple_derivative_match(self):
results = matchtable.select().where(matchtable.c.title.match('nutshells')).execute().fetchall()
- self.assertEquals([5], [r.id for r in results])
+ eq_([5], [r.id for r in results])
def test_or_match(self):
results1 = matchtable.select().where(or_(matchtable.c.title.match('nutshells'),
matchtable.c.title.match('rubies'))
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([3, 5], [r.id for r in results1])
+ eq_([3, 5], [r.id for r in results1])
results2 = matchtable.select().where(matchtable.c.title.match('nutshells | rubies'),
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([3, 5], [r.id for r in results2])
+ eq_([3, 5], [r.id for r in results2])
def test_and_match(self):
results1 = matchtable.select().where(and_(matchtable.c.title.match('python'),
matchtable.c.title.match('nutshells'))
).execute().fetchall()
- self.assertEquals([5], [r.id for r in results1])
+ eq_([5], [r.id for r in results1])
results2 = matchtable.select().where(matchtable.c.title.match('python & nutshells'),
).execute().fetchall()
- self.assertEquals([5], [r.id for r in results2])
+ eq_([5], [r.id for r in results2])
def test_match_across_joins(self):
results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id,
or_(cattable.c.description.match('Ruby'),
matchtable.c.title.match('nutshells')))
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([1, 3, 5], [r.id for r in results])
+ eq_([1, 3, 5], [r.id for r in results])
-if __name__ == "__main__":
- testenv.main()
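
The postgres hunks above also swap the unittest assertion methods for module-level helpers: self.assertRaises becomes assert_raises, and self.assertRaisesMessage becomes assert_raises_message, both imported from sqlalchemy.test.testing. For instance:

    from sqlalchemy import Table, MetaData, exc
    from sqlalchemy.test import testing
    from sqlalchemy.test.testing import assert_raises

    # was: self.assertRaises(exc.NoSuchTableError, Table, ...)
    assert_raises(exc.NoSuchTableError,
                  Table, 'this_table_does_not_exist',
                  MetaData(testing.db), autoload=True)

assert_raises_message additionally takes the expected message text as its second argument, as in the reflection hunk further down.
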
"""SQLite-specific tests."""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import datetime
from sqlalchemy import *
from sqlalchemy import exc, sql
from sqlalchemy.databases import sqlite
-from testlib import *
+from sqlalchemy.test import *
class TestTypes(TestBase, AssertsExecutionResults):
meta.drop_all()
def test_string_dates_raise(self):
- self.assertRaises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30)))
+ assert_raises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30)))
def test_time_microseconds(self):
dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125) # 125 usec
- self.assertEquals(str(dt), '2008-06-27 12:00:00.000125')
+ eq_(str(dt), '2008-06-27 12:00:00.000125')
sldt = sqlite.SLDateTime()
bp = sldt.bind_processor(None)
- self.assertEquals(bp(dt), '2008-06-27 12:00:00.000125')
+ eq_(bp(dt), '2008-06-27 12:00:00.000125')
rp = sldt.result_processor(None)
- self.assertEquals(rp(bp(dt)), dt)
+ eq_(rp(bp(dt)), dt)
sldt.__legacy_microseconds__ = True
bp = sldt.bind_processor(None)
- self.assertEquals(bp(dt), '2008-06-27 12:00:00.125')
- self.assertEquals(rp(bp(dt)), dt)
+ eq_(bp(dt), '2008-06-27 12:00:00.125')
+ eq_(rp(bp(dt)), dt)
def test_no_convert_unicode(self):
"""test no utf-8 encoding occurs"""
rt = Table('t_defaults', m2, autoload=True)
expected = [c[1] for c in specs]
for i, reflected in enumerate(rt.c):
- self.assertEquals(reflected.server_default.arg.text, expected[i])
+ eq_(reflected.server_default.arg.text, expected[i])
finally:
m.drop_all()
rt = Table('r_defaults', m, autoload=True)
for i, reflected in enumerate(rt.c):
- self.assertEquals(reflected.server_default.arg.text, expected[i])
+ eq_(reflected.server_default.arg.text, expected[i])
finally:
db.execute("DROP TABLE r_defaults")
schema='alt_schema')
meta.create_all(cx)
- self.assertEquals(dialect.table_names(cx, 'alt_schema'),
+ eq_(dialect.table_names(cx, 'alt_schema'),
['created'])
assert len(alt_master.c) > 0
table.insert().execute()
rows = table.select().execute().fetchall()
- self.assertEquals(len(rows), wanted)
+ eq_(len(rows), wanted)
finally:
table.drop()
@testing.exclude('sqlite', '<', (3, 3, 8), 'no database support')
def test_empty_insert_pk2(self):
- self.assertRaises(
+ assert_raises(
exc.DBAPIError,
self._test_empty_insert,
Table('b', MetaData(testing.db),
@testing.exclude('sqlite', '<', (3, 3, 8), 'no database support')
def test_empty_insert_pk3(self):
- self.assertRaises(
+ assert_raises(
exc.DBAPIError,
self._test_empty_insert,
Table('c', MetaData(testing.db),
__only_on__ = 'sqlite'
__skip_if__ = (full_text_search_missing, )
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, cattable, matchtable
metadata = MetaData(testing.db)
{'id': 5, 'title': 'Python in a Nutshell', 'category_id': 1}
])
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_expression(self):
def test_simple_match(self):
results = matchtable.select().where(matchtable.c.title.match('python')).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([2, 5], [r.id for r in results])
+ eq_([2, 5], [r.id for r in results])
def test_simple_prefix_match(self):
results = matchtable.select().where(matchtable.c.title.match('nut*')).execute().fetchall()
- self.assertEquals([5], [r.id for r in results])
+ eq_([5], [r.id for r in results])
def test_or_match(self):
results2 = matchtable.select().where(matchtable.c.title.match('nutshell OR ruby'),
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([3, 5], [r.id for r in results2])
+ eq_([3, 5], [r.id for r in results2])
def test_and_match(self):
results2 = matchtable.select().where(matchtable.c.title.match('python nutshell'),
).execute().fetchall()
- self.assertEquals([5], [r.id for r in results2])
+ eq_([5], [r.id for r in results2])
def test_match_across_joins(self):
results = matchtable.select().where(and_(cattable.c.id==matchtable.c.category_id,
cattable.c.description.match('Ruby'))
).order_by(matchtable.c.id).execute().fetchall()
- self.assertEquals([1, 3], [r.id for r in results])
+ eq_([1, 3], [r.id for r in results])
-if __name__ == "__main__":
- testenv.main()
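
The sqlite hunks above keep the existing class- and method-level guards: __only_on__ and __skip_if__ restrict a whole class to particular backends, while decorators such as testing.exclude, testing.fails_on and testing.requires gate individual methods. Roughly:

    from sqlalchemy.test import *

    class GuardedTest(TestBase):
        # runs only when the configured test target is sqlite
        __only_on__ = 'sqlite'

        @testing.exclude('sqlite', '<', (3, 3, 8), 'no database support')
        def test_on_modern_sqlite_only(self):
            # skipped on sqlite versions older than 3.3.8
            assert testing.db.execute("SELECT 1").scalar() == 1
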
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy import sql
from sqlalchemy.databases import sybase
-from testlib import *
+from sqlalchemy.test import *
class CompileTest(TestBase, AssertsCompiledSQL):
-if __name__ == "__main__":
- testenv.main()
-from testlib import sa, testing
-from testlib.testing import adict
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy.test.testing import adict
class TablesTest(testing.TestBase):
tables = None
other_artifacts = None
- def setUpAll(self):
- if self.run_setup_bind is None:
- assert self.bind is not None
- assert self.run_deletes in (None, 'each')
- if self.run_inserts == 'once':
- assert self.run_deletes is None
+ @classmethod
+ def setup_class(cls):
+ if cls.run_setup_bind is None:
+ assert cls.bind is not None
+ assert cls.run_deletes in (None, 'each')
+ if cls.run_inserts == 'once':
+ assert cls.run_deletes is None
- cls = self.__class__
if cls.tables is None:
cls.tables = adict()
if cls.other_artifacts is None:
cls.other_artifacts = adict()
- if self.bind is None:
- setattr(type(self), 'bind', self.setup_bind())
-
- if self.metadata is None:
- setattr(type(self), 'metadata', sa.MetaData())
+ if cls.bind is None:
+ setattr(cls, 'bind', cls.setup_bind())
- if self.metadata.bind is None:
- self.metadata.bind = self.bind
+ if cls.metadata is None:
+ setattr(cls, 'metadata', sa.MetaData())
- if self.run_define_tables:
- self.define_tables(self.metadata)
- self.metadata.create_all()
- self.tables.update(self.metadata.tables)
+ if cls.metadata.bind is None:
+ cls.metadata.bind = cls.bind
- if self.run_inserts:
- self._load_fixtures()
- self.insert_data()
+ if cls.run_define_tables == 'once':
+ cls.define_tables(cls.metadata)
+ cls.metadata.create_all()
+ cls.tables.update(cls.metadata.tables)
- def setUp(self):
- if self._sa_first_test:
- return
+ if cls.run_inserts == 'once':
+ cls._load_fixtures()
+ cls.insert_data()
+ def setup(self):
cls = self.__class__
if self.setup_bind == 'each':
self._load_fixtures()
self.insert_data()
- def tearDown(self):
+ def teardown(self):
# no need to run deletes if tables are recreated on setup
if self.run_define_tables != 'each' and self.run_deletes:
for table in reversed(self.metadata.sorted_tables):
if self.run_dispose_bind == 'each':
self.dispose_bind(self.bind)
- def tearDownAll(self):
- self.metadata.drop_all()
+ @classmethod
+ def teardown_class(cls):
+ cls.metadata.drop_all()
- if self.dispose_bind:
- self.dispose_bind(self.bind)
+ if cls.dispose_bind:
+ cls.dispose_bind(cls.bind)
- self.metadata.bind = None
+ cls.metadata.bind = None
- if self.run_setup_bind is not None:
- self.bind = None
+ if cls.run_setup_bind is not None:
+ cls.bind = None
- def setup_bind(self):
+ @classmethod
+ def setup_bind(cls):
return testing.db
- def dispose_bind(self, bind):
+ @classmethod
+ def dispose_bind(cls, bind):
if hasattr(bind, 'dispose'):
bind.dispose()
elif hasattr(bind, 'close'):
bind.close()
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
raise NotImplementedError()
- def fixtures(self):
+ @classmethod
+ def fixtures(cls):
return {}
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
pass
def sql_count_(self, count, fn):
class AltEngineTest(testing.TestBase):
engine = None
- def setUpAll(self):
- type(self).engine = self.create_engine()
- testing.TestBase.setUpAll(self)
-
- def tearDownAll(self):
- testing.TestBase.tearDownAll(self)
- self.engine.dispose()
- type(self).engine = None
-
- def create_engine(self):
+ @classmethod
+ def setup_class(cls):
+ cls.engine = cls.create_engine()
+ super(AltEngineTest, cls).setup_class()
+
+ @classmethod
+ def teardown_class(cls):
+ cls.engine.dispose()
+ cls.engine = None
+ super(AltEngineTest, cls).teardown_class()
+
+ @classmethod
+ def create_engine(cls):
raise NotImplementedError
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
- modules_to_test = (
- # connectivity, execution
- 'engine.parseconnect',
- 'engine.pool',
- 'engine.bind',
- 'engine.reconnect',
- 'engine.execute',
- 'engine.metadata',
- 'engine.transaction',
-
- # schema/tables
- 'engine.reflection',
- 'engine.ddlevents',
-
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
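
The TablesTest and AltEngineTest bases above move all of their state onto the class via setup_class and teardown_class, and per-package suite modules such as the alltests.py just removed are no longer needed, since nose finds the test modules itself. A subclass of TablesTest would now look roughly like the sketch below; the import path is not visible in this hunk and is an assumption, and the flag values follow the asserts in setup_class above.

    import sqlalchemy as sa
    from test.engine._base import TablesTest   # assumed location, not shown above

    class WidgetTest(TablesTest):
        run_define_tables = 'once'   # create tables once per class
        run_inserts = 'once'         # load fixture data once per class
        run_deletes = None           # required when run_inserts == 'once'

        @classmethod
        def define_tables(cls, metadata):
            sa.Table('widgets', metadata,
                     sa.Column('id', sa.Integer, primary_key=True),
                     sa.Column('name', sa.String(50)))

        @classmethod
        def insert_data(cls):
            cls.tables['widgets'].insert().execute(
                {'id': 1, 'name': 'spanner'},
                {'id': 2, 'name': 'wrench'})

        def test_fixture_rows(self):
            assert self.tables['widgets'].count().scalar() == 2
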
"""tests the "bind" attribute/argument across schema and SQL,
including the deprecated versions of these arguments"""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
from sqlalchemy import engine, exc
from sqlalchemy import MetaData, ThreadLocalMetaData
-from testlib.sa import Table, Column, Integer, text
-from testlib import sa, testing
+from sqlalchemy import Integer, text
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as sa
+from sqlalchemy.test import testing
class BindTest(testing.TestBase):
meth()
assert False
except exc.UnboundExecutionError, e:
- self.assertEquals(
+ eq_(
str(e),
"The MetaData "
"is not bound to an Engine or Connection. "
meth()
assert False
except exc.UnboundExecutionError, e:
- self.assertEquals(
+ eq_(
str(e),
"The Table 'test_table' "
"is not bound to an Engine or Connection. "
meth()
assert False
except exc.UnboundExecutionError, e:
- self.assertEquals(
+ eq_(
str(e),
"The Table 'test_table' "
"is not bound to an Engine or Connection. "
metadata.drop_all(bind=testing.db)
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy.schema import DDL
from sqlalchemy import create_engine
-from testlib.sa import MetaData, Table, Column, Integer, String
-import testlib.sa as tsa
-from testlib import TestBase, testing, engines
+from sqlalchemy import MetaData, Integer, String
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, testing, engines
class DDLEventTest(TestBase):
assert bind is self.bind
self.state = action
- def setUp(self):
+ def setup(self):
self.bind = engines.mock_engine()
self.metadata = MetaData()
self.table = Table('t', self.metadata, Column('id', Integer))
fn = lambda *a: None
table.append_ddl_listener('before-create', fn)
- self.assertRaises(LookupError, table.append_ddl_listener, 'blah', fn)
+ assert_raises(LookupError, table.append_ddl_listener, 'blah', fn)
metadata.append_ddl_listener('before-create', fn)
- self.assertRaises(LookupError, metadata.append_ddl_listener, 'blah', fn)
+ assert_raises(LookupError, metadata.append_ddl_listener, 'blah', fn)
class DDLExecutionTest(TestBase):
- def setUp(self):
+ def setup(self):
self.engine = engines.mock_engine()
self.metadata = MetaData(self.engine)
self.users = Table('users', self.metadata,
ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
- self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
- self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t')
- self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
- self.assertEquals(ddl._expand(insane_schema, bind),
+ eq_(ddl._expand(sane_alone, bind), '-t-t')
+ eq_(ddl._expand(sane_schema, bind), 's-t-s.t')
+ eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
+ eq_(ddl._expand(insane_schema, bind),
'"s s"-"t t"-"s s"."t t"')
# overrides are used piece-meal and verbatim.
ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
- self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
- self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
- self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
- self.assertEquals(ddl._expand(insane_schema, bind),
+ eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
+ eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
+ eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
+ eq_(ddl._expand(insane_schema, bind),
'S S-T T-"s s"."t t"-b')
def test_filter(self):
cx = self.mock_engine()
assert repr(DDL('s', on='engine', context={'a':1}))
-if __name__ == "__main__":
- testenv.main()
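
Two more pieces of sqlalchemy.test appear in the DDL hunks above: engines.mock_engine(), whose executed constructs are collected on its .mock attribute (asserted against further down in the table-options hunk), and the Table and Column wrappers imported from sqlalchemy.test.schema, which the converted modules now use in place of testlib.sa. A small sketch of the mock engine, using the plain sqlalchemy Table and Column so that it stands alone:

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy.test import engines

    # the mock engine records DDL instead of executing it against a database
    engine = engines.mock_engine()
    metadata = MetaData(engine)
    t = Table('mocked', metadata, Column('id', Integer))
    t.create()
    assert [str(x) for x in engine.mock if 'CREATE TABLE' in str(x)]
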
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import re
from sqlalchemy.interfaces import ConnectionProxy
-from testlib.sa import MetaData, Table, Column, Integer, String, INT, \
- VARCHAR, func, bindparam
-import testlib.sa as tsa
-from testlib import TestBase, testing, engines
+from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, bindparam
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, testing, engines
users, metadata = None, None
class ExecuteTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global users, metadata
metadata = MetaData(testing.db)
users = Table('users', metadata,
)
metadata.create_all()
- def tearDown(self):
+ def teardown(self):
testing.db.connect().execute(users.delete())
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
@testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite')
def test_empty_insert(self):
"""test that execute() interprets [] as a list with no params"""
result = testing.db.execute(users.insert().values(user_name=bindparam('name')), [])
- self.assertEquals(result.rowcount, 1)
+ eq_(result.rowcount, 1)
class ProxyConnectionTest(TestBase):
@testing.fails_on('firebird', 'Data type unknown')
assert_stmts(cursor, cursor_stmts)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
import pickle
from sqlalchemy import MetaData
-from testlib.sa import Table, Column, Integer, String, UniqueConstraint, \
- CheckConstraint, ForeignKey
-import testlib.sa as tsa
-from testlib import TestBase, ComparesTables, testing, engines
-from testlib.testing import eq_
+from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, ComparesTables, testing, engines
+from sqlalchemy.test.testing import eq_
class MetaDataTest(TestBase, ComparesTables):
def test_metadata_connect(self):
def test_nonexistent(self):
- self.assertRaises(tsa.exc.NoSuchTableError, Table,
+ assert_raises(tsa.exc.NoSuchTableError, Table,
'fake_table',
MetaData(testing.db), autoload=True)
class TableOptionsTest(TestBase):
- def setUp(self):
+ def setup(self):
self.engine = engines.mock_engine()
self.metadata = MetaData(self.engine)
table2.create()
assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)]
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
import ConfigParser, StringIO
import sqlalchemy.engine.url as url
from sqlalchemy import create_engine, engine_from_config
-import testlib.sa as tsa
-from testlib import TestBase
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase
class ParseConnectTest(TestBase):
pass
mock_dbapi = MockDBAPI()
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
import threading, time, gc
from sqlalchemy import pool, interfaces
-import testlib.sa as tsa
-from testlib import TestBase
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase
mcid = 1
class PoolTestBase(TestBase):
- def setUp(self):
+ def setup(self):
pool.clear_managers()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
pool.clear_managers()
class PoolTest(PoolTestBase):
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import weakref
-from testlib.sa import select, MetaData, Table, Column, Integer, String, pool
-import testlib.sa as tsa
-from testlib import TestBase, testing, engines
+from sqlalchemy import select, MetaData, Integer, String, pool
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, testing, engines
import time
import gc
db, dbapi = None, None
class MockReconnectTest(TestBase):
- def setUp(self):
+ def setup(self):
global db, dbapi
dbapi = MockDBAPI()
engine = None
class RealReconnectTest(TestBase):
- def setUp(self):
+ def setup(self):
global engine
engine = engines.reconnecting_engine()
- def tearDown(self):
+ def teardown(self):
engine.dispose()
def test_reconnect(self):
conn = engine.connect()
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
engine.test_shutdown()
assert conn.invalidated
assert conn.invalidated
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
# one more time
if not e.connection_invalidated:
raise
assert conn.invalidated
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
conn.close()
def test_null_pool(self):
engine = engines.reconnecting_engine(options=dict(poolclass=pool.NullPool))
conn = engine.connect()
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
engine.test_shutdown()
try:
raise
assert not conn.closed
assert conn.invalidated
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
def test_close(self):
conn = engine.connect()
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
engine.test_shutdown()
conn.close()
conn = engine.connect()
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
def test_with_transaction(self):
conn = engine.connect()
trans = conn.begin()
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.closed
engine.test_shutdown()
assert not trans.is_active
assert conn.invalidated
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
assert not conn.invalidated
class RecycleTest(TestBase):
engine = engines.reconnecting_engine(options={'pool_recycle':1, 'pool_threadlocal':threadlocal})
conn = engine.contextual_connect()
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
conn.close()
engine.test_shutdown()
time.sleep(2)
conn = engine.contextual_connect()
- self.assertEquals(conn.execute(select([1])).scalar(), 1)
+ eq_(conn.execute(select([1])).scalar(), 1)
conn.close()
meta, table, engine = None, None, None
class InvalidateDuringResultTest(TestBase):
- def setUp(self):
+ def setup(self):
global meta, table, engine
engine = engines.reconnecting_engine()
meta = MetaData(engine)
[{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
)
- def tearDown(self):
+ def teardown(self):
meta.drop_all()
engine.dispose()
assert conn.invalidated
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import StringIO, unicodedata
import sqlalchemy as sa
-from testlib.sa import MetaData, Table, Column
-from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
+from sqlalchemy import MetaData
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase, ComparesTables, testing, engines
metadata, users = None, None
foo = Table('foo', meta2, autoload=True,
include_columns=['b', 'f', 'e'])
# test that cols come back in original order
- self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
+ eq_([c.name for c in foo.c], ['b', 'e', 'f'])
for c in ('b', 'f', 'e'):
assert c in foo.c
for c in ('a', 'c', 'd'):
foo = Table('foo', meta3, autoload=True)
foo = Table('foo', meta3, include_columns=['b', 'f', 'e'],
useexisting=True)
- self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
+ eq_([c.name for c in foo.c], ['b', 'e', 'f'])
for c in ('b', 'f', 'e'):
assert c in foo.c
for c in ('a', 'c', 'd'):
dialect_module.ischema_names = {}
try:
m2 = MetaData(testing.db)
- self.assertRaises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
+ assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
@testing.emits_warning('Did not recognize type')
def warns():
a2 = Table('a', m2, include_columns=['z'], autoload=True)
b2 = Table('b', m2, autoload=True)
- self.assertRaises(tsa.exc.NoReferencedColumnError, a2.join, b2)
+ assert_raises(tsa.exc.NoReferencedColumnError, a2.join, b2)
finally:
meta.drop_all()
Column('slot', sa.String(128)),
)
- self.assertRaisesMessage(tsa.exc.InvalidRequestError, "Could not find table 'pkgs' with which to generate a foreign key", metadata.create_all)
+ assert_raises_message(tsa.exc.InvalidRequestError, "Could not find table 'pkgs' with which to generate a foreign key", metadata.create_all)
def test_composite_pks(self):
"""test reflection of a composite primary key"""
m1.drop_all()
class CreateDropTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, users
metadata = MetaData()
users = Table('users', metadata,
def test_createdrop(self):
metadata.create_all(bind=testing.db)
- self.assertEqual( testing.db.has_table('items'), True )
- self.assertEqual( testing.db.has_table('email_addresses'), True )
+ eq_( testing.db.has_table('items'), True )
+ eq_( testing.db.has_table('email_addresses'), True )
metadata.create_all(bind=testing.db)
- self.assertEqual( testing.db.has_table('items'), True )
+ eq_( testing.db.has_table('items'), True )
metadata.drop_all(bind=testing.db)
- self.assertEqual( testing.db.has_table('items'), False )
- self.assertEqual( testing.db.has_table('email_addresses'), False )
+ eq_( testing.db.has_table('items'), False )
+ eq_( testing.db.has_table('email_addresses'), False )
metadata.drop_all(bind=testing.db)
- self.assertEqual( testing.db.has_table('items'), False )
+ eq_( testing.db.has_table('items'), False )
def test_tablenames(self):
metadata.create_all(bind=testing.db)
class HasSequenceTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, users
metadata = MetaData()
users = Table('users', metadata,
@testing.requires.sequences
def test_hassequence(self):
metadata.create_all(bind=testing.db)
- self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True)
+ eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), True)
metadata.drop_all(bind=testing.db)
- self.assertEqual(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
+ eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
-if __name__ == "__main__":
- testenv.main()
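
Throughout the converted modules, unittest's self.assertEquals(), self.assertRaises() and
self.assertRaisesMessage() are replaced by the plain functions eq_(), assert_raises() and
assert_raises_message() imported from sqlalchemy.test.testing. The sketch below is
illustrative only (the shipped helpers live in sqlalchemy/test/testing.py), but it shows
the behaviour the converted assertions rely on:

    import re

    def eq_(a, b, msg=None):
        # plain equality assertion with a readable failure message
        assert a == b, msg or "%r != %r" % (a, b)

    def assert_raises(except_cls, callable_, *args, **kw):
        # the callable must raise except_cls
        try:
            callable_(*args, **kw)
            assert False, "Callable did not raise an exception"
        except except_cls:
            pass

    def assert_raises_message(except_cls, msg, callable_, *args, **kw):
        # as above, and the exception's string must match the msg regex
        try:
            callable_(*args, **kw)
            assert False, "Callable did not raise an exception"
        except except_cls, e:
            assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
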
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import sys, time, threading
-from testlib.sa import create_engine, MetaData, Table, Column, INT, VARCHAR, \
- Sequence, select, Integer, String, func, text
-from testlib import TestBase, testing
+from sqlalchemy import create_engine, MetaData, INT, VARCHAR, Sequence, select, Integer, String, func, text
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.test import TestBase, testing
users, metadata = None, None
class TransactionTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global users, metadata
metadata = MetaData()
users = Table('query_users', metadata,
)
users.create(testing.db)
- def tearDown(self):
+ def teardown(self):
testing.db.connect().execute(users.delete())
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
users.drop(testing.db)
def test_commits(self):
connection.execute(users.insert(), user_id=3, user_name='user3')
transaction.commit()
- self.assertEquals(
+ eq_(
connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[(1,),(3,)]
)
connection.execute(users.insert(), user_id=3, user_name='user3')
transaction.commit()
- self.assertEquals(
+ eq_(
connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[(1,),(2,),(3,)]
)
connection.execute(users.insert(), user_id=4, user_name='user4')
transaction.commit()
- self.assertEquals(
+ eq_(
connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[(1,),(4,)]
)
transaction.prepare()
transaction.rollback()
- self.assertEquals(
+ eq_(
connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[(1,),(2,)]
)
transaction.commit()
- self.assertEquals(
+ eq_(
connection.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[(1,),(2,),(5,)]
)
connection.close()
connection2 = testing.db.connect()
- self.assertEquals(
+ eq_(
connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[]
)
recoverables = connection2.recover_twophase()
- self.assertTrue(
- transaction.xid in recoverables
- )
+ assert transaction.xid in recoverables
connection2.commit_prepared(transaction.xid, recover=True)
- self.assertEquals(
+ eq_(
connection2.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[(1,)]
)
xa.commit()
result = conn.execute(select([users.c.user_name]).order_by(users.c.user_id))
- self.assertEqual(result.fetchall(), [('user1',),('user4',)])
+ eq_(result.fetchall(), [('user1',),('user4',)])
conn.close()
class AutoRollbackTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata
metadata = MetaData()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all(testing.db)
def test_rollback_deadlock(self):
__only_on__ = 'postgres'
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, foo
metadata = MetaData(testing.db)
foo = Table('foo', metadata, Column('id', Integer, primary_key=True), Column('data', String(100)))
metadata.create_all()
testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql")
- def tearDown(self):
+ def teardown(self):
foo.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
testing.db.execute("drop function insert_foo(varchar)")
metadata.drop_all()
tlengine = None
class TLTransactionTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global users, metadata, tlengine
tlengine = create_engine(testing.db.url, strategy='threadlocal')
metadata = MetaData()
test_needs_acid=True,
)
users.create(tlengine)
- def tearDown(self):
+ def teardown(self):
tlengine.execute(users.delete())
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
users.drop(tlengine)
tlengine.dispose()
def test_nested_unsupported(self):
- self.assertRaises(NotImplementedError, tlengine.contextual_connect().begin_nested)
- self.assertRaises(NotImplementedError, tlengine.begin_nested)
+ assert_raises(NotImplementedError, tlengine.contextual_connect().begin_nested)
+ assert_raises(NotImplementedError, tlengine.begin_nested)
def test_connection_close(self):
"""test that when connections are closed for real, transactions are rolled back and disposed."""
tlengine.prepare()
tlengine.rollback()
- self.assertEquals(
+ eq_(
tlengine.execute(select([users.c.user_id]).order_by(users.c.user_id)).fetchall(),
[(1,),(2,)]
)
counters = None
class ForUpdateTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global counters, metadata
metadata = MetaData()
counters = Table('forupdate_counters', metadata,
test_needs_acid=True,
)
counters.create(testing.db)
- def tearDown(self):
+ def teardown(self):
testing.db.connect().execute(counters.delete())
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
counters.drop(testing.db)
def increment(self, count, errors, update_style=True, delay=0.005):
self.assert_(len(errors) != 0)
-if __name__ == "__main__":
- testenv.main()
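
Per-test fixtures keep the lowercase nose names setup() and teardown(), while the
one-per-class setUpAll()/tearDownAll() hooks become the classmethods
setup_class()/teardown_class(), as the conversions above show. A minimal sketch of an
engine-level test written in the converted style (the table and test names here are
invented for illustration):

    from sqlalchemy import MetaData, Integer, String
    from sqlalchemy.test.schema import Table, Column
    from sqlalchemy.test import TestBase, testing
    from sqlalchemy.test.testing import eq_

    metadata = widgets = None

    class WidgetRoundTripTest(TestBase):
        @classmethod
        def setup_class(cls):
            # once per class, replacing the old setUpAll()
            global metadata, widgets
            metadata = MetaData(testing.db)
            widgets = Table('widgets', metadata,
                            Column('id', Integer, primary_key=True),
                            Column('name', String(30)))
            metadata.create_all()

        def teardown(self):
            # after every test, replacing the old tearDown()
            widgets.delete().execute()

        @classmethod
        def teardown_class(cls):
            # once per class, replacing the old tearDownAll()
            metadata.drop_all()

        def test_roundtrip(self):
            widgets.insert().execute(id=1, name='w1')
            eq_(widgets.select().execute().fetchall(), [(1, 'w1')])
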
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-import doctest, sys
-
-from testlib import sa_unittest as unittest
-
-
-def suite():
- unittest_modules = (
- 'ext.declarative',
- 'ext.orderinglist',
- 'ext.associationproxy',
- 'ext.serializer',
- 'ext.compiler',
- )
-
- if sys.version_info < (2, 4):
- doctest_modules = ()
- else:
- doctest_modules = (
- ('sqlalchemy.ext.orderinglist', {'optionflags': doctest.ELLIPSIS}),
- ('sqlalchemy.ext.sqlsoup', {})
- )
-
- alltests = unittest.TestSuite()
- for name in unittest_modules:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- for name, opts in doctest_modules:
- alltests.addTest(doctest.DocTestSuite(name, **opts))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import gc
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.orm.collections import collection
from sqlalchemy.ext.associationproxy import *
-from testlib import *
+from sqlalchemy.test import *
class DictCollection(dict):
return iter(self.values)
class _CollectionOperations(TestBase):
- def setUp(self):
+ def setup(self):
collection_class = self.collection_class
metadata = MetaData(testing.db)
self.session = create_session()
self.Parent, self.Child = Parent, Child
- def tearDown(self):
+ def teardown(self):
self.metadata.drop_all()
def roundtrip(self, obj):
self.assert_(p1.children == after)
self.assert_([c.name for c in p1._children] == after)
- self.assertRaises(TypeError, set, [p1.children])
+ assert_raises(TypeError, set, [p1.children])
p1.children *= 0
after = []
except TypeError:
self.assert_(True)
- self.assertRaises(TypeError, set, [p1.children])
+ assert_raises(TypeError, set, [p1.children])
class SetTest(_CollectionOperations):
except TypeError:
self.assert_(True)
- self.assertRaises(TypeError, set, [p1.children])
+ assert_raises(TypeError, set, [p1.children])
def test_set_comparisons(self):
set(['c','d']), set(['e', 'f', 'g']),
set()):
- self.assertEqual(p1.children.union(other),
+ eq_(p1.children.union(other),
control.union(other))
- self.assertEqual(p1.children.difference(other),
+ eq_(p1.children.difference(other),
control.difference(other))
- self.assertEqual((p1.children - other),
+ eq_((p1.children - other),
(control - other))
- self.assertEqual(p1.children.intersection(other),
+ eq_(p1.children.intersection(other),
control.intersection(other))
- self.assertEqual(p1.children.symmetric_difference(other),
+ eq_(p1.children.symmetric_difference(other),
control.symmetric_difference(other))
- self.assertEqual(p1.children.issubset(other),
+ eq_(p1.children.issubset(other),
control.issubset(other))
- self.assertEqual(p1.children.issuperset(other),
+ eq_(p1.children.issuperset(other),
control.issuperset(other))
self.assert_((p1.children == other) == (control == other))
class LazyLoadTest(TestBase):
- def setUp(self):
+ def setup(self):
metadata = MetaData(testing.db)
parents_table = Table('Parent', metadata,
self.Parent, self.Child = Parent, Child
self.table = parents_table
- def tearDown(self):
+ def teardown(self):
self.metadata.drop_all()
def roundtrip(self, obj):
class ReconstitutionTest(TestBase):
- def setUp(self):
+ def setup(self):
metadata = MetaData(testing.db)
parents = Table('parents', metadata,
Column('id', Integer, primary_key=True,
self.metadata = metadata
self.Parent = Parent
- def tearDown(self):
+ def teardown(self):
self.metadata.drop_all()
def test_weak_identity_map(self):
assert set(p_copy.kids) == set(['c1', 'c2']), p.kids
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.sql.expression import ClauseElement, ColumnClause
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import table, column
-from testlib import *
+from sqlalchemy.test import *
class UserDefinedTest(TestBase, AssertsCompiledSQL):
"DROP THINGY",
)
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy.ext import declarative as decl
from sqlalchemy import exc
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index
-from testlib.sa.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, ForeignKeyConstraint, asc, Index
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref, clear_mappers, polymorphic_union, deferred
+from sqlalchemy.test.testing import eq_
-from orm._base import ComparableEntity, MappedTest
+from test.orm._base import ComparableEntity, MappedTest
class DeclarativeTestBase(testing.TestBase, testing.AssertsExecutionResults):
- def setUp(self):
+ def setup(self):
global Base
Base = decl.declarative_base(testing.db)
- def tearDown(self):
+ def teardown(self):
clear_mappers()
Base.metadata.drop_all()
def go():
class User(Base):
id = Column('id', Integer, primary_key=True)
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "does not have a __table__", go)
+ assert_raises_message(sa.exc.InvalidRequestError, "does not have a __table__", go)
def test_cant_add_columns(self):
t = Table('t', Base.metadata, Column('id', Integer, primary_key=True), Column('data', String))
__table__ = t
foo = Column(Integer, primary_key=True)
# can't specify new columns not already in the table
- self.assertRaisesMessage(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go)
+ assert_raises_message(sa.exc.ArgumentError, "Can't add additional column 'foo' when specifying __table__", go)
# regular re-mapping works tho
class Bar(Base):
sess.add(u1)
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(User).filter(User.name == 'ed').one(),
+ eq_(sess.query(User).filter(User.name == 'ed').one(),
User(name='ed', addresses=[Address(email='xyz'), Address(email='def'), Address(email='abc')])
)
__tablename__ = 'foo'
id = Column(Integer, primary_key=True)
rel = relation("User", primaryjoin="User.addresses==Foo.id")
- self.assertRaisesMessage(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers)
+ assert_raises_message(exc.InvalidRequestError, "'addresses' is not an instance of ColumnProperty", compile_mappers)
def test_string_dependency_resolution_in_backref(self):
class User(Base, ComparableEntity):
sess.add(u1)
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(User).filter(User.name == 'ed').one(),
+ eq_(sess.query(User).filter(User.name == 'ed').one(),
User(name='ed', addresses=[Address(email='abc'), Address(email='def'), Address(email='xyz')])
)
# this used to raise an error when accessing User.id but that's no longer the case
# since we got rid of _CompileOnAttr.
- self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+ assert_raises(sa.exc.ArgumentError, compile_mappers)
def test_nice_dependency_error_works_with_hasattr(self):
class User(Base):
# hasattr() on a compile-loaded attribute
hasattr(User.addresses, 'property')
# the exception is preserved
- self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
+ assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
def test_custom_base(self):
class MyBase(object):
i = Index('my_index', User.name)
# compile fails due to the nonexistent Addresses relation
- self.assertRaises(sa.exc.InvalidRequestError, compile_mappers)
+ assert_raises(sa.exc.InvalidRequestError, compile_mappers)
# index configured
assert i in User.__table__.indexes
id = Column('id', Integer, primary_key=True),
name = Column('name', String(50))
assert False
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Mapper Mapper|User|users could not assemble any primary key",
define)
__mapper_args__ = {'polymorphic_identity':'engineer'}
primary_language = Column('primary_language', String(50))
foo_bar = Column(Integer, primary_key=True)
- self.assertRaisesMessage(sa.exc.ArgumentError, "place primary key", go)
+ assert_raises_message(sa.exc.ArgumentError, "place primary key", go)
def test_single_no_table_args(self):
class Person(Base, ComparableEntity):
__mapper_args__ = {'polymorphic_identity':'engineer'}
primary_language = Column('primary_language', String(50))
__table_args__ = ()
- self.assertRaisesMessage(sa.exc.ArgumentError, "place __table_args__", go)
+ assert_raises_message(sa.exc.ArgumentError, "place __table_args__", go)
def test_concrete(self):
engineers = Table('engineers', Base.metadata,
)
-def produce_test(inline, stringbased):
+def _produce_test(inline, stringbased):
class ExplicitJoinTest(MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global User, Address
Base = decl.declarative_base(metadata=metadata)
else:
Address.user = relation(User, primaryjoin=User.id==Address.user_id, backref="addresses")
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
params = [dict(zip(('id', 'name'), column_values)) for column_values in
[(7, 'jack'),
(8, 'ed'),
for inline in (True, False):
for stringbased in (True, False):
- testclass = produce_test(inline, stringbased)
+ testclass = _produce_test(inline, stringbased)
exec("%s = testclass" % testclass.__name__)
del testclass
class DeclarativeReflectionTest(testing.TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global reflection_metadata
reflection_metadata = MetaData(testing.db)
reflection_metadata.create_all()
- def setUp(self):
+ def setup(self):
global Base
Base = decl.declarative_base(testing.db)
- def tearDown(self):
+ def teardown(self):
for t in reversed(reflection_metadata.sorted_tables):
t.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
reflection_metadata.drop_all()
def test_basic(self):
eq_(a1, Address(email='two'))
eq_(a1.user, User(nom='u1'))
- self.assertRaises(TypeError, User, name='u3')
+ assert_raises(TypeError, User, name='u3')
def test_supplied_fk(self):
meta = MetaData(testing.db)
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.orderinglist import *
-from testlib.testing import eq_
-from testlib import *
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test import *
metadata = None
return s
class OrderingListTest(TestBase):
- def setUp(self):
+ def setup(self):
global metadata, slides_table, bullets_table, Slide, Bullet
slides_table, bullets_table = None, None
Slide, Bullet = None, None
metadata.create_all()
- def tearDown(self):
+ def teardown(self):
metadata.drop_all()
def test_append_no_reorder(self):
self.assert_(alpha[li].position == pos)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy.ext import serializer
from sqlalchemy import exc
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, select, desc, func, util
-from testlib.sa.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, select, desc, func, util
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import relation, sessionmaker, scoped_session, class_mapper, mapper, eagerload, compile_mappers, aliased
+from sqlalchemy.test.testing import eq_
-from orm._base import ComparableEntity, MappedTest
+from test.orm._base import ComparableEntity, MappedTest
class User(ComparableEntity):
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global users, addresses
users = Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', Integer, ForeignKey('users.id')),
)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
global Session
Session = scoped_session(sessionmaker())
compile_mappers()
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
params = [dict(zip(('id', 'name'), column_values)) for column_values in
[(7, 'jack'),
(8, 'ed'),
import inspect
import sys
import types
-from testlib import config, sa, testing
-from testlib.testing import resolve_artifact_names, adict
-from testlib.compat import _function_named
+import sqlalchemy as sa
+from sqlalchemy.test import config, testing
+from sqlalchemy.test.testing import resolve_artifact_names, adict
+from sqlalchemy.util import function_named
_repr_stack = set()
class ORMTest(testing.TestBase, testing.AssertsExecutionResults):
__requires__ = ('subqueries',)
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
sa.orm.session.Session.close_all()
sa.orm.clear_mappers()
# TODO: ensure mapper registry is empty
classes = None
other_artifacts = None
- def setUpAll(self):
- if self.run_setup_classes == 'each':
- assert self.run_setup_mappers != 'once'
+ @classmethod
+ def setup_class(cls):
+ if cls.run_setup_classes == 'each':
+ assert cls.run_setup_mappers != 'once'
- assert self.run_deletes in (None, 'each')
- if self.run_inserts == 'once':
- assert self.run_deletes is None
+ assert cls.run_deletes in (None, 'each')
+ if cls.run_inserts == 'once':
+ assert cls.run_deletes is None
- assert not hasattr(self, 'keep_mappers')
- assert not hasattr(self, 'keep_data')
+ assert not hasattr(cls, 'keep_mappers')
+ assert not hasattr(cls, 'keep_data')
- cls = self.__class__
if cls.tables is None:
cls.tables = adict()
if cls.classes is None:
if cls.other_artifacts is None:
cls.other_artifacts = adict()
- if self.metadata is None:
- setattr(type(self), 'metadata', sa.MetaData())
+ if cls.metadata is None:
+ setattr(cls, 'metadata', sa.MetaData())
- if self.metadata.bind is None:
- self.metadata.bind = getattr(self, 'engine', config.db)
+ if cls.metadata.bind is None:
+ cls.metadata.bind = getattr(cls, 'engine', config.db)
- if self.run_define_tables:
- self.define_tables(self.metadata)
- self.metadata.create_all()
- self.tables.update(self.metadata.tables)
+ if cls.run_define_tables == 'once':
+ cls.define_tables(cls.metadata)
+ cls.metadata.create_all()
+ cls.tables.update(cls.metadata.tables)
- if self.run_setup_classes:
+ if cls.run_setup_classes == 'once':
baseline = subclasses(BasicEntity)
- self.setup_classes()
- self._register_new_class_artifacts(baseline)
+ cls.setup_classes()
+ cls._register_new_class_artifacts(baseline)
- if self.run_setup_mappers:
+ if cls.run_setup_mappers == 'once':
baseline = subclasses(BasicEntity)
- self.setup_mappers()
- self._register_new_class_artifacts(baseline)
+ cls.setup_mappers()
+ cls._register_new_class_artifacts(baseline)
- if self.run_inserts:
- self._load_fixtures()
- self.insert_data()
-
- def setUp(self):
- if self._sa_first_test:
- return
+ if cls.run_inserts == 'once':
+ cls._load_fixtures()
+ cls.insert_data()
+ def setup(self):
if self.run_define_tables == 'each':
self.tables.clear()
self.metadata.drop_all()
self._load_fixtures()
self.insert_data()
- def tearDown(self):
+ def teardown(self):
sa.orm.session.Session.close_all()
# some tests create mappers in the test bodies
print >> sys.stderr, "Error emptying table %s: %r" % (
table, ex)
- def tearDownAll(self):
- for cls in self.classes.values():
- self.unregister_class(cls)
- ORMTest.tearDownAll(self)
- self.metadata.drop_all()
- self.metadata.bind = None
+ @classmethod
+ def teardown_class(cls):
+ for cl in cls.classes.values():
+ cls.unregister_class(cl)
+ ORMTest.teardown_class()
+ cls.metadata.drop_all()
+ cls.metadata.bind = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
raise NotImplementedError()
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
pass
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
pass
- def fixtures(self):
+ @classmethod
+ def fixtures(cls):
return {}
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
pass
def sql_count_(self, count, fn):
if name[0].isupper():
delattr(cls, name)
del cls.classes[name]
-
- def _load_fixtures(self):
+
+ @classmethod
+ def _load_fixtures(cls):
headers, rows = {}, {}
- for table, data in self.fixtures().iteritems():
+ for table, data in cls.fixtures().iteritems():
if isinstance(table, basestring):
- table = self.tables[table]
+ table = cls.tables[table]
headers[table] = data[0]
rows[table] = data[1:]
- for table in self.metadata.sorted_tables:
+ for table in cls.metadata.sorted_tables:
if table not in headers:
continue
table.bind.execute(
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import attributes
-from testlib.testing import fixture
-from orm import _base
+from sqlalchemy import MetaData, Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import attributes
+from sqlalchemy.test.testing import fixture
+from test.orm import _base
__all__ = ()
Address=Address,
Dingaling=Dingaling)
- def setUpAll(self):
- assert not hasattr(self, 'refresh_data')
- assert not hasattr(self, 'only_tables')
- #refresh_data = False
- #only_tables = False
-
- #if type(self) is not FixtureTest:
- # setattr(type(self), 'classes', _base.adict(self.classes))
-
- #if self.run_setup_classes:
- # for cls in self.classes.values():
- # self.register_class(cls)
- super(FixtureTest, self).setUpAll()
-
- #if not self.only_tables and self.keep_data:
- # _registry.load()
-
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
pass
- def setup_classes(self):
- for cls in self.fixture_classes.values():
- self.register_class(cls)
+ @classmethod
+ def setup_classes(cls):
+ for cl in cls.fixture_classes.values():
+ cls.register_class(cl)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
pass
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
_load_fixtures()
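
In the converted test/orm/_base.py above, setup_class() now drives table creation, class
registration, mapper configuration and fixture inserts, so the hooks a MappedTest subclass
overrides (define_tables(), setup_classes(), setup_mappers(), fixtures(), insert_data())
all become classmethods. A minimal sketch of a subclass written against the converted base
(table and class names are illustrative, and the class relies on MappedTest's default
run_* settings):

    from sqlalchemy import Integer, String
    from sqlalchemy.test.schema import Table, Column
    from test.orm import _base

    gadgets = None

    class GadgetTest(_base.MappedTest):
        @classmethod
        def define_tables(cls, metadata):
            # called once with the class-level MetaData; creation and
            # teardown of the table is handled by MappedTest itself
            global gadgets
            gadgets = Table('gadgets', metadata,
                            Column('id', Integer, primary_key=True),
                            Column('name', String(30)))

        @classmethod
        def insert_data(cls):
            gadgets.insert().execute(dict(id=1, name='g1'))

        def test_fixture_present(self):
            assert gadgets.select().execute().fetchall() == [(1, 'g1')]
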
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-import inheritance.alltests as inheritance
-import sharding.alltests as sharding
-
-def suite():
- modules_to_test = (
- 'orm.attributes',
- 'orm.bind',
- 'orm.extendedattr',
- 'orm.instrumentation',
- 'orm.query',
- 'orm.lazy_relations',
- 'orm.eager_relations',
- 'orm.mapper',
- 'orm.expire',
- 'orm.selectable',
- 'orm.collection',
- 'orm.generative',
- 'orm.lazytest1',
- 'orm.assorted_eager',
-
- 'orm.naturalpks',
- 'orm.defaults',
- 'orm.unitofwork',
- 'orm.session',
- 'orm.transaction',
- 'orm.scoping',
- 'orm.cascade',
- 'orm.relationships',
- 'orm.association',
- 'orm.merge',
- 'orm.pickled',
- 'orm.utils',
-
- 'orm.cycles',
-
- 'orm.compile',
- 'orm.manytomany',
- 'orm.onetoone',
- 'orm.dynamic',
-
- 'orm.evaluator',
-
- 'orm.deprecations',
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- alltests.addTest(inheritance.suite())
- alltests.addTest(sharding.suite())
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
- modules_to_test = (
- 'orm.inheritance.basic',
- 'orm.inheritance.query',
- 'orm.inheritance.manytomany',
- 'orm.inheritance.single',
- 'orm.inheritance.concrete',
- 'orm.inheritance.polymorph',
- 'orm.inheritance.polymorph2',
- 'orm.inheritance.poly_linked_list',
- 'orm.inheritance.abc_polymorphic',
- 'orm.inheritance.abc_inheritance',
- 'orm.inheritance.productspec',
- 'orm.inheritance.magazine',
- 'orm.inheritance.selects',
-
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
-from testlib import testing
-from orm import _base
+from sqlalchemy.test import testing
+from test.orm import _base
def produce_test(parent, child, direction):
relationship between two of the classes, using either one-to-many or
many-to-one."""
class ABCTest(_base.MappedTest):
- def define_tables(self, meta):
+ @classmethod
+ def define_tables(cls, metadata):
global ta, tb, tc
- ta = ["a", meta]
+ ta = ["a", metadata]
ta.append(Column('id', Integer, primary_key=True)),
ta.append(Column('a_data', String(30)))
if "a"== parent and direction == MANYTOONE:
ta.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo")))
ta = Table(*ta)
- tb = ["b", meta]
+ tb = ["b", metadata]
tb.append(Column('id', Integer, ForeignKey("a.id"), primary_key=True, ))
tb.append(Column('b_data', String(30)))
tb.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo")))
tb = Table(*tb)
- tc = ["c", meta]
+ tc = ["c", metadata]
tc.append(Column('id', Integer, ForeignKey("b.id"), primary_key=True, ))
tc.append(Column('c_data', String(30)))
tc.append(Column('parent_id', Integer, ForeignKey("%s.id" % parent, use_alter=True, name="foo")))
tc = Table(*tc)
- def tearDown(self):
+ def teardown(self):
if direction == MANYTOONE:
parent_table = {"a":ta, "b":tb, "c": tc}[parent]
parent_table.update(values={parent_table.c.child_id:None}).execute()
elif direction == ONETOMANY:
child_table = {"a":ta, "b":tb, "c": tc}[child]
child_table.update(values={child_table.c.parent_id:None}).execute()
- super(ABCTest, self).tearDown()
+ super(ABCTest, self).teardown()
def test_roundtrip(self):
parent_table = {"a":ta, "b":tb, "c": tc}[parent]
exec("%s = testclass" % testclass.__name__)
del testclass
-if __name__ == "__main__":
- testenv.main()
+del produce_test
\ No newline at end of file
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy import util
from sqlalchemy.orm import *
-from testlib import _function_named
-from orm import _base, _fixtures
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
class ABCTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global a, b, c
a = Table('a', metadata,
Column('id', Integer, primary_key=True),
C(cdata='c2', bdata='c2', adata='c2'),
] == sess.query(C).all()
- test_roundtrip = _function_named(
+ test_roundtrip = function_named(
test_roundtrip, 'test_%s' % fetchtype)
return test_roundtrip
test_none = make_test('none')
-if __name__ == '__main__':
- testenv.main()
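
Several of the modules above build their test methods in a loop. Two details of those
conversions are easy to miss: each generated function gets a distinct, discoverable name
via sqlalchemy.util.function_named() (replacing testlib's _function_named), and the factory
functions themselves are renamed with a leading underscore (produce_test becomes
_produce_test, create_test becomes _create_test, and so on) so that nose's collector does
not pick up the factories as tests in their own right; several modules also del the factory
once the loop has run. A compressed sketch of the pattern, with placeholder names:

    from sqlalchemy.test import TestBase
    from sqlalchemy.util import function_named

    def _produce_test(flag):
        # the leading underscore keeps nose from collecting the factory
        def test_roundtrip(self):
            assert flag in (True, False)
        # attach a unique name so each generated test reports separately
        return function_named(test_roundtrip, 'test_roundtrip_%s' % flag)

    class GeneratedRoundTripTest(TestBase):
        pass

    for _flag in (True, False):
        _fn = _produce_test(_flag)
        setattr(GeneratedRoundTripTest, _fn.__name__, _fn)
    del _fn, _flag
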
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy import exc as sa_exc, util
from sqlalchemy.orm import *
from sqlalchemy.orm import exc as orm_exc
-#from testlib import *
-#from testlib import fixtures
-from testlib import _function_named, testing, engines
-from orm import _base, _fixtures
+from sqlalchemy.test import testing, engines
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
class O2MTest(_base.MappedTest):
"""deals with inheritance and one-to-many relationships"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global foo, bar, blub
foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq', optional=True),
self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
class FalseDiscriminatorTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global t1
t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), Column('type', Integer, nullable=False))
assert isinstance(sess.query(Foo).one(), Bar)
class PolymorphicSynonymTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global t1, t2
t1 = Table('t1', metadata,
Column('id', Integer, primary_key=True),
sess.add(at2)
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(T2).filter(T2.info=='at2').one(), at2)
- self.assertEquals(at2.info, "THE INFO IS:at2")
+ eq_(sess.query(T2).filter(T2.info=='at2').one(), at2)
+ eq_(at2.info, "THE INFO IS:at2")
class CascadeTest(_base.MappedTest):
cascading along the path of the instance's mapper, not
the base mapper."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global t1, t2, t3, t4
t1= Table('t1', metadata,
Column('id', Integer, primary_key=True),
sess.flush()
class GetTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global foo, bar, blub
foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq', optional=True),
Column('bar_id', Integer, ForeignKey('bar.id')),
Column('data', String(20)))
- def create_test(polymorphic, name):
+ def _create_test(polymorphic, name):
def test_get(self):
class Foo(object):
pass
self.assert_sql_count(testing.db, go, 3)
- test_get = _function_named(test_get, name)
+ test_get = function_named(test_get, name)
return test_get
- test_get_polymorphic = create_test(True, 'test_get_polymorphic')
- test_get_nonpolymorphic = create_test(False, 'test_get_nonpolymorphic')
+ test_get_polymorphic = _create_test(True, 'test_get_polymorphic')
+ test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic')
class EagerLazyTest(_base.MappedTest):
"""tests eager load/lazy load of child items off inheritance mappers, tests that
LazyLoader constructs the right query condition."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global foo, bar, bar_foo
foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_seq', optional=True),
class FlushTest(_base.MappedTest):
"""test dependency sorting among inheriting mappers"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global users, roles, user_roles, admins
users = Table('users', metadata,
Column('id', Integer, primary_key=True),
assert user_roles.count().scalar() == 1
class VersioningTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global base, subtable, stuff
base = Table('base', metadata,
Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ),
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global person_table, employee_table, Person, Employee
person_table = Table("persons", metadata,
class Employee(Person): pass
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
person_insert = person_table.insert()
person_insert.execute(id=1, name='alice')
person_insert.execute(id=2, name='bob')
class SyncCompileTest(_base.MappedTest):
"""test that syncrules compile properly on custom inherit conds"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global _a_table, _b_table, _c_table
_a_table = Table('a', metadata,
class OverrideColKeyTest(_base.MappedTest):
"""test overriding of column attributes."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global base, subtable
base = Table('base', metadata,
# Sub gets a "base_id" property using the "base_id"
# column of both tables.
- self.assertEquals(
+ eq_(
class_mapper(Sub).get_property('base_id').columns,
[base.c.base_id, subtable.c.base_id]
)
'id':[base.c.base_id, subtable.c.base_id]
})
- self.assertEquals(
+ eq_(
class_mapper(Sub).get_property('id').columns,
[base.c.base_id, subtable.c.base_id]
)
})
mapper(Sub, subtable, inherits=Base)
- self.assertEquals(
+ eq_(
class_mapper(Sub).get_property('id').columns,
[base.c.base_id]
)
- self.assertEquals(
+ eq_(
class_mapper(Sub).get_property('base_id').columns,
[subtable.c.base_id]
)
# it has its own "id" property. Sub's "id" property
# gets joined normally with the extra column.
- self.assertEquals(
+ eq_(
class_mapper(Sub).get_property('id').columns,
[base.c.base_id, subtable.c.base_id]
)
a column in the join condition is not available.
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global base, sub
base = Table('base', metadata,
Column('id', Integer, primary_key=True),
assert s1.sub == 's1sub'
class PKDiscriminatorTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
parents = Table('parents', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(60)))
class DeleteOrphanTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global single, parent
single = Table('single', metadata,
Column('id', Integer, primary_key=True),
sess = create_session()
s1 = SubClass(data='s1')
sess.add(s1)
- self.assertRaisesMessage(orm_exc.FlushError,
+ assert_raises_message(orm_exc.FlushError,
"is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.orm import exc as orm_exc
-from testlib import *
-from testlib import sa, testing
-from orm import _base
+from sqlalchemy.test import *
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from test.orm import _base
from sqlalchemy.orm import attributes
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
class Employee(object):
def __init__(self, name):
class ConcreteTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global managers_table, engineers_table, hackers_table, companies, employees_table
companies = Table('companies', metadata,
manager = session.query(Manager).one()
session.expire(manager, ['manager_data'])
- self.assertEquals(manager.manager_data, "knows how to manage things")
+ eq_(manager.manager_data, "knows how to manage things")
def test_multi_level_no_base(self):
pjoin = polymorphic_union({
assert 'name' not in attributes.instance_state(hacker).expired_attributes
assert 'nickname' not in attributes.instance_state(hacker).expired_attributes
def go():
- self.assertEquals(jerry.name, "Jerry")
- self.assertEquals(hacker.nickname, "Badass")
+ eq_(jerry.name, "Jerry")
+ eq_(hacker.nickname, "Badass")
self.assert_sql_count(testing.db, go, 0)
session.expunge_all()
session.flush()
def go():
- self.assertEquals(jerry.name, "Jerry")
- self.assertEquals(hacker.nickname, "Badass")
+ eq_(jerry.name, "Jerry")
+ eq_(hacker.nickname, "Badass")
self.assert_sql_count(testing.db, go, 0)
session.expunge_all()
self.assert_sql_count(testing.db, go, 1)
class PropertyInheritanceTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('a_table', metadata,
Column('id', Integer, primary_key=True),
Column('some_c_id', Integer, ForeignKey('c_table.id')),
)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class A(_base.ComparableEntity):
pass
b = B()
c = C()
- self.assertRaises(AttributeError, setattr, b, 'some_c', c)
+ assert_raises(AttributeError, setattr, b, 'some_c', c)
clear_mappers()
mapper(A, a_table, properties={
mapper(B, b_table,inherits=A, concrete=True)
mapper(C, c_table)
b = B()
- self.assertRaises(AttributeError, setattr, b, 'a_id', 3)
+ assert_raises(AttributeError, setattr, b, 'a_id', 3)
clear_mappers()
mapper(A, a_table, properties={
b1 = B(some_c=c1, bname='b1')
b2 = B(some_c=c1, bname='b2')
- self.assertRaises(AttributeError, setattr, b1, 'aname', 'foo')
- self.assertRaises(AttributeError, getattr, A, 'bname')
+ assert_raises(AttributeError, setattr, b1, 'aname', 'foo')
+ assert_raises(AttributeError, getattr, A, 'bname')
assert c2.many_a == [a2]
assert c1.many_a == [a1]
class ColKeysTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global offices_table, refugees_table
refugees_table = Table('refugee', metadata,
Column('refugee_fid', Integer, primary_key=True),
Column('office_fid', Integer, primary_key=True),
Column('office_name', Unicode(30), key='name'))
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
refugees_table.insert().execute(
dict(refugee_fid=1, name=u"refugee1"),
dict(refugee_fid=2, name=u"refugee2")
eq_(sess.query(Office).get(1).name, "office1")
eq_(sess.query(Office).get(2).name, "office2")
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import testing, _function_named
-from orm import _base
+from sqlalchemy.test import testing
+from sqlalchemy.util import function_named
+from test.orm import _base
class BaseObject(object):
def __init__(self, *args, **kwargs):
class MagazineTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global publication_table, issue_table, location_table, location_name_table, magazine_table, \
page_table, magazine_page_table, classified_page_table, page_size_table
print [page, page2, page3]
assert repr(p.issues[0].locations[0].magazine.pages) == repr([page, page2, page3]), repr(p.issues[0].locations[0].magazine.pages)
- test_roundtrip = _function_named(
+ test_roundtrip = function_named(
test_roundtrip, "test_%s" % (not use_union and (use_joins and "joins" or "select") or "unions"))
setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip)
generate_round_trip_test(use_union, use_join)
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import testing
-from orm import _base
+from sqlalchemy.test import testing
+from test.orm import _base
class InheritTest(_base.MappedTest):
"""deals with inheritance and many-to-many relationships"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global principals
global users
global groups
class InheritTest2(_base.MappedTest):
"""deals with inheritance and many-to-many relationships"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global foo, bar, foo_bar
foo = Table('foo', metadata,
Column('id', Integer, Sequence('foo_id_seq', optional=True),
class InheritTest3(_base.MappedTest):
"""deals with inheritance and many-to-many relationships"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global foo, bar, blub, bar_foo, blub_bar, blub_foo
# the 'data' columns are to appease SQLite which can't handle a blank INSERT
l = sess.query(Bar).all()
print repr(l[0]) + repr(l[0].foos)
found = repr(l[0]) + repr(sorted([repr(o) for o in l[0].foos]))
- self.assertEqual(found, compare)
+ eq_(found, compare)
@testing.fails_on('maxdb', 'FIXME: unknown')
def testadvanced(self):
self.assert_(repr(x) == compare)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from orm import _base
-from testlib import testing
+from test.orm import _base
+from sqlalchemy.test import testing
class PolymorphicCircularTest(_base.MappedTest):
run_setup_mappers = 'once'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global Table1, Table1B, Table2, Table3, Data
table1 = Table('table1', metadata,
Column('id', Integer, primary_key=True),
@testing.fails_on('maxdb', 'FIXME: unknown')
def testone(self):
- self.do_testlist([Table1, Table2, Table1, Table2])
+ self._testlist([Table1, Table2, Table1, Table2])
@testing.fails_on('maxdb', 'FIXME: unknown')
def testtwo(self):
- self.do_testlist([Table3])
+ self._testlist([Table3])
@testing.fails_on('maxdb', 'FIXME: unknown')
def testthree(self):
- self.do_testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1])
+ self._testlist([Table2, Table1, Table1B, Table3, Table3, Table1B, Table1B, Table2, Table1])
@testing.fails_on('maxdb', 'FIXME: unknown')
def testfour(self):
- self.do_testlist([
+ self._testlist([
Table2('t2', [Data('data1'), Data('data2')]),
Table1('t1', []),
Table3('t3', [Data('data3')]),
Table1B('t1b', [Data('data4'), Data('data5')])
])
- def do_testlist(self, classes):
+ def _testlist(self, classes):
sess = create_session( )
# create objects in a linked list
# everything should match !
assert original == forwards == backwards
-if __name__ == '__main__':
- testenv.main()
"""tests basic polymorphic mapper loading/saving, minimal relations"""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.orm import exc as orm_exc
-from testlib import _function_named, Column, testing
-from orm import _fixtures, _base
+from sqlalchemy.test import Column, testing
+from sqlalchemy.util import function_named
+from test.orm import _fixtures, _base
class Person(_fixtures.Base):
pass
pass
class PolymorphTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global companies, people, engineers, managers, boss
companies = Table('companies', metadata,
session.add(c)
session.flush()
session.expunge_all()
- self.assertEquals(session.query(Company).get(c.company_id), c)
+ eq_(session.query(Company).get(c.company_id), c)
class RelationToSubclassTest(PolymorphTest):
def test_basic(self):
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(Company).filter_by(company_id=c.company_id).one(), c)
+ eq_(sess.query(Company).filter_by(company_id=c.company_id).one(), c)
assert c.managers[0].company is c
class RoundTripTest(PolymorphTest):
pass
-def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic):
+def _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic):
"""generates a round trip test.
include_base - whether or not to include the base 'person' type in the union.
session.flush()
session.expunge_all()
- self.assertEquals(session.query(Person).get(dilbert.person_id), dilbert)
+ eq_(session.query(Person).get(dilbert.person_id), dilbert)
session.expunge_all()
- self.assertEquals(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert)
+ eq_(session.query(Person).filter(Person.person_id==dilbert.person_id).one(), dilbert)
session.expunge_all()
def go():
cc = session.query(Company).get(c.company_id)
- self.assertEquals(cc.employees, employees)
+ eq_(cc.employees, employees)
if not lazy_relation:
if with_polymorphic != 'none':
# test selecting from the query, using the base mapped table (people) as the selection criterion.
# in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join"
- self.assertEquals(
+ eq_(
session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first(),
dilbert
)
assert session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first().person_id
- self.assertEquals(
+ eq_(
session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first(),
dilbert
)
# test standalone orphans
daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})
session.add(daboss)
- self.assertRaises(orm_exc.FlushError, session.flush)
+ assert_raises(orm_exc.FlushError, session.flush)
c = session.query(Company).first()
daboss.company = c
manager_list = [e for e in c.employees if isinstance(e, Manager)]
session.flush()
session.expunge_all()
- self.assertEquals(session.query(Manager).order_by(Manager.person_id).all(), manager_list)
+ eq_(session.query(Manager).order_by(Manager.person_id).all(), manager_list)
c = session.query(Company).first()
session.delete(c)
session.flush()
- self.assertEquals(people.count().scalar(), 0)
+ eq_(people.count().scalar(), 0)
- test_roundtrip = _function_named(
+ test_roundtrip = function_named(
test_roundtrip, "test_%s%s%s_%s" % (
(lazy_relation and "lazy" or "eager"),
(include_base and "_inclbase" or ""),
for with_polymorphic in ['unions', 'joins', 'auto', 'none']:
if with_polymorphic == 'unions':
for include_base in [True, False]:
- generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic)
+ _generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic)
else:
- generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic)
+ _generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic)
-if __name__ == "__main__":
- testenv.main()
inheritance setups for which we maintain compatibility.
"""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
from sqlalchemy import *
from sqlalchemy import util
from sqlalchemy.orm import *
-from testlib import _function_named, TestBase, AssertsExecutionResults, testing
-from orm import _base, _fixtures
-from testlib.testing import eq_
+from sqlalchemy.test import TestBase, AssertsExecutionResults, testing
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
class AttrSettable(object):
def __init__(self, **kwargs):
class RelationTest1(_base.MappedTest):
"""test self-referential relationships on polymorphic mappers"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, managers
people = Table('people', metadata,
Column('manager_name', String(50))
)
- def tearDown(self):
+ def teardown(self):
people.update(values={people.c.manager_id:None}).execute()
- super(RelationTest1, self).tearDown()
+ super(RelationTest1, self).teardown()
- def testparentrefsdescendant(self):
+ def test_parent_refs_descendant(self):
class Person(AttrSettable):
pass
class Manager(Person):
mapper(Manager, managers, inherits=Person,
inherit_condition=people.c.person_id==managers.c.person_id)
- self.assertEquals(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)])
+ eq_(class_mapper(Person).get_property('manager').synchronize_pairs, [(managers.c.person_id,people.c.manager_id)])
session = create_session()
p = Person(name='some person')
print p, m, p.manager
assert p.manager is m
- def testdescendantrefsparent(self):
+ def test_descendant_refs_parent(self):
class Person(AttrSettable):
pass
class Manager(Person):
class RelationTest2(_base.MappedTest):
"""test self-referential relationships on polymorphic mappers"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, managers, data
people = Table('people', metadata,
Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
class RelationTest3(_base.MappedTest):
"""test self-referential relationships on polymorphic mappers"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, managers, data
people = Table('people', metadata,
Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
Column('data', String(30))
)
-def generate_test(jointype="join1", usedata=False):
+def _generate_test(jointype="join1", usedata=False):
def do_test(self):
class Person(AttrSettable):
pass
assert p.data.data == 'ps data'
assert m.data.data == 'ms data'
- do_test = _function_named(
- do_test, 'test_relationonbaseclass_%s_%s' % (
+ do_test = function_named(
+ do_test, 'test_relation_on_base_class_%s_%s' % (
jointype, data and "data" or "nodata"))
return do_test
for jointype in ["join1", "join2", "join3", "join4"]:
for data in (True, False):
- func = generate_test(jointype, data)
+ func = _generate_test(jointype, data)
setattr(RelationTest3, func.__name__, func)
-
+del func
class RelationTest4(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
Column('person_id', Integer, primary_key=True),
assert c.car_id==car1.car_id
class RelationTest5(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, engineers, managers, cars
people = Table('people', metadata,
Column('person_id', Integer, primary_key=True),
class RelationTest6(_base.MappedTest):
"""test self-referential relationships on a single joined-table inheritance mapper"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, managers, data
people = Table('people', metadata,
Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
assert m.colleague is m2
class RelationTest7(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, engineers, managers, cars, offroad_cars
cars = Table('cars', metadata,
Column('car_id', Integer, primary_key=True),
assert p.car_id == p.car.car_id
class RelationTest8(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global taggable, users
taggable = Table('taggable', metadata,
Column('id', Integer, primary_key=True),
)
class GenerativeTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
# cars---owned by--- people (abstract) --- has a --- status
# | ^ ^ |
# | | | |
metadata.create_all()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
- def tearDown(self):
+ def teardown(self):
clear_mappers()
for t in reversed(metadata.sorted_tables):
t.delete().execute()
assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
class MultiLevelTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global table_Employee, table_Engineer, table_Manager
table_Employee = Table( 'Employee', metadata,
Column( 'name', type_= String(100), ),
assert session.query( Manager).all() == [c]
class ManyToManyPolyTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global base_item_table, item_table, base_item_collection_table, collection_table
base_item_table = Table(
'base_item', metadata,
class_mapper(BaseItem)
class CustomPKTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global t1, t2
t1 = Table('t1', metadata,
Column('id', Integer, primary_key=True),
sess.flush()
class InheritingEagerTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, employees, tags, peopleTags
people = Table('people', metadata,
assert len(instance.tags) == 2
class MissingPolymorphicOnTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global tablea, tableb, tablec, tabled
tablea = Table('tablea', metadata,
Column('id', Integer, primary_key=True),
sess.add(d)
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')])
+ eq_(sess.query(A).all(), [C(cdata='c1', adata='a1'), D(cdata='c2', adata='a2', ddata='d2')])
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
from datetime import datetime
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import testing
-from orm import _base
+from sqlalchemy.test import testing
+from test.orm import _base
class InheritTest(_base.MappedTest):
"""tests some various inheritance round trips involving a particular set of polymorphic inheritance relationships"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global products_table, specification_table, documents_table
global Product, Detail, Assembly, SpecLine, Document, RasterDocument
print new
assert orig == new == '<Assembly a1> specification=[<SpecLine 1.0 <Detail d1>>] documents=[<Document doc1>, <RasterDocument doc2>]'
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy import exc as sa_exc
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.engine import default
-from testlib import AssertsCompiledSQL, testing
-from orm import _base, _fixtures
-from testlib.testing import eq_
+from sqlalchemy.test import AssertsCompiledSQL, testing
+from test.orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
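
The unittest assertion methods are replaced throughout by the module-level helpers imported above, which accounts for the bulk of the mechanical changes that follow. A small usage sketch of the helpers (the arithmetic is purely illustrative):

    from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message

    # was: self.assertEquals(1 + 1, 2)
    eq_(1 + 1, 2)

    # was: self.assertRaises(ZeroDivisionError, <callable>)
    assert_raises(ZeroDivisionError, lambda: 1 / 0)

    # was: self.assertRaisesMessage(...); the message argument is matched
    # against the string of the raised exception
    assert_raises_message(ZeroDivisionError, "zero", lambda: 1 / 0)
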
class Company(_fixtures.Base):
pass
class Paperwork(_fixtures.Base):
pass
-def make_test(select_type):
+def _produce_test(select_type):
class PolymorphicQueryTest(_base.MappedTest, AssertsCompiledSQL):
run_inserts = 'once'
run_setup_mappers = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global companies, people, engineers, managers, boss, paperwork, machines
companies = Table('companies', metadata,
mapper(Paperwork, paperwork)
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
global all_employees, c1_employees, c2_employees, e1, e2, b1, m1, e3, c1, c2
c1 = Company(name="MegaCorp, Inc.")
sess = create_session()
def go():
- self.assertEquals(sess.query(Person).all(), all_employees)
+ eq_(sess.query(Person).all(), all_employees)
self.assert_sql_count(testing.db, go, {'':14, 'Polymorphic':9}.get(select_type, 10))
def test_primary_eager_aliasing(self):
sess = create_session()
def go():
- self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
+ eq_(sess.query(Person).options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4))
sess = create_session()
assert sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().subquery().count().scalar() == 2
def go():
- self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
+ eq_(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3], all_employees[1:3])
self.assert_sql_count(testing.db, go, 3)
# for all mappers, ensure the primary key has been calculated as just the "person_id"
# column
- self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
- self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
- self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
+ eq_(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+ eq_(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+ eq_(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
def test_multi_join(self):
sess = create_session()
q = sess.query(Company, Person, c, e).join((Person, Company.employees)).join((e, c.employees)).\
filter(Person.name=='dilbert').filter(e.name=='wally')
- self.assertEquals(q.count(), 1)
- self.assertEquals(q.all(), [
+ eq_(q.count(), 1)
+ eq_(q.all(), [
(
Company(company_id=1,name=u'MegaCorp, Inc.'),
Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
def test_filter_on_subclass(self):
sess = create_session()
- self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert"))
+ eq_(sess.query(Engineer).all()[0], Engineer(name="dilbert"))
- self.assertEquals(sess.query(Engineer).first(), Engineer(name="dilbert"))
+ eq_(sess.query(Engineer).first(), Engineer(name="dilbert"))
- self.assertEquals(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert"))
+ eq_(sess.query(Engineer).filter(Engineer.person_id==e1.person_id).first(), Engineer(name="dilbert"))
- self.assertEquals(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert"))
+ eq_(sess.query(Manager).filter(Manager.person_id==m1.person_id).one(), Manager(name="dogbert"))
- self.assertEquals(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
+ eq_(sess.query(Manager).filter(Manager.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
- self.assertEquals(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
+ eq_(sess.query(Boss).filter(Boss.person_id==b1.person_id).one(), Boss(name="pointy haired boss"))
def test_join_from_polymorphic(self):
sess = create_session()
for aliased in (True, False):
- self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
+ eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
- self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
+ eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
- self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
+ eq_(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
- self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+ eq_(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_from_with_polymorphic(self):
sess = create_session()
for aliased in (True, False):
sess.expunge_all()
- self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
+ eq_(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
sess.expunge_all()
- self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
+ eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
sess.expunge_all()
- self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+ eq_(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_to_polymorphic(self):
sess = create_session()
- self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
+ eq_(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
- self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
+ eq_(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
def test_polymorphic_any(self):
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(Company).\
filter(Company.employees.any(Person.name=='vlad')).all(), [c2]
)
# test that the aliasing on "Person" does not bleed into the
# EXISTS clause generated by any()
- self.assertEquals(
+ eq_(
sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\
filter(Company.employees.any(Person.name=='wally')).all(), [c1]
)
- self.assertEquals(
+ eq_(
sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\
filter(Company.employees.any(Person.name=='vlad')).all(), []
)
- self.assertEquals(
+ eq_(
sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(),
c2
)
calias = aliased(Company)
- self.assertEquals(
+ eq_(
sess.query(calias).filter(calias.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(),
c2
)
- self.assertEquals(
+ eq_(
sess.query(Company).filter(Company.employees.of_type(Boss).any(Boss.golf_swing=='fore')).one(),
c1
)
- self.assertEquals(
+ eq_(
sess.query(Company).filter(Company.employees.of_type(Boss).any(Manager.manager_name=='pointy')).one(),
c1
)
if select_type != '':
- self.assertEquals(
+ eq_(
sess.query(Person).filter(Engineer.machines.any(Machine.name=="Commodore 64")).all(), [e2, e3]
)
- self.assertEquals(
+ eq_(
sess.query(Person).filter(Person.paperwork.any(Paperwork.description=="review #2")).all(), [m1]
)
- self.assertEquals(
+ eq_(
sess.query(Company).filter(Company.employees.of_type(Engineer).any(and_(Engineer.primary_language=='cobol'))).one(),
c2
)
def test_join_from_columns_or_subclass(self):
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(Manager.name).order_by(Manager.name).all(),
[(u'dogbert',), (u'pointy haired boss',)]
)
- self.assertEquals(
+ eq_(
sess.query(Manager.name).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(),
[(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)]
)
- self.assertEquals(
+ eq_(
sess.query(Person.name).join((Paperwork, Person.paperwork)).order_by(Person.name).all(),
[(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)]
)
- self.assertEquals(
+ eq_(
sess.query(Person.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Person.name).all(),
[(u'dilbert',), (u'dilbert',), (u'dogbert',), (u'dogbert',), (u'pointy haired boss',), (u'vlad',), (u'wally',), (u'wally',)]
)
- self.assertEquals(
+ eq_(
sess.query(Manager).join((Paperwork, Manager.paperwork)).order_by(Manager.name).all(),
[m1, b1]
)
- self.assertEquals(
+ eq_(
sess.query(Manager.name).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(),
[(u'dogbert',), (u'dogbert',), (u'pointy haired boss',)]
)
- self.assertEquals(
+ eq_(
sess.query(Manager.person_id).join((paperwork, Manager.person_id==paperwork.c.person_id)).order_by(Manager.name).all(),
[(4,), (4,), (3,)]
)
- self.assertEquals(
+ eq_(
sess.query(Manager.name, Paperwork.description).join((Paperwork, Manager.person_id==Paperwork.person_id)).all(),
[(u'pointy haired boss', u'review #1'), (u'dogbert', u'review #2'), (u'dogbert', u'review #3')]
)
malias = aliased(Manager)
- self.assertEquals(
+ eq_(
sess.query(malias.name).join((paperwork, malias.person_id==paperwork.c.person_id)).all(),
[(u'pointy haired boss',), (u'dogbert',), (u'dogbert',)]
)
sess = create_session()
- self.assertRaises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork)
- self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss)
- self.assertRaises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person)
+ assert_raises(sa_exc.InvalidRequestError, sess.query(Person).with_polymorphic, Paperwork)
+ assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Boss)
+ assert_raises(sa_exc.InvalidRequestError, sess.query(Engineer).with_polymorphic, Person)
# compare to entities without related collections to prevent additional lazy SQL from firing on
# loaded entities
Manager(name="dogbert", manager_name="dogbert", status="regular manager"),
Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
]
- self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+ eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
def go():
- self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
+ eq_(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
self.assert_sql_count(testing.db, go, 1)
sess.expunge_all()
def go():
- self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+ eq_(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 1)
sess.expunge_all()
def go():
- self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations)
+ eq_(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 3)
sess.expunge_all()
def go():
- self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations)
+ eq_(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 3)
sess.expunge_all()
def go():
# limit the polymorphic join down to just "Person", overriding select_table
- self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations)
+ eq_(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 6)
def test_relation_to_polymorphic(self):
def go():
# test load Companies with lazy load to 'employees'
- self.assertEquals(sess.query(Company).all(), assert_result)
+ eq_(sess.query(Company).all(), assert_result)
self.assert_sql_count(testing.db, go, {'':9, 'Polymorphic':4}.get(select_type, 5))
sess = create_session()
def go():
# currently, it doesn't matter if we say Company.employees, or Company.employees.of_type(Engineer). eagerloader doesn't
# pick up on the "of_type()" as of yet.
- self.assertEquals(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result)
+ eq_(sess.query(Company).options(eagerload_all([Company.employees.of_type(Engineer), Engineer.machines])).all(), assert_result)
# in the case of select_type='', the eagerload doesn't take in this case;
# it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines"
sess = create_session()
def go():
# test load People with eagerload to engineers + machines
- self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(),
+ eq_(sess.query(Person).with_polymorphic('*').options(eagerload([Engineer.machines])).filter(Person.name=='dilbert').all(),
[Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer", machines=[Machine(name="IBM ThinkPad"), Machine(name="IPhone")])]
)
self.assert_sql_count(testing.db, go, 1)
def test_join_to_subclass(self):
sess = create_session()
- self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+ eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
if select_type == '':
- self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
- self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+ eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
+ eq_(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
ealias = aliased(Engineer)
- self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1])
+ eq_(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1])
- self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
- self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
- self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
- self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+ eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
+ eq_(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+ eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
+ eq_(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
else:
- self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
- self.assertEquals(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1])
- self.assertEquals(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3])
- self.assertEquals(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
- self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2])
- self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+ eq_(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
+ eq_(sess.query(Company).join(['employees']).filter(Engineer.primary_language=='java').all(), [c1])
+ eq_(sess.query(Person).join(Engineer.machines).all(), [e1, e2, e3])
+ eq_(sess.query(Person).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+ eq_(sess.query(Company).join(['employees', Engineer.machines]).all(), [c1, c2])
+ eq_(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
# non-polymorphic
- self.assertEquals(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3])
- self.assertEquals(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
+ eq_(sess.query(Engineer).join(Engineer.machines).all(), [e1, e2, e3])
+ eq_(sess.query(Engineer).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
# here's the new way
- self.assertEquals(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1])
- self.assertEquals(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
+ eq_(sess.query(Company).join(Company.employees.of_type(Engineer)).filter(Engineer.primary_language=='java').all(), [c1])
+ eq_(sess.query(Company).join([Company.employees.of_type(Engineer), 'machines']).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
def test_join_through_polymorphic(self):
sess = create_session()
for aliased in (True, False):
- self.assertEquals(
+ eq_(
sess.query(Company).\
join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#2%')).all(),
[c1]
)
- self.assertEquals(
+ eq_(
sess.query(Company).\
join(['employees', 'paperwork'], aliased=aliased).filter(Paperwork.description.like('%#%')).all(),
[c1, c2]
)
- self.assertEquals(
+ eq_(
sess.query(Company).\
join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#2%')).all(),
[c1]
)
- self.assertEquals(
+ eq_(
sess.query(Company).\
join(['employees', 'paperwork'], aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).filter(Paperwork.description.like('%#%')).all(),
[c1, c2]
)
- self.assertEquals(
+ eq_(
sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\
join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#2%')).all(),
[c1]
)
- self.assertEquals(
+ eq_(
sess.query(Company).join('employees', aliased=aliased).filter(Person.name.in_(['dilbert', 'vlad'])).\
join('paperwork', from_joinpoint=True, aliased=aliased).filter(Paperwork.description.like('%#%')).all(),
[c1, c2]
# ORMJoin using regular table foreign key connections. Engineer
# is expressed as "(select * people join engineers) as anon_1"
# so the join is contained.
- self.assertEquals(
+ eq_(
sess.query(Company).join(Engineer).filter(Engineer.engineer_name=='vlad').one(),
c2
)
# same, using explicit join condition. Query.join() must adapt the on clause
# here to match the subquery wrapped around "people join engineers".
- self.assertEquals(
+ eq_(
sess.query(Company).join((Engineer, Company.company_id==Engineer.company_id)).filter(Engineer.engineer_name=='vlad').one(),
c2
)
def test_filter_on_baseclass(self):
sess = create_session()
- self.assertEquals(sess.query(Person).all(), all_employees)
+ eq_(sess.query(Person).all(), all_employees)
- self.assertEquals(sess.query(Person).first(), all_employees[0])
+ eq_(sess.query(Person).first(), all_employees[0])
- self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
+ eq_(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
def test_from_alias(self):
sess = create_session()
palias = aliased(Person)
- self.assertEquals(
+ eq_(
sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(),
[e1, e2]
)
c1_employees = [e1, e2, b1, m1]
palias = aliased(Person)
- self.assertEquals(
+ eq_(
sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(),
[
]
)
- self.assertEquals(
+ eq_(
sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(),
[
# the subquery and usually results in recursion overflow errors within the adaption.
subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar()
- self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
+ eq_(sess.query(Person).filter(Person.person_id==subq).one(), e1)
def test_mixed_entities(self):
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
[(u'Elbonia, Inc.',
Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))]
)
- self.assertEquals(
+ eq_(
sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
[(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
u'Elbonia, Inc.')]
)
- self.assertEquals(
+ eq_(
sess.query(Manager.name).all(),
[('pointy haired boss', ), ('dogbert',)]
)
- self.assertEquals(
+ eq_(
sess.query(Manager.name + " foo").all(),
[('pointy haired boss foo', ), ('dogbert foo',)]
)
assert row.primary_language == 'java'
- self.assertEquals(
+ eq_(
sess.query(Engineer.name, Engineer.primary_language).all(),
[(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')]
)
- self.assertEquals(
+ eq_(
sess.query(Boss.name, Boss.golf_swing).all(),
[(u'pointy haired boss', u'fore')]
)
# []
# )
- self.assertEquals(
+ eq_(
sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
[(u'vlad',u'Elbonia, Inc.')]
)
- self.assertEquals(
+ eq_(
sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(),
[(u'java',), (u'c++',), (u'cobol',)]
)
if select_type != '':
- self.assertEquals(
+ eq_(
sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(),
[
(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'),
]
)
- self.assertEquals(
+ eq_(
sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(),
[(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')]
)
palias = aliased(Person)
- self.assertEquals(
+ eq_(
sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
[(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
u'Elbonia, Inc.',
Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))]
)
- self.assertEquals(
+ eq_(
sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
[(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
u'Elbonia, Inc.',
]
)
- self.assertEquals(
+ eq_(
sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
[(u'vlad', u'Elbonia, Inc.', u'dilbert')]
)
palias = aliased(Person)
- self.assertEquals(
+ eq_(
sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(),
[(u'manager', u'dogbert', u'engineer', u'dilbert'),
(u'manager', u'dogbert', u'boss', u'pointy haired boss')]
)
- self.assertEquals(
+ eq_(
sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(),
[(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'),
(u'dogbert', u'review #3'),
)
if select_type != '':
- self.assertEquals(
+ eq_(
sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(),
[(1, )]
)
- self.assertEquals(
+ eq_(
sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(),
[(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
)
- self.assertEquals(
+ eq_(
sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(),
[(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
)
return PolymorphicQueryTest
for select_type in ('', 'Polymorphic', 'Unions', 'AliasedJoins', 'Joins'):
- testclass = make_test(select_type)
+ testclass = _produce_test(select_type)
exec("%s = testclass" % testclass.__name__)
del testclass
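
The loop above produces one configured test class per select_type variant and uses exec to bind each class into the module namespace under its own name, so nose discovers every variant; the factory carries a leading underscore for the same reason as the method generators earlier. A stripped-down sketch of the mechanism (hypothetical names throughout):

    # Illustrative only: one class per variant, bound into the module namespace.
    def _produce_variant(tag):
        class VariantTest(object):
            select_type = tag
            def test_tag_present(self):
                assert self.select_type == tag
        VariantTest.__name__ = 'Variant%sTest' % (tag or 'Default')
        return VariantTest

    for tag in ('', 'Polymorphic', 'Unions'):
        testclass = _produce_variant(tag)
        exec("%s = testclass" % testclass.__name__)
    del testclass
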
class SelfReferentialTestJoinedToBase(_base.MappedTest):
run_setup_mappers = 'once'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, engineers
people = Table('people', metadata,
Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
Column('reports_to_id', Integer, ForeignKey('people.person_id'))
)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Engineer, engineers, inherits=Person,
inherit_condition=engineers.c.person_id==people.c.person_id,
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert'))
+ eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Person.name=='dogbert')).first(), Engineer(name='dilbert'))
def test_oftype_aliases_in_exists(self):
e1 = Engineer(name='dilbert', primary_language='java')
sess.add_all([e1, e2])
sess.flush()
- self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2)
+ eq_(sess.query(Engineer).filter(Engineer.reports_to.of_type(Engineer).has(Engineer.name=='dilbert')).first(), e2)
def test_join(self):
p1 = Person(name='dogbert')
sess.flush()
sess.expunge_all()
- self.assertEquals(
+ eq_(
sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(),
Engineer(name='dilbert'))
class SelfReferentialJ2JTest(_base.MappedTest):
run_setup_mappers = 'once'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, engineers, managers
people = Table('people', metadata,
Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Manager, managers, inherits=Person, polymorphic_identity='manager')
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert'))
+ eq_(sess.query(Engineer).filter(Engineer.reports_to.has(Manager.name=='dogbert')).first(), Engineer(name='dilbert'))
def test_join(self):
m1 = Manager(name='dogbert')
sess.flush()
sess.expunge_all()
- self.assertEquals(
+ eq_(
sess.query(Engineer).join('reports_to', aliased=True).filter(Manager.name=='dogbert').first(),
Engineer(name='dilbert'))
sess.expunge_all()
# filter aliasing applied to Engineer doesn't whack Manager
- self.assertEquals(
+ eq_(
sess.query(Manager).join(Manager.engineers).filter(Manager.name=='dogbert').all(),
[m1]
)
- self.assertEquals(
+ eq_(
sess.query(Manager).join(Manager.engineers).filter(Engineer.name=='dilbert').all(),
[m2]
)
- self.assertEquals(
+ eq_(
sess.query(Manager, Engineer).join(Manager.engineers).order_by(Manager.name.desc()).all(),
[
(m2, e2),
sess.flush()
sess.expunge_all()
- self.assertEquals(
+ eq_(
sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==None).all(),
[]
)
- self.assertEquals(
+ eq_(
sess.query(Manager).join(Manager.engineers).filter(Engineer.reports_to==m1).all(),
[m1]
)
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, engineers, organizations, engineers_to_org
organizations = Table('organizations', metadata,
Column('primary_language', String(50)),
)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
global Organization
class Organization(_fixtures.Base):
pass
mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
e1 = Engineer(name='e1')
e2 = Engineer(name='e2')
e3 = Engineer(name='e3')
e1 = sess.query(Person).filter(Engineer.name=='e1').one()
# this works
- self.assertEquals(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')])
+ eq_(sess.query(Organization).filter(~Organization.engineers.of_type(Engineer).contains(e1)).all(), [Organization(name='org2')])
# this had a bug
- self.assertEquals(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')])
+ eq_(sess.query(Organization).filter(~Organization.engineers.contains(e1)).all(), [Organization(name='org2')])
def test_any(self):
sess = create_session()
- self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
- self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+ eq_(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+ eq_(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
class SelfReferentialM2MTest(_base.MappedTest, AssertsCompiledSQL):
run_setup_mappers = 'once'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global Parent, Child1, Child2
Base = declarative_base(metadata=metadata)
assert q.first() is c1
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import testing
-from orm._fixtures import Base
-from orm._base import MappedTest
+from sqlalchemy.test import testing
+from test.orm._fixtures import Base
+from test.orm._base import MappedTest
class InheritingSelectablesTest(MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global foo, bar, baz
foo = Table('foo', metadata,
Column('a', String(30), primary_key=1),
assert [Baz(), Baz(), Bar(), Bar()] == s.query(Foo).all()
assert [Bar(), Bar()] == s.query(Bar).all()
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import testing
-from orm import _fixtures
-from orm._base import MappedTest, ComparableEntity
+from sqlalchemy.test import testing
+from test.orm import _fixtures
+from test.orm._base import MappedTest, ComparableEntity
class SingleInheritanceTest(MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('employees', metadata,
Column('employee_id', Integer, primary_key=True),
Column('name', String(50)),
Column('name', String(50)),
)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Employee(ComparableEntity):
pass
class Manager(Employee):
class JuniorEngineer(Engineer):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Employee, employees, polymorphic_on=employees.c.type)
mapper(Manager, inherits=Employee, polymorphic_identity='manager')
mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
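
Where a setup hook already carried @testing.resolve_artifact_names, the new @classmethod decorator is stacked outermost (listed first): the artifact-name wrapper is applied to the plain function and the wrapped result is then bound as a classmethod. A tiny standalone illustration of why the ordering matters; the traced decorator below is made up for the example:

    def traced(fn):
        # stand-in for a wrapping decorator such as resolve_artifact_names
        def wrapped(*args, **kw):
            return fn(*args, **kw)
        wrapped.__name__ = fn.__name__
        return wrapped

    class Demo(object):
        @classmethod    # outermost: binds the already-wrapped function
        @traced         # innermost: receives the plain function
        def setup_mappers(cls):
            return cls.__name__

    assert Demo.setup_mappers() == 'Demo'

Reversing the two decorators would hand traced a classmethod object rather than a plain function, which breaks the wrapping.
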
m1 = session.query(Manager).one()
session.expire(m1, ['manager_data'])
- self.assertEquals(m1.manager_data, "knows how to manage things")
+ eq_(m1.manager_data, "knows how to manage things")
row = session.query(Engineer.name, Engineer.employee_id).filter(Engineer.name=='Kurt').first()
assert row.name == 'Kurt'
session.flush()
ealias = aliased(Engineer)
- self.assertEquals(
+ eq_(
session.query(Manager, ealias).all(),
[(m1, e1), (m1, e2)]
)
- self.assertEquals(
+ eq_(
session.query(Manager.name).all(),
[("Tom",)]
)
- self.assertEquals(
+ eq_(
session.query(Manager.name, ealias.name).all(),
[("Tom", "Kurt"), ("Tom", "Ed")]
)
- self.assertEquals(
+ eq_(
session.query(func.upper(Manager.name), func.upper(ealias.name)).all(),
[("TOM", "KURT"), ("TOM", "ED")]
)
- self.assertEquals(
+ eq_(
session.query(Manager).add_entity(ealias).all(),
[(m1, e1), (m1, e2)]
)
- self.assertEquals(
+ eq_(
session.query(Manager.name).add_column(ealias.name).all(),
[("Tom", "Kurt"), ("Tom", "Ed")]
)
sess.add_all([m1, m2, e1, e2])
sess.flush()
- self.assertEquals(
+ eq_(
sess.query(Manager).select_from(employees.select().limit(10)).all(),
[m1, m2]
)
sess.add_all([m1, m2, e1, e2])
sess.flush()
- self.assertEquals(sess.query(Manager).count(), 2)
- self.assertEquals(sess.query(Engineer).count(), 2)
- self.assertEquals(sess.query(Employee).count(), 4)
+ eq_(sess.query(Manager).count(), 2)
+ eq_(sess.query(Engineer).count(), 2)
+ eq_(sess.query(Employee).count(), 4)
- self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
- self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
+ eq_(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
+ eq_(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
@testing.resolve_artifact_names
def test_type_filtering(self):
class RelationToSingleTest(MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('employees', metadata,
Column('employee_id', Integer, primary_key=True),
Column('name', String(50)),
Column('name', String(50)),
)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Company(ComparableEntity):
pass
sess.add_all([c1, c2, m1, m2, e1, e2])
sess.commit()
sess.expunge_all()
- self.assertEquals(
+ eq_(
sess.query(Company).filter(Company.employees.of_type(JuniorEngineer).any()).all(),
[
Company(name='c1'),
]
)
- self.assertEquals(
+ eq_(
sess.query(Company).join(Company.employees.of_type(JuniorEngineer)).all(),
[
Company(name='c1'),
sess.add_all([c1, c2, m1, m2, e1, e2])
sess.commit()
- self.assertEquals(c1.engineers, [e2])
- self.assertEquals(c2.engineers, [e1])
+ eq_(c1.engineers, [e2])
+ eq_(c2.engineers, [e1])
sess.expunge_all()
- self.assertEquals(sess.query(Company).order_by(Company.name).all(),
+ eq_(sess.query(Company).order_by(Company.name).all(),
[
Company(name='c1', engineers=[JuniorEngineer(name='Ed')]),
Company(name='c2', engineers=[Engineer(name='Kurt')])
# eager load join should limit to only "Engineer"
sess.expunge_all()
- self.assertEquals(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(),
+ eq_(sess.query(Company).options(eagerload('engineers')).order_by(Company.name).all(),
[
Company(name='c1', engineers=[JuniorEngineer(name='Ed')]),
Company(name='c2', engineers=[Engineer(name='Kurt')])
# join() to Company.engineers, Employee as the requested entity
sess.expunge_all()
- self.assertEquals(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(),
+ eq_(sess.query(Company, Employee).join(Company.engineers).order_by(Company.name).all(),
[
(Company(name='c1'), JuniorEngineer(name='Ed')),
(Company(name='c2'), Engineer(name='Kurt'))
# join() to Company.engineers, Engineer as the requested entity.
# this actually applies the IN criterion twice which is less than ideal.
sess.expunge_all()
- self.assertEquals(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(),
+ eq_(sess.query(Company, Engineer).join(Company.engineers).order_by(Company.name).all(),
[
(Company(name='c1'), JuniorEngineer(name='Ed')),
(Company(name='c2'), Engineer(name='Kurt'))
# join() to Company.engineers without any Employee/Engineer entity
sess.expunge_all()
- self.assertEquals(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
+ eq_(sess.query(Company).join(Company.engineers).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
[
Company(name='c2')
]
@testing.fails_on_everything_except()
def go():
sess.expunge_all()
- self.assertEquals(sess.query(Company).\
+ eq_(sess.query(Company).\
filter(Company.company_id==Engineer.company_id).filter(Engineer.name.in_(['Tom', 'Kurt'])).all(),
[
Company(name='c2')
go()
class SingleOnJoinedTest(MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global persons_table, employees_table
persons_table = Table('persons', metadata,
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(Person).order_by(Person.person_id).all(), [
+ eq_(sess.query(Person).order_by(Person.person_id).all(), [
Person(name='p1'),
Employee(name='e1', employee_data='ed1'),
Manager(name='m1', employee_data='ed2', manager_data='md1')
])
sess.expunge_all()
- self.assertEquals(sess.query(Employee).order_by(Person.person_id).all(), [
+ eq_(sess.query(Employee).order_by(Person.person_id).all(), [
Employee(name='e1', employee_data='ed1'),
Manager(name='m1', employee_data='ed2', manager_data='md1')
])
sess.expunge_all()
- self.assertEquals(sess.query(Manager).order_by(Person.person_id).all(), [
+ eq_(sess.query(Manager).order_by(Person.person_id).all(), [
Manager(name='m1', employee_data='ed2', manager_data='md1')
])
sess.expunge_all()
def go():
- self.assertEquals(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [
+ eq_(sess.query(Person).with_polymorphic('*').order_by(Person.person_id).all(), [
Person(name='p1'),
Employee(name='e1', employee_data='ed1'),
Manager(name='m1', employee_data='ed2', manager_data='md1')
])
self.assert_sql_count(testing.db, go, 1)
-if __name__ == '__main__':
- testenv.main()
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
- modules_to_test = (
- 'orm.sharding.shard',
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
-import testenv; testenv.configure_for_tests()
import datetime, os
from sqlalchemy import *
from sqlalchemy import sql
from sqlalchemy.orm import *
from sqlalchemy.orm.shard import ShardedSession
from sqlalchemy.sql import operators
-from testlib import *
-from testlib.testing import eq_
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
# TODO: ShardTest can be turned into a base for further subclasses
class ShardTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global db1, db2, db3, db4, weather_locations, weather_reports
db1 = create_engine('sqlite:///shard1.db')
db1.execute(ids.insert(), nextid=1)
- self.setup_session()
- self.setup_mappers()
+ cls.setup_session()
+ cls.setup_mappers()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
for db in (db1, db2, db3, db4):
db.connect().invalidate()
for i in range(1,5):
os.remove("shard%d.db" % i)
- def setup_session(self):
+ @classmethod
+ def setup_session(cls):
global create_session
shard_lookup = {
}, shard_chooser=shard_chooser, id_chooser=id_chooser, query_chooser=query_chooser)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
global WeatherLocation, Report
class WeatherLocation(object):
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
-from testlib.testing import eq_
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
+from sqlalchemy.test.testing import eq_
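
Table and Column are now imported from sqlalchemy.test.schema rather than from the core package. Judging from their use with test-only hints such as test_needs_autoincrement=True elsewhere in the patch, these appear to be thin, test-aware wrappers around the ordinary constructs; the following is only a guess at the general shape of such a wrapper, not the actual implementation:

    # Hypothetical sketch: strip test-only keyword arguments, then delegate
    # to the real Column construct.
    from sqlalchemy import schema

    def Column(*args, **kw):
        kw.pop('test_needs_autoincrement', None)   # illustrative test-only hint
        return schema.Column(*args, **kw)
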
class AssociationTest(_base.MappedTest):
run_setup_classes = 'once'
run_setup_mappers = 'once'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('items', metadata,
Column('item_id', Integer, primary_key=True),
Column('name', String(40)))
Column('keyword_id', Integer, primary_key=True),
Column('name', String(40)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Item(_base.BasicEntity):
def __init__(self, name):
self.name = name
return "KeywordAssociation itemid=%d keyword=%r data=%s" % (
self.item_id, self.keyword, self.data)
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
- items, item_keywords, keywords = self.tables.get_all(
+ def setup_mappers(cls):
+ items, item_keywords, keywords = cls.tables.get_all(
'items', 'item_keywords', 'keywords')
mapper(Keyword, keywords)
item2.keywords.append(KeywordAssociation(Keyword('green'), 'green_assoc'))
sess.add_all((item1, item2))
sess.flush()
- eq_(self.tables.item_keywords.count().scalar(), 3)
+ eq_(item_keywords.count().scalar(), 3)
sess.delete(item1)
sess.delete(item2)
sess.flush()
- eq_(self.tables.item_keywords.count().scalar(), 0)
+ eq_(item_keywords.count().scalar(), 0)
-if __name__ == "__main__":
- testenv.main()
Derived from mailing list-reported problems and trac tickets.
"""
-import testenv; testenv.configure_for_tests()
import datetime
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, backref, create_session
-from testlib.testing import eq_
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
class EagerTest(_base.MappedTest):
run_deletes = None
run_inserts = "once"
+ run_setup_mappers = "once"
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
# determine a literal value for "false" based on the dialect
# FIXME: this DefaultClause setup is bogus.
false = text('FALSE')
else:
false = str(False)
- self.other_artifacts['false'] = false
+ cls.other_artifacts['false'] = false
Table('owners', metadata ,
Column('id', Integer, primary_key=True, nullable=False),
Column('someoption', sa.Boolean, server_default=false,
nullable=False))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Owner(_base.BasicEntity):
pass
class Category(_base.BasicEntity):
pass
- class Test(_base.BasicEntity):
+ class Thing(_base.BasicEntity):
pass
class Option(_base.BasicEntity):
pass
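
The mapped fixture class previously named Test is renamed to Thing throughout this module, since nose's name-based collection would otherwise treat a class literally named Test as a test case even though it defines no test methods. A short naming sketch (illustrative):

    class Thing(object):       # ordinary entity name; not collected by nose
        pass

    # A name such as 'Test' or 'TestWidget' risks being collected; setting
    # __test__ = False on a class is another way to opt it out of collection.
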
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Owner, owners)
mapper(Category, categories)
mapper(Option, options, properties=dict(
owner=relation(Owner),
- test=relation(Test)))
+ test=relation(Thing)))
- mapper(Test, tests, properties=dict(
+ mapper(Thing, tests, properties=dict(
owner=relation(Owner, backref='tests'),
category=relation(Category),
owner_option=relation(Option,
foreign_keys=[options.c.test_id, options.c.owner_id],
uselist=False)))
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
session = create_session()
o = Owner()
c = Category(name='Some Category')
session.add_all((
- Test(owner=o, category=c),
- Test(owner=o, category=c, owner_option=Option(someoption=True)),
- Test(owner=o, category=c, owner_option=Option())))
+ Thing(owner=o, category=c),
+ Thing(owner=o, category=c, owner_option=Option(someoption=True)),
+ Thing(owner=o, category=c, owner_option=Option())))
session.flush()
@testing.resolve_artifact_names
def test_withouteagerload(self):
s = create_session()
- l = (s.query(Test).
+ l = (s.query(Thing).
select_from(tests.outerjoin(options,
sa.and_(tests.c.id == options.c.test_id,
tests.c.owner_id ==
"""
s = create_session()
- q=s.query(Test).options(sa.orm.eagerload('category'))
+ q=s.query(Thing).options(sa.orm.eagerload('category'))
l=(q.select_from(tests.outerjoin(options,
sa.and_(tests.c.id ==
def test_dslish(self):
"""test the same as witheagerload except using generative"""
s = create_session()
- q = s.query(Test).options(sa.orm.eagerload('category'))
+ q = s.query(Thing).options(sa.orm.eagerload('category'))
l = q.filter (
sa.and_(tests.c.owner_id == 1,
sa.or_(options.c.someoption == None,
@testing.resolve_artifact_names
def test_without_outerjoin_literal(self):
s = create_session()
- q = s.query(Test).options(sa.orm.eagerload('category'))
+ q = s.query(Thing).options(sa.orm.eagerload('category'))
l = (q.filter(
(tests.c.owner_id==1) &
('options.someoption is null or options.someoption=%s' % false)).
@testing.resolve_artifact_names
def test_withoutouterjoin(self):
s = create_session()
- q = s.query(Test).options(sa.orm.eagerload('category'))
+ q = s.query(Thing).options(sa.orm.eagerload('category'))
l = q.filter(
(tests.c.owner_id==1) &
((options.c.someoption==None) | (options.c.someoption==False))
class EagerTest2(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('left', metadata,
Column('id', Integer, ForeignKey('middle.id'), primary_key=True),
Column('data', String(50), primary_key=True))
Column('id', Integer, ForeignKey('middle.id'), primary_key=True),
Column('data', String(50), primary_key=True))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Left(_base.BasicEntity):
def __init__(self, data):
self.data = data
def __init__(self, data):
self.data = data
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
# set up bi-directional eager loads
mapper(Left, left)
mapper(Right, right)
class EagerTest3(_base.MappedTest):
"""Eager loading combined with nested SELECT statements, functions, and aggregates."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('datas', metadata,
Column('id', Integer, primary_key=True, nullable=False),
Column('a', Integer, nullable=False))
Column('data_id', Integer, ForeignKey('datas.id')),
Column('somedata', Integer, nullable=False ))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Data(_base.BasicEntity):
pass
class EagerTest4(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('departments', metadata,
Column('department_id', Integer, primary_key=True),
Column('name', String(50)))
Column('department_id', Integer,
ForeignKey('departments.department_id')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Department(_base.BasicEntity):
pass
class EagerTest5(_base.MappedTest):
"""Construction of AliasedClauses for the same eager load property but different parent mappers, due to inheritance."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('base', metadata,
Column('uid', String(30), primary_key=True),
Column('x', String(30)))
Column('uid', String(30), ForeignKey('base.uid')),
Column('comment', String(30)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Base(_base.BasicEntity):
def __init__(self, uid, x):
self.uid = uid
class EagerTest6(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('design_types', metadata,
Column('design_type_id', Integer, primary_key=True))
Column('part_id', Integer, ForeignKey('parts.part_id')),
Column('design_id', Integer, ForeignKey('design.design_id')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Part(_base.BasicEntity):
pass
class EagerTest7(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('companies', metadata,
Column('company_id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('code', String(20)),
Column('qty', Integer))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Company(_base.ComparableEntity):
pass
class EagerTest8(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('prj', metadata,
Column('id', Integer, primary_key=True),
Column('created', sa.DateTime ),
Column('name', sa.Unicode(20)),
Column('display_name', sa.Unicode(20)))
+ @classmethod
@testing.resolve_artifact_names
- def fixtures(self):
+ def fixtures(cls):
return dict(
prj=(('id',),
(1,)),
task=(('title', 'task_type_id', 'status_id', 'prj_id'),
(u'task 1', 1, 1, 1)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Task_Type(_base.BasicEntity):
pass
throughout the query setup/mapper instances process.
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('accounts', metadata,
Column('account_id', Integer, primary_key=True),
Column('name', String(40)))
Column('transaction_id', Integer,
ForeignKey('transactions.transaction_id')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Account(_base.BasicEntity):
pass
class Entry(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Account, accounts)
mapper(Transaction, transactions)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
import pickle
import sqlalchemy.orm.attributes as attributes
from sqlalchemy.orm.collections import collection
from sqlalchemy.orm.interfaces import AttributeExtension
from sqlalchemy import exc as sa_exc
-from testlib import *
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
import gc
# global for pickling tests
class AttributesTest(_base.ORMTest):
- def setUp(self):
+ def setup(self):
global MyTest, MyTest2
class MyTest(object): pass
class MyTest2(object): pass
- def tearDown(self):
+ def teardown(self):
global MyTest, MyTest2
MyTest, MyTest2 = None, None
self.assert_(p.jack is None)
class PendingBackrefTest(_base.ORMTest):
- def setUp(self):
+ def setup(self):
global Post, Blog, called, lazy_load
class Post(object):
assert f1.barset.pop().data == "some bar appended"
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib.sa import MetaData, Table, Column, Integer
-from testlib.sa.orm import mapper, create_session
-from testlib import sa, testing
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy import MetaData, Integer
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, create_session
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from test.orm import _base
class BindTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('test_table', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('data', Integer))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Foo(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
meta = MetaData()
test_table.tometadata(meta)
def test_session_unbound(self):
sess = create_session()
sess.add(Foo())
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.UnboundExecutionError,
('Could not locate a bind configured on Mapper|Foo|test_table '
'or this Session'),
sess.flush)
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib.sa import Table, Column, Integer, String, ForeignKey, Sequence, exc as sa_exc
-from testlib.sa.orm import mapper, relation, create_session, class_mapper, backref
-from testlib.sa.orm import attributes, exc as orm_exc
-from testlib import testing
-from testlib.testing import eq_
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref
+from sqlalchemy.orm import attributes, exc as orm_exc
+from sqlalchemy.test import testing
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
class O2MCascadeTest(_fixtures.FixtureTest):
run_inserts = None
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Address, addresses)
mapper(User, users, properties = dict(
addresses = relation(Address, cascade="all, delete-orphan", backref="user"),
class O2OCascadeTest(_fixtures.FixtureTest):
run_inserts = None
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Address, addresses)
mapper(User, users, properties = {
'address':relation(Address, backref=backref("user", single_parent=True), uselist=False)
a1 = Address(email_address='some address')
u1 = User(name='u1', address=a1)
- self.assertRaises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1)
+ assert_raises(sa_exc.InvalidRequestError, Address, email_address='asd', user=u1)
a2 = Address(email_address='asd')
u1.address = a2
class O2MBackrefTest(_fixtures.FixtureTest):
run_inserts = None
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users, properties = dict(
orders = relation(
mapper(Order, orders), cascade="all, delete-orphan", backref="user")
class O2MCascadeNoOrphanTest(_fixtures.FixtureTest):
run_inserts = None
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users, properties = dict(
orders = relation(
mapper(Order, orders), cascade="all")
class M2OCascadeTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("extra", metadata,
Column("id", Integer, Sequence("extra_id_seq", optional=True),
primary_key=True),
Column('name', String(40)),
Column('pref_id', Integer, ForeignKey('prefs.id')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_fixtures.Base):
pass
class Pref(_fixtures.Base):
class Extra(_fixtures.Base):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Extra, extra)
mapper(Pref, prefs, properties=dict(
extra = relation(Extra, cascade="all, delete")
pref = relation(Pref, lazy=False, cascade="all, delete-orphan", single_parent=True )
))
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()]))
u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()]))
u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()]))
[Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")])
class M2OCascadeDeleteTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50)),
Column('id', Integer, primary_key=True),
Column('data', String(50)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class T1(_fixtures.Base):
pass
class T2(_fixtures.Base):
class T3(_fixtures.Base):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(T1, t1, properties={'t2': relation(T2, cascade="all")})
mapper(T2, t2, properties={'t3': relation(T3, cascade="all")})
mapper(T3, t3)
class M2OCascadeDeleteOrphanTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50)),
Column('id', Integer, primary_key=True),
Column('data', String(50)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class T1(_fixtures.Base):
pass
class T2(_fixtures.Base):
class T3(_fixtures.Base):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(T1, t1, properties=dict(
t2=relation(T2, cascade="all, delete-orphan", single_parent=True)))
mapper(T2, t2, properties=dict(
y = T2(data='T2a')
x = T1(data='T1a', t2=y)
- self.assertRaises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y)
+ assert_raises(sa_exc.InvalidRequestError, T1, data='T1b', t2=y)
@testing.resolve_artifact_names
def test_single_parent_backref(self):
x = T2(data='T2a', t3=y)
# can't attach the T3 to another T2
- self.assertRaises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y)
+ assert_raises(sa_exc.InvalidRequestError, T2, data='T2b', t3=y)
# set via backref though is OK, unsets from previous parent
# first
assert x.t3 is None
class M2MCascadeTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('a', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)),
)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class A(_fixtures.Base):
pass
class B(_fixtures.Base):
b1 = B(data='b1')
a1 = A(data='a1', bs=[b1])
- self.assertRaises(sa_exc.InvalidRequestError,
+ assert_raises(sa_exc.InvalidRequestError,
A, data='a2', bs=[b1]
)
b1 = B(data='b1')
a1 = A(data='a1', bs=[b1])
- self.assertRaises(
+ assert_raises(
sa_exc.InvalidRequestError,
A, data='a2', bs=[b1]
)
class UnsavedOrphansTest(_base.MappedTest):
"""Pending entities that are orphans"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users', metadata,
Column('user_id', Integer,
Sequence('user_id_seq', optional=True),
Column('user_id', Integer, ForeignKey('users.user_id')),
Column('email_address', String(40)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_fixtures.Base):
pass
class Address(_fixtures.Base):
class UnsavedOrphansTest2(_base.MappedTest):
"""same test as UnsavedOrphans only three levels deep"""
- def define_tables(self, meta):
+ @classmethod
+ def define_tables(cls, meta):
Table('orders', meta,
Column('id', Integer, Sequence('order_id_seq'),
primary_key=True),
class UnsavedOrphansTest3(_base.MappedTest):
"""test not expunging double parents"""
- def define_tables(self, meta):
+ @classmethod
+ def define_tables(cls, meta):
Table('sales_reps', meta,
Column('sales_rep_id', Integer,
Sequence('sales_rep_id_seq'),
class DoubleParentOrphanTest(_base.MappedTest):
"""test orphan detection for an entity with two parent relations"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('addresses', metadata,
Column('address_id', Integer, primary_key=True),
Column('street', String(30)),
assert True
class CollectionAssignmentOrphanTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('table_a', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(30)))
"""test cascade behavior as it relates to object lists passed to flush().
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("base", metadata,
Column("id", Integer, primary_key=True),
Column("descr", String(50))
assert c1 not in sess.new
assert c2 in sess.new
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import sys
from operator import and_
import sqlalchemy.orm.collections as collections
from sqlalchemy.orm.collections import collection
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa import util, exc as sa_exc
-from testlib.sa.orm import create_session, mapper, relation, \
- attributes
-from orm import _base
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy import util, exc as sa_exc
+from sqlalchemy.orm import create_session, mapper, relation, attributes
+from test.orm import _base
+from sqlalchemy.test.testing import eq_
class Canary(sa.orm.interfaces.AttributeExtension):
def __init__(self):
def __repr__(self):
return str((id(self), self.a, self.b, self.c))
- def setUpAll(self):
- attributes.register_class(self.Entity)
+ @classmethod
+ def setup_class(cls):
+ attributes.register_class(cls.Entity)
- def tearDownAll(self):
- attributes.unregister_class(self.Entity)
- _base.ORMTest.tearDownAll(self)
+ @classmethod
+ def teardown_class(cls):
+ attributes.unregister_class(cls.Entity)
+ super(CollectionsTest, cls).teardown_class()
_entity_id = 1
pass
self.assert_(obj.attr is not real_dict)
self.assert_('badkey' not in obj.attr)
- self.assertEquals(set(collections.collection_adapter(obj.attr)),
+ eq_(set(collections.collection_adapter(obj.attr)),
set([e2]))
self.assert_(e3 not in canary.added)
else:
obj.attr = real_dict
self.assert_(obj.attr is not real_dict)
self.assert_('keyignored1' not in obj.attr)
- self.assertEquals(set(collections.collection_adapter(obj.attr)),
+ eq_(set(collections.collection_adapter(obj.attr)),
set([e3]))
self.assert_(e2 in canary.removed)
self.assert_(e3 in canary.added)
obj.attr = typecallable()
- self.assertEquals(list(collections.collection_adapter(obj.attr)), [])
+ eq_(list(collections.collection_adapter(obj.attr)), [])
e4 = creator()
try:
class DictHelpersTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('parents', metadata,
Column('id', Integer, primary_key=True),
Column('label', String(128)))
Column('b', String(128)),
Column('c', String(128)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Parent(_base.BasicEntity):
def __init__(self, label=None):
self.label = label
p = session.query(Parent).get(pid)
- self.assertEquals(set(p.children.keys()), set(['foo', 'bar']))
+ eq_(set(p.children.keys()), set(['foo', 'bar']))
cid = p.children['foo'].id
collections.collection_adapter(p.children).append_with_event(
# remove if so
class CustomCollectionsTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('sometable', metadata,
Column('col1',Integer, primary_key=True),
Column('data', String(30)))
instrumented = collections._instrument_class(Touchy)
assert True
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
-from testlib import *
-from orm import _base
+from sqlalchemy.test import *
+from test.orm import _base
class CompileTest(_base.ORMTest):
"""test various mapper compilation scenarios"""
- def tearDown(self):
+ def teardown(self):
clear_mappers()
def testone(self):
assert str(e).index("Error creating backref") > -1
-if __name__ == '__main__':
- testenv.main()
T1/T2.
"""
-import testenv; testenv.configure_for_tests()
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, backref, create_session
-from testlib.testing import eq_
-from testlib.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf
-from orm import _base
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, create_session
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.assertsql import RegexSQL, ExactSQL, CompiledSQL, AllOf
+from test.orm import _base
class SelfReferentialTest(_base.MappedTest):
"""A self-referential mapper with an additional list of child objects."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('c1', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('c1id', Integer, ForeignKey('t1.c1')),
Column('data', String(20)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class C1(_base.BasicEntity):
def __init__(self, data=None):
self.data = data
class SelfReferentialNoPKTest(_base.MappedTest):
"""A self-referential relationship that joins on a column other than the primary key column"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('item', metadata,
Column('id', Integer, primary_key=True),
Column('uuid', String(32), unique=True, nullable=False),
Column('parent_uuid', String(32), ForeignKey('item.uuid'),
nullable=True))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class TT(_base.BasicEntity):
def __init__(self):
self.uuid = hex(id(self))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(TT, item, properties={
'children': relation(
TT,
class InheritTestOne(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("parent", metadata,
Column("id", Integer, primary_key=True),
Column("parent_data", String(50)),
nullable=False),
Column("child2_data", String(50)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Parent(_base.BasicEntity):
pass
class Child2(Parent):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Parent, parent)
mapper(Child1, child1, inherits=Parent)
mapper(Child2, child2, inherits=Parent, properties=dict(
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('a', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)),
Column('aid', Integer,
ForeignKey('a.id', use_alter=True, name="foo")))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class A(_base.BasicEntity):
pass
class BiDirectionalManyToOneTest(_base.MappedTest):
run_define_tables = 'each'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)),
Column('t1id', Integer, ForeignKey('t1.id'), nullable=False),
Column('t2id', Integer, ForeignKey('t2.id'), nullable=False))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class T1(_base.BasicEntity):
pass
class T2(_base.BasicEntity):
class T3(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(T1, t1, properties={
't2':relation(T2, primaryjoin=t1.c.t2id == t2.c.id)})
mapper(T2, t2, properties={
run_define_tables = 'each'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('c1', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('c2', Integer,
ForeignKey('t1.c1', use_alter=True, name='t1c1_fk')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class C1(_base.BasicEntity):
pass
run_define_tables = 'each'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('c1', Integer, primary_key=True),
Column('c2', Integer, ForeignKey('t2.c1')),
Column('data', String(20)),
test_needs_autoincrement=True)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class C1(_base.BasicEntity):
pass
class C1Data(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(C2, t2, properties={
'c1s': relation(C1,
primaryjoin=t2.c.c1 == t1.c.c2,
"""
run_define_tables = 'each'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('ball', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
Column('data', String(30)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Person(_base.BasicEntity):
pass
class SelfReferentialPostUpdateTest(_base.MappedTest):
"""Post_update on a single self-referential mapper"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('node', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('next_sibling_id', Integer,
ForeignKey('node.id'), nullable=True))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Node(_base.BasicEntity):
def __init__(self, path=''):
self.path = path
class SelfReferentialPostUpdateTest2(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("a_table", metadata,
Column("id", Integer(), primary_key=True),
Column("fui", String(128)),
Column("b", Integer(), ForeignKey("a_table.id")))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class A(_base.BasicEntity):
pass
assert f2.foo is f1
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
-from testlib.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
+from sqlalchemy.test.testing import eq_
class TriggerDefaultsTest(_base.MappedTest):
__requires__ = ('row_triggers',)
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
dt = Table('dt', metadata,
Column('id', Integer, primary_key=True),
Column('col1', String(20)),
sa.DDL("DROP TRIGGER dt_up").execute_at('before-drop', dt)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Default(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Default, dt)
@testing.resolve_artifact_names
eq_(d1.col4, 'up')
class ExcludedDefaultsTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
dt = Table('dt', metadata,
Column('id', Integer, primary_key=True),
Column('col1', String(20), default="hello"),
sess.flush()
eq_(dt.select().execute().fetchall(), [(1, "hello")])
-if __name__ == "__main__":
- testenv.main()
be migrated directly to the wiki, docs, etc.
"""
-import testenv; testenv.configure_for_tests()
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, func
-from testlib.sa.orm import mapper, relation, create_session, sessionmaker
-from orm import _base
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, sessionmaker
+from test.orm import _base
class QueryAlternativesTest(_base.MappedTest):
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users_table', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(64)))
Column('purpose', String(16)),
Column('bounces', Integer, default=0))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.BasicEntity):
pass
class Address(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users_table, properties=dict(
addresses=relation(Address, backref='user'),
))
mapper(Address, addresses_table)
- def fixtures(self):
+ @classmethod
+ def fixtures(cls):
return dict(
users_table=(
('id', 'name'),
assert len(users) == 1 and users[0].name == 'ed'
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import operator
from sqlalchemy.orm import dynamic_loader, backref
-from testlib import testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, desc, select, func
-from testlib.sa.orm import mapper, relation, create_session, Query, attributes
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, desc, select, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, Query, attributes
from sqlalchemy.orm.dynamic import AppenderMixin
-from testlib.testing import eq_
-from testlib.compat import _function_named
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.util import function_named
+from test.orm import _base, _fixtures
class DynamicTest(_fixtures.FixtureTest):
sess.flush()
from sqlalchemy.orm import attributes
- self.assertEquals(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
+ eq_(attributes.get_history(attributes.instance_state(u1), 'addresses'), ([], [Address(email_address='lala@hoho.com')], []))
sess.expunge_all()
sess.close()
-def create_backref_test(autoflush, saveuser):
+def _create_backref_test(autoflush, saveuser):
@testing.resolve_artifact_names
def test_backref(self):
sess.flush()
self.assert_(list(u.addresses) == [])
- test_backref = _function_named(
+ test_backref = function_named(
test_backref, "test%s%s" % ((autoflush and "_autoflush" or ""),
(saveuser and "_saveuser" or "_savead")))
setattr(SessionTest, test_backref.__name__, test_backref)
for autoflush in (False, True):
for saveuser in (False, True):
- create_backref_test(autoflush, saveuser)
+ _create_backref_test(autoflush, saveuser)
class DontDereferenceTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(40)),
Column('email_address', String(100), nullable=False),
Column('user_id', Integer, ForeignKey('users.id')))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class User(_base.ComparableEntity):
pass
eq_(query3(), [Address(email_address='joe@joesdomain.example')])
-if __name__ == '__main__':
- testenv.main()
"""basic tests of eager loaded attributes"""
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
+from sqlalchemy.test.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
from sqlalchemy.orm import eagerload, deferred, undefer
-from testlib.sa import Table, Column, Integer, String, Date, ForeignKey, and_, select, func
-from testlib.sa.orm import mapper, relation, create_session, lazyload, aliased
-from testlib.testing import eq_
-from testlib.assertsql import CompiledSQL
-from orm import _base, _fixtures
+from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, lazyload, aliased
+from sqlalchemy.test.testing import eq_
+from sqlalchemy.test.assertsql import CompiledSQL
+from test.orm import _base, _fixtures
import datetime
class EagerTest(_fixtures.FixtureTest):
q = sess.query(User)
assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all()
- self.assertEquals(self.static.user_address_result, q.order_by(User.id).all())
+ eq_(self.static.user_address_result, q.order_by(User.id).all())
@testing.resolve_artifact_names
def test_late_compile(self):
assert sa.orm.class_mapper(Address).get_property('user').lazy is False
sess = create_session()
- self.assertEquals(self.static.user_address_result, sess.query(User).order_by(User.id).all())
+ eq_(self.static.user_address_result, sess.query(User).order_by(User.id).all())
@testing.resolve_artifact_names
def test_double(self):
def go():
o1 = sess.query(Order).options(lazyload('address')).filter(Order.id==5).one()
- self.assertEquals(o1.address, None)
+ eq_(o1.address, None)
self.assert_sql_count(testing.db, go, 2)
sess.expunge_all()
def go():
o1 = sess.query(Order).filter(Order.id==5).one()
- self.assertEquals(o1.address, None)
+ eq_(o1.address, None)
self.assert_sql_count(testing.db, go, 1)
@testing.resolve_artifact_names
self.assert_sql_count(testing.db, go, 1)
class OrderBySecondaryTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('m2m', metadata,
Column('id', Integer, primary_key=True),
Column('aid', Integer, ForeignKey('a.id')),
Column('id', Integer, primary_key=True),
Column('data', String(50)))
- def fixtures(self):
+ @classmethod
+ def fixtures(cls):
return dict(
a=(('id', 'data'),
(1, 'a1'),
class SelfReferentialEagerTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('nodes', metadata,
Column('id', Integer, sa.Sequence('node_id_seq', optional=True),
primary_key=True),
sess.expunge_all()
def go():
- self.assertEquals(
+ eq_(
Node(data='n1', children=[Node(data='n11'), Node(data='n12')]),
sess.query(Node).order_by(Node.id).first(),
)
self.assert_sql_count(testing.db, go, 3)
class MixedSelfReferentialEagerTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('a_table', metadata,
Column('id', Integer, primary_key=True)
)
Column('parent_b2_id', Integer, ForeignKey('b_table.id')))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class A(_base.ComparableEntity):
pass
class B(_base.ComparableEntity):
)
});
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
a_table.insert().execute(dict(id=1), dict(id=2), dict(id=3))
b_table.insert().execute(
dict(id=1, parent_a_id=2, parent_b1_id=None, parent_b2_id=None),
self.assert_sql_count(testing.db, go, 1)
class SelfReferentialM2MEagerTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('widget', metadata,
Column('id', Integer, primary_key=True),
Column('name', sa.Unicode(40), nullable=False, unique=True),
run_inserts = 'once'
run_deletes = None
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users, properties={
'addresses':relation(Address, backref='user'),
'orders':relation(Order, backref='user'), # o2m, m2o
class CyclicalInheritingEagerTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('c1', Integer, primary_key=True),
Column('c2', String(30)),
create_session().query(SubT).all()
class SubqueryTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users_table', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(16))
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
users = Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(50))
Column('date', Date),
Column('user_id', Integer, ForeignKey('users.id')))
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
users.insert().execute(
{'id':1, 'name':'user1'},
{'id':2, 'name':'user2'},
self.assert_sql_count(testing.db, go, 1)
-if __name__ == '__main__':
- testenv.main()
"""Evluating SQL expressions on ORM objects"""
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, String, Integer, select
-from testlib.sa.orm import mapper, create_session
-from testlib.testing import eq_
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import String, Integer, select
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
from sqlalchemy import and_, or_, not_
from sqlalchemy.orm import evaluator
return testeval
class EvaluateTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(64)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.ComparableEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users)
@testing.resolve_artifact_names
(User(id=None, name=None), None),
])
-if __name__ == '__main__':
- testenv.main()
"""Attribute/instance expiration, deferral of attributes, etc."""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import gc
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, exc as sa_exc
-from testlib.sa.orm import mapper, relation, create_session, attributes, deferred
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, exc as sa_exc
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, attributes, deferred
+from test.orm import _base, _fixtures
class ExpireTest(_fixtures.FixtureTest):
u = s.query(User).get(7)
s.expunge_all()
- self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u)
+ assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u)
@testing.resolve_artifact_names
def test_get_refreshes(self):
u = s.query(User).get(10) # get() refreshes
self.assert_sql_count(testing.db, go, 1)
def go():
- self.assertEquals(u.name, 'chuck') # attributes unexpired
+ eq_(u.name, 'chuck') # attributes unexpired
self.assert_sql_count(testing.db, go, 0)
def go():
u = s.query(User).get(10) # expire flag reset, so not expired
# add it back
s.add(u)
# nope, raises ObjectDeletedError
- self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name')
+ assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u, 'name')
# do a get()/remove u from session again
assert s.query(User).get(10) is None
assert u in s
# but now it's back, rollback has occurred, the _remove_newly_deleted
# is reverted
- self.assertEquals(u.name, 'chuck')
+ eq_(u.name, 'chuck')
@testing.resolve_artifact_names
def test_deferred(self):
s = create_session(autoflush=True, autocommit=False)
u = s.query(User).get(8)
adlist = u.addresses
- self.assertEquals(adlist, [
+ eq_(adlist, [
Address(email_address='ed@bettyboop.com'),
Address(email_address='ed@lala.com'),
Address(email_address='ed@wood.com'),
a1 = u.addresses[2]
a1.email_address = 'aaaaa'
s.expire(u, ['addresses'])
- self.assertEquals(u.addresses, [
+ eq_(u.addresses, [
Address(email_address='aaaaa'),
Address(email_address='ed@bettyboop.com'),
Address(email_address='ed@lala.com'),
mapper(Address, addresses)
s = create_session(autoflush=True, autocommit=False)
u = s.query(User).get(8)
- self.assertRaisesMessage(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses'])
+ assert_raises_message(sa_exc.InvalidRequestError, "properties specified for refresh", s.refresh, u, ['addresses'])
# in contrast to a regular query with no columns
- self.assertRaisesMessage(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all)
+ assert_raises_message(sa_exc.InvalidRequestError, "no columns with which to SELECT", s.query().all)
@testing.resolve_artifact_names
def test_refresh_cancels_expire(self):
def go():
u = s.query(User).get(7)
- self.assertEquals(u.name, 'jack')
+ eq_(u.name, 'jack')
self.assert_sql_count(testing.db, go, 0)
@testing.resolve_artifact_names
sess.expire(u, attribute_names=['name'])
sess.expunge(u)
- self.assertRaises(sa.exc.UnboundExecutionError, getattr, u, 'name')
+ assert_raises(sa.exc.UnboundExecutionError, getattr, u, 'name')
@testing.resolve_artifact_names
def test_pending_raises(self):
sess = create_session()
u = User(id=15)
sess.add(u)
- self.assertRaises(sa.exc.InvalidRequestError, sess.expire, u, ['name'])
+ assert_raises(sa.exc.InvalidRequestError, sess.expire, u, ['name'])
@testing.resolve_artifact_names
def test_no_instance_key(self):
userlist = sess.query(User).order_by(User.id).all()
u = userlist[1]
- self.assertEquals(self.static.user_address_result, userlist)
+ eq_(self.static.user_address_result, userlist)
assert len(list(sess)) == 9
class PolymorphicExpireTest(_base.MappedTest):
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global people, engineers, Person, Engineer
people = Table('people', metadata,
Column('status', String(30)),
)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Person(_base.ComparableEntity):
pass
class Engineer(Person):
pass
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
people.insert().execute(
{'person_id':1, 'name':'person1', 'type':'person'},
{'person_id':2, 'name':'engineer1', 'type':'engineer'},
assert e1.status == 'new engineer'
assert e2.status == 'old engineer'
self.assert_sql_count(testing.db, go, 2)
- self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1']))
+ eq_(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1']))
class ExpiredPendingTest(_fixtures.FixtureTest):
run_define_tables = 'once'
s = create_session()
u = s.query(User).get(7)
s.expunge_all()
- self.assertRaisesMessage(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
+ assert_raises_message(sa.exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
@testing.resolve_artifact_names
def test_refresh_expired(self):
s.refresh(u)
-if __name__ == '__main__':
- testenv.main()
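Three assertion helpers recur throughout these hunks: eq_, assert_raises and assert_raises_message replace the TestCase methods assertEquals, assertRaises and assertRaisesMessage. Since they are plain functions imported from sqlalchemy.test.testing, they work in module-level tests as well as in classes. A small standalone sketch; divide is a made-up helper:

    from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message

    def divide(a, b):
        return a / b

    def test_division_helpers():
        # was self.assertEquals(...)
        eq_(divide(6, 3), 2)
        # was self.assertRaises(...)
        assert_raises(ZeroDivisionError, divide, 6, 0)
        # was self.assertRaisesMessage(...)
        assert_raises_message(ZeroDivisionError, "division", divide, 6, 0)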
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import pickle
from sqlalchemy import util
import sqlalchemy.orm.attributes as attributes
from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import InstrumentationManager
-from testlib import *
-from orm import _base
+from sqlalchemy.test import *
+from test.orm import _base
class MyTypesManager(InstrumentationManager):
del self._goofy_dict[key]
class UserDefinedExtensionTest(_base.ORMTest):
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
clear_mappers()
attributes._install_lookup_strategy(util.symbol('native'))
assert Foo in attributes.instrumentation_registry._state_finders
f = Foo()
attributes.instance_state(f).expire_attributes(None)
- self.assertEquals(f.a, "this is a")
- self.assertEquals(f.b, 12)
+ eq_(f.a, "this is a")
+ eq_(f.b, 12)
f.a = "this is some new a"
attributes.instance_state(f).expire_attributes(None)
- self.assertEquals(f.a, "this is a")
- self.assertEquals(f.b, 12)
+ eq_(f.a, "this is a")
+ eq_(f.b, 12)
attributes.instance_state(f).expire_attributes(None)
f.a = "this is another new a"
- self.assertEquals(f.a, "this is another new a")
- self.assertEquals(f.b, 12)
+ eq_(f.a, "this is another new a")
+ eq_(f.b, 12)
attributes.instance_state(f).expire_attributes(None)
- self.assertEquals(f.a, "this is a")
- self.assertEquals(f.b, 12)
+ eq_(f.a, "this is a")
+ eq_(f.b, 12)
del f.a
- self.assertEquals(f.a, None)
- self.assertEquals(f.b, 12)
+ eq_(f.a, None)
+ eq_(f.b, 12)
attributes.instance_state(f).commit_all(attributes.instance_dict(f))
- self.assertEquals(f.a, None)
- self.assertEquals(f.b, 12)
+ eq_(f.a, None)
+ eq_(f.b, 12)
def test_inheritance(self):
"""tests that attributes are polymorphic"""
f1 = Foo()
f1.name = 'f1'
- self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ()))
+ eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], (), ()))
b1 = Bar()
b1.name = 'b1'
f1.bars.append(b1)
- self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+ eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
attributes.instance_state(f1).commit_all(attributes.instance_dict(f1))
attributes.instance_state(b1).commit_all(attributes.instance_dict(b1))
- self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
- self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
+ eq_(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
+ eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
f1.name = 'f1mod'
b2 = Bar()
b2.name = 'b2'
f1.bars.append(b2)
- self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1']))
- self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
+ eq_(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1']))
+ eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
f1.bars.remove(b1)
- self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
+ eq_(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
def test_null_instrumentation(self):
class Foo(MyBaseClass):
assert attributes.manager_of_class(None) is None
assert attributes.instance_state(k) is not None
- self.assertRaises((AttributeError, KeyError),
+ assert_raises((AttributeError, KeyError),
attributes.instance_state, u)
- self.assertRaises((AttributeError, KeyError),
+ assert_raises((AttributeError, KeyError),
attributes.instance_state, None)
-import testenv; testenv.configure_for_tests()
-from testlib import testing, sa
-from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func
+from sqlalchemy.test.testing import eq_
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, MetaData, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
from sqlalchemy.orm import mapper, relation, create_session
-from testlib.testing import eq_
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
class GenerativeQueryTest(_base.MappedTest):
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('foo', metadata,
Column('id', Integer, sa.Sequence('foo_id_seq'), primary_key=True),
Column('bar', Integer),
Column('range', Integer))
- def fixtures(self):
+ @classmethod
+ def fixtures(cls):
rows = tuple([(i, i % 10) for i in range(100)])
foo_data = (('bar', 'range'),) + rows
return dict(foo=foo_data)
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class Foo(_base.BasicEntity):
pass
class GenerativeTest2(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('Table1', metadata,
Column('id', Integer, primary_key=True))
Table('Table2', metadata,
primary_key=True),
Column('num', Integer, primary_key=True))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class Obj1(_base.BasicEntity):
pass
class Obj2(_base.BasicEntity):
mapper(Obj1, Table1)
mapper(Obj2, Table2)
- def fixtures(self):
+ @classmethod
+ def fixtures(cls):
return dict(
Table1=(('id',),
(1,),
run_inserts = 'once'
run_deletes = None
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users, properties={
'orders':relation(mapper(Order, orders, properties={
'addresses':relation(mapper(Address, addresses))}))})
class CaseSensitiveTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('Table1', metadata,
Column('ID', Integer, primary_key=True))
Table('Table2', metadata,
primary_key=True),
Column('NUM', Integer, primary_key=True))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class Obj1(_base.BasicEntity):
pass
class Obj2(_base.BasicEntity):
mapper(Obj1, Table1)
mapper(Obj2, Table2)
- def fixtures(self):
+ @classmethod
+ def fixtures(cls):
return dict(
Table1=(('ID',),
(1,),
res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1))
assert res.count() == 3
res = q.filter(sa.and_(Table1.c.ID==Table2.c.T1ID,Table2.c.T1ID==1)).distinct()
- self.assertEqual(res.count(), 1)
+ eq_(res.count(), 1)
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib import sa
-from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util
-from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers
-from testlib.testing import eq_, ne_
-from testlib.compat import _function_named
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy import MetaData, Integer, ForeignKey, util
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers
+from sqlalchemy.test.testing import eq_, ne_
+from sqlalchemy.util import function_named
+from test.orm import _base
def modifies_instrumentation_finders(fn):
finally:
del attributes.instrumentation_finders[:]
attributes.instrumentation_finders.extend(pristine)
- return _function_named(decorated, fn.func_name)
+ return function_named(decorated, fn.func_name)
def with_lookup_strategy(strategy):
def decorate(fn):
return fn(*args, **kw)
finally:
attributes._install_lookup_strategy(sa.util.symbol('native'))
- return _function_named(wrapped, fn.func_name)
+ return function_named(wrapped, fn.func_name)
return decorate
m = mapper(A, self.fixture())
# B is not mapped in the current implementation
- self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B)
+ assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, B)
# C is not mapped in the current implementation
- self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C)
+ assert_raises(sa.orm.exc.UnmappedClassError, class_mapper, C)
class InstrumentationCollisionTest(_base.ORMTest):
def test_none(self):
class B(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
- self.assertRaises(TypeError, attributes.register_class, B)
+ assert_raises(TypeError, attributes.register_class, B)
def test_single_up(self):
class B(A):
__sa_instrumentation_manager__ = staticmethod(mgr_factory)
attributes.register_class(B)
- self.assertRaises(TypeError, attributes.register_class, A)
+ assert_raises(TypeError, attributes.register_class, A)
def test_diamond_b1(self):
mgr_factory = lambda cls: attributes.ClassManager(cls)
__sa_instrumentation_manager__ = mgr_factory
class C(object): pass
- self.assertRaises(TypeError, attributes.register_class, B1)
+ assert_raises(TypeError, attributes.register_class, B1)
def test_diamond_b2(self):
mgr_factory = lambda cls: attributes.ClassManager(cls)
__sa_instrumentation_manager__ = mgr_factory
class C(object): pass
- self.assertRaises(TypeError, attributes.register_class, B2)
+ assert_raises(TypeError, attributes.register_class, B2)
def test_diamond_c_b(self):
mgr_factory = lambda cls: attributes.ClassManager(cls)
class C(object): pass
attributes.register_class(C)
- self.assertRaises(TypeError, attributes.register_class, B1)
+ assert_raises(TypeError, attributes.register_class, B1)
class OnLoadTest(_base.ORMTest):
finally:
del A
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
clear_mappers()
attributes._install_lookup_strategy(util.symbol('native'))
sa = attributes.ClassManager.STATE_ATTR
ma = attributes.ClassManager.MANAGER_ATTR
- fails = lambda method, attr: self.assertRaises(
+ fails = lambda method, attr: assert_raises(
KeyError, getattr(manager, method), attr, property())
fails('install_member', sa)
class T(object): pass
- self.assertRaises(KeyError, mapper, T, t)
+ assert_raises(KeyError, mapper, T, t)
@with_lookup_strategy(sa.util.symbol('native'))
def test_mapped_managerattr(self):
Column(attributes.ClassManager.MANAGER_ATTR, Integer))
class T(object): pass
- self.assertRaises(KeyError, mapper, T, t)
+ assert_raises(KeyError, mapper, T, t)
class MiscTest(_base.ORMTest):
eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
-if __name__ == "__main__":
- testenv.main()
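_function_named from testlib.compat becomes function_named in sqlalchemy.util; as in the _create_backref_test loop earlier in this diff, it renames a generated function so that each variant shows up as a distinct test under nose. A minimal sketch with an invented helper:

    from sqlalchemy.util import function_named

    def _make_range_test(bound):
        def test_in_range(self):
            assert 0 <= bound < 10
        # give each generated variant its own discoverable name
        return function_named(test_in_range, "test_in_range_%d" % bound)

The generated function is then attached to a test class with setattr(SomeTestClass, fn.__name__, fn), mirroring the backref loop above.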
"""basic tests of lazy loaded attributes"""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
import datetime
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import attributes
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from testlib.testing import eq_
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
class LazyTest(_fixtures.FixtureTest):
q = sess.query(User)
u = q.filter(users.c.id == 7).first()
sess.expunge(u)
- self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses')
+ assert_raises(sa_exc.InvalidRequestError, getattr, u, 'addresses')
@testing.resolve_artifact_names
def test_orderby(self):
class CorrelatedTest(_base.MappedTest):
- def define_tables(self, meta):
+ @classmethod
+ def define_tables(cls, meta):
Table('user_t', meta,
Column('id', Integer, primary_key=True),
Column('date', sa.Date),
Column('user_id', Integer, ForeignKey('user_t.id')))
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
user_t.insert().execute(
{'id':1, 'name':'user1'},
{'id':2, 'name':'user2'},
])
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
class LazyTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('infos', metadata,
Column('pk', Integer, primary_key=True),
Column('info', String(128)))
Column('start', Integer),
Column('finish', Integer))
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
infos.insert().execute(
{'pk':1, 'info':'pk_1_info'},
{'pk':2, 'info':'pk_2_info'},
assert len(info.rels[0].datas) == 3
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
class M2MTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('place', metadata,
Column('place_id', Integer, sa.Sequence('pid_seq', optional=True),
primary_key=True),
Column('pl1_id', Integer, ForeignKey('place.place_id')),
Column('pl2_id', Integer, ForeignKey('place.place_id')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Place(_base.BasicEntity):
def __init__(self, name=None):
self.name = name
mapper(Transition, transition, properties={
'places':relation(Place, secondary=place_input, backref='transitions')
})
- self.assertRaisesMessage(sa.exc.ArgumentError, "Error creating backref",
+ assert_raises_message(sa.exc.ArgumentError, "Error creating backref",
sa.orm.compile_mappers)
@testing.resolve_artifact_names
class M2MTest2(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('student', metadata,
Column('name', String(20), primary_key=True))
Column('course_id', String(20), ForeignKey('course.name'),
primary_key=True))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Student(_base.BasicEntity):
def __init__(self, name=''):
self.name = name
s1.courses.append(c1)
s1.courses.append(c1)
sess.add(s1)
- self.assertRaises(sa.exc.DBAPIError, sess.flush)
+ assert_raises(sa.exc.DBAPIError, sess.flush)
@testing.resolve_artifact_names
def test_delete(self):
assert enroll.count().scalar() == 0
class M2MTest3(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('c', metadata,
Column('c1', Integer, primary_key = True),
Column('c2', String(20)))
# how about some data/inserts/queries/assertions for this one
-if __name__ == "__main__":
- testenv.main()
"""General mapper operations with an emphasis on selecting/loading."""
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, func
-from testlib.sa.engine import default
-from testlib.sa.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased
-from testlib.sa.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property
-from testlib.testing import eq_, AssertsCompiledSQL
-import pickleable
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing, pickleable
+from sqlalchemy import MetaData, Integer, String, ForeignKey, func
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.engine import default
+from sqlalchemy.orm import mapper, relation, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased
+from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relation, dynamic_loader, comparable_property
+from sqlalchemy.test.testing import eq_, AssertsCompiledSQL
+from test.orm import _base, _fixtures
class MapperTest(_fixtures.FixtureTest):
properties={
'addresses':relation(Address, backref='email_address')
})
- self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers)
+ assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers)
@testing.resolve_artifact_names
def test_update_attr_keys(self):
@testing.resolve_artifact_names
def test_prop_accessor(self):
mapper(User, users)
- self.assertRaises(NotImplementedError,
+ assert_raises(NotImplementedError,
getattr, sa.orm.class_mapper(User), 'properties')
@testing.resolve_artifact_names
def test_bad_cascade(self):
mapper(Address, addresses)
- self.assertRaises(sa.exc.ArgumentError,
+ assert_raises(sa.exc.ArgumentError,
relation, Address, cascade="fake, all, delete-orphan")
@testing.resolve_artifact_names
})
hasattr(Address.user, 'property')
- self.assertRaisesMessage(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
+ assert_raises_message(sa.exc.InvalidRequestError, r"suppressed within a hasattr\(\)", compile_mappers)
@testing.resolve_artifact_names
def test_column_prefix(self):
@testing.resolve_artifact_names
def test_no_pks_1(self):
s = sa.select([users.c.name]).alias('foo')
- self.assertRaises(sa.exc.ArgumentError, mapper, User, s)
+ assert_raises(sa.exc.ArgumentError, mapper, User, s)
@testing.emits_warning(
'mapper Mapper|User|Select object creating an alias for '
@testing.resolve_artifact_names
def test_no_pks_2(self):
s = sa.select([users.c.name])
- self.assertRaises(sa.exc.ArgumentError, mapper, User, s)
+ assert_raises(sa.exc.ArgumentError, mapper, User, s)
@testing.resolve_artifact_names
def test_recompile_on_other_mapper(self):
create_session).extension)
sess = create_session()
- self.assertRaises(TypeError, Foo, 'one', _sa_session=sess)
+ assert_raises(TypeError, Foo, 'one', _sa_session=sess)
eq_(len(list(sess)), 0)
- self.assertRaises(TypeError, Foo, 'one')
+ assert_raises(TypeError, Foo, 'one')
Foo('one', 'two', _sa_session=sess)
eq_(len(list(sess)), 1)
raise Exception("this exception should be stated as a warning")
sess.expunge = bad_expunge
- self.assertRaises(sa.exc.SAWarning, Foo, _sa_session=sess)
+ assert_raises(sa.exc.SAWarning, Foo, _sa_session=sess)
@testing.resolve_artifact_names
def test_constructor_exc_2(self):
mapper(Foo, users)
mapper(Bar, addresses)
- self.assertRaises(TypeError, Foo, x=5)
- self.assertRaises(TypeError, Bar, x=5)
+ assert_raises(TypeError, Foo, x=5)
+ assert_raises(TypeError, Bar, x=5)
@testing.resolve_artifact_names
def test_props(self):
# excluding the discriminator column is currently not allowed
class Foo(Person):
pass
- self.assertRaises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) )
+ assert_raises(sa.exc.InvalidRequestError, mapper, Foo, inherits=Person, polymorphic_identity='foo', exclude_properties=('type',) )
@testing.resolve_artifact_names
def test_mapping_to_join(self):
properties=dict(
name=relation(mapper(Address, addresses))))
- self.assertRaises(sa.exc.ArgumentError, go)
+ assert_raises(sa.exc.ArgumentError, go)
@testing.resolve_artifact_names
def test_override_2(self):
mapper(User, users, properties={
'not_name':synonym('_name', map_column=True)})
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
("Can't compile synonym '_name': no column on table "
"'users' named 'not_name'"),
eq_(User.uc_name.method1(), "method1")
eq_(User.uc_name.method2('x'), "method2")
- self.assertRaisesMessage(
+ assert_raises_message(
AttributeError,
"Neither 'extendedproperty' object nor 'UCComparator' object has an attribute 'nonexistent'",
getattr, User.uc_name, 'nonexistent')
'name':sa.orm.column_property(users.c.name, comparator_factory=MyComparator)
})
- self.assertRaisesMessage(
+ assert_raises_message(
AttributeError,
"Neither 'InstrumentedAttribute' object nor 'MyComparator' object has an attribute 'nonexistent'",
getattr, User.name, "nonexistent")
'addresses':relation(Address)
})
- self.assertRaises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers)
+ assert_raises(sa.orm.exc.UnmappedClassError, sa.orm.compile_mappers)
@testing.resolve_artifact_names
def test_oldstyle_mixin(self):
class DeepOptionsTest(_fixtures.FixtureTest):
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Keyword, keywords)
mapper(Item, items, properties=dict(
def test_deep_options_4(self):
sess = create_session()
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
r"Can't find entity Mapper\|Order\|orders in Query. "
r"Current list: \['Mapper\|User\|users'\]",
sess = create_session()
u1 = User(name='ed')
eq_(u1.name, 'ed modified')
- self.assertRaises(AssertionError, setattr, u1, "name", "fred")
+ assert_raises(AssertionError, setattr, u1, "name", "fred")
eq_(u1.name, 'ed modified')
sess.add(u1)
sess.flush()
mapper(Address, addresses)
sess = create_session()
u1 = User(name='edward')
- self.assertRaises(AssertionError, u1.addresses.append, Address(email_address='noemail'))
+ assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail'))
u1.addresses.append(Address(id=15, email_address='foo@bar.com'))
sess.add(u1)
sess.flush()
eq_(item.description, 'item 4')
class DeferredPopulationTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("thing", metadata,
Column("id", Integer, primary_key=True),
Column("name", String(20)))
Column("thing_id", Integer, ForeignKey("thing.id")),
Column("name", String(20)))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class Human(_base.BasicEntity): pass
class Thing(_base.BasicEntity): pass
mapper(Human, human, properties={"thing": relation(Thing)})
mapper(Thing, thing, properties={"name": deferred(thing.c.name)})
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
thing.insert().execute([
{"id": 1, "name": "Chair"},
])
class CompositeTypesTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('graphs', metadata,
Column('id', Integer, primary_key=True),
Column('version_id', Integer, primary_key=True, nullable=True),
class RequirementsTest(_base.MappedTest):
"""Tests the contract for user classes."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('ht1', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
class OldStyle:
pass
- self.assertRaises(sa.exc.ArgumentError, mapper, OldStyle, ht1)
+ assert_raises(sa.exc.ArgumentError, mapper, OldStyle, ht1)
- self.assertRaises(sa.exc.ArgumentError, mapper, 123)
+ assert_raises(sa.exc.ArgumentError, mapper, 123)
class NoWeakrefSupport(str):
pass
class MagicNamesTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('cartographers', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(50)),
Column('state', String(2)),
Column('data', sa.Text))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Cartographer(_base.BasicEntity):
pass
class T(object):
pass
- self.assertRaisesMessage(
+ assert_raises_message(
KeyError,
('%r: requested attribute name conflicts with '
'instrumentation attribute of the same name.' % reserved),
class M(object):
pass
- self.assertRaisesMessage(
+ assert_raises_message(
KeyError,
('requested attribute name conflicts with '
'instrumentation attribute of the same name'),
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa.util import OrderedSet
-from testlib.sa.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property
-from testlib.testing import eq_, ne_
-from orm import _base, _fixtures
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy.util import OrderedSet
+from sqlalchemy.orm import mapper, relation, create_session, PropComparator, synonym, comparable_property
+from sqlalchemy.test.testing import eq_, ne_
+from test.orm import _base, _fixtures
class MergeTest(_fixtures.FixtureTest):
sess = create_session()
u = User()
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+ assert_raises_message(sa.exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
@testing.resolve_artifact_names
-if __name__ == "__main__":
- testenv.main()
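The hunks above swap the unittest TestCase assertion methods for the module-level helpers imported from sqlalchemy.test.testing: eq_() replaces assertEquals(), assert_raises() replaces assertRaises(), and assert_raises_message() replaces assertRaisesMessage() while also checking the exception's message. A minimal sketch of the new style follows; the test function and values are invented for illustration and assume SQLAlchemy is installed so that the sqlalchemy.test package is importable:

    from sqlalchemy.test.testing import eq_, assert_raises

    def test_assertion_helpers():
        # eq_() raises AssertionError when its two arguments differ,
        # standing in for unittest's assertEquals().
        eq_(1 + 1, 2)

        # assert_raises() invokes the callable with the given arguments and
        # fails unless the named exception type is raised, standing in for
        # unittest's assertRaises().
        assert_raises(ZeroDivisionError, lambda: 1 / 0)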
Primary key changing capabilities and passive/non-passive cascading updates.
"""
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
class NaturalPKTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
users = Table('users', metadata,
Column('username', String(50), primary_key=True),
Column('fullname', String(100)),
Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True),
test_needs_fk=True)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.ComparableEntity):
pass
class Address(_base.ComparableEntity):
sess.expunge_all()
u1 = sess.query(User).get('ed')
- self.assertEquals(User(username='ed', fullname='jack'), u1)
+ eq_(User(username='ed', fullname='jack'), u1)
@testing.resolve_artifact_names
def test_load_after_expire(self):
# in this case so there's no way to look it up. criterion-
# based session invalidation could solve this [ticket:911]
sess.expire(u1)
- self.assertRaises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username')
+ assert_raises(sa.orm.exc.ObjectDeletedError, getattr, u1, 'username')
sess.expunge_all()
assert sess.query(User).get('jack') is None
assert u1.addresses[0].username == 'ed'
sess.expunge_all()
- self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+ eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
u1 = sess.query(User).get('ed')
u1.username = 'jack'
sess.expunge_all()
assert sess.query(Address).get('jack1').username is None
u1 = sess.query(User).get('fred')
- self.assertEquals(User(username='fred', fullname='jack'), u1)
+ eq_(User(username='fred', fullname='jack'), u1)
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
assert a1.username == a2.username == 'ed'
sess.expunge_all()
- self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+ eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
def test_onetoone_passive(self):
self.assert_sql_count(testing.db, go, 0)
sess.expunge_all()
- self.assertEquals([Address(username='ed')], sess.query(Address).all())
+ eq_([Address(username='ed')], sess.query(Address).all())
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
def test_bidirectional_passive(self):
u1.username = 'ed'
(ad1, ad2) = sess.query(Address).all()
- self.assertEquals([Address(username='jack'), Address(username='jack')], [ad1, ad2])
+ eq_([Address(username='jack'), Address(username='jack')], [ad1, ad2])
def go():
sess.flush()
if passive_updates:
self.assert_sql_count(testing.db, go, 1)
else:
self.assert_sql_count(testing.db, go, 3)
- self.assertEquals([Address(username='ed'), Address(username='ed')], [ad1, ad2])
+ eq_([Address(username='ed'), Address(username='ed')], [ad1, ad2])
sess.expunge_all()
- self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+ eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
u1 = sess.query(User).get('ed')
assert len(u1.addresses) == 2 # load addresses
else:
self.assert_sql_count(testing.db, go, 3)
sess.expunge_all()
- self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all())
+ eq_([Address(username='fred'), Address(username='fred')], sess.query(Address).all())
@testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
r = sess.query(Item).all()
# ComparableEntity can't handle a comparison with the backrefs
# involved....
- self.assertEquals(Item(itemname='item1'), r[0])
- self.assertEquals(['jack'], [u.username for u in r[0].users])
- self.assertEquals(Item(itemname='item2'), r[1])
- self.assertEquals(['jack', 'fred'], [u.username for u in r[1].users])
+ eq_(Item(itemname='item1'), r[0])
+ eq_(['jack'], [u.username for u in r[0].users])
+ eq_(Item(itemname='item2'), r[1])
+ eq_(['jack', 'fred'], [u.username for u in r[1].users])
u2.username='ed'
def go():
sess.expunge_all()
r = sess.query(Item).all()
- self.assertEquals(Item(itemname='item1'), r[0])
- self.assertEquals(['jack'], [u.username for u in r[0].users])
- self.assertEquals(Item(itemname='item2'), r[1])
- self.assertEquals(['ed', 'jack'], sorted([u.username for u in r[1].users]))
+ eq_(Item(itemname='item1'), r[0])
+ eq_(['jack'], [u.username for u in r[0].users])
+ eq_(Item(itemname='item2'), r[1])
+ eq_(['ed', 'jack'], sorted([u.username for u in r[1].users]))
sess.expunge_all()
u2 = sess.query(User).get(u2.username)
u2.username='wendy'
sess.flush()
r = sess.query(Item).with_parent(u2).all()
- self.assertEquals(Item(itemname='item2'), r[0])
+ eq_(Item(itemname='item2'), r[0])
class SelfRefTest(_base.MappedTest):
__unsupported_on__ = 'mssql' # mssql doesn't allow ON UPDATE on self-referential keys
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('nodes', metadata,
Column('name', String(50), primary_key=True),
Column('parent', String(50),
ForeignKey('nodes.name', onupdate='cascade')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Node(_base.ComparableEntity):
pass
class NonPKCascadeTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('username', String(50), unique=True),
test_needs_fk=True
)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.ComparableEntity):
pass
class Address(_base.ComparableEntity):
sess.flush()
a1 = u1.addresses[0]
- self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)])
+ eq_(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)])
assert sess.query(Address).get(a1.id) is u1.addresses[0]
u1.username = 'ed'
sess.flush()
assert u1.addresses[0].username == 'ed'
- self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)])
+ eq_(sa.select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)])
sess.expunge_all()
- self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
+ eq_([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
u1 = sess.query(User).get(u1.id)
u1.username = 'jack'
sess.flush()
sess.expunge_all()
a1 = sess.query(Address).get(a1.id)
- self.assertEquals(a1.username, None)
+ eq_(a1.username, None)
- self.assertEquals(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)])
+ eq_(sa.select([addresses.c.username]).execute().fetchall(), [(None,), (None,)])
u1 = sess.query(User).get(u1.id)
- self.assertEquals(User(username='fred', fullname='jack'), u1)
+ eq_(User(username='fred', fullname='jack'), u1)
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session
-from orm import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session
+from test.orm import _base
class O2OTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('jack', metadata,
Column('id', Integer, primary_key=True),
Column('number', String(50)),
Column('description', String(100)),
Column('jack_id', Integer, ForeignKey("jack.id")))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class Jack(_base.BasicEntity):
pass
class Port(_base.BasicEntity):
session.delete(j)
session.flush()
-if __name__ == "__main__":
- testenv.main()
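In the import hunks, testlib gives way to the sqlalchemy.test package, and Table and Column are now taken from sqlalchemy.test.schema. These wrappers accept the testing-only keyword arguments that appear in the table definitions throughout this patch, such as test_needs_fk on Table and test_needs_autoincrement on Column. The fragment below is a hypothetical sketch of that pattern (the table and column names are invented, and running it outside the nose harness may additionally depend on the plugin's configuration):

    import sqlalchemy as sa
    from sqlalchemy import Integer, String
    from sqlalchemy.test.schema import Table
    from sqlalchemy.test.schema import Column

    # Invented table definition showing the test-schema wrappers; the
    # test_needs_* keywords are testing hints handled by the wrappers rather
    # than standard Table/Column arguments.
    metadata = sa.MetaData()
    widgets = Table('widgets', metadata,
        Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
        Column('name', String(50)),
        test_needs_fk=True)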
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import pickle
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session, attributes
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, attributes
+from test.orm import _base, _fixtures
User, EmailUser = None, None
sess.expunge_all()
- self.assertEquals(u1, sess.query(User).get(u2.id))
+ eq_(u1, sess.query(User).get(u2.id))
@testing.resolve_artifact_names
def test_class_deferred_cols(self):
u2 = pickle.loads(pickle.dumps(u1))
sess2 = create_session()
sess2.add(u2)
- self.assertEquals(u2.name, 'ed')
- self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+ eq_(u2.name, 'ed')
+ eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
u2 = pickle.loads(pickle.dumps(u1))
sess2 = create_session()
u2 = sess2.merge(u2, dont_load=True)
- self.assertEquals(u2.name, 'ed')
- self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+ eq_(u2.name, 'ed')
+ eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
@testing.resolve_artifact_names
def test_instance_deferred_cols(self):
u2 = pickle.loads(pickle.dumps(u1))
sess2 = create_session()
sess2.add(u2)
- self.assertEquals(u2.name, 'ed')
+ eq_(u2.name, 'ed')
assert 'addresses' not in u2.__dict__
ad = u2.addresses[0]
assert 'email_address' not in ad.__dict__
- self.assertEquals(ad.email_address, 'ed@bar.com')
- self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+ eq_(ad.email_address, 'ed@bar.com')
+ eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
u2 = pickle.loads(pickle.dumps(u1))
sess2 = create_session()
u2 = sess2.merge(u2, dont_load=True)
- self.assertEquals(u2.name, 'ed')
+ eq_(u2.name, 'ed')
assert 'addresses' not in u2.__dict__
ad = u2.addresses[0]
assert 'email_address' in ad.__dict__ # mapper options don't transmit over merge() right now
- self.assertEquals(ad.email_address, 'ed@bar.com')
- self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+ eq_(ad.email_address, 'ed@bar.com')
+ eq_(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
@testing.resolve_artifact_names
def test_options_with_descriptors(self):
sa.orm.eagerload(["addresses", User.addresses]),
]:
opt2 = pickle.loads(pickle.dumps(opt))
- self.assertEquals(opt.key, opt2.key)
+ eq_(opt.key, opt2.key)
u1 = sess.query(User).options(opt).first()
class PolymorphicDeferredTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(30)),
Column('id', Integer, ForeignKey('users.id'), primary_key=True),
Column('email_address', String(30)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
global User, EmailUser
class User(_base.BasicEntity):
pass
class EmailUser(User):
pass
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
global User, EmailUser
User, EmailUser = None, None
- _base.MappedTest.tearDownAll(self)
+ super(PolymorphicDeferredTest, cls).teardown_class()
@testing.resolve_artifact_names
def test_polymorphic_deferred(self):
sess2 = create_session()
sess2.add(eu2)
assert 'email_address' not in eu2.__dict__
- self.assertEquals(eu2.email_address, 'foo@bar.com')
+ eq_(eu2.email_address, 'foo@bar.com')
-class CustomSetupTeardowntest(_fixtures.FixtureTest):
+class CustomSetupTeardownTest(_fixtures.FixtureTest):
@testing.resolve_artifact_names
def test_rebuild_state(self):
"""not much of a 'test', but illustrate how to
attributes.manager_of_class(User).setup_instance(u2)
assert attributes.instance_state(u2)
-if __name__ == '__main__':
- testenv.main()
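The fixture hooks on test classes (define_tables(), setup_classes(), setup_mappers(), insert_data()) become classmethods receiving cls instead of self, tearDownAll() is renamed to teardown_class(), and where @testing.resolve_artifact_names is applied it sits beneath @classmethod. A schematic, hypothetical example of the new shape (the class, table and column names are invented and not part of the patch):

    from sqlalchemy import Integer, String
    from sqlalchemy.test.schema import Table, Column
    from test.orm import _base

    class ExampleTest(_base.MappedTest):

        @classmethod
        def define_tables(cls, metadata):
            # tables are created against the metadata handed to the hook
            Table('examples', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(50)))

        @classmethod
        def setup_classes(cls):
            # entity classes are declared here, mirroring the patched tests
            class Example(_base.ComparableEntity):
                pass

        @classmethod
        def teardown_class(cls):
            # class-level teardown now chains to the base via super()
            super(ExampleTest, cls).teardown_class()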
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import operator
from sqlalchemy import *
from sqlalchemy import exc as sa_exc, util
from sqlalchemy.orm import *
from sqlalchemy.orm import attributes
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
-from testlib import sa, testing, AssertsCompiledSQL, Column, engines
+import sqlalchemy as sa
+from sqlalchemy.test import testing, AssertsCompiledSQL, Column, engines
-from orm import _fixtures
-from orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \
+from test.orm import _fixtures
+from test.orm._fixtures import keywords, addresses, Base, Keyword, FixtureTest, \
Dingaling, item_keywords, dingalings, User, items,\
orders, Address, users, nodes, \
order_items, Item, Order, Node
-from orm import _base
+from test.orm import _base
from sqlalchemy.orm.util import join, outerjoin, with_parent
run_deletes = None
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
mapper(User, users, properties={
'addresses':relation(Address, backref='user', order_by=addresses.c.id),
'orders':relation(Order, backref='user', order_by=orders.c.id), # o2m, m2o
s = create_session()
q = s.query(User).join('addresses').filter(Address.user_id==8)
- self.assertRaises(sa_exc.InvalidRequestError, q.get, 7)
- self.assertRaises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19)
+ assert_raises(sa_exc.InvalidRequestError, q.get, 7)
+ assert_raises(sa_exc.InvalidRequestError, s.query(User).filter(User.id==7).get, 19)
# order_by()/get() doesn't raise
s.query(User).order_by(User.id).get(8)
class LocalFoo(Base):
pass
mapper(LocalFoo, table)
- self.assertEquals(create_session().query(LocalFoo).get(ustring),
+ eq_(create_session().query(LocalFoo).get(ustring),
LocalFoo(id=ustring, data=ustring))
finally:
metadata.drop_all()
def test_query_str(self):
s = create_session()
q = s.query(User).filter(User.id==1)
- self.assertEquals(
+ eq_(
str(q).replace('\n',''),
'SELECT users.id AS users_id, users.name AS users_name FROM users WHERE users.id = ?'
)
s.query(User).offset(2),
s.query(User).limit(2).offset(2)
):
- self.assertRaises(sa_exc.InvalidRequestError, q.join, "addresses")
+ assert_raises(sa_exc.InvalidRequestError, q.join, "addresses")
- self.assertRaises(sa_exc.InvalidRequestError, q.filter, User.name=='ed')
+ assert_raises(sa_exc.InvalidRequestError, q.filter, User.name=='ed')
- self.assertRaises(sa_exc.InvalidRequestError, q.filter_by, name='ed')
+ assert_raises(sa_exc.InvalidRequestError, q.filter_by, name='ed')
- self.assertRaises(sa_exc.InvalidRequestError, q.order_by, 'foo')
+ assert_raises(sa_exc.InvalidRequestError, q.order_by, 'foo')
- self.assertRaises(sa_exc.InvalidRequestError, q.group_by, 'foo')
+ assert_raises(sa_exc.InvalidRequestError, q.group_by, 'foo')
- self.assertRaises(sa_exc.InvalidRequestError, q.having, 'foo')
+ assert_raises(sa_exc.InvalidRequestError, q.having, 'foo')
def test_no_from(self):
s = create_session()
q = s.query(User).select_from(users)
- self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+ assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
q = s.query(User).join('addresses')
- self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+ assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
q = s.query(User).order_by(User.id)
- self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+ assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
# this is fine, however
q.from_self()
def test_invalid_select_from(self):
s = create_session()
q = s.query(User)
- self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id==5)
- self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id)
+ assert_raises(sa_exc.ArgumentError, q.select_from, User.id==5)
+ assert_raises(sa_exc.ArgumentError, q.select_from, User.id)
def test_invalid_from_statement(self):
s = create_session()
q = s.query(User)
- self.assertRaises(sa_exc.ArgumentError, q.from_statement, User.id==5)
- self.assertRaises(sa_exc.ArgumentError, q.from_statement, users.join(addresses))
+ assert_raises(sa_exc.ArgumentError, q.from_statement, User.id==5)
+ assert_raises(sa_exc.ArgumentError, q.from_statement, users.join(addresses))
def test_invalid_column(self):
s = create_session()
q = s.query(User)
- self.assertRaises(sa_exc.InvalidRequestError, q.add_column, object())
+ assert_raises(sa_exc.InvalidRequestError, q.add_column, object())
def test_mapper_zero(self):
s = create_session()
q = s.query(User, Address)
- self.assertRaises(sa_exc.InvalidRequestError, q.get, 5)
+ assert_raises(sa_exc.InvalidRequestError, q.get, 5)
def test_from_statement(self):
s = create_session()
q = s.query(User).filter(User.id==5)
- self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+ assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
q = s.query(User).filter_by(id=5)
- self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+ assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
q = s.query(User).limit(5)
- self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+ assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
q = s.query(User).group_by(User.name)
- self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+ assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
q = s.query(User).order_by(User.name)
- self.assertRaises(sa_exc.InvalidRequestError, q.from_statement, "x")
+ assert_raises(sa_exc.InvalidRequestError, q.from_statement, "x")
class OperatorTest(QueryTest, AssertsCompiledSQL):
"""test sql.Comparator implementation for MapperProperties"""
"users.id IN (:id_1, :id_2)")
def test_in_on_relation_not_supported(self):
- self.assertRaises(NotImplementedError, Address.user.in_, [User(id=5)])
+ assert_raises(NotImplementedError, Address.user.in_, [User(id=5)])
def test_between(self):
self._test(User.id.between('a', 'b'),
assert [Address(id=5)] == sess.query(Address).filter(Address.dingaling==dingaling).all()
# m2m
- self.assertEquals(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)])
- self.assertEquals(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)])
+ eq_(sess.query(Item).filter(Item.keywords==None).all(), [Item(id=4), Item(id=5)])
+ eq_(sess.query(Item).filter(Item.keywords!=None).all(), [Item(id=1),Item(id=2), Item(id=3)])
def test_filter_by(self):
sess = create_session()
sess = create_session()
# o2o
- self.assertEquals([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all())
- self.assertEquals([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all())
+ eq_([Address(id=1), Address(id=3), Address(id=4)], sess.query(Address).filter(Address.dingaling==None).all())
+ eq_([Address(id=2), Address(id=5)], sess.query(Address).filter(Address.dingaling != None).all())
# m2o
- self.assertEquals([Order(id=5)], sess.query(Order).filter(Order.address==None).all())
- self.assertEquals([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all())
+ eq_([Order(id=5)], sess.query(Order).filter(Order.address==None).all())
+ eq_([Order(id=1), Order(id=2), Order(id=3), Order(id=4)], sess.query(Order).order_by(Order.id).filter(Order.address!=None).all())
# o2m
- self.assertEquals([User(id=10)], sess.query(User).filter(User.addresses==None).all())
- self.assertEquals([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all())
+ eq_([User(id=10)], sess.query(User).filter(User.addresses==None).all())
+ eq_([User(id=7),User(id=8),User(id=9)], sess.query(User).filter(User.addresses!=None).order_by(User.id).all())
class FromSelfTest(QueryTest, AssertsCompiledSQL):
def test_multiple_entities(self):
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(),
[
(User(id=8), Address(id=2)),
]
)
- self.assertEquals(
+ eq_(
sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(),
# order_by(User.id, Address.id).first(),
ed = s.query(User).filter(User.name=='ed')
jack = s.query(User).filter(User.name=='jack')
- self.assertEquals(fred.union(ed).order_by(User.name).all(),
+ eq_(fred.union(ed).order_by(User.name).all(),
[User(name='ed'), User(name='fred')]
)
- self.assertEquals(fred.union(ed, jack).order_by(User.name).all(),
+ eq_(fred.union(ed, jack).order_by(User.name).all(),
[User(name='ed'), User(name='fred'), User(name='jack')]
)
fred = s.query(User).filter(User.name=='fred')
ed = s.query(User).filter(User.name=='ed')
jack = s.query(User).filter(User.name=='jack')
- self.assertEquals(fred.intersect(ed, jack).all(),
+ eq_(fred.intersect(ed, jack).all(),
[]
)
- self.assertEquals(fred.union(ed).intersect(ed.union(jack)).all(),
+ eq_(fred.union(ed).intersect(ed.union(jack)).all(),
[User(name='ed')]
)
jack = s.query(User).filter(User.name=='jack')
def go():
- self.assertEquals(
+ eq_(
fred.union(ed).order_by(User.name).options(eagerload(User.addresses)).all(),
[
User(name='ed', addresses=[Address(), Address(), Address()]),
def test_sum(self):
sess = create_session()
orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
- self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,))
- self.assertEquals(orders.value(func.sum(Order.user_id * Order.address_id)), 79)
+ eq_(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,))
+ eq_(orders.value(func.sum(Order.user_id * Order.address_id)), 79)
def test_apply(self):
sess = create_session()
q = iter(sess.query(User).yield_per(1).from_statement("select * from users"))
ret = []
- self.assertEquals(len(sess.identity_map), 0)
+ eq_(len(sess.identity_map), 0)
ret.append(q.next())
ret.append(q.next())
- self.assertEquals(len(sess.identity_map), 2)
+ eq_(len(sess.identity_map), 2)
ret.append(q.next())
ret.append(q.next())
- self.assertEquals(len(sess.identity_map), 4)
+ eq_(len(sess.identity_map), 4)
try:
q.next()
assert False
def test_as_column(self):
s = create_session()
- self.assertRaises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name"))
+ assert_raises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name"))
eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')])
sess = create_session()
for oalias,ialias in [(True, True), (False, False), (True, False), (False, True)]:
- self.assertEquals(
+ eq_(
sess.query(User).join('orders', aliased=oalias).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description == 'item 4').all(),
[User(name='jack')]
)
# use middle criterion
- self.assertEquals(
+ eq_(
sess.query(User).join('orders', aliased=oalias).filter(Order.user_id==9).join('items', from_joinpoint=True, aliased=ialias).filter(Item.description=='item 4').all(),
[]
)
orderalias = aliased(Order)
itemalias = aliased(Item)
- self.assertEquals(
+ eq_(
sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(itemalias.description == 'item 4').all(),
[User(name='jack')]
)
- self.assertEquals(
+ eq_(
sess.query(User).join([('orders', orderalias), ('items', itemalias)]).filter(orderalias.user_id==9).filter(itemalias.description=='item 4').all(),
[]
)
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(User).join(Address.user).filter(Address.email_address=='ed@wood.com').all(),
[User(id=8,name=u'ed')]
)
# it's actually not so controversial if you view it in terms
# of multiple entities.
- self.assertEquals(
+ eq_(
sess.query(User, Address).join(Address.user).filter(Address.email_address=='ed@wood.com').all(),
[(User(id=8,name=u'ed'), Address(email_address='ed@wood.com'))]
)
# this was the controversial part. now, raise an error if the feature is abused.
# before the error raise was added, this would silently work.....
- self.assertRaises(
+ assert_raises(
sa_exc.InvalidRequestError,
sess.query(User).join, (Address, Address.user),
)
# but this one would silently fail
adalias = aliased(Address)
- self.assertRaises(
+ assert_raises(
sa_exc.InvalidRequestError,
sess.query(User).join, (adalias, Address.user),
)
oalias2 = aliased(Order)
result = sess.query(ualias).join((oalias1, ualias.orders), (oalias2, ualias.orders)).\
filter(or_(oalias1.user_id==9, oalias2.user_id==7)).all()
- self.assertEquals(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')])
+ eq_(result, [User(id=7,name=u'jack'), User(id=9,name=u'fred')])
def test_orderby_arg_bug(self):
sess = create_session()
def test_no_onclause(self):
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
[User(name='jack')]
)
- self.assertEquals(
+ eq_(
sess.query(User.name).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
[('jack',)]
)
- self.assertEquals(
+ eq_(
sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(),
[User(name='jack')]
)
def test_clause_onclause(self):
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(User).join(
(Order, User.id==Order.user_id),
(order_items, Order.id==order_items.c.order_id),
[User(name='jack')]
)
- self.assertEquals(
+ eq_(
sess.query(User.name).join(
(Order, User.id==Order.user_id),
(order_items, Order.id==order_items.c.order_id),
)
ualias = aliased(User)
- self.assertEquals(
+ eq_(
sess.query(ualias.name).join(
(Order, ualias.id==Order.user_id),
(order_items, Order.id==order_items.c.order_id),
# explicit onclause with from_self(), means
# the onclause must be aliased against the query's custom
# FROM object
- self.assertEquals(
+ eq_(
sess.query(User).order_by(User.id).offset(2).from_self().join(
(Order, User.id==Order.user_id)
).all(),
)
# same with an explicit select_from()
- self.assertEquals(
+ eq_(
sess.query(User).select_from(select([users]).order_by(User.id).offset(2).alias()).join(
(Order, User.id==Order.user_id)
).all(),
AdAlias = aliased(Address)
q = q.add_entity(AdAlias).select_from(outerjoin(User, AdAlias))
l = q.order_by(User.id, AdAlias.id).all()
- self.assertEquals(l, expected)
+ eq_(l, expected)
sess.expunge_all()
q = sess.query(User).add_entity(AdAlias)
l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
- self.assertEquals(l, [(user8, address3)])
+ eq_(l, [(user8, address3)])
l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all()
- self.assertEquals(l, [(user8, address3)])
+ eq_(l, [(user8, address3)])
l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
- self.assertEquals(l, [(user8, address3)])
+ eq_(l, [(user8, address3)])
# this is the first test where we are joining "backwards" - from AdAlias to User even though
# the query is against User
q = sess.query(User, AdAlias)
l = q.join(AdAlias.user).filter(User.name=='ed')
- self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+ eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed')
- self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+ eq_(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
def test_implicit_joins_from_aliases(self):
sess = create_session()
OrderAlias = aliased(Order)
- self.assertEquals(
+ eq_(
sess.query(OrderAlias).join('items').filter_by(description='item 3').\
order_by(OrderAlias.id).all(),
[
]
)
- self.assertEquals(
+ eq_(
sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').\
order_by(User.id, OrderAlias.id).all(),
[
q = sess.query(Order)
q = q.add_entity(Item).select_from(join(Order, Item, 'items')).order_by(Order.id, Item.id)
l = q.all()
- self.assertEquals(l, expected)
+ eq_(l, expected)
IAlias = aliased(Item)
q = sess.query(Order, IAlias).select_from(join(Order, IAlias, 'items')).filter(IAlias.description=='item 3')
l = q.all()
- self.assertEquals(l,
+ eq_(l,
[
(order1, item3),
(order2, item3),
sess = create_session()
ualias = aliased(User)
- self.assertEquals(
+ eq_(
sess.query(User, ualias).filter(User.id > ualias.id).order_by(desc(ualias.id), User.name).all(),
[
(User(id=10,name=u'chuck'), User(id=9,name=u'fred')),
sess = create_session()
- self.assertEquals(
+ eq_(
sess.query(User.name).join((addresses, User.id==addresses.c.user_id)).order_by(User.id).all(),
[(u'jack',), (u'ed',), (u'ed',), (u'ed',), (u'fred',)]
)
class MultiplePathTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global t1, t2, t1t2_1, t1t2_2
t1 = Table('t1', metadata,
Column('id', Integer, primary_key=True),
mapper(T2, t2)
q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint()
- self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.",
+ assert_raises_message(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.",
q.join, 't2s_2'
)
class SynonymTest(QueryTest):
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
mapper(User, users, properties={
'name_syn':synonym('name'),
'addresses':relation(Address),
adalias = addresses.alias()
q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias))
def go():
- self.assertEquals(self.static.user_address_result, q.order_by(User.id).all())
+ eq_(self.static.user_address_result, q.order_by(User.id).all())
self.assert_sql_count(testing.db, go, 1)
sess.expunge_all()
sel = users.select(User.id.in_([7, 8])).alias()
q = sess.query(User)
q2 = q.select_from(sel).values(User.name)
- self.assertEquals(list(q2), [(u'jack',), (u'ed',)])
+ eq_(list(q2), [(u'jack',), (u'ed',)])
q = sess.query(User)
q2 = q.order_by(User.id).values(User.name, User.name + " " + cast(User.id, String))
- self.assertEquals(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')])
+ eq_(list(q2), [(u'jack', u'jack 7'), (u'ed', u'ed 8'), (u'fred', u'fred 9'), (u'chuck', u'chuck 10')])
q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(User.id, Address.id).values(User.name, Address.email_address)
- self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
+ eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address)).slice(1, 3).values(User.name, Address.email_address)
- self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
+ eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
adalias = aliased(Address)
q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address)
- self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
+ eq_(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
q2 = q.values(func.count(User.name))
assert q2.next() == (4,)
q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name)
- self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')])
+ eq_(list(q2), [(u'ed', u'ed', u'ed')])
# using User.xxx is aliased against "sel", so this query returns nothing
q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name)
- self.assertEquals(list(q2), [])
+ eq_(list(q2), [])
# whereas this uses users.c.xxx, is not aliased and creates a new join
q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name)
- self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')])
+ eq_(list(q2), [(u'ed', u'jack', u'jack')])
@testing.fails_on('mssql', 'FIXME: unknown')
def test_values_specific_order_by(self):
q = sess.query(User)
u2 = aliased(User)
q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name)
- self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
+ eq_(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
@testing.fails_on('mssql', 'FIXME: unknown')
def test_values_with_boolean_selects(self):
q = sess.query(User)
q2 = q.group_by([User.name.like('%j%')]).order_by(desc(User.name.like('%j%'))).values(User.name.like('%j%'), func.count(User.name.like('%j%')))
- self.assertEquals(list(q2), [(True, 1), (False, 3)])
+ eq_(list(q2), [(True, 1), (False, 3)])
def test_correlated_subquery(self):
"""test that a subquery constructed from ORM attributes doesn't leak out
label('count')
# we don't want Address to be outside of the subquery here
- self.assertEquals(
+ eq_(
list(sess.query(User, subq)[0:3]),
[(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)]
)
label('count')
# we don't want Address to be outside of the subquery here
- self.assertEquals(
+ eq_(
list(sess.query(User, subq)[0:3]),
[(User(id=7,name=u'jack'), 1), (User(id=8,name=u'ed'), 3), (User(id=9,name=u'fred'), 1)]
)
def test_tuple_labeling(self):
sess = create_session()
for row in sess.query(User, Address).join(User.addresses).all():
- self.assertEquals(set(row.keys()), set(['User', 'Address']))
- self.assertEquals(row.User, row[0])
- self.assertEquals(row.Address, row[1])
+ eq_(set(row.keys()), set(['User', 'Address']))
+ eq_(row.User, row[0])
+ eq_(row.Address, row[1])
for row in sess.query(User.name, User.id.label('foobar')):
- self.assertEquals(set(row.keys()), set(['name', 'foobar']))
- self.assertEquals(row.name, row[0])
- self.assertEquals(row.foobar, row[1])
+ eq_(set(row.keys()), set(['name', 'foobar']))
+ eq_(row.name, row[0])
+ eq_(row.foobar, row[1])
for row in sess.query(User).values(User.name, User.id.label('foobar')):
- self.assertEquals(set(row.keys()), set(['name', 'foobar']))
- self.assertEquals(row.name, row[0])
- self.assertEquals(row.foobar, row[1])
+ eq_(set(row.keys()), set(['name', 'foobar']))
+ eq_(row.name, row[0])
+ eq_(row.foobar, row[1])
oalias = aliased(Order)
for row in sess.query(User, oalias).join(User.orders).all():
- self.assertEquals(set(row.keys()), set(['User']))
- self.assertEquals(row.User, row[0])
+ eq_(set(row.keys()), set(['User']))
+ eq_(row.User, row[0])
oalias = aliased(Order, name='orders')
for row in sess.query(User, oalias).join(User.orders).all():
- self.assertEquals(set(row.keys()), set(['User', 'orders']))
- self.assertEquals(row.User, row[0])
- self.assertEquals(row.orders, row[1])
+ eq_(set(row.keys()), set(['User', 'orders']))
+ eq_(row.User, row[0])
+ eq_(row.orders, row[1])
def test_column_queries(self):
sess = create_session()
- self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)])
+ eq_(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)])
sel = users.select(User.id.in_([7, 8])).alias()
q = sess.query(User.name)
q2 = q.select_from(sel).all()
- self.assertEquals(list(q2), [(u'jack',), (u'ed',)])
+ eq_(list(q2), [(u'jack',), (u'ed',)])
- self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [
+ eq_(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [
(u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'),
(u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'),
(u'fred', u'fred@fred.com')
])
- self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(),
+ eq_(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(),
[(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)]
)
- self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
+ eq_(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
[(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
)
- self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
+ eq_(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(),
[(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))]
)
adalias = aliased(Address)
- self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(),
+ eq_(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(),
[(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
)
- self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(),
+ eq_(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(),
[(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))]
)
# select from aliasing + explicit aliasing
- self.assertEquals(
+ eq_(
sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(),
[
(User(name=u'jack',id=7), u'jack@bean.com'),
)
# anon + select from aliasing
- self.assertEquals(
+ eq_(
sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(),
[
User(name=u'ed',id=8),
sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
]:
- self.assertEquals(
+ eq_(
q.all(),
[(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'),
def go():
results = sess.query(User).limit(1).options(eagerload('addresses')).add_column(User.name).all()
- self.assertEquals(results, [(User(name='jack'), 'jack')])
+ eq_(results, [(User(name='jack'), 'jack')])
self.assert_sql_count(testing.db, go, 1)
def test_self_referential(self):
]:
- self.assertEquals(
+ eq_(
q.all(),
[
(Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)),
sess = create_session()
selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
- self.assertEquals(list(sess.query(User, Address).instances(selectquery.execute())), expected)
+ eq_(list(sess.query(User, Address).instances(selectquery.execute())), expected)
sess.expunge_all()
for address_entity in (Address, aliased(Address)):
q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id)
- self.assertEquals(q.all(), expected)
+ eq_(q.all(), expected)
sess.expunge_all()
q = sess.query(User).add_entity(address_entity)
q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
- self.assertEquals(q.all(), [(user8, address3)])
+ eq_(q.all(), [(user8, address3)])
sess.expunge_all()
q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
- self.assertEquals(q.all(), [(user8, address3)])
+ eq_(q.all(), [(user8, address3)])
sess.expunge_all()
q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
- self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)])
+ eq_(list(util.OrderedSet(q.all())), [(user8, address3)])
sess.expunge_all()
def test_aliased_multi_mappers(self):
assert sess.query(User).add_column(add_col).all() == expected
sess.expunge_all()
- self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
+ assert_raises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
def test_add_multi_columns(self):
"""test that add_column accepts a FROM clause."""
q = sess.query(User)
q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count'))
- self.assertEquals(q.all(), expected)
+ eq_(q.all(), expected)
sess.expunge_all()
adalias = aliased(Address)
q = sess.query(User)
q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count'))
- self.assertEquals(q.all(), expected)
+ eq_(q.all(), expected)
sess.expunge_all()
s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
run_inserts = 'once'
run_deletes = None
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Address, addresses)
mapper(User, users, properties=dict(
def test_one(self):
sess = create_session()
- self.assertRaises(sa.orm.exc.NoResultFound,
+ assert_raises(sa.orm.exc.NoResultFound,
sess.query(User).filter(User.id == 99).one)
eq_(sess.query(User).filter(User.id == 7).one().id, 7)
- self.assertRaises(sa.orm.exc.MultipleResultsFound,
+ assert_raises(sa.orm.exc.MultipleResultsFound,
sess.query(User).one)
- self.assertRaises(
+ assert_raises(
sa.orm.exc.NoResultFound,
sess.query(User.id, User.name).filter(User.id == 99).one)
eq_(sess.query(User.id, User.name).filter(User.id == 7).one(),
(7, 'jack'))
- self.assertRaises(sa.orm.exc.MultipleResultsFound,
+ assert_raises(sa.orm.exc.MultipleResultsFound,
sess.query(User.id, User.name).one)
- self.assertRaises(sa.orm.exc.NoResultFound,
+ assert_raises(sa.orm.exc.NoResultFound,
(sess.query(User, Address).
join(User.addresses).
filter(Address.id == 99)).one)
filter(Address.id == 4)).one(),
(User(id=8), Address(id=4)))
- self.assertRaises(sa.orm.exc.MultipleResultsFound,
+ assert_raises(sa.orm.exc.MultipleResultsFound,
sess.query(User, Address).join(User.addresses).one)
@testing.future
eq_(sess.query(User.id, User.name).filter_by(id=7).value(User.id), 7)
eq_(sess.query(User).filter_by(id=0).value(User.id), None)
- sess.bind = sa.testing.db
+ sess.bind = testing.db
eq_(sess.query().value(sa.literal_column('1').label('x')), 1)
sel = users.select(users.c.id.in_([7, 8])).alias()
sess = create_session()
- self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)])
+ eq_(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)])
- self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)])
+ eq_(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)])
- self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [
+ eq_(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [
User(name='jack',id=7), User(name='ed',id=8)
])
- self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [
+ eq_(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [
User(name='ed',id=8), User(name='jack',id=7)
])
- self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(),
+ eq_(sess.query(User).select_from(sel).options(eagerload('addresses')).first(),
User(name='jack', addresses=[Address(id=1)])
)
sel = users.select(users.c.id.in_([7, 8]))
sess = create_session()
- self.assertEquals(sess.query(User).select_from(sel).all(),
+ eq_(sess.query(User).select_from(sel).all(),
[
User(name='jack',id=7), User(name='ed',id=8)
]
sel = users.select(users.c.id.in_([7, 8]))
sess = create_session()
- self.assertEquals(sess.query(User).select_from(sel).all(),
+ eq_(sess.query(User).select_from(sel).all(),
[
User(name='jack',id=7), User(name='ed',id=8)
]
sel = users.select(users.c.id.in_([7, 8]))
sess = create_session()
- self.assertEquals(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(),
+ eq_(sess.query(User).select_from(sel).join('addresses').add_entity(Address).order_by(User.id).order_by(Address.id).all(),
[
(User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)),
(User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)),
)
adalias = aliased(Address)
- self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(),
+ eq_(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(),
[
(User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)),
(User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)),
# TODO: remove
sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all()
- self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+ eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
User(name=u'jack',id=7)
])
- self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+ eq_(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
User(name=u'jack',id=7)
])
def go():
- self.assertEquals(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+ eq_(sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
User(name=u'jack',orders=[
Order(description=u'order 1',items=[
Item(description=u'item 1',keywords=[Keyword(name=u'red'), Keyword(name=u'big'), Keyword(name=u'round')]),
sess.expunge_all()
sel2 = orders.select(orders.c.id.in_([1,2,3]))
- self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [
+ eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords']).filter(Keyword.name == 'red').order_by(Order.id).all(), [
Order(description=u'order 1',id=1),
Order(description=u'order 2',id=2),
])
- self.assertEquals(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [
+ eq_(sess.query(Order).select_from(sel2).join(['items', 'keywords'], aliased=True).filter(Keyword.name == 'red').order_by(Order.id).all(), [
Order(description=u'order 1',id=1),
Order(description=u'order 2',id=2),
])
sess = create_session()
def go():
- self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(),
+ eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id).all(),
[
User(id=7, addresses=[Address(id=1)]),
User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])
sess.expunge_all()
def go():
- self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(),
+ eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).order_by(User.id).all(),
[User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])]
)
self.assert_sql_count(testing.db, go, 1)
sess.expunge_all()
def go():
- self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
+ eq_(sess.query(User).options(eagerload('addresses')).select_from(sel).order_by(User.id)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
self.assert_sql_count(testing.db, go, 1)
class CustomJoinTest(QueryTest):
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global nodes
nodes = Table('nodes', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('parent_id', Integer, ForeignKey('nodes.id')),
Column('data', String(30)))
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
global Node
class Node(Base):
filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first()
assert node.data == 'n122'
- self.assertEquals(
+ eq_(
list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
[('n122', 'n12', 'n1')])
n1 = aliased(Node)
# using 'n1.parent' implicitly joins to unaliased Node
- self.assertEquals(
+ eq_(
sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(),
[Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
)
# explicit (new syntax)
- self.assertEquals(
+ eq_(
sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(),
[Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
)
parent = aliased(Node)
grandparent = aliased(Node)
- self.assertEquals(
+ eq_(
sess.query(Node, parent, grandparent).\
join((Node.parent, parent), (parent.parent, grandparent)).\
filter(Node.data=='n122').filter(parent.data=='n12').\
(Node(data='n122'), Node(data='n12'), Node(data='n1'))
)
- self.assertEquals(
+ eq_(
sess.query(Node, parent, grandparent).\
join((Node.parent, parent), (parent.parent, grandparent)).\
filter(Node.data=='n122').filter(parent.data=='n12').\
)
# same, change order around
- self.assertEquals(
+ eq_(
sess.query(parent, grandparent, Node).\
join((Node.parent, parent), (parent.parent, grandparent)).\
filter(Node.data=='n122').filter(parent.data=='n12').\
(Node(data='n12'), Node(data='n1'), Node(data='n122'))
)
- self.assertEquals(
+ eq_(
sess.query(Node, parent, grandparent).\
join((Node.parent, parent), (parent.parent, grandparent)).\
filter(Node.data=='n122').filter(parent.data=='n12').\
(Node(data='n122'), Node(data='n12'), Node(data='n1'))
)
- self.assertEquals(
+ eq_(
sess.query(Node, parent, grandparent).\
join((Node.parent, parent), (parent.parent, grandparent)).\
filter(Node.data=='n122').filter(parent.data=='n12').\
def test_any(self):
sess = create_session()
- self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
- self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
- self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
+ eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
+ eq_(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
+ eq_(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
def test_has(self):
sess = create_session()
- self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
- self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
- self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
+ eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+ eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
+ eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
def test_contains(self):
sess = create_session()
n122 = sess.query(Node).filter(Node.data=='n122').one()
- self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')])
+ eq_(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')])
n13 = sess.query(Node).filter(Node.data=='n13').one()
- self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
+ eq_(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
def test_eq_ne(self):
sess = create_session()
n12 = sess.query(Node).filter(Node.data=='n12').one()
- self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
+ eq_(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
- self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
+ eq_(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
class SelfReferentialM2MTest(_base.MappedTest):
run_setup_mappers = 'once'
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global nodes, node_to_nodes
nodes = Table('nodes', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('right_node_id', Integer, ForeignKey('nodes.id'),primary_key=True),
)
- def insert_data(self):
+ @classmethod
+ def insert_data(cls):
global Node
class Node(Base):
def test_any(self):
sess = create_session()
- self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')])
+ eq_(sess.query(Node).filter(Node.children.any(Node.data=='n3')).all(), [Node(data='n1'), Node(data='n2')])
def test_contains(self):
sess = create_session()
n4 = sess.query(Node).filter_by(data='n4').one()
- self.assertEquals(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')])
- self.assertEquals(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')])
+ eq_(sess.query(Node).filter(Node.children.contains(n4)).order_by(Node.data).all(), [Node(data='n1'), Node(data='n3')])
+ eq_(sess.query(Node).filter(not_(Node.children.contains(n4))).order_by(Node.data).all(), [Node(data='n2'), Node(data='n4'), Node(data='n5'), Node(data='n6'), Node(data='n7')])
def test_explicit_join(self):
sess = create_session()
n1 = aliased(Node)
- self.assertEquals(
+ eq_(
sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data.in_(['n3', 'n7'])).order_by(Node.id).all(),
[Node(data='n1'), Node(data='n2')]
)
def test_external_columns_bad(self):
- self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
+ assert_raises_message(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
'concat': (users.c.id * 2),
})
clear_mappers()
sess.query(Address).options(eagerload('user')).all()
- self.assertEquals(sess.query(User).all(),
+ eq_(sess.query(User).all(),
[
User(id=7, concat=14, count=1),
User(id=8, concat=16, count=3),
Address(id=4, user=User(id=8, concat=16, count=3)),
Address(id=5, user=User(id=9, concat=18, count=1))
]
- self.assertEquals(sess.query(Address).all(), address_result)
+ eq_(sess.query(Address).all(), address_result)
# run the eager version twice to test caching of aliased clauses
for x in range(2):
sess.expunge_all()
def go():
- self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result)
+ eq_(sess.query(Address).options(eagerload('user')).all(), address_result)
self.assert_sql_count(testing.db, go, 1)
ualias = aliased(User)
- self.assertEquals(
+ eq_(
sess.query(Address, ualias).join(('user', ualias)).all(),
[(address, address.user) for address in address_result]
)
- self.assertEquals(
+ eq_(
sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(),
[
(Address(id=1), 1),
]
)
- self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(),
+ eq_(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).order_by(Address.id).all(),
[
(Address(id=1), 14, 1),
(Address(id=2), 16, 3),
)
ua = aliased(User)
- self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
+ eq_(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
[
(Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1),
(Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3),
]
)
- self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)),
+ eq_(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)),
[(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
)
- self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)),
+ eq_(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)),
[(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
)
sess = create_session()
def go():
o1 = sess.query(Order).options(eagerload_all('address.user')).get(1)
- self.assertEquals(o1.address.user.count, 1)
+ eq_(o1.address.user.count, 1)
self.assert_sql_count(testing.db, go, 1)
sess = create_session()
def go():
o1 = sess.query(Order).options(eagerload_all('address.user')).first()
- self.assertEquals(o1.address.user.count, 1)
+ eq_(o1.address.user.count, 1)
self.assert_sql_count(testing.db, go, 1)
class TestOverlyEagerEquivalentCols(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global base, sub1, sub2
base = Table('base', metadata,
Column('id', Integer, primary_key=True),
q = sess.query(Base).outerjoin('sub2', aliased=True)
assert sub1.c.id not in q._filter_aliases.equivalents
- self.assertEquals(
+ eq_(
sess.query(Base).join('sub1').outerjoin('sub2', aliased=True).\
filter(Sub1.id==1).one(),
b1
)
class UpdateDeleteTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(32)),
Column('user_id', None, ForeignKey('users.id')),
Column('title', String(32)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.ComparableEntity):
pass
class Document(_base.ComparableEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
users.insert().execute([
dict(id=1, name='john', age=25),
dict(id=2, name='jack', age=47),
dict(id=3, user_id=2, title='baz'),
])
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users)
mapper(Document, documents, properties={
'user': relation(User, lazy=False, backref=backref('documents', lazy=True))
sess = create_session(bind=testing.db, autocommit=False)
rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age + 0})
- self.assertEquals(rowcount, 2)
+ eq_(rowcount, 2)
rowcount = sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
- self.assertEquals(rowcount, 2)
+ eq_(rowcount, 2)
@testing.resolve_artifact_names
def test_delete_returns_rowcount(self):
sess = create_session(bind=testing.db, autocommit=False)
rowcount = sess.query(User).filter(User.age > 26).delete(synchronize_session=False)
- self.assertEquals(rowcount, 3)
+ eq_(rowcount, 3)
@testing.resolve_artifact_names
def test_update_with_eager_relations(self):
eq_(sess.query(Document.title).all(), zip(['baz']))
-if __name__ == '__main__':
- testenv.main()
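The hunks above swap the unittest assertion methods for the standalone helpers imported from sqlalchemy.test.testing. For orientation, a minimal sketch of how those helpers are called, assuming a 0.5.5 source checkout is importable; the values are invented, only the helper names and argument order come from these hunks:

    from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message

    # eq_(a, b) replaces self.assertEquals(a, b)
    eq_(1 + 1, 2)

    # assert_raises(exc_cls, callable_, *args, **kw) replaces self.assertRaises(...)
    assert_raises(ZeroDivisionError, lambda: 1 / 0)

    # assert_raises_message(exc_cls, msg, callable_, *args, **kw) replaces
    # self.assertRaisesMessage(...); msg is matched as a regular expression
    # against the string of the raised exception.
    assert_raises_message(ValueError, "invalid literal", int, "not a number")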
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
import datetime
-from testlib import sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, and_
-from testlib.sa.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers
-from testlib.testing import eq_, startswith_
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import Integer, String, ForeignKey, MetaData, and_
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, create_session, compile_mappers, clear_mappers
+from sqlalchemy.test.testing import eq_, startswith_
+from test.orm import _base, _fixtures
class RelationTest(_base.MappedTest):
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("tbl_a", metadata,
Column("id", Integer, primary_key=True),
Column("name", String(128)))
Column("tbl_b_id", Integer, ForeignKey("tbl_b.id")),
Column("name", String(128)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class A(_base.Entity):
pass
class B(_base.Entity):
class D(_base.Entity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(A, tbl_a, properties=dict(
c_rows=relation(C, cascade="all, delete-orphan", backref="a_row")))
mapper(B, tbl_b)
mapper(D, tbl_d, properties=dict(
b_row=relation(B)))
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
session = create_session()
a = A(name='a1')
b = B(name='b1')
key where one column in the foreign key is 'joined to itself'.
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('company_t', metadata,
Column('company_id', Integer, primary_key=True),
Column('name', sa.Unicode(30)))
assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5'
class RelationTest3(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("jobs", metadata,
Column("jobno", sa.Unicode(15), primary_key=True),
Column("created", sa.DateTime, nullable=False,
["jobno", "pagename"],
["pages.jobno", "pages.pagename"]))
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
class Job(_base.Entity):
def create_page(self, pagename):
return Page(job=self, pagename=pagename)
class RelationTest4(_base.MappedTest):
"""Syncrules on foreign keys that are also primary"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("tableA", metadata,
Column("id",Integer,primary_key=True),
Column("foo",Integer,),
Column("id",Integer,ForeignKey("tableA.id"),primary_key=True),
test_needs_fk=True)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class A(_base.Entity):
pass
class RelationTest5(_base.MappedTest):
"""Test a map to a select that relates to a map to the table."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('items', metadata,
Column('item_policy_num', String(10), primary_key=True,
key='policyNum'),
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('tags', metadata, Column("id", Integer, primary_key=True),
Column("data", String(50)),
)
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
subscriber_table = Table('subscriber', metadata,
Column('id', Integer, primary_key=True),
Column('dummy', String(10)) # to appease older sqlite version
Column('type', String(1), primary_key=True),
)
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
subscriber_and_address = subscriber.join(address,
and_(address.c.subscriber_id==subscriber.c.id, address.c.type.in_(['A', 'B', 'C'])))
'user':relation(User, back_populates='addresses')
})
- self.assertRaises(sa.exc.InvalidRequestError, compile_mappers)
+ assert_raises(sa.exc.InvalidRequestError, compile_mappers)
@testing.resolve_artifact_names
def test_invalid_target(self):
'dingaling':relation(Dingaling)
})
- self.assertRaisesMessage(sa.exc.ArgumentError,
+ assert_raises_message(sa.exc.ArgumentError,
r"reverse_property 'dingaling' on relation User.addresses references "
"relation Address.dingaling, which does not reference mapper Mapper\|User\|users",
compile_mappers)
c1id = Column('c1id', Integer, ForeignKey('c1.id'))
c2 = relation(C1, primaryjoin=C1.id)
- self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+ assert_raises(sa.exc.ArgumentError, compile_mappers)
def test_clauseelement_pj_false(self):
from sqlalchemy.ext.declarative import declarative_base
c1id = Column('c1id', Integer, ForeignKey('c1.id'))
c2 = relation(C1, primaryjoin="x"=="y")
- self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+ assert_raises(sa.exc.ArgumentError, compile_mappers)
def test_fk_error_raised(self):
mapper(C1, t1, properties={'c2':relation(C2)})
mapper(C2, t3)
- self.assertRaises(sa.exc.NoReferencedColumnError, compile_mappers)
+ assert_raises(sa.exc.NoReferencedColumnError, compile_mappers)
def test_join_error_raised(self):
m = MetaData()
mapper(C1, t1, properties={'c2':relation(C2)})
mapper(C2, t3)
- self.assertRaises(sa.exc.ArgumentError, compile_mappers)
+ assert_raises(sa.exc.ArgumentError, compile_mappers)
- def tearDown(self):
+ def teardown(self):
clear_mappers()
class TypeMatchTest(_base.MappedTest):
"""test errors raised when trying to add items whose type is not handled by a relation"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("a", metadata,
Column('aid', Integer, primary_key=True),
Column('data', String(30)))
sess.add(a1)
sess.add(b1)
sess.add(c1)
- self.assertRaisesMessage(sa.orm.exc.FlushError,
+ assert_raises_message(sa.orm.exc.FlushError,
"Attempting to flush an item", sess.flush)
@testing.resolve_artifact_names
sess.add(a1)
sess.add(b1)
sess.add(c1)
- self.assertRaisesMessage(sa.orm.exc.FlushError,
+ assert_raises_message(sa.orm.exc.FlushError,
"Attempting to flush an item", sess.flush)
@testing.resolve_artifact_names
sess = create_session()
sess.add(b1)
sess.add(d1)
- self.assertRaisesMessage(sa.orm.exc.FlushError,
+ assert_raises_message(sa.orm.exc.FlushError,
"Attempting to flush an item", sess.flush)
@testing.resolve_artifact_names
d1 = D()
d1.a = b1
sess = create_session()
- self.assertRaisesMessage(AssertionError,
+ assert_raises_message(AssertionError,
"doesn't handle objects of type", sess.add, d1)
class TypedAssociationTable(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
class MySpecialType(sa.types.TypeDecorator):
impl = String
def process_bind_param(self, value, dialect):
class ViewOnlyOverlappingNames(_base.MappedTest):
"""'viewonly' mappings with overlapping PK column names."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("t1", metadata,
Column('id', Integer, primary_key=True),
Column('data', String(40)))
class ViewOnlyUniqueNames(_base.MappedTest):
"""'viewonly' mappings with unique PK column names."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("t1", metadata,
Column('t1id', Integer, primary_key=True),
Column('data', String(40)))
class ViewOnlyNonEquijoin(_base.MappedTest):
"""'viewonly' mappings based on non-equijoins."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('foos', metadata,
Column('id', Integer, primary_key=True))
Table('bars', metadata,
class ViewOnlyRepeatedRemoteColumn(_base.MappedTest):
"""'viewonly' mappings that contain the same 'remote' column twice"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('foos', metadata,
Column('id', Integer, primary_key=True),
Column('bid1', Integer,ForeignKey('bars.id')),
class ViewOnlyRepeatedLocalColumn(_base.MappedTest):
"""'viewonly' mappings that contain the same 'local' column twice"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('foos', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50)))
class ViewOnlyComplexJoin(_base.MappedTest):
"""'viewonly' mappings with a complex join condition."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50)))
Column('t2id', Integer, ForeignKey('t2.id')),
Column('t3id', Integer, ForeignKey('t3.id')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class T1(_base.ComparableEntity):
pass
class T2(_base.ComparableEntity):
't1':relation(T1),
't3s':relation(T3, secondary=t2tot3)})
mapper(T3, t3)
- self.assertRaisesMessage(sa.exc.ArgumentError,
+ assert_raises_message(sa.exc.ArgumentError,
"Specify remote_side argument",
sa.orm.compile_mappers)
class ExplicitLocalRemoteTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('id', String(50), primary_key=True),
Column('data', String(50)))
Column('data', String(50)),
Column('t1id', String(50)))
+ @classmethod
@testing.resolve_artifact_names
- def setup_classes(self):
+ def setup_classes(cls):
class T1(_base.ComparableEntity):
pass
class T2(_base.ComparableEntity):
foreign_keys=[t2.c.t1id],
remote_side=[t2.c.t1id])})
mapper(T2, t2)
- self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers)
+ assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers)
@testing.resolve_artifact_names
def test_escalation_2(self):
primaryjoin=t1.c.id==sa.func.lower(t2.c.t1id),
_local_remote_pairs=[(t1.c.id, t2.c.t1id)])})
mapper(T2, t2)
- self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers)
+ assert_raises(sa.exc.ArgumentError, sa.orm.compile_mappers)
class InvalidRemoteSideTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50)),
Column('t_id', Integer, ForeignKey('t1.id'))
)
+ @classmethod
@testing.resolve_artifact_names
- def setup_classes(self):
+ def setup_classes(cls):
class T1(_base.ComparableEntity):
pass
't1s':relation(T1, backref='parent')
})
- self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
+ assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
"both of the same direction <symbol 'ONETOMANY>. Did you "
"mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
't1s':relation(T1, backref=backref('parent', remote_side=t1.c.id), remote_side=t1.c.id)
})
- self.assertRaisesMessage(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
+ assert_raises_message(sa.exc.ArgumentError, "T1.t1s and back-reference T1.parent are "
"both of the same direction <symbol 'MANYTOONE>. Did you "
"mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
})
# can't be sure of ordering here
- self.assertRaisesMessage(sa.exc.ArgumentError,
+ assert_raises_message(sa.exc.ArgumentError,
"both of the same direction <symbol 'ONETOMANY>. Did you "
"mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
})
# can't be sure of ordering here
- self.assertRaisesMessage(sa.exc.ArgumentError,
+ assert_raises_message(sa.exc.ArgumentError,
"both of the same direction <symbol 'MANYTOONE>. Did you "
"mean to set remote_side on the many-to-one side ?", sa.orm.compile_mappers)
class InvalidRelationEscalationTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('foos', metadata,
Column('id', Integer, primary_key=True),
Column('fid', Integer))
Column('id', Integer, primary_key=True),
Column('fid', Integer))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Foo(_base.Entity):
pass
class Bar(_base.Entity):
'bars':relation(Bar)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine join condition between parent/child "
"tables on relation", sa.orm.compile_mappers)
'foos':relation(Foo)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine join condition between parent/child "
"tables on relation", sa.orm.compile_mappers)
primaryjoin=foos.c.id>bars.c.fid)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for primaryjoin condition",
sa.orm.compile_mappers)
foreign_keys=bars.c.fid)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not locate any equated, locally mapped column pairs "
"for primaryjoin condition", sa.orm.compile_mappers)
foreign_keys=[foos.c.id, bars.c.fid])})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Do the columns in 'foreign_keys' represent only the "
"'foreign' columns in this join condition ?",
)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"could not determine any local/remote column pairs",
sa.orm.compile_mappers)
)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"could not determine any local/remote column pairs",
sa.orm.compile_mappers)
primaryjoin=foos.c.id>foos.c.fid)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for primaryjoin condition",
sa.orm.compile_mappers)
foreign_keys=[foos.c.fid])})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not locate any equated, locally mapped column pairs "
"for primaryjoin condition", sa.orm.compile_mappers)
viewonly=True)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for primaryjoin condition",
sa.orm.compile_mappers)
viewonly=True)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Specify the 'foreign_keys' argument to indicate which columns "
"on the relation are foreign.", sa.orm.compile_mappers)
primaryjoin=foos.c.id==bars.c.fid)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for primaryjoin condition",
sa.orm.compile_mappers)
'foos':relation(Foo,
primaryjoin=foos.c.id==foos.c.fid)})
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for primaryjoin condition",
sa.orm.compile_mappers)
primaryjoin=foos.c.id==foos.c.fid,
foreign_keys=[bars.c.id])})
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for primaryjoin condition",
sa.orm.compile_mappers)
class InvalidRelationEscalationTestM2M(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('foos', metadata,
Column('id', Integer, primary_key=True))
Table('foobars', metadata,
Table('bars', metadata,
Column('id', Integer, primary_key=True))
+ @classmethod
@testing.resolve_artifact_names
- def setup_classes(self):
+ def setup_classes(cls):
class Foo(_base.Entity):
pass
class Bar(_base.Entity):
'bars': relation(Bar, secondary=foobars)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine join condition between parent/child tables "
"on relation", sa.orm.compile_mappers)
primaryjoin=foos.c.id > foobars.c.fid)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine join condition between parent/child tables "
"on relation",
secondaryjoin=foobars.c.bid<=bars.c.id)})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for primaryjoin condition",
sa.orm.compile_mappers)
foreign_keys=[foobars.c.fid])})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not determine relation direction for secondaryjoin "
"condition", sa.orm.compile_mappers)
foreign_keys=[foobars.c.fid, foobars.c.bid])})
mapper(Bar, bars)
- self.assertRaisesMessage(
+ assert_raises_message(
sa.exc.ArgumentError,
"Could not locate any equated, locally mapped column pairs for "
"secondaryjoin condition", sa.orm.compile_mappers)
-if __name__ == "__main__":
- testenv.main()
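Another recurring change in the hunks above: the MappedTest fixture hooks (define_tables, setup_classes, setup_mappers, insert_data) become classmethods so the nose-driven fixture machinery can run them once per class, and where @testing.resolve_artifact_names is used it is stacked below @classmethod. A minimal sketch of the resulting shape, assuming the test/orm/_base.py machinery from the 0.5.5 tree; the table and class names are invented for illustration:

    from sqlalchemy import Integer, String
    from sqlalchemy.test import testing
    from sqlalchemy.test.schema import Table, Column
    from sqlalchemy.orm import mapper
    from test.orm import _base

    class WidgetTest(_base.MappedTest):
        @classmethod
        def define_tables(cls, metadata):
            Table('widgets', metadata,
                  Column('id', Integer, primary_key=True),
                  Column('name', String(30)))

        @classmethod
        def setup_classes(cls):
            class Widget(_base.ComparableEntity):
                pass

        @classmethod
        @testing.resolve_artifact_names   # stacked below @classmethod, as in these hunks
        def setup_mappers(cls):
            # 'Widget' and 'widgets' are injected by resolve_artifact_names
            # from the fixture registries at call time.
            mapper(Widget, widgets)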
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
from sqlalchemy.orm import scoped_session
-from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, query
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy import Integer, String, ForeignKey
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, query
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
class _ScopedTest(_base.MappedTest):
_artifact_registries = (
_base.MappedTest._artifact_registries + ('scoping',))
- def setUpAll(self):
- type(self).scoping = _base.adict()
- _base.MappedTest.setUpAll(self)
+ @classmethod
+ def setup_class(cls):
+ cls.scoping = _base.adict()
+ super(_ScopedTest, cls).setup_class()
- def tearDownAll(self):
- self.scoping.clear()
- _base.MappedTest.tearDownAll(self)
+ @classmethod
+ def teardown_class(cls):
+ cls.scoping.clear()
+ super(_ScopedTest, cls).teardown_class()
class ScopedSessionTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('table1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)))
class ScopedMapperTest(_ScopedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('table1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)))
Column('id', Integer, primary_key=True),
Column('someid', None, ForeignKey('table1.id')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class SomeObject(_base.ComparableEntity):
pass
class SomeOtherObject(_base.ComparableEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
Session = scoped_session(sa.orm.create_session)
Session.mapper(SomeObject, table1, properties={
'options':relation(SomeOtherObject)
})
Session.mapper(SomeOtherObject, table2)
- self.scoping['Session'] = Session
+ cls.scoping['Session'] = Session
+ @classmethod
@testing.resolve_artifact_names
- def insert_data(self):
+ def insert_data(cls):
s = SomeObject()
s.id = 1
s.data = 'hello'
scope.mapper(B, table2)
A(foo='bar')
- self.assertRaises(TypeError, B, foo='bar')
+ assert_raises(TypeError, B, foo='bar')
scope = scoped_session(sa.orm.sessionmaker())
scope.mapper(C, table1)
scope.mapper(D, table2)
- self.assertRaises(TypeError, C, foo='bar')
+ assert_raises(TypeError, C, foo='bar')
D(foo='bar')
@testing.resolve_artifact_names
Session.mapper(ValidatedOtherObject, table2, validate=True)
v1 = ValidatedOtherObject(someid=12)
- self.assertRaises(sa.exc.ArgumentError, ValidatedOtherObject,
+ assert_raises(sa.exc.ArgumentError, ValidatedOtherObject,
someid=12, bogus=345)
@testing.resolve_artifact_names
class ScopedMapperTest2(_ScopedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('table1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(30)),
Column('someid', None, ForeignKey('table1.id')),
Column('somedata', String(30)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class BaseClass(_base.ComparableEntity):
pass
class SubClass(BaseClass):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
Session = scoped_session(sa.orm.sessionmaker())
Session.mapper(BaseClass, table1,
polymorphic_identity='sub',
inherits=BaseClass)
- self.scoping['Session'] = Session
+ cls.scoping['Session'] = Session
@testing.resolve_artifact_names
def test_inheritance(self):
SubClass.query.all())
-if __name__ == "__main__":
- testenv.main()
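The _ScopedTest hunk above also shows the class-level lifecycle rename: the unittest-style setUpAll/tearDownAll hooks become setup_class/teardown_class classmethods, and the explicit base-class calls become super() calls. A minimal sketch of that shape, with an invented subclass name:

    from test.orm import _base

    class MyScopedTest(_base.MappedTest):
        @classmethod
        def setup_class(cls):
            cls.registry = {}                 # per-class state, created once
            super(MyScopedTest, cls).setup_class()

        @classmethod
        def teardown_class(cls):
            cls.registry.clear()
            super(MyScopedTest, cls).teardown_class()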
"""Generic mapping to Select statements"""
-import testenv; testenv.configure_for_tests()
-from testlib import sa, testing
-from testlib.sa import Table, Column, String, Integer, select
-from testlib.sa.orm import mapper, create_session
-from testlib.testing import eq_
-from orm import _base
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import String, Integer, select
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, create_session
+from sqlalchemy.test.testing import eq_
+from test.orm import _base
# TODO: more tests mapping to selects
class SelectableNoFromsTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('common', metadata,
Column('id', Integer, primary_key=True),
Column('data', Integer),
Column('extra', String(45)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Subset(_base.ComparableEntity):
pass
def test_no_tables(self):
selectable = select(["x", "y", "z"])
- self.assertRaisesMessage(sa.exc.InvalidRequestError,
+ assert_raises_message(sa.exc.InvalidRequestError,
"Could not find any Table objects",
mapper, Subset, selectable)
Subset(data=1))
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import gc
import inspect
import pickle
from sqlalchemy.orm import create_session, sessionmaker, attributes
-from testlib import engines, sa, testing, config
-from testlib.sa import Table, Column, Integer, String, Sequence
-from testlib.sa.orm import mapper, relation, backref, eagerload
-from testlib.testing import eq_
-from engine import _base as engine_base
-from orm import _base, _fixtures
+import sqlalchemy as sa
+from sqlalchemy.test import engines, testing, config
+from sqlalchemy import Integer, String, Sequence
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, backref, eagerload
+from sqlalchemy.test.testing import eq_
+from test.engine import _base as engine_base
+from test.orm import _base, _fixtures
class SessionTest(_fixtures.FixtureTest):
sess.commit()
- self.assertEquals(set(sess.query(User).all()), set([u2]))
+ eq_(set(sess.query(User).all()), set([u2]))
sess.begin()
sess.begin_nested()
sess.commit() # commit the nested transaction
sess.rollback()
- self.assertEquals(set(sess.query(User).all()), set([u2]))
+ eq_(set(sess.query(User).all()), set([u2]))
sess.close()
sess.close()
- self.assertEquals(len(sess.query(User).all()), 1)
+ eq_(len(sess.query(User).all()), 1)
t1 = sess.begin()
t2 = sess.begin_nested()
sess.close()
- self.assertEquals(len(sess.query(User).all()), 1)
+ eq_(len(sess.query(User).all()), 1)
@testing.resolve_artifact_names
def test_error_on_using_inactive_session(self):
sess.flush()
sess.rollback()
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True)
+ assert_raises_message(sa.exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True)
sess.close()
@testing.resolve_artifact_names
sess.flush()
assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect())
+ assert_raises_message(sa.exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect())
transaction.rollback()
assert len(sess.query(User).all()) == 0
user = User(name='u1')
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.update, user)
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "is not persisted", s.delete, user)
+ assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.update, user)
+ assert_raises_message(sa.exc.InvalidRequestError, "is not persisted", s.delete, user)
s.add(user)
s.flush()
assert user in s
assert user not in s.dirty
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already persistent", s.save, user)
+ assert_raises_message(sa.exc.InvalidRequestError, "is already persistent", s.save, user)
s2 = create_session()
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user)
+ assert_raises_message(sa.exc.InvalidRequestError, "is already attached to session", s2.delete, user)
u2 = s2.query(User).get(user.id)
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2)
+ assert_raises_message(sa.exc.InvalidRequestError, "another instance with key", s.delete, u2)
s.expire(user)
s.expunge(user)
u = User(name='u1')
sess.add(u)
sess.flush()
- self.assertEquals(sess.query(User).order_by(User.name).all(),
+ eq_(sess.query(User).order_by(User.name).all(),
[
User(name='another u1'),
User(name='u1')
)
sess.flush()
- self.assertEquals(sess.query(User).order_by(User.name).all(),
+ eq_(sess.query(User).order_by(User.name).all(),
[
User(name='another u1'),
User(name='u1')
u.name='u2'
sess.flush()
- self.assertEquals(sess.query(User).order_by(User.name).all(),
+ eq_(sess.query(User).order_by(User.name).all(),
[
User(name='another u1'),
User(name='another u2'),
sess.delete(u)
sess.flush()
- self.assertEquals(sess.query(User).order_by(User.name).all(),
+ eq_(sess.query(User).order_by(User.name).all(),
[
User(name='another u1'),
]
u = User(name='u1')
sess.add(u)
sess.flush()
- self.assertEquals(sess.query(User).order_by(User.name).all(),
+ eq_(sess.query(User).order_by(User.name).all(),
[
User(name='u1')
]
sess.add(User(name='u2'))
sess.flush()
sess.expunge_all()
- self.assertEquals(sess.query(User).order_by(User.name).all(),
+ eq_(sess.query(User).order_by(User.name).all(),
[
User(name='u1 modified'),
User(name='u2')
sess = create_session(extension=MyExt())
sess.add(User(name='foo'))
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "already flushing", sess.flush)
+ assert_raises_message(sa.exc.InvalidRequestError, "already flushing", sess.flush)
@testing.resolve_artifact_names
def test_pickled_update(self):
u1 = User(name='u1')
sess1.add(u1)
- self.assertRaisesMessage(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1)
+ assert_raises_message(sa.exc.InvalidRequestError, "already attached to session", sess2.add, u1)
u2 = pickle.loads(pickle.dumps(u1))
assert u2 is not None and u2 is not u1
assert u2 in sess
- self.assertRaises(Exception, lambda: sess.add(u1))
+ assert_raises(Exception, lambda: sess.add(u1))
sess.expunge(u2)
assert u2 not in sess
run_inserts = 'once'
run_deletes = None
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
global t1
t1 = Table('t1', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50))
)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
global T
class T(object):
def __init__(self, data):
self.data = data
mapper(T, t1)
- def tearDown(self):
+ def teardown(self):
from sqlalchemy.orm.session import _sessions
_sessions.clear()
- super(DisposedStates, self).tearDown()
+ super(DisposedStates, self).teardown()
def _set_imap_in_disposal(self, sess, *objs):
"""remove selected objects from the given session, as though they
def x_raises_(obj, method, *args, **kw):
watchdog.add(method)
callable_ = getattr(obj, method)
- self.assertRaises(sa.orm.exc.UnmappedInstanceError,
+ assert_raises(sa.orm.exc.UnmappedInstanceError,
callable_, *args, **kw)
def raises_(method, *args, **kw):
def raises_(method, *args, **kw):
watchdog.add(method)
callable_ = getattr(create_session(), method)
- self.assertRaises(sa.orm.exc.UnmappedClassError,
+ assert_raises(sa.orm.exc.UnmappedClassError,
callable_, *args, **kw)
raises_('connection', mapper=user_arg)
class TLTransactionTest(engine_base.AltEngineTest, _base.MappedTest):
- def create_engine(self):
+ @classmethod
+ def create_engine(cls):
return engines.testing_engine(options=dict(strategy='threadlocal'))
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users', metadata,
Column('id', Integer, primary_key=True),
Column('name', String(20)),
test_needs_acid=True)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users)
- def setUpAll(self):
- engine_base.AltEngineTest.setUpAll(self)
- _base.MappedTest.setUpAll(self)
-
-
- def tearDownAll(self):
- _base.MappedTest.tearDownAll(self)
- engine_base.AltEngineTest.tearDownAll(self)
-
@testing.exclude('mysql', '<', (5, 0, 3), 'FIXME: unknown')
@testing.resolve_artifact_names
def test_session_nesting(self):
self.engine.commit()
-if __name__ == "__main__":
- testenv.main()
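Per-test cleanup follows the same naming shift: the tearDown overrides in the hunks above become a lowercase teardown() that chains to the base class with super(), as in the DisposedStates hunk. A minimal sketch with an invented class name; the _sessions cleanup mirrors what that hunk does:

    from sqlalchemy.orm.session import _sessions
    from test.orm import _base

    class SessionCleanupTest(_base.MappedTest):
        def teardown(self):
            _sessions.clear()                 # drop any sessions left over by the test
            super(SessionCleanupTest, self).teardown()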
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy.orm import attributes
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import *
-from testlib import testing
-from orm import _base
-from orm._fixtures import FixtureTest, User, Address, users, addresses
+from sqlalchemy.test import testing
+from test.orm import _base
+from test.orm._fixtures import FixtureTest, User, Address, users, addresses
import gc
run_inserts = None
session = sessionmaker()
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
mapper(User, users, properties={
'addresses':relation(Address, backref='user',
cascade="all, delete-orphan"),
u1 = sess.query(User).get(7)
u1.name = 'ed'
sess.rollback()
- self.assertEquals(u1.name, 'jack')
+ eq_(u1.name, 'jack')
def test_commit_persistent(self):
sess = self.session()
u1.name = 'ed'
sess.flush()
sess.commit()
- self.assertEquals(u1.name, 'ed')
+ eq_(u1.name, 'ed')
def test_concurrent_commit_persistent(self):
s1 = self.session()
u1.addresses.remove(a1)
s.flush()
- self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), [])
+ eq_(s.query(Address).filter(Address.email_address=='foo').all(), [])
s.rollback()
assert a1 not in s.deleted
assert u1.addresses == [a1]
sess.add(u1)
sess.flush()
sess.commit()
- self.assertEquals(u1.name, 'newuser')
+ eq_(u1.name, 'newuser')
def test_concurrent_commit_pending(self):
u1.name = 'edward'
a1.email_address = 'foober'
s.add(u2)
- self.assertRaises(sa_exc.FlushError, s.commit)
- self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+ assert_raises(sa_exc.FlushError, s.commit)
+ assert_raises(sa_exc.InvalidRequestError, s.commit)
s.rollback()
assert u2 not in s
assert a2 not in s
u1.name = 'edward'
a1.email_address = 'foober'
s.commit()
- self.assertEquals(
+ eq_(
s.query(User).all(),
[User(id=1, name='edward', addresses=[Address(email_address='foober')])]
)
a1.email_address = 'foober'
s.begin_nested()
s.add(u2)
- self.assertRaises(sa_exc.FlushError, s.commit)
- self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+ assert_raises(sa_exc.FlushError, s.commit)
+ assert_raises(sa_exc.InvalidRequestError, s.commit)
s.rollback()
assert u2 not in s
assert a2 not in s
u1.name = 'edward'
u2.name = 'jackward'
s.add_all([u3, u4])
- self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
s.rollback()
assert u1.name == 'ed'
assert u2.name == 'jack'
- self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+ eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
s.commit()
assert u1.name == 'ed'
assert u2.name == 'jack'
- self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+ eq_(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
@testing.requires.savepoints
def test_savepoint_delete(self):
u1 = User(name='ed')
s.add(u1)
s.commit()
- self.assertEquals(s.query(User).filter_by(name='ed').count(), 1)
+ eq_(s.query(User).filter_by(name='ed').count(), 1)
s.begin_nested()
s.delete(u1)
s.commit()
- self.assertEquals(s.query(User).filter_by(name='ed').count(), 0)
+ eq_(s.query(User).filter_by(name='ed').count(), 0)
s.commit()
@testing.requires.savepoints
u1.name = 'edward'
u2.name = 'jackward'
s.add_all([u3, u4])
- self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
s.commit()
def go():
assert u1.name == 'edward'
assert u2.name == 'jackward'
- self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
self.assert_sql_count(testing.db, go, 1)
s.commit()
- self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+ eq_(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
@testing.requires.savepoints
def test_savepoint_rollback_collections(self):
s.begin_nested()
u2 = User(name='jack', addresses=[Address(email_address='bat')])
s.add(u2)
- self.assertEquals(s.query(User).order_by(User.id).all(),
+ eq_(s.query(User).order_by(User.id).all(),
[
User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
User(name='jack', addresses=[Address(email_address='bat')])
]
)
s.rollback()
- self.assertEquals(s.query(User).order_by(User.id).all(),
+ eq_(s.query(User).order_by(User.id).all(),
[
User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
]
)
s.commit()
- self.assertEquals(s.query(User).order_by(User.id).all(),
+ eq_(s.query(User).order_by(User.id).all(),
[
User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
]
s.begin_nested()
u2 = User(name='jack', addresses=[Address(email_address='bat')])
s.add(u2)
- self.assertEquals(s.query(User).order_by(User.id).all(),
+ eq_(s.query(User).order_by(User.id).all(),
[
User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
User(name='jack', addresses=[Address(email_address='bat')])
]
)
s.commit()
- self.assertEquals(s.query(User).order_by(User.id).all(),
+ eq_(s.query(User).order_by(User.id).all(),
[
User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
User(name='jack', addresses=[Address(email_address='bat')])
]
)
s.commit()
- self.assertEquals(s.query(User).order_by(User.id).all(),
+ eq_(s.query(User).order_by(User.id).all(),
[
User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
User(name='jack', addresses=[Address(email_address='bat')])
class AutoCommitTest(TransactionTest):
def test_begin_nested_requires_trans(self):
sess = create_session(autocommit=True)
- self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested)
+ assert_raises(sa_exc.InvalidRequestError, sess.begin_nested)
def test_begin_preflush(self):
sess = create_session(autocommit=True)
-if __name__ == '__main__':
- testenv.main()
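Each file's preamble in this diff follows the same translation away from the old testlib package. A representative new-style preamble, assembled from the + lines in the preambles above (it assumes a 0.5.5 checkout with the distribution directory on the path):

    import sqlalchemy as sa                        # was: from testlib import sa
    from sqlalchemy.test import testing            # was: from testlib import testing
    from sqlalchemy import Integer, String         # generic types now come straight from sqlalchemy
    from sqlalchemy.test.schema import Table       # test-aware Table/Column wrappers
    from sqlalchemy.test.schema import Column      # (was: from testlib.sa import Table, Column)
    from sqlalchemy.orm import mapper, relation, create_session   # was: from testlib.sa.orm import ...
    from sqlalchemy.test.testing import eq_        # was: from testlib.testing import eq_
    from test.orm import _base, _fixtures          # was: from orm import _base, _fixtures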
# coding: utf-8
"""Tests unitofwork operations."""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import datetime
import operator
from sqlalchemy.orm import mapper as orm_mapper
-from testlib import engines, sa, testing
-from testlib.sa import Table, Column, Integer, String, ForeignKey, literal_column
-from testlib.sa.orm import mapper, relation, create_session, column_property
-from testlib.testing import eq_, ne_
-from orm import _base, _fixtures
-from engine import _base as engine_base
-import pickleable
-from testlib.assertsql import AllOf, CompiledSQL
+import sqlalchemy as sa
+from sqlalchemy.test import engines, testing, pickleable
+from sqlalchemy import Integer, String, ForeignKey, literal_column
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, column_property
+from sqlalchemy.test.testing import eq_, ne_
+from test.orm import _base, _fixtures
+from test.engine import _base as engine_base
+from sqlalchemy.test.assertsql import AllOf, CompiledSQL
import gc
class UnitOfWorkTest(object):
class HistoryTest(_fixtures.FixtureTest):
run_inserts = None
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.ComparableEntity):
pass
class Address(_base.ComparableEntity):
class VersioningTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('version_table', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('version_id', Integer, nullable=False),
Column('value', String(40), nullable=False))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Foo(_base.ComparableEntity):
pass
# Only dialects with a sane rowcount can detect the
# ConcurrentModificationError
if testing.db.dialect.supports_sane_rowcount:
- self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit)
+ assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit)
s1.rollback()
else:
s1.commit()
s1.delete(f2)
if testing.db.dialect.supports_sane_multi_rowcount:
- self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.commit)
+ assert_raises(sa.orm.exc.ConcurrentModificationError, s1.commit)
else:
s1.commit()
s2.commit()
# load, version is wrong
- self.assertRaises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
+ assert_raises(sa.orm.exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
# reload it
s1.query(Foo).populate_existing().get(f1s1.id)
class UnicodeTest(_base.MappedTest):
__requires__ = ('unicode_connections',)
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('uni_t1', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
test_needs_autoincrement=True),
Column('txt', sa.Unicode(50), ForeignKey('uni_t1')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Test(_base.BasicEntity):
pass
class Test2(_base.BasicEntity):
class UnicodeSchemaTest(engine_base.AltEngineTest, _base.MappedTest):
__requires__ = ('unicode_connections', 'unicode_ddl',)
- def create_engine(self):
+ @classmethod
+ def create_engine(cls):
return engines.utf8_engine()
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
t1 = Table('unitable1', metadata,
Column(u'méil', Integer, primary_key=True, key='a', test_needs_autoincrement=True),
Column(u'\u6e2c\u8a66', Integer, key='b'),
test_needs_fk=True,
test_needs_autoincrement=True)
- self.tables['t1'] = t1
- self.tables['t2'] = t2
+ cls.tables['t1'] = t1
+ cls.tables['t2'] = t2
- def setUpAll(self):
- engine_base.AltEngineTest.setUpAll(self)
- _base.MappedTest.setUpAll(self)
+ @classmethod
+ def setup_class(cls):
+ super(UnicodeSchemaTest, cls).setup_class()
- def tearDownAll(self):
- _base.MappedTest.tearDownAll(self)
- engine_base.AltEngineTest.tearDownAll(self)
+ @classmethod
+ def teardown_class(cls):
+ super(UnicodeSchemaTest, cls).teardown_class()
@testing.fails_on('mssql', 'pyodbc returns a non unicode encoding of the results description.')
@testing.resolve_artifact_names
class MutableTypesTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('mutable_t', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('data', sa.PickleType),
Column('val', sa.Unicode(30)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Foo(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Foo, mutable_t)
@testing.resolve_artifact_names
self.sql_count_(0, session.commit)
-class PickledDicts(_base.MappedTest):
+class PickledDictsTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('mutable_t', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('data', sa.PickleType(comparator=operator.eq)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Foo(_base.BasicEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(Foo, mutable_t)
@testing.resolve_artifact_names
class PKTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('multipk1', metadata,
Column('multi_id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('date_assigned', sa.Date, key='assigned', primary_key=True),
Column('data', String(30)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Entry(_base.BasicEntity):
pass
class ForeignPKTest(_base.MappedTest):
"""Detection of the relationship direction on PK joins."""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("people", metadata,
Column('person', String(10), primary_key=True),
Column('firstname', String(10)),
primary_key=True),
Column('site', String(10)))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Person(_base.BasicEntity):
pass
class PersonSite(_base.BasicEntity):
class ClauseAttributesTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('users_t', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('name', String(30)),
Column('counter', Integer, default=1))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class User(_base.ComparableEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
mapper(User, users_t)
@testing.resolve_artifact_names
class PassiveDeletesTest(_base.MappedTest):
__requires__ = ('foreign_keys',)
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('mytable', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
ondelete="CASCADE"),
test_needs_fk=True)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class MyClass(_base.BasicEntity):
pass
class MyOtherClass(_base.BasicEntity):
class ExtraPassiveDeletesTest(_base.MappedTest):
__requires__ = ('foreign_keys',)
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('mytable', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('data', String(30)),
['mytable.id']),
test_needs_fk=True)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class MyClass(_base.BasicEntity):
pass
class MyOtherClass(_base.BasicEntity):
assert myothertable.count().scalar() == 4
mc = session.query(MyClass).get(mc.id)
session.delete(mc)
- self.assertRaises(sa.exc.DBAPIError, session.flush)
+ assert_raises(sa.exc.DBAPIError, session.flush)
@testing.resolve_artifact_names
def test_extra_passive_2(self):
mc = session.query(MyClass).get(mc.id)
session.delete(mc)
mc.children[0].data = 'some new data'
- self.assertRaises(sa.exc.DBAPIError, session.flush)
+ assert_raises(sa.exc.DBAPIError, session.flush)
class DefaultTest(_base.MappedTest):
"""
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
use_string_defaults = testing.against('postgres', 'oracle', 'sqlite', 'mssql')
if use_string_defaults:
hohoval = 9
althohoval = 15
- self.other_artifacts['hohoval'] = hohoval
- self.other_artifacts['althohoval'] = althohoval
+ cls.other_artifacts['hohoval'] = hohoval
+ cls.other_artifacts['althohoval'] = althohoval
dt = Table('default_t', metadata,
Column('id', Integer, primary_key=True,
st.append_column(
Column('hoho', hohotype, ForeignKey('default_t.hoho')))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Hoho(_base.ComparableEntity):
pass
class Secondary(_base.ComparableEntity):
Secondary(data='s2')]))
class ColumnPropertyTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('data', metadata,
Column('id', Integer, primary_key=True),
Column('a', String(50)),
Column('c', String(50)),
)
- def setup_mappers(self):
+ @classmethod
+ def setup_mappers(cls):
class Data(_base.BasicEntity):
pass
sd1 = SubData(a="hello", b="there", c="hi")
sess.add(sd1)
sess.flush()
- self.assertEquals(sd1.aplusb, "hello there")
+ eq_(sd1.aplusb, "hello there")
@testing.resolve_artifact_names
def _test(self):
sess.add(d1)
sess.flush()
- self.assertEquals(d1.aplusb, "hello there")
+ eq_(d1.aplusb, "hello there")
d1.b = "bye"
sess.flush()
- self.assertEquals(d1.aplusb, "hello bye")
+ eq_(d1.aplusb, "hello bye")
d1.b = 'foobar'
d1.aplusb = 'im setting this explicitly'
sess.flush()
- self.assertEquals(d1.aplusb, "im setting this explicitly")
+ eq_(d1.aplusb, "im setting this explicitly")
class OneToManyTest(_fixtures.FixtureTest):
run_inserts = None
u1 = User(name='user1')
u2 = User(name='user2')
session.add_all((u1, u2))
- self.assertRaises(AssertionError, session.flush)
+ assert_raises(AssertionError, session.flush)
class ManyToOneTest(_fixtures.FixtureTest):
)
class SaveTest3(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('items', metadata,
Column('item_id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('keyword_id', Integer, ForeignKey("keywords")),
Column('foo', sa.Boolean, default=True))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class Keyword(_base.BasicEntity):
pass
class Item(_base.BasicEntity):
assert assoc.count().scalar() == 0
class BooleanColTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('t1_t', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('name', String(30)),
class RowSwitchTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
# parent
Table('t5', metadata,
Column('id', Integer, primary_key=True),
Column('t5id', Integer, ForeignKey('t5.id'),nullable=False),
Column('t7id', Integer, ForeignKey('t7.id'),nullable=False))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class T5(_base.ComparableEntity):
pass
assert list(sess.execute(t6.select(), mapper=T5)) == [(1, 'some other t6', 2)]
class InheritingRowSwitchTest(_base.MappedTest):
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table('parent', metadata,
Column('id', Integer, primary_key=True),
Column('pdata', String(30))
Column('cdata', String(30))
)
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class P(_base.ComparableEntity):
pass
# be specified. it'll raise immediately post-INSERT, instead of at
# COMMIT. either way, this test should pass.
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
t1 = Table('t1', metadata,
Column('id', Integer, primary_key=True))
Column('t1_id', Integer,
ForeignKey('t1.id', deferrable=True, initially='deferred')
))
- def setup_classes(self):
+ @classmethod
+ def setup_classes(cls):
class T1(_base.ComparableEntity):
pass
class T2(_base.ComparableEntity):
pass
+ @classmethod
@testing.resolve_artifact_names
- def setup_mappers(self):
+ def setup_mappers(cls):
orm_mapper(T1, t1)
orm_mapper(T2, t2)
if testing.against('postgres'):
t1.bind.engine.dispose()
-if __name__ == "__main__":
- testenv.main()
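The hunks above repeat the core mechanics of the migration: setUpAll/tearDownAll and setUp/tearDown become nose-style setup_class/teardown_class classmethods and setup/teardown methods, self.assertEquals/self.assertRaises calls become the eq_/assert_raises helpers from sqlalchemy.test.testing, and the trailing "if __name__ == '__main__': testenv.main()" blocks are dropped along with the testenv imports. A minimal sketch of what a converted class looks like end to end (WidgetTest and its dict fixture are illustrative only, not part of the patch):

    from sqlalchemy.test import TestBase
    from sqlalchemy.test.testing import eq_, assert_raises

    class WidgetTest(TestBase):
        @classmethod
        def setup_class(cls):
            # class-level fixture, built once per class (was setUpAll(self))
            cls.widgets = {'a': 1, 'b': 2, 'c': 3}

        def test_lookup(self):
            eq_(len(self.widgets), 3)                       # was self.assertEquals(...)
            assert_raises(KeyError, self.widgets.pop, 'z')  # was self.assertRaises(...)

        @classmethod
        def teardown_class(cls):
            # was tearDownAll(self); no __main__/testenv.main() block remains
            cls.widgets.clear()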
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
from sqlalchemy.orm import interfaces, util
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy.orm import mapper, create_session
-from testlib import TestBase, testing
+from sqlalchemy.test import TestBase, testing
-from orm import _fixtures
-from testlib.testing import eq_
+from test.orm import _fixtures
+from sqlalchemy.test.testing import eq_
class ExtensionCarrierTest(TestBase):
assert carrier.translate_row() is interfaces.EXT_CONTINUE
assert 'translate_row' not in carrier
- self.assertRaises(AttributeError, lambda: carrier.snickysnack)
+ assert_raises(AttributeError, lambda: carrier.snickysnack)
class Partial(object):
def __init__(self, marker):
table = self.point_map(Point)
alias = aliased(Point)
- self.assertRaises(TypeError, alias)
+ assert_raises(TypeError, alias)
def test_instancemethods(self):
class Point(object):
key = util.identity_key(User, row=row)
eq_(key, (User, (1,)))
-if __name__ == '__main__':
- testenv.main()
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
- modules_to_test = (
- 'profiling.memusage',
- 'profiling.compiler',
- 'profiling.pool',
- 'profiling.zoomark',
- 'profiling.zoomark_orm',
- )
- alltests = unittest.TestSuite()
- if testenv.testlib.config.coverage_enabled:
- return alltests
-
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
-from engine import _base as engine_base
+from test.engine import _base as engine_base
TablesTest = engine_base.TablesTest
+++ /dev/null
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-
-def suite():
- modules_to_test = (
- 'sql.testtypes',
- 'sql.columns',
- 'sql.constraints',
-
- 'sql.generative',
-
- # SQL syntax
- 'sql.select',
- 'sql.selectable',
- 'sql.case_statement',
- 'sql.labels',
- 'sql.unicode',
-
- # assorted round-trip tests
- 'sql.functions',
- 'sql.query',
- 'sql.quote',
- 'sql.rowcount',
-
- # defaults, sequences (postgres/oracle)
- 'sql.defaults',
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-if __name__ == '__main__':
- testenv.main(suite())
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
import sys
from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy import util, exc
from sqlalchemy.sql import table, column
class CaseTest(TestBase, AssertsCompiledSQL):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
metadata = MetaData(testing.db)
global info_table
info_table = Table('infos', metadata,
{'pk':4, 'info':'pk_4_data'},
{'pk':5, 'info':'pk_5_data'},
{'pk':6, 'info':'pk_6_data'})
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
info_table.drop()
@testing.fails_on('firebird', 'FIXME: unknown')
def test_literal_interpretation(self):
t = table('test', column('col1'))
- self.assertRaises(exc.ArgumentError, case, [("x", "y")])
+ assert_raises(exc.ArgumentError, case, [("x", "y")])
self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END")
self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END")
('other', 3),
]
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy import exc, sql
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy import Table, Column # don't use testlib's wrappers
def test_incomplete(self):
c = self.columns()
- self.assertRaises(exc.ArgumentError, Table, 't', MetaData(), *c)
+ assert_raises(exc.ArgumentError, Table, 't', MetaData(), *c)
def test_incomplete_key(self):
c = Column(Integer)
def test_bogus(self):
- self.assertRaises(exc.ArgumentError, Column, 'foo', name='bar')
- self.assertRaises(exc.ArgumentError, Column, 'foo', Integer,
+ assert_raises(exc.ArgumentError, Column, 'foo', name='bar')
+ assert_raises(exc.ArgumentError, Column, 'foo', Integer,
type_=Integer())
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy import exc
-from testlib import *
-from testlib import config, engines
+from sqlalchemy.test import *
+from sqlalchemy.test import config, engines
class ConstraintTest(TestBase, AssertsExecutionResults):
- def setUp(self):
+ def setup(self):
global metadata
metadata = MetaData(testing.db)
- def tearDown(self):
+ def teardown(self):
metadata.drop_all()
def test_constraint(self):
def test_double_fk_usage_raises(self):
f = ForeignKey('b.id')
- self.assertRaises(exc.InvalidRequestError, Table, "a", metadata,
+ assert_raises(exc.InvalidRequestError, Table, "a", metadata,
Column('x', Integer, f),
Column('y', Integer, f)
)
t1 = Table("sometable", MetaData(), Column("foo", Integer))
schemagen.visit_index(Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
- self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
+ eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_name_is_t_1 ON sometable (foo)")
schemagen.buffer.truncate(0)
schemagen.visit_index(Index("this_other_name_is_too_long_for_what_were_doing", t1.c.foo))
- self.assertEquals(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
+ eq_(schemagen.buffer.getvalue(), "CREATE INDEX this_other_nam_2 ON sometable (foo)")
schemadrop = dialect.schemadropper(dialect, None)
schemadrop.execute = lambda: None
- self.assertRaises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
+ assert_raises(exc.IdentifierError, schemadrop.visit_index, Index("this_name_is_too_long_for_what_were_doing", t1.c.foo))
class ConstraintCompilationTest(TestBase, AssertsExecutionResults):
def clear(self):
del self.statements[:]
- def setUp(self):
+ def setup(self):
self.sql = self.accum()
opts = config.db_opts.copy()
opts['strategy'] = 'mock'
assert 'INITIALLY DEFERRED' in self.sql, self.sql
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import datetime
from sqlalchemy import Sequence, Column, func
from sqlalchemy.sql import select, text
-from testlib import sa, testing
-from testlib.sa import MetaData, Table, Integer, String, ForeignKey, Boolean
-from testlib.testing import eq_
-from sql import _base
+import sqlalchemy as sa
+from sqlalchemy.test import testing
+from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.testing import eq_
+from test.sql import _base
class DefaultTest(testing.TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global t, f, f2, ts, currenttime, metadata, default_generator
db = testing.db
server_default='ddl'))
t.create()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
t.drop()
- def tearDown(self):
+ def teardown(self):
default_generator['x'] = 50
t.delete().execute()
fn4 = FN4()
for fn in fn1, fn2, fn3, fn4:
- self.assertRaisesMessage(sa.exc.ArgumentError,
+ assert_raises_message(sa.exc.ArgumentError,
ex_msg,
sa.ColumnDefault, fn)
class PKDefaultTest(_base.TablesTest):
__requires__ = ('subqueries',)
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
t2 = Table('t2', metadata,
Column('nextid', Integer))
class PKIncrementTest(_base.TablesTest):
run_define_tables = 'each'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
Table("aitable", metadata,
Column('id', Integer, Sequence('ai_id_seq', optional=True),
primary_key=True),
try:
result = t1.insert().execute()
- self.assertEquals(1, select([func.count(text('*'))], from_obj=t1).scalar())
- self.assertEquals(True, t1.select().scalar())
+ eq_(1, select([func.count(text('*'))], from_obj=t1).scalar())
+ eq_(True, t1.select().scalar())
finally:
metadata.drop_all()
__requires__ = ('identity',)
run_define_tables = 'each'
- def define_tables(self, metadata):
+ @classmethod
+ def define_tables(cls, metadata):
"""Each test manipulates self.metadata individually."""
@testing.exclude('sqlite', '<', (3, 4), 'no database support')
class SequenceTest(testing.TestBase):
__requires__ = ('sequences',)
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global cartitems, sometable, metadata
metadata = MetaData(testing.db)
cartitems = Table("cartitems", metadata,
x = cartitems.c.cart_id.sequence.execute()
self.assert_(1 <= x <= 4)
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
-if __name__ == "__main__":
- testenv.main()
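The DefaultTest, PKIncrementTest and SequenceTest hunks above exercise Python-side column defaults and optional sequences. As a reference for the constructs those fixtures build, here is a small sketch in the same style (the widgets table and its column names are illustrative, not taken from the patch):

    import datetime
    from sqlalchemy import MetaData, Table, Column, Integer, String, Sequence

    meta = MetaData()

    widgets = Table('widgets', meta,
        # explicit Sequence with optional=True: used only on backends that
        # need it, mirroring the aitable definition above
        Column('id', Integer, Sequence('widgets_id_seq', optional=True),
               primary_key=True),
        # Python-side default from a zero-argument callable; ColumnDefault
        # also accepts a one-argument callable receiving the execution context
        Column('created', String(30),
               default=lambda: datetime.datetime.now().isoformat()),
        # plain scalar default
        Column('status', String(10), default='new'))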
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_
import datetime
from sqlalchemy import *
from sqlalchemy.sql import table, column
from sqlalchemy import databases, sql, util
from sqlalchemy.sql.compiler import BIND_TEMPLATES
from sqlalchemy.engine import default
-from testlib.engines import all_dialects
+from sqlalchemy.test.engines import all_dialects
from sqlalchemy import types as sqltypes
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy.sql.functions import GenericFunction
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
from decimal import Decimal as _python_Decimal
from sqlalchemy.databases import *
t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi")
res = exec_sorted(select([t2.c.value, t2.c.stuff]))
- self.assertEquals(res, [(-14, 'hi'), (3, None), (7, None)])
+ eq_(res, [(-14, 'hi'), (3, None), (7, None)])
t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff")
assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")]
return sorted([tuple(row)
for row in statement.execute(*args, **kw).fetchall()])
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.sql import table, column, ClauseElement
from sqlalchemy.sql.expression import _clone, _from_objects
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy.sql.visitors import *
from sqlalchemy import util
from sqlalchemy.sql import util as sql_util
"""test ClauseVisitor's traversal, particularly its ability to copy and modify
a ClauseElement in place."""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global A, B
# establish two fictitious ClauseElements.
class ClauseTest(TestBase, AssertsCompiledSQL):
"""test copy-in-place behavior of various ClauseElements."""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global t1, t2
t1 = table("table1",
column("col1"),
self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global t1, t2
t1 = table("table1",
column("col1"),
)
class SpliceJoinsTest(TestBase, AssertsCompiledSQL):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global table1, table2, table3, table4
def _table(name):
return table(name, column("col1"), column("col2"),column("col3"))
class SelectTest(TestBase, AssertsCompiledSQL):
"""tests the generative capability of Select"""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global t1, t2
t1 = table("table1",
column("col1"),
# fixme: consolidate coverage from elsewhere here and expand
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global t1, t2
t1 = table("table1",
column("col1"),
"table1 (col1, col2, col3) "
"VALUES (:col1, :col2, :col3)")
-if __name__ == '__main__':
- testenv.main()
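The TraversalTest/ClauseTest hunks above concern ClauseVisitor traversal and copy-in-place of ClauseElements. For orientation, a minimal sketch of the functional traversal API the suite imports from sqlalchemy.sql.visitors, assuming the traverse(obj, opts, visitors) signature used by this era of the library (the expression here is made up for illustration):

    from sqlalchemy.sql import table, column, visitors

    t = table('t', column('a'), column('b'))
    expr = (t.c.a == 5) & (t.c.b == 7)

    # collect every ColumnClause in the expression; the visitors dict is
    # keyed by each element's __visit_name__ ('column' for column())
    seen = []
    visitors.traverse(expr, {}, {'column': seen.append})
    assert sorted(c.name for c in seen) == ['a', 'b']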
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
from sqlalchemy import *
from sqlalchemy import exc as exceptions
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy.engine import default
IDENT_LENGTH = 29
assert isinstance(select([t.c.col2]).as_scalar().label('lala').type, Float)
class LongLabelsTest(TestBase, AssertsCompiledSQL):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, table1, table2, maxlen
metadata = MetaData(testing.db)
table1 = Table("some_large_named_table", metadata,
maxlen = testing.db.dialect.max_identifier_length
testing.db.dialect.max_identifier_length = IDENT_LENGTH
- def tearDown(self):
+ def teardown(self):
table1.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
testing.db.dialect.max_identifier_length = maxlen
def test_too_long_name_disallowed(self):
m = MetaData(testing.db)
t1 = Table("this_name_is_too_long_for_what_were_doing_in_this_test", m, Column('foo', Integer))
- self.assertRaises(exceptions.IdentifierError, m.create_all)
- self.assertRaises(exceptions.IdentifierError, m.drop_all)
- self.assertRaises(exceptions.IdentifierError, t1.create)
- self.assertRaises(exceptions.IdentifierError, t1.drop)
+ assert_raises(exceptions.IdentifierError, m.create_all)
+ assert_raises(exceptions.IdentifierError, m.drop_all)
+ assert_raises(exceptions.IdentifierError, t1.create)
+ assert_raises(exceptions.IdentifierError, t1.drop)
def test_result(self):
table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
"FROM some_large_named_table WHERE some_large_named_table.this_is_the_primarykey_column = :_1) AS _1", dialect=compile_dialect)
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
import datetime
from sqlalchemy import *
from sqlalchemy import exc, sql
from sqlalchemy.engine import default
-from testlib import *
-from testlib.testing import eq_
+from sqlalchemy.test import *
+from sqlalchemy.test.testing import eq_
class QueryTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global users, users2, addresses, metadata
metadata = MetaData(testing.db)
users = Table('query_users', metadata,
users.delete().execute()
users2.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_insert(self):
)
concat = ("test: " + users.c.user_name).label('thedata')
- self.assertEquals(
+ eq_(
select([concat]).order_by(concat).execute().fetchall(),
[("test: ed",), ("test: fred",), ("test: jack",)]
)
concat = ("test: " + users.c.user_name).label('thedata')
- self.assertEquals(
+ eq_(
select([concat]).order_by(desc(concat)).execute().fetchall(),
[("test: jack",), ("test: fred",), ("test: ed",)]
)
concat = ("test: " + users.c.user_name).label('thedata')
- self.assertEquals(
+ eq_(
select([concat]).order_by(concat + "x").execute().fetchall(),
[("test: ed",), ("test: fred",), ("test: jack",)]
)
def test_or_and_as_columns(self):
true, false = literal(True), literal(False)
- self.assertEquals(testing.db.execute(select([and_(true, false)])).scalar(), False)
- self.assertEquals(testing.db.execute(select([and_(true, true)])).scalar(), True)
- self.assertEquals(testing.db.execute(select([or_(true, false)])).scalar(), True)
- self.assertEquals(testing.db.execute(select([or_(false, false)])).scalar(), False)
- self.assertEquals(testing.db.execute(select([not_(or_(false, false))])).scalar(), True)
+ eq_(testing.db.execute(select([and_(true, false)])).scalar(), False)
+ eq_(testing.db.execute(select([and_(true, true)])).scalar(), True)
+ eq_(testing.db.execute(select([or_(true, false)])).scalar(), True)
+ eq_(testing.db.execute(select([or_(false, false)])).scalar(), False)
+ eq_(testing.db.execute(select([not_(or_(false, false))])).scalar(), True)
row = testing.db.execute(select([or_(false, false).label("x"), and_(true, false).label("y")])).fetchone()
assert row.x == False
{'user_id':4, 'user_name':'OnE'},
)
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )])
+ eq_(select([users.c.user_id]).where(users.c.user_name.ilike('one')).execute().fetchall(), [(1, ), (3, ), (4, )])
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
+ eq_(select([users.c.user_id]).where(users.c.user_name.ilike('TWO')).execute().fetchall(), [(2, )])
if testing.against('postgres'):
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
- self.assertEquals(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
+ eq_(select([users.c.user_id]).where(users.c.user_name.like('one')).execute().fetchall(), [(1, )])
+ eq_(select([users.c.user_id]).where(users.c.user_name.like('TWO')).execute().fetchall(), [])
def test_compiled_execute(self):
def a_eq(executable, wanted):
got = list(executable.execute())
- self.assertEquals(got, wanted)
+ eq_(got, wanted)
for labels in False, True:
a_eq(users.select(order_by=[users.c.user_id],
def test_keys(self):
users.insert().execute(user_id=1, user_name='foo')
r = users.select().execute().fetchone()
- self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+ eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
def test_items(self):
users.insert().execute(user_id=1, user_name='foo')
r = users.select().execute().fetchone()
- self.assertEqual([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
+ eq_([(x[0].lower(), x[1]) for x in r.items()], [('user_id', 1), ('user_name', 'foo')])
def test_len(self):
users.insert().execute(user_id=1, user_name='foo')
r = users.select().execute().fetchone()
- self.assertEqual(len(r), 2)
+ eq_(len(r), 2)
r.close()
r = testing.db.execute('select user_name, user_id from query_users').fetchone()
- self.assertEqual(len(r), 2)
+ eq_(len(r), 2)
r.close()
r = testing.db.execute('select user_name from query_users').fetchone()
- self.assertEqual(len(r), 1)
+ eq_(len(r), 1)
r.close()
def test_cant_execute_join(self):
# should return values in column definition order
users.insert().execute(user_id=1, user_name='foo')
r = users.select(users.c.user_id==1).execute().fetchone()
- self.assertEqual(r[0], 1)
- self.assertEqual(r[1], 'foo')
- self.assertEqual([x.lower() for x in r.keys()], ['user_id', 'user_name'])
- self.assertEqual(r.values(), [1, 'foo'])
+ eq_(r[0], 1)
+ eq_(r[1], 'foo')
+ eq_([x.lower() for x in r.keys()], ['user_id', 'user_name'])
+ eq_(r.values(), [1, 'foo'])
def test_column_order_with_text_query(self):
# should return values in query order
users.insert().execute(user_id=1, user_name='foo')
r = testing.db.execute('select user_name, user_id from query_users').fetchone()
- self.assertEqual(r[0], 'foo')
- self.assertEqual(r[1], 1)
- self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id'])
- self.assertEqual(r.values(), ['foo', 1])
+ eq_(r[0], 'foo')
+ eq_(r[1], 1)
+ eq_([x.lower() for x in r.keys()], ['user_name', 'user_id'])
+ eq_(r.values(), ['foo', 1])
@testing.crashes('oracle', 'FIXME: unknown, verify not fails_on()')
@testing.crashes('firebird', 'An identifier must begin with a letter')
operation the same way we do for text() and column labels.
"""
+ @classmethod
@testing.crashes('mysql', 'mysqldb calls name % (params)')
@testing.crashes('postgres', 'postgres calls name % (params)')
- def setUpAll(self):
+ def setup_class(cls):
global percent_table, metadata
metadata = MetaData(testing.db)
percent_table = Table('percent%table', metadata,
)
metadata.create_all()
+ @classmethod
@testing.crashes('mysql', 'mysqldb calls name % (params)')
@testing.crashes('postgres', 'postgres calls name % (params)')
- def tearDownAll(self):
+ def teardown_class(cls):
metadata.drop_all()
@testing.crashes('mysql', 'mysqldb calls name % (params)')
class LimitTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global users, addresses, metadata
metadata = MetaData(testing.db)
users = Table('query_users', metadata,
Column('user_id', Integer, ForeignKey('query_users.user_id')),
Column('address', String(30)))
metadata.create_all()
- self._data()
-
- def _data(self):
+
users.insert().execute(user_id=1, user_name='john')
addresses.insert().execute(address_id=1, user_id=1, address='addr1')
users.insert().execute(user_id=2, user_name='jack')
users.insert().execute(user_id=7, user_name='fido')
addresses.insert().execute(address_id=7, user_id=7, address='addr5')
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_select_limit(self):
class CompoundTest(TestBase):
"""test compound statements like UNION, INTERSECT, particularly their ability to nest on
different databases."""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, t1, t2, t3
metadata = MetaData(testing.db)
t1 = Table('t1', metadata,
dict(col2="t3col2r3", col3="ccc", col4="bbb"),
])
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def _fetchall_sorted(self, executed):
wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
('ccc', 'aaa')]
found1 = self._fetchall_sorted(u.execute())
- self.assertEquals(found1, wanted)
+ eq_(found1, wanted)
found2 = self._fetchall_sorted(u.alias('bar').select().execute())
- self.assertEquals(found2, wanted)
+ eq_(found2, wanted)
def test_union_ordered(self):
(s1, s2) = (
wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
('ccc', 'aaa')]
- self.assertEquals(u.execute().fetchall(), wanted)
+ eq_(u.execute().fetchall(), wanted)
@testing.fails_on('maxdb', 'FIXME: unknown')
@testing.requires.subqueries
wanted = [('aaa', 'aaa'), ('bbb', 'bbb'), ('bbb', 'ccc'),
('ccc', 'aaa')]
- self.assertEquals(u.alias('bar').select().execute().fetchall(), wanted)
+ eq_(u.alias('bar').select().execute().fetchall(), wanted)
@testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
@testing.fails_on('mysql', 'FIXME: unknown')
wanted = [('aaa',),('aaa',),('bbb',), ('bbb',), ('ccc',),('ccc',)]
found1 = self._fetchall_sorted(e.execute())
- self.assertEquals(found1, wanted)
+ eq_(found1, wanted)
found2 = self._fetchall_sorted(e.alias('foo').select().execute())
- self.assertEquals(found2, wanted)
+ eq_(found2, wanted)
@testing.crashes('firebird', 'Does not support intersect')
@testing.crashes('sybase', 'FIXME: unknown, verify not fails_on')
wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
found1 = self._fetchall_sorted(i.execute())
- self.assertEquals(found1, wanted)
+ eq_(found1, wanted)
found2 = self._fetchall_sorted(i.alias('bar').select().execute())
- self.assertEquals(found2, wanted)
+ eq_(found2, wanted)
@testing.crashes('firebird', 'Does not support except')
@testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
found = self._fetchall_sorted(e.alias('bar').select().execute())
- self.assertEquals(found, wanted)
+ eq_(found, wanted)
@testing.crashes('firebird', 'Does not support except')
@testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
found1 = self._fetchall_sorted(e.execute())
- self.assertEquals(found1, wanted)
+ eq_(found1, wanted)
found2 = self._fetchall_sorted(e.alias('bar').select().execute())
- self.assertEquals(found2, wanted)
+ eq_(found2, wanted)
@testing.crashes('firebird', 'Does not support except')
@testing.crashes('oracle', 'FIXME: unknown, verify not fails_on')
select([t3.c.col3], t3.c.col3 == 'ccc'), #ccc
)
)
- self.assertEquals(e.execute().fetchall(), [('ccc',)])
- self.assertEquals(e.alias('foo').select().execute().fetchall(),
+ eq_(e.execute().fetchall(), [('ccc',)])
+ eq_(e.alias('foo').select().execute().fetchall(),
[('ccc',)])
@testing.crashes('firebird', 'Does not support intersect')
wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
found = self._fetchall_sorted(u.execute())
- self.assertEquals(found, wanted)
+ eq_(found, wanted)
@testing.crashes('firebird', 'Does not support intersect')
@testing.fails_on('mysql', 'FIXME: unknown')
wanted = [('aaa', 'bbb'), ('bbb', 'ccc'), ('ccc', 'aaa')]
found = self._fetchall_sorted(ua.select().execute())
- self.assertEquals(found, wanted)
+ eq_(found, wanted)
class JoinTest(TestBase):
database seems to be sensitive to this.
"""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata
global t1, t2, t3
{'t2_id': 21, 't1_id': 11, 'name': 't2 #21'})
t3.insert().execute({'t3_id': 30, 't2_id': 20, 'name': 't3 #30'})
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def assertRows(self, statement, expected):
found = sorted([tuple(row)
for row in statement.execute().fetchall()])
- self.assertEquals(found, sorted(expected))
+ eq_(found, sorted(expected))
def test_join_x1(self):
"""Joins t1->t2."""
class OperatorTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global metadata, flds
metadata = MetaData(testing.db)
flds = Table('flds', metadata,
dict(intcol=13, strcol='bar')
])
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
@testing.fails_on('maxdb', 'FIXME: unknown')
def test_modulo(self):
- self.assertEquals(
+ eq_(
select([flds.c.intcol % 3],
order_by=flds.c.idcol).execute().fetchall(),
[(2,),(1,)]
)
-
-
-
-if __name__ == "__main__":
- testenv.main()
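CompoundTest above ("test compound statements like UNION, INTERSECT, particularly their ability to nest on different databases") builds its statements the same way on every backend. A compact sketch of the composition pattern it round-trips, using throwaway tables rather than the fixtures from the patch:

    from sqlalchemy import MetaData, Table, Column, String, select, union

    meta = MetaData()
    t1 = Table('t1', meta, Column('col2', String(30)), Column('col3', String(30)))
    t2 = Table('t2', meta, Column('col2', String(30)), Column('col3', String(30)))

    # two selects combined with UNION...
    u = union(
        select([t1.c.col2, t1.c.col3], t1.c.col2 == 'aaa'),
        select([t2.c.col2, t2.c.col3], t2.c.col2 == 'bbb'))

    # ...and the union reused as a derived table, as the tests do with
    # u.alias('bar').select().execute()
    print u.alias('bar').select()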
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy import sql
from sqlalchemy.sql import compiler
-from testlib import *
+from sqlalchemy.test import *
class QuoteTest(TestBase, AssertsCompiledSQL):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
# TODO: figure out which databases/which identifiers allow special
# characters to be used, such as: spaces, quote characters,
# punctuation characters, set up tests for those as well.
table1.create()
table2.create()
- def tearDown(self):
+ def teardown(self):
table1.delete().execute()
table2.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
table1.drop()
table2.drop()
a_eq(unformat('`foo`.bar'), ['foo', 'bar'])
a_eq(unformat('`foo`.`b``a``r`.`baz`'), ['foo', 'b`a`r', 'baz'])
-if __name__ == "__main__":
- testenv.main()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
class FoundRowsTest(TestBase, AssertsExecutionResults):
"""tests rowcount functionality"""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
metadata = MetaData(testing.db)
global employees_table
)
employees_table.create()
- def setUp(self):
+ def setup(self):
global data
data = [ ('Angela', 'A'),
('Andrew', 'A'),
i = employees_table.insert()
i.execute(*[{'name':n, 'department':d} for n, d in data])
- def tearDown(self):
+ def teardown(self):
employees_table.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
employees_table.drop()
def testbasic(self):
if testing.db.dialect.supports_sane_rowcount:
assert r.rowcount == 3
-if __name__ == '__main__':
- testenv.main()
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import datetime, re, operator
from sqlalchemy import *
from sqlalchemy import exc, sql, util
from sqlalchemy.sql.expression import ClauseList
from sqlalchemy.engine import default
from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
-from testlib import *
+from sqlalchemy.test import *
table1 = table('mytable',
column('myid', Integer),
t2 = table('t2', column('c'), column('d'))
s = select([t.c.a]).where(t.c.a==t2.c.d).as_scalar()
s2 =select([t, t2, s])
- self.assertRaises(exc.InvalidRequestError, str, s2)
+ assert_raises(exc.InvalidRequestError, str, s2)
# intentional again
s = s.correlate(t, t2)
# check that conflicts with "unique" params are caught
s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('myid_1')))
- self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+ assert_raises_message(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('myid_1')))
- self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+ assert_raises_message(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
def test_binds_no_hash_collision(self):
"""test that construct_params doesn't corrupt dict due to hash collisions"""
)
def check_results(dialect, expected_results, literal):
- self.assertEqual(len(expected_results), 5, 'Incorrect number of expected results')
- self.assertEqual(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
- self.assertEqual(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
- self.assertEqual(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
- self.assertEqual(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
- self.assertEqual(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
+ eq_(len(expected_results), 5, 'Incorrect number of expected results')
+ eq_(str(cast(tbl.c.v1, Numeric).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[0])
+ eq_(str(cast(tbl.c.v1, Numeric(12, 9)).compile(dialect=dialect)), 'CAST(casttest.v1 AS %s)' %expected_results[1])
+ eq_(str(cast(tbl.c.ts, Date).compile(dialect=dialect)), 'CAST(casttest.ts AS %s)' %expected_results[2])
+ eq_(str(cast(1234, TEXT).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[3]))
+ eq_(str(cast('test', String(20)).compile(dialect=dialect)), 'CAST(%s AS %s)' %(literal, expected_results[4]))
# fixme: shoving all of this dialect-specific stuff in one test
# is now officially completely ridiculous AND non-obviously omits
# coverage on other dialects.
sel = select([tbl, cast(tbl.c.v1, Numeric)]).compile(dialect=dialect)
if isinstance(dialect, type(mysql.dialect())):
- self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
+ eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS DECIMAL(10, 2)) AS anon_1 \nFROM casttest")
else:
- self.assertEqual(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
+ eq_(str(sel), "SELECT casttest.id, casttest.v1, casttest.v2, casttest.ts, CAST(casttest.v1 AS NUMERIC(10, 2)) AS anon_1 \nFROM casttest")
# first test with Postgres engine
check_results(postgres.dialect(), ['NUMERIC(10, 2)', 'NUMERIC(12, 9)', 'DATE', 'TEXT', 'VARCHAR(20)'], '%(param_1)s')
self.assert_compile(table4.insert(values=(2, 5, 'test')), "INSERT INTO remote_owner.remotetable (rem_id, datatype_id, value) VALUES "\
"(:rem_id, :datatype_id, :value)")
-if __name__ == "__main__":
- testenv.main()
"""Test various algorithmic properties of selectables."""
-import testenv; testenv.configure_for_tests()
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
from sqlalchemy import *
-from testlib import *
+from sqlalchemy.test import *
from sqlalchemy.sql import util as sql_util, visitors
from sqlalchemy import exc
from sqlalchemy.sql import table, column
s = select([t2, t3], use_labels=True)
- self.assertRaises(exc.NoReferencedTableError, s.join, t1)
+ assert_raises(exc.NoReferencedTableError, s.join, t1)
class PrimaryKeyTest(TestBase, AssertsExecutionResults):
def test_join_pk_collapse_implicit(self):
Column('id', Integer, ForeignKey( 'Employee.id', ), primary_key=True),
)
- self.assertEquals(
+ eq_(
util.column_set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key),
util.column_set([employee.c.id])
)
- self.assertEquals(
+ eq_(
util.column_set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key),
util.column_set([employee.c.id])
)
Column('t3data', String(30)))
- self.assertEquals(
+ eq_(
util.column_set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])),
util.column_set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data])
)
s = select([engineers, managers]).where(engineers.c.engineer_name==managers.c.manager_name)
- self.assertEquals(util.column_set(sql_util.reduce_columns(list(s.c), s)),
+ eq_(util.column_set(sql_util.reduce_columns(list(s.c), s)),
util.column_set([s.c.engineer_id, s.c.engineer_name, s.c.manager_id])
)
)
pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin')
- self.assertEquals(
+ eq_(
util.column_set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])),
util.column_set([pjoin.c.people_person_id])
)
'Item':base_item_table.join(item_table),
}, None, 'item_join')
- self.assertEquals(
+ eq_(
util.column_set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])),
util.column_set([item_join.c.id, item_join.c.dummy, item_join.c.child_name])
)
'c': page_table.join(magazine_page_table).join(classified_page_table),
}, None, 'page_join')
- self.assertEquals(
+ eq_(
util.column_set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])),
util.column_set([pjoin.c.id])
)
assert b4.left is bin.left # since column is immutable
assert b4.right is not bin.right is not b2.right is not b3.right
-if __name__ == "__main__":
- testenv.main()
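Several of the hunks above call sql_util.reduce_columns(), which collapses a column collection down to its significant members by dropping columns that are equated to another column in the list via foreign keys or the given clauses. A small sketch of that behavior under the same call form the tests use (the parent/child tables are illustrative):

    from sqlalchemy import MetaData, Table, Column, Integer, String, ForeignKey, util
    from sqlalchemy.sql import util as sql_util

    meta = MetaData()
    parent = Table('parent', meta,
        Column('id', Integer, primary_key=True),
        Column('data', String(30)))
    child = Table('child', meta,
        Column('id', Integer, ForeignKey('parent.id'), primary_key=True),
        Column('cdata', String(30)))

    # child.c.id is equated to parent.c.id through its foreign key,
    # so only the referenced column survives the reduction
    reduced = sql_util.reduce_columns(
        [parent.c.id, parent.c.data, child.c.id, child.c.cdata])
    expected = util.column_set([parent.c.id, parent.c.data, child.c.cdata])
    assert util.column_set(reduced) == expected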
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import decimal
-import testenv; testenv.configure_for_tests()
-import datetime, os, pickleable, re
+import datetime, os, re
from sqlalchemy import *
from sqlalchemy import exc, types, util
from sqlalchemy.sql import operators
-from testlib.testing import eq_
+from sqlalchemy.test.testing import eq_
import sqlalchemy.engine.url as url
from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
-from testlib import *
+from sqlalchemy.test import *
class AdaptTest(TestBase):
l
):
for col in row[1:5]:
- self.assertEquals(col, assertstr)
- self.assertEquals(row[5], assertint)
- self.assertEquals(row[6], assertint2)
+ eq_(col, assertstr)
+ eq_(row[5], assertint)
+ eq_(row[6], assertint2)
for col in row[3], row[4]:
assert isinstance(col, unicode)
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global users, metadata
class MyType(types.TypeEngine):
metadata.create_all()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
class ColumnsTest(TestBase, AssertsExecutionResults):
)
for aCol in testTable.c:
- self.assertEquals(
+ eq_(
expectedResults[aCol.name],
db.dialect.schemagenerator(db.dialect, db, None, None).\
get_column_specification(aCol))
class UnicodeTest(TestBase, AssertsExecutionResults):
"""tests the Unicode type. also tests the TypeDecorator with instances in the types package."""
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global unicode_table
metadata = MetaData(testing.db)
unicode_table = Table('unicode_table', metadata,
Column('plain_varchar', String(250))
)
unicode_table.create()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
unicode_table.drop()
- def tearDown(self):
+ def teardown(self):
unicode_table.delete().execute()
def test_round_trip(self):
('mysql', '<', (4, 1, 1)), # screwy varbinary types
)
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global binary_table, MyPickleType
class MyPickleType(types.TypeDecorator):
)
binary_table.create()
- def tearDown(self):
+ def teardown(self):
binary_table.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
binary_table.drop()
@testing.fails_on('mssql', 'MSSQL BINARY type right pads the fixed length with \x00')
text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType, 'mypickle':MyPickleType}, bind=testing.db)
):
l = stmt.execute().fetchall()
- self.assertEquals(list(stream1), list(l[0]['data']))
- self.assertEquals(list(stream1[0:100]), list(l[0]['data_slice']))
- self.assertEquals(list(stream2), list(l[1]['data']))
- self.assertEquals(testobj1, l[0]['pickled'])
- self.assertEquals(testobj2, l[1]['pickled'])
- self.assertEquals(testobj3.moredata, l[0]['mypickle'].moredata)
- self.assertEquals(l[0]['mypickle'].stuff, 'this is the right stuff')
+ eq_(list(stream1), list(l[0]['data']))
+ eq_(list(stream1[0:100]), list(l[0]['data_slice']))
+ eq_(list(stream2), list(l[1]['data']))
+ eq_(testobj1, l[0]['pickled'])
+ eq_(testobj2, l[1]['pickled'])
+ eq_(testobj3.moredata, l[0]['mypickle'].moredata)
+ eq_(l[0]['mypickle'].stuff, 'this is the right stuff')
def load_stream(self, name, len=12579):
- f = os.path.join(os.path.dirname(testenv.__file__), name)
+ f = os.path.join(os.path.dirname(__file__), "..", name)
# put a number less than the typical MySQL default BLOB size
return file(f).read(len)
class ExpressionTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global test_table, meta
class MyCustomType(types.TypeEngine):
test_table.insert().execute({'id':1, 'data':'somedata', 'atimestamp':datetime.date(2007, 10, 15), 'avalue':25})
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
meta.drop_all()
def test_control(self):
assert testing.db.execute(select([expr])).scalar() == -15
class DateTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global users_with_date, insert_data
db = testing.db
for idict in insert_dicts:
users_with_date.insert().execute(**idict)
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
users_with_date.drop()
def testdate(self):
# test mismatched date/datetime
t.insert().execute(adate=d2, adatetime=d2)
- self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
- self.assertEquals(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
+ eq_(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
+ eq_(select([t.c.adate, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, d2)])
finally:
t.drop(checkfirst=True)
return True
class NumericTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global numeric_table, metadata
metadata = MetaData(testing.db)
numeric_table = Table('numeric_table', metadata,
)
metadata.create_all()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
- def tearDown(self):
+ def teardown(self):
numeric_table.delete().execute()
@testing.fails_if(_missing_decimal)
assert isinstance(row['fcasdec'], decimal.Decimal)
def test_length_deprecation(self):
- self.assertRaises(exc.SADeprecationWarning, Numeric, length=8)
+ assert_raises(exc.SADeprecationWarning, Numeric, length=8)
@testing.uses_deprecated(".*is deprecated for Numeric")
def go():
class IntervalTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global interval_table, metadata
metadata = MetaData(testing.db)
interval_table = Table("intervaltable", metadata,
)
metadata.create_all()
- def tearDown(self):
+ def teardown(self):
interval_table.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
metadata.drop_all()
def test_roundtrip(self):
assert interval_table.select().execute().fetchone()['interval'] is None
class BooleanTest(TestBase, AssertsExecutionResults):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global bool_table
metadata = MetaData(testing.db)
bool_table = Table('booltest', metadata,
Column('id', Integer, primary_key=True),
Column('value', Boolean))
bool_table.create()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
bool_table.drop()
def testbasic(self):
bool_table.insert().execute(id=1, value=True)
def test_noeq_deprecation(self):
p1 = PickleType()
- self.assertRaises(DeprecationWarning,
+ assert_raises(DeprecationWarning,
p1.compare_values, pickleable.BarWithoutCompare(1, 2), pickleable.BarWithoutCompare(1, 2)
)
- self.assertRaises(DeprecationWarning,
+ assert_raises(DeprecationWarning,
p1.compare_values, pickleable.OldSchoolWithoutCompare(1, 2), pickleable.OldSchoolWithoutCompare(1, 2)
)
):
assert p1.compare_values(p1.copy_value(obj), obj)
- self.assertRaises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo'))
+ assert_raises(NotImplementedError, p1.compare_values, pickleable.BrokenComparable('foo'),pickleable.BrokenComparable('foo'))
def test_nonmutable_comparison(self):
p1 = PickleType()
assert p1.compare_values(p1.copy_value(obj), obj)
class CallableTest(TestBase):
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global meta
meta = MetaData(testing.db)
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
meta.drop_all()
def test_callable_as_arg(self):
assert isinstance(thang_table.c.name.type, Unicode)
thang_table.create()
-if __name__ == "__main__":
- testenv.main()
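The test_types hunks above convert several custom type fixtures, including the TypeDecorator-based MyPickleType. For readers unfamiliar with the pattern those fixtures rely on, a minimal sketch of a TypeDecorator, assuming the process_bind_param/process_result_value hooks that MyPickleType overrides (TrimmedString itself is illustrative):

    from sqlalchemy import types

    class TrimmedString(types.TypeDecorator):
        # a String that strips surrounding whitespace on the way in
        impl = types.String

        def process_bind_param(self, value, dialect):
            # applied to values headed for the database
            if value is not None:
                value = value.strip()
            return value

        def process_result_value(self, value, dialect):
            # values coming back are passed through unchanged
            return value

    # usable anywhere a type is expected, e.g. Column('name', TrimmedString(50))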
# coding: utf-8
"""verrrrry basic unicode column name testing"""
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
-from testlib import *
-from testlib.engines import utf8_engine
+from sqlalchemy.test import *
+from sqlalchemy.test.engines import utf8_engine
from sqlalchemy.sql import column
class UnicodeSchemaTest(TestBase):
__requires__ = ('unicode_ddl',)
- def setUpAll(self):
+ @classmethod
+ def setup_class(cls):
global unicode_bind, metadata, t1, t2, t3
unicode_bind = utf8_engine()
)
metadata.create_all()
- def tearDown(self):
+ def teardown(self):
if metadata.tables:
t3.delete().execute()
t2.delete().execute()
t1.delete().execute()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
global unicode_bind
metadata.drop_all()
del unicode_bind
t1.drop()
-if __name__ == '__main__':
- testenv.main()
+++ /dev/null
-"""First import for all test cases, sets sys.path and loads configuration."""
-
-import sys, os, logging, warnings
-
-if sys.version_info < (2, 4):
- warnings.filterwarnings('ignore', category=FutureWarning)
-
-
-from testlib.testing import main
-import testlib.config
-
-
-_setup = False
-
-def configure_for_tests():
- """import testenv; testenv.configure_for_tests()"""
-
- global _setup
- if not _setup:
- sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
- logging.basicConfig()
-
- testlib.config.configure()
- _setup = True
-
-def simple_setup():
- """import testenv; testenv.simple_setup()"""
-
- global _setup
- if not _setup:
- sys.path.insert(0, os.path.join(os.getcwd(), 'lib'))
- logging.basicConfig()
-
- testlib.config.configure_defaults()
- _setup = True
-
+++ /dev/null
-"""Enhance unittest and instrument SQLAlchemy classes for testing.
-
-Load after sqlalchemy imports to use instrumented stand-ins like Table.
-"""
-
-import sys
-import testlib.config
-from testlib.schema import Table, Column
-import testlib.testing as testing
-from testlib.testing import \
- AssertsCompiledSQL, \
- AssertsExecutionResults, \
- ComparesTables, \
- TestBase, \
- rowset
-from testlib.orm import mapper
-import testlib.profiling as profiling
-import testlib.engines as engines
-import testlib.requires as requires
-from testlib.compat import _function_named
-
-
-__all__ = ('testing',
- 'mapper',
- 'Table', 'Column',
- 'rowset',
- 'TestBase', 'AssertsExecutionResults',
- 'AssertsCompiledSQL', 'ComparesTables',
- 'profiling', 'engines',
- '_function_named')
-
-
-testing.requires = requires
-
-sys.modules['testlib.sa'] = sa = testing.CompositeModule(
- 'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule(
- 'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm'))
-sys.modules['testlib.sa.orm'] = sa.orm
+++ /dev/null
-import types
-import __builtin__
-
-__all__ = '_function_named', 'callable'
-
-
-def _function_named(fn, newname):
- try:
- fn.__name__ = newname
- except:
- fn = types.FunctionType(fn.func_code, fn.func_globals, newname,
- fn.func_defaults, fn.func_closure)
- return fn
-
-try:
- callable = __builtin__.callable
-except NameError:
- def callable(fn): return hasattr(fn, '__call__')
-
+++ /dev/null
-import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
-logging, require = None, None
-
-
-__all__ = 'parser', 'configure', 'options',
-
-db = None
-db_label, db_url, db_opts = None, None, {}
-
-options = None
-file_config = None
-coverage_enabled = False
-
-base_config = """
-[db]
-sqlite=sqlite:///:memory:
-sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test
-oracle=oracle://scott:tiger@127.0.0.1:1521
-oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
-mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
-firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
-maxdb=maxdb://MONA:RED@/maxdb1
-"""
-
-parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")
-
-def configure():
- global options, config
- global getopts_options, file_config
-
- file_config = ConfigParser.ConfigParser()
- file_config.readfp(StringIO.StringIO(base_config))
- file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
-
- # Opt parsing can fire immediate actions, like logging and coverage
- (options, args) = parser.parse_args()
- sys.argv[1:] = args
-
- # Lazy setup of other options (post coverage)
- for fn in post_configure:
- fn(options, file_config)
-
- return options, file_config
-
-def configure_defaults():
- global options, config
- global getopts_options, file_config
- global db
-
- file_config = ConfigParser.ConfigParser()
- file_config.readfp(StringIO.StringIO(base_config))
- file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
- (options, args) = parser.parse_args([])
-
- # make error messages raised by decorators that depend on a default
- # database clearer.
- class _engine_bomb(object):
- def __getattr__(self, key):
- raise RuntimeError('No default engine available, testlib '
- 'was configured with defaults only.')
-
- db = _engine_bomb()
- import testlib.testing
- testlib.testing.db = db
-
- return options, file_config
-
-def _log(option, opt_str, value, parser):
- global logging
- if not logging:
- import logging
- logging.basicConfig()
-
- if opt_str.endswith('-info'):
- logging.getLogger(value).setLevel(logging.INFO)
- elif opt_str.endswith('-debug'):
- logging.getLogger(value).setLevel(logging.DEBUG)
-
-def _start_cumulative_coverage(option, opt_str, value, parser):
- _start_coverage(option, opt_str, value, parser, erase=False)
-
-def _start_coverage(option, opt_str, value, parser, erase=True):
- import sys, atexit, coverage
- true_out = sys.stdout
-
- global coverage_enabled
- coverage_enabled = True
-
- def _iter_covered_files(mod, recursive=True):
-
- if recursive:
- ff = os.walk
- else:
- ff = os.listdir
-
- for rec in ff(os.path.dirname(mod.__file__)):
- for x in rec[2]:
- if x.endswith('.py'):
- yield os.path.join(rec[0], x)
-
- def _stop():
- coverage.stop()
- true_out.write("\nPreparing coverage report...\n")
-
- from sqlalchemy import sql, orm, engine, \
- ext, databases, log
-
- import sqlalchemy
-
- for modset in [
- _iter_covered_files(sqlalchemy, recursive=False),
- _iter_covered_files(databases),
- _iter_covered_files(engine),
- _iter_covered_files(ext),
- _iter_covered_files(orm),
- ]:
- coverage.report(list(modset),
- show_missing=False, ignore_errors=False,
- file=true_out)
- atexit.register(_stop)
- if erase:
- coverage.erase()
- coverage.start()
-
-def _list_dbs(*args):
- print "Available --db options (use --dburi to override)"
- for macro in sorted(file_config.options('db')):
- print "%20s\t%s" % (macro, file_config.get('db', macro))
- sys.exit(0)
-
-def _server_side_cursors(options, opt_str, value, parser):
- db_opts['server_side_cursors'] = True
-
-def _engine_strategy(options, opt_str, value, parser):
- if value:
- db_opts['strategy'] = value
-
-opt = parser.add_option
-opt("--verbose", action="store_true", dest="verbose",
- help="enable stdout echoing/printing")
-opt("--quiet", action="store_true", dest="quiet", help="suppress output")
-opt("--log-info", action="callback", type="string", callback=_log,
- help="turn on info logging for <LOG> (multiple OK)")
-opt("--log-debug", action="callback", type="string", callback=_log,
- help="turn on debug logging for <LOG> (multiple OK)")
-opt("--require", action="append", dest="require", default=[],
- help="require a particular driver or module version (multiple OK)")
-opt("--db", action="store", dest="db", default="sqlite",
- help="Use prefab database uri")
-opt('--dbs', action='callback', callback=_list_dbs,
- help="List available prefab dbs")
-opt("--dburi", action="store", dest="dburi",
- help="Database uri (overrides --db)")
-opt("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")
-opt("--mockpool", action="store_true", dest="mockpool",
- help="Use mock pool (asserts only one connection used)")
-opt("--enginestrategy", action="callback", type="string",
- callback=_engine_strategy,
- help="Engine strategy (plain or threadlocal, defaults to plain)")
-opt("--reversetop", action="store_true", dest="reversetop", default=False,
- help="Reverse the collection ordering for topological sorts (helps "
- "reveal dependency issues)")
-opt("--unhashable", action="store_true", dest="unhashable", default=False,
- help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
-opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
- help="Disallow SQLAlchemy from performing == on mapped test objects.")
-opt("--truthless", action="store_true", dest="truthless", default=False,
- help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
-opt("--serverside", action="callback", callback=_server_side_cursors,
- help="Turn on server side cursors for PG")
-opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
- help="Use the specified MySQL storage engine for all tables, default is "
- "a db-default/InnoDB combo.")
-opt("--table-option", action="append", dest="tableopts", default=[],
- help="Add a dialect-specific table option, key=value")
-opt("--coverage", action="callback", callback=_start_coverage,
- help="Dump a full coverage report after running tests")
-opt("--cumulative-coverage", action="callback", callback=_start_cumulative_coverage,
- help="Like --coverage, but accumlate coverage into the current DB")
-opt("--profile", action="append", dest="profile_targets", default=[],
- help="Enable a named profile target (multiple OK.)")
-opt("--profile-sort", action="store", dest="profile_sort", default=None,
- help="Sort profile stats with this comma-separated sort order")
-opt("--profile-limit", type="int", action="store", dest="profile_limit",
- default=None,
- help="Limit function count in profile stats")
-
-class _ordered_map(object):
- def __init__(self):
- self._keys = list()
- self._data = dict()
-
- def __setitem__(self, key, value):
- if key not in self._keys:
- self._keys.append(key)
- self._data[key] = value
-
- def __iter__(self):
- for key in self._keys:
- yield self._data[key]
-
-# at one point in refactoring, modules were injecting into the config
-# process. this could probably just become a list now.
-post_configure = _ordered_map()
-
-def _engine_uri(options, file_config):
- global db_label, db_url
- db_label = 'sqlite'
- if options.dburi:
- db_url = options.dburi
- db_label = db_url[:db_url.index(':')]
- elif options.db:
- db_label = options.db
- db_url = None
-
- if db_url is None:
- if db_label not in file_config.options('db'):
- raise RuntimeError(
- "Unknown engine. Specify --dbs for known engines.")
- db_url = file_config.get('db', db_label)
-post_configure['engine_uri'] = _engine_uri
-
-def _require(options, file_config):
- if not(options.require or
- (file_config.has_section('require') and
- file_config.items('require'))):
- return
-
- try:
- import pkg_resources
- except ImportError:
- raise RuntimeError("setuptools is required for version requirements")
-
- cmdline = []
- for requirement in options.require:
- pkg_resources.require(requirement)
- cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
-
- if file_config.has_section('require'):
- for label, requirement in file_config.items('require'):
- if not label == db_label or label.startswith('%s.' % db_label):
- continue
- seen = [c for c in cmdline if requirement.startswith(c)]
- if seen:
- continue
- pkg_resources.require(requirement)
-post_configure['require'] = _require
-
-def _engine_pool(options, file_config):
- if options.mockpool:
- from sqlalchemy import pool
- db_opts['poolclass'] = pool.AssertionPool
-post_configure['engine_pool'] = _engine_pool
-
-def _create_testing_engine(options, file_config):
- from testlib import engines, testing
- global db
- db = engines.testing_engine(db_url, db_opts)
- testing.db = db
-post_configure['create_engine'] = _create_testing_engine
-
-def _prep_testing_database(options, file_config):
- from testlib import engines
- from sqlalchemy import schema
-
- try:
- # also create alt schemas etc. here?
- if options.dropfirst:
- e = engines.utf8_engine()
- existing = e.table_names()
- if existing:
- if not options.quiet:
- print "Dropping existing tables in database: " + db_url
- try:
- print "Tables: %s" % ', '.join(existing)
- except:
- pass
- print "Abort within 5 seconds..."
- time.sleep(5)
- md = schema.MetaData(e, reflect=True)
- md.drop_all()
- e.dispose()
- except (KeyboardInterrupt, SystemExit):
- raise
- except Exception, e:
- if not options.quiet:
- warnings.warn(RuntimeWarning(
- "Error checking for existing tables in testing "
- "database: %s" % e))
-post_configure['prep_db'] = _prep_testing_database
-
-def _set_table_options(options, file_config):
- import testlib.schema
-
- table_options = testlib.schema.table_options
- for spec in options.tableopts:
- key, value = spec.split('=')
- table_options[key] = value
-
- if options.mysql_engine:
- table_options['mysql_engine'] = options.mysql_engine
-post_configure['table_options'] = _set_table_options
-
-def _reverse_topological(options, file_config):
- if options.reversetop:
- from sqlalchemy.orm import unitofwork
- from sqlalchemy import topological
- class RevQueueDepSort(topological.QueueDependencySorter):
- def __init__(self, tuples, allitems):
- self.tuples = list(tuples)
- self.allitems = list(allitems)
- self.tuples.reverse()
- self.allitems.reverse()
- topological.QueueDependencySorter = RevQueueDepSort
- unitofwork.DependencySorter = RevQueueDepSort
-post_configure['topological'] = _reverse_topological
-
-def _set_profile_targets(options, file_config):
- from testlib import profiling
-
- profile_config = profiling.profile_config
-
- for target in options.profile_targets:
- profile_config['targets'].add(target)
-
- if options.profile_sort:
- profile_config['sort'] = options.profile_sort.split(',')
-
- if options.profile_limit:
- profile_config['limit'] = options.profile_limit
-
- if options.quiet:
- profile_config['report'] = False
-
- # magic "all" target
- if 'all' in profiling.all_targets:
- targets = profile_config['targets']
- if 'all' in targets and len(targets) != 1:
- targets.clear()
- targets.add('all')
-post_configure['profile_targets'] = _set_profile_targets
+++ /dev/null
-#!/usr/bin/python
-#
-# Perforce Defect Tracking Integration Project
-# <http://www.ravenbrook.com/project/p4dti/>
-#
-# COVERAGE.PY -- COVERAGE TESTING
-#
-# Gareth Rees, Ravenbrook Limited, 2001-12-04
-# Ned Batchelder, 2004-12-12
-# http://nedbatchelder.com/code/modules/coverage.html
-#
-#
-# 1. INTRODUCTION
-#
-# This module provides coverage testing for Python code.
-#
-# The intended readership is all Python developers.
-#
-# This document is not confidential.
-#
-# See [GDR 2001-12-04a] for the command-line interface, programmatic
-# interface and limitations. See [GDR 2001-12-04b] for requirements and
-# design.
-
-r"""\
-Usage:
-
-coverage.py -x [-p] MODULE.py [ARG1 ARG2 ...]
- Execute module, passing the given command-line arguments, collecting
- coverage data. With the -p option, write to a temporary file containing
- the machine name and process ID.
-
-coverage.py -e
- Erase collected coverage data.
-
-coverage.py -c
- Collect data from multiple coverage files (as created by -p option above)
- and store it into a single file representing the union of the coverage.
-
-coverage.py -r [-m] [-o dir1,dir2,...] FILE1 FILE2 ...
- Report on the statement coverage for the given files. With the -m
- option, show line numbers of the statements that weren't executed.
-
-coverage.py -a [-d dir] [-o dir1,dir2,...] FILE1 FILE2 ...
- Make annotated copies of the given files, marking statements that
- are executed with > and statements that are missed with !. With
- the -d option, make the copies in that directory. Without the -d
- option, make each copy in the same directory as the original.
-
--o dir,dir2,...
- Omit reporting or annotating files when their filename path starts with
- a directory listed in the omit list.
- e.g. python coverage.py -i -r -o c:\python23,lib\enthought\traits
-
-Coverage data is saved in the file .coverage by default. Set the
-COVERAGE_FILE environment variable to save it somewhere else."""
-
-__version__ = "2.75.20070722" # see detailed history at the end of this file.
-
-import compiler
-import compiler.visitor
-import glob
-import os
-import re
-import string
-import symbol
-import sys
-import threading
-import token
-import types
-from socket import gethostname
-
-
-# 2. IMPLEMENTATION
-#
-# This uses the "singleton" pattern.
-#
-# The word "morf" means a module object (from which the source file can
-# be deduced by suitable manipulation of the __file__ attribute) or a
-# filename.
-#
-# When we generate a coverage report we have to canonicalize every
-# filename in the coverage dictionary just in case it refers to the
-# module we are reporting on. It seems a shame to throw away this
-# information so the data in the coverage dictionary is transferred to
-# the 'cexecuted' dictionary under the canonical filenames.
-#
-# The coverage dictionary is called "c" and the trace function "t". The
-# reason for these short names is that Python looks up variables by name
-# at runtime and so execution time depends on the length of variables!
-# In the bottleneck of this application it's appropriate to abbreviate
-# names to increase speed.
-
-class StatementFindingAstVisitor(compiler.visitor.ASTVisitor):
- """ A visitor for a parsed Abstract Syntax Tree which finds executable
- statements.
- """
- def __init__(self, statements, excluded, suite_spots):
- compiler.visitor.ASTVisitor.__init__(self)
- self.statements = statements
- self.excluded = excluded
- self.suite_spots = suite_spots
- self.excluding_suite = 0
-
- def doRecursive(self, node):
- for n in node.getChildNodes():
- self.dispatch(n)
-
- visitStmt = visitModule = doRecursive
-
- def doCode(self, node):
- if hasattr(node, 'decorators') and node.decorators:
- self.dispatch(node.decorators)
- self.recordAndDispatch(node.code)
- else:
- self.doSuite(node, node.code)
-
- visitFunction = visitClass = doCode
-
- def getFirstLine(self, node):
- # Find the first line in the tree node.
- lineno = node.lineno
- for n in node.getChildNodes():
- f = self.getFirstLine(n)
- if lineno and f:
- lineno = min(lineno, f)
- else:
- lineno = lineno or f
- return lineno
-
- def getLastLine(self, node):
-        # Find the last line in the tree node.
- lineno = node.lineno
- for n in node.getChildNodes():
- lineno = max(lineno, self.getLastLine(n))
- return lineno
-
- def doStatement(self, node):
- self.recordLine(self.getFirstLine(node))
-
- visitAssert = visitAssign = visitAssTuple = visitPrint = \
- visitPrintnl = visitRaise = visitSubscript = visitDecorators = \
- doStatement
-
- def visitPass(self, node):
- # Pass statements have weird interactions with docstrings. If this
- # pass statement is part of one of those pairs, claim that the statement
- # is on the later of the two lines.
- l = node.lineno
- if l:
- lines = self.suite_spots.get(l, [l,l])
- self.statements[lines[1]] = 1
-
- def visitDiscard(self, node):
- # Discard nodes are statements that execute an expression, but then
- # discard the results. This includes function calls, so we can't
- # ignore them all. But if the expression is a constant, the statement
- # won't be "executed", so don't count it now.
- if node.expr.__class__.__name__ != 'Const':
- self.doStatement(node)
-
- def recordNodeLine(self, node):
- # Stmt nodes often have None, but shouldn't claim the first line of
- # their children (because the first child might be an ignorable line
- # like "global a").
- if node.__class__.__name__ != 'Stmt':
- return self.recordLine(self.getFirstLine(node))
- else:
- return 0
-
- def recordLine(self, lineno):
- # Returns a bool, whether the line is included or excluded.
- if lineno:
- # Multi-line tests introducing suites have to get charged to their
- # keyword.
- if lineno in self.suite_spots:
- lineno = self.suite_spots[lineno][0]
- # If we're inside an excluded suite, record that this line was
- # excluded.
- if self.excluding_suite:
- self.excluded[lineno] = 1
- return 0
- # If this line is excluded, or suite_spots maps this line to
-            # another line that is excluded, then we're excluded.
- elif self.excluded.has_key(lineno) or \
- self.suite_spots.has_key(lineno) and \
- self.excluded.has_key(self.suite_spots[lineno][1]):
- return 0
- # Otherwise, this is an executable line.
- else:
- self.statements[lineno] = 1
- return 1
- return 0
-
- default = recordNodeLine
-
- def recordAndDispatch(self, node):
- self.recordNodeLine(node)
- self.dispatch(node)
-
- def doSuite(self, intro, body, exclude=0):
- exsuite = self.excluding_suite
- if exclude or (intro and not self.recordNodeLine(intro)):
- self.excluding_suite = 1
- self.recordAndDispatch(body)
- self.excluding_suite = exsuite
-
- def doPlainWordSuite(self, prevsuite, suite):
- # Finding the exclude lines for else's is tricky, because they aren't
- # present in the compiler parse tree. Look at the previous suite,
- # and find its last line. If any line between there and the else's
- # first line are excluded, then we exclude the else.
- lastprev = self.getLastLine(prevsuite)
- firstelse = self.getFirstLine(suite)
- for l in range(lastprev+1, firstelse):
- if self.suite_spots.has_key(l):
- self.doSuite(None, suite, exclude=self.excluded.has_key(l))
- break
- else:
- self.doSuite(None, suite)
-
- def doElse(self, prevsuite, node):
- if node.else_:
- self.doPlainWordSuite(prevsuite, node.else_)
-
- def visitFor(self, node):
- self.doSuite(node, node.body)
- self.doElse(node.body, node)
-
- visitWhile = visitFor
-
- def visitIf(self, node):
- # The first test has to be handled separately from the rest.
- # The first test is credited to the line with the "if", but the others
- # are credited to the line with the test for the elif.
- self.doSuite(node, node.tests[0][1])
- for t, n in node.tests[1:]:
- self.doSuite(t, n)
- self.doElse(node.tests[-1][1], node)
-
- def visitTryExcept(self, node):
- self.doSuite(node, node.body)
- for i in range(len(node.handlers)):
- a, b, h = node.handlers[i]
- if not a:
- # It's a plain "except:". Find the previous suite.
- if i > 0:
- prev = node.handlers[i-1][2]
- else:
- prev = node.body
- self.doPlainWordSuite(prev, h)
- else:
- self.doSuite(a, h)
- self.doElse(node.handlers[-1][2], node)
-
- def visitTryFinally(self, node):
- self.doSuite(node, node.body)
- self.doPlainWordSuite(node.body, node.final)
-
- def visitGlobal(self, node):
- # "global" statements don't execute like others (they don't call the
- # trace function), so don't record their line numbers.
- pass
-
-the_coverage = None
-
-class CoverageException(Exception): pass
-
-class coverage:
- # Name of the cache file (unless environment variable is set).
- cache_default = ".coverage"
-
- # Environment variable naming the cache file.
- cache_env = "COVERAGE_FILE"
-
- # A dictionary with an entry for (Python source file name, line number
- # in that file) if that line has been executed.
- c = {}
-
- # A map from canonical Python source file name to a dictionary in
- # which there's an entry for each line number that has been
- # executed.
- cexecuted = {}
-
- # Cache of results of calling the analysis2() method, so that you can
- # specify both -r and -a without doing double work.
- analysis_cache = {}
-
- # Cache of results of calling the canonical_filename() method, to
- # avoid duplicating work.
- canonical_filename_cache = {}
-
- def __init__(self):
- global the_coverage
- if the_coverage:
- raise CoverageException, "Only one coverage object allowed."
- self.usecache = 1
- self.cache = None
- self.parallel_mode = False
- self.exclude_re = ''
- self.nesting = 0
- self.cstack = []
- self.xstack = []
- self.relative_dir = os.path.normcase(os.path.abspath(os.curdir)+os.sep)
- self.exclude('# *pragma[: ]*[nN][oO] *[cC][oO][vV][eE][rR]')
-
- # t(f, x, y). This method is passed to sys.settrace as a trace function.
- # See [van Rossum 2001-07-20b, 9.2] for an explanation of sys.settrace and
- # the arguments and return value of the trace function.
- # See [van Rossum 2001-07-20a, 3.2] for a description of frame and code
- # objects.
-
- def t(self, f, w, unused): #pragma: no cover
- if w == 'line':
- #print "Executing %s @ %d" % (f.f_code.co_filename, f.f_lineno)
- self.c[(f.f_code.co_filename, f.f_lineno)] = 1
- for c in self.cstack:
- c[(f.f_code.co_filename, f.f_lineno)] = 1
- return self.t
-
- def help(self, error=None): #pragma: no cover
- if error:
- print error
- print
- print __doc__
- sys.exit(1)
-
- def command_line(self, argv, help_fn=None):
- import getopt
- help_fn = help_fn or self.help
- settings = {}
- optmap = {
- '-a': 'annotate',
- '-c': 'collect',
- '-d:': 'directory=',
- '-e': 'erase',
- '-h': 'help',
- '-i': 'ignore-errors',
- '-m': 'show-missing',
- '-p': 'parallel-mode',
- '-r': 'report',
- '-x': 'execute',
- '-o:': 'omit=',
- }
- short_opts = string.join(map(lambda o: o[1:], optmap.keys()), '')
- long_opts = optmap.values()
- options, args = getopt.getopt(argv, short_opts, long_opts)
- for o, a in options:
- if optmap.has_key(o):
- settings[optmap[o]] = 1
- elif optmap.has_key(o + ':'):
- settings[optmap[o + ':']] = a
- elif o[2:] in long_opts:
- settings[o[2:]] = 1
- elif o[2:] + '=' in long_opts:
- settings[o[2:]+'='] = a
- else: #pragma: no cover
- pass # Can't get here, because getopt won't return anything unknown.
-
- if settings.get('help'):
- help_fn()
-
- for i in ['erase', 'execute']:
- for j in ['annotate', 'report', 'collect']:
- if settings.get(i) and settings.get(j):
- help_fn("You can't specify the '%s' and '%s' "
- "options at the same time." % (i, j))
-
- args_needed = (settings.get('execute')
- or settings.get('annotate')
- or settings.get('report'))
- action = (settings.get('erase')
- or settings.get('collect')
- or args_needed)
- if not action:
- help_fn("You must specify at least one of -e, -x, -c, -r, or -a.")
- if not args_needed and args:
- help_fn("Unexpected arguments: %s" % " ".join(args))
-
- self.parallel_mode = settings.get('parallel-mode')
- self.get_ready()
-
- if settings.get('erase'):
- self.erase()
- if settings.get('execute'):
- if not args:
- help_fn("Nothing to do.")
- sys.argv = args
- self.start()
- import __main__
- sys.path[0] = os.path.dirname(sys.argv[0])
- execfile(sys.argv[0], __main__.__dict__)
- if settings.get('collect'):
- self.collect()
- if not args:
- args = self.cexecuted.keys()
-
- ignore_errors = settings.get('ignore-errors')
- show_missing = settings.get('show-missing')
- directory = settings.get('directory=')
-
- omit = settings.get('omit=')
- if omit is not None:
- omit = omit.split(',')
- else:
- omit = []
-
- if settings.get('report'):
- self.report(args, show_missing, ignore_errors, omit_prefixes=omit)
- if settings.get('annotate'):
- self.annotate(args, directory, ignore_errors, omit_prefixes=omit)
-
- def use_cache(self, usecache, cache_file=None):
- self.usecache = usecache
- if cache_file and not self.cache:
- self.cache_default = cache_file
-
- def get_ready(self, parallel_mode=False):
- if self.usecache and not self.cache:
- self.cache = os.environ.get(self.cache_env, self.cache_default)
- if self.parallel_mode:
- self.cache += "." + gethostname() + "." + str(os.getpid())
- self.restore()
- self.analysis_cache = {}
-
- def start(self, parallel_mode=False):
- self.get_ready()
- if self.nesting == 0: #pragma: no cover
- sys.settrace(self.t)
- if hasattr(threading, 'settrace'):
- threading.settrace(self.t)
- self.nesting += 1
-
- def stop(self):
- self.nesting -= 1
- if self.nesting == 0: #pragma: no cover
- sys.settrace(None)
- if hasattr(threading, 'settrace'):
- threading.settrace(None)
-
- def erase(self):
- self.get_ready()
- self.c = {}
- self.analysis_cache = {}
- self.cexecuted = {}
- if self.cache and os.path.exists(self.cache):
- os.remove(self.cache)
-
- def exclude(self, re):
- if self.exclude_re:
- self.exclude_re += "|"
- self.exclude_re += "(" + re + ")"
-
- def begin_recursive(self):
- self.cstack.append(self.c)
- self.xstack.append(self.exclude_re)
-
- def end_recursive(self):
- self.c = self.cstack.pop()
- self.exclude_re = self.xstack.pop()
-
- # save(). Save coverage data to the coverage cache.
-
- def save(self):
- if self.usecache and self.cache:
- self.canonicalize_filenames()
- cache = open(self.cache, 'wb')
- import marshal
- marshal.dump(self.cexecuted, cache)
- cache.close()
-
- # restore(). Restore coverage data from the coverage cache (if it exists).
-
- def restore(self):
- self.c = {}
- self.cexecuted = {}
- assert self.usecache
- if os.path.exists(self.cache):
- self.cexecuted = self.restore_file(self.cache)
-
- def restore_file(self, file_name):
- try:
- cache = open(file_name, 'rb')
- import marshal
- cexecuted = marshal.load(cache)
- cache.close()
- if isinstance(cexecuted, types.DictType):
- return cexecuted
- else:
- return {}
- except:
- return {}
-
- # collect(). Collect data in multiple files produced by parallel mode
-
- def collect(self):
- cache_dir, local = os.path.split(self.cache)
- for f in os.listdir(cache_dir or '.'):
- if not f.startswith(local):
- continue
-
- full_path = os.path.join(cache_dir, f)
- cexecuted = self.restore_file(full_path)
- self.merge_data(cexecuted)
-
- def merge_data(self, new_data):
- for file_name, file_data in new_data.items():
- if self.cexecuted.has_key(file_name):
- self.merge_file_data(self.cexecuted[file_name], file_data)
- else:
- self.cexecuted[file_name] = file_data
-
- def merge_file_data(self, cache_data, new_data):
- for line_number in new_data.keys():
- if not cache_data.has_key(line_number):
- cache_data[line_number] = new_data[line_number]
-
- # canonical_filename(filename). Return a canonical filename for the
- # file (that is, an absolute path with no redundant components and
- # normalized case). See [GDR 2001-12-04b, 3.3].
-
- def canonical_filename(self, filename):
- if not self.canonical_filename_cache.has_key(filename):
- f = filename
- if os.path.isabs(f) and not os.path.exists(f):
- f = os.path.basename(f)
- if not os.path.isabs(f):
- for path in [os.curdir] + sys.path:
- g = os.path.join(path, f)
- if os.path.exists(g):
- f = g
- break
- cf = os.path.normcase(os.path.abspath(f))
- self.canonical_filename_cache[filename] = cf
- return self.canonical_filename_cache[filename]
-
- # canonicalize_filenames(). Copy results from "c" to "cexecuted",
- # canonicalizing filenames on the way. Clear the "c" map.
-
- def canonicalize_filenames(self):
- for filename, lineno in self.c.keys():
- if filename == '<string>':
- # Can't do anything useful with exec'd strings, so skip them.
- continue
- f = self.canonical_filename(filename)
- if not self.cexecuted.has_key(f):
- self.cexecuted[f] = {}
- self.cexecuted[f][lineno] = 1
- self.c = {}
-
- # morf_filename(morf). Return the filename for a module or file.
-
- def morf_filename(self, morf):
- if isinstance(morf, types.ModuleType):
- if not hasattr(morf, '__file__'):
- raise CoverageException, "Module has no __file__ attribute."
- f = morf.__file__
- else:
- f = morf
- return self.canonical_filename(f)
-
- # analyze_morf(morf). Analyze the module or filename passed as
- # the argument. If the source code can't be found, raise an error.
- # Otherwise, return a tuple of (1) the canonical filename of the
- # source code for the module, (2) a list of lines of statements
- # in the source code, (3) a list of lines of excluded statements,
- # and (4), a map of line numbers to multi-line line number ranges, for
- # statements that cross lines.
-
- def analyze_morf(self, morf):
- if self.analysis_cache.has_key(morf):
- return self.analysis_cache[morf]
- filename = self.morf_filename(morf)
- ext = os.path.splitext(filename)[1]
- if ext == '.pyc':
- if not os.path.exists(filename[0:-1]):
- raise CoverageException, ("No source for compiled code '%s'."
- % filename)
- filename = filename[0:-1]
- elif ext != '.py':
- raise CoverageException, "File '%s' not Python source." % filename
- source = open(filename, 'r')
- lines, excluded_lines, line_map = self.find_executable_statements(
- source.read(), exclude=self.exclude_re
- )
- source.close()
- result = filename, lines, excluded_lines, line_map
- self.analysis_cache[morf] = result
- return result
-
- def first_line_of_tree(self, tree):
- while True:
- if len(tree) == 3 and type(tree[2]) == type(1):
- return tree[2]
- tree = tree[1]
-
- def last_line_of_tree(self, tree):
- while True:
- if len(tree) == 3 and type(tree[2]) == type(1):
- return tree[2]
- tree = tree[-1]
-
- def find_docstring_pass_pair(self, tree, spots):
- for i in range(1, len(tree)):
- if self.is_string_constant(tree[i]) and self.is_pass_stmt(tree[i+1]):
- first_line = self.first_line_of_tree(tree[i])
- last_line = self.last_line_of_tree(tree[i+1])
- self.record_multiline(spots, first_line, last_line)
-
- def is_string_constant(self, tree):
- try:
- return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.expr_stmt
- except:
- return False
-
- def is_pass_stmt(self, tree):
- try:
- return tree[0] == symbol.stmt and tree[1][1][1][0] == symbol.pass_stmt
- except:
- return False
-
- def record_multiline(self, spots, i, j):
- for l in range(i, j+1):
- spots[l] = (i, j)
-
- def get_suite_spots(self, tree, spots):
- """ Analyze a parse tree to find suite introducers which span a number
- of lines.
- """
- for i in range(1, len(tree)):
- if type(tree[i]) == type(()):
- if tree[i][0] == symbol.suite:
- # Found a suite, look back for the colon and keyword.
- lineno_colon = lineno_word = None
- for j in range(i-1, 0, -1):
- if tree[j][0] == token.COLON:
- # Colons are never executed themselves: we want the
- # line number of the last token before the colon.
- lineno_colon = self.last_line_of_tree(tree[j-1])
- elif tree[j][0] == token.NAME:
- if tree[j][1] == 'elif':
- # Find the line number of the first non-terminal
- # after the keyword.
- t = tree[j+1]
- while t and token.ISNONTERMINAL(t[0]):
- t = t[1]
- if t:
- lineno_word = t[2]
- else:
- lineno_word = tree[j][2]
- break
- elif tree[j][0] == symbol.except_clause:
- # "except" clauses look like:
- # ('except_clause', ('NAME', 'except', lineno), ...)
- if tree[j][1][0] == token.NAME:
- lineno_word = tree[j][1][2]
- break
- if lineno_colon and lineno_word:
- # Found colon and keyword, mark all the lines
- # between the two with the two line numbers.
- self.record_multiline(spots, lineno_word, lineno_colon)
-
- # "pass" statements are tricky: different versions of Python
- # treat them differently, especially in the common case of a
- # function with a doc string and a single pass statement.
- self.find_docstring_pass_pair(tree[i], spots)
-
- elif tree[i][0] == symbol.simple_stmt:
- first_line = self.first_line_of_tree(tree[i])
- last_line = self.last_line_of_tree(tree[i])
- if first_line != last_line:
- self.record_multiline(spots, first_line, last_line)
- self.get_suite_spots(tree[i], spots)
-
- def find_executable_statements(self, text, exclude=None):
- # Find lines which match an exclusion pattern.
- excluded = {}
- suite_spots = {}
- if exclude:
- reExclude = re.compile(exclude)
- lines = text.split('\n')
- for i in range(len(lines)):
- if reExclude.search(lines[i]):
- excluded[i+1] = 1
-
- # Parse the code and analyze the parse tree to find out which statements
- # are multiline, and where suites begin and end.
- import parser
- tree = parser.suite(text+'\n\n').totuple(1)
- self.get_suite_spots(tree, suite_spots)
- #print "Suite spots:", suite_spots
-
- # Use the compiler module to parse the text and find the executable
- # statements. We add newlines to be impervious to final partial lines.
- statements = {}
- ast = compiler.parse(text+'\n\n')
- visitor = StatementFindingAstVisitor(statements, excluded, suite_spots)
- compiler.walk(ast, visitor, walker=visitor)
-
- lines = statements.keys()
- lines.sort()
- excluded_lines = excluded.keys()
- excluded_lines.sort()
- return lines, excluded_lines, suite_spots
-
- # format_lines(statements, lines). Format a list of line numbers
- # for printing by coalescing groups of lines as long as the lines
- # represent consecutive statements. This will coalesce even if
- # there are gaps between statements, so if statements =
- # [1,2,3,4,5,10,11,12,13,14] and lines = [1,2,5,10,11,13,14] then
- # format_lines will return "1-2, 5-11, 13-14".
-
- def format_lines(self, statements, lines):
- pairs = []
- i = 0
- j = 0
- start = None
- pairs = []
- while i < len(statements) and j < len(lines):
- if statements[i] == lines[j]:
- if start == None:
- start = lines[j]
- end = lines[j]
- j = j + 1
- elif start:
- pairs.append((start, end))
- start = None
- i = i + 1
- if start:
- pairs.append((start, end))
- def stringify(pair):
- start, end = pair
- if start == end:
- return "%d" % start
- else:
- return "%d-%d" % (start, end)
- ret = string.join(map(stringify, pairs), ", ")
- return ret
-
- # Backward compatibility with version 1.
- def analysis(self, morf):
- f, s, _, m, mf = self.analysis2(morf)
- return f, s, m, mf
-
- def analysis2(self, morf):
- filename, statements, excluded, line_map = self.analyze_morf(morf)
- self.canonicalize_filenames()
- if not self.cexecuted.has_key(filename):
- self.cexecuted[filename] = {}
- missing = []
- for line in statements:
- lines = line_map.get(line, [line, line])
- for l in range(lines[0], lines[1]+1):
- if self.cexecuted[filename].has_key(l):
- break
- else:
- missing.append(line)
- return (filename, statements, excluded, missing,
- self.format_lines(statements, missing))
-
- def relative_filename(self, filename):
- """ Convert filename to relative filename from self.relative_dir.
- """
- return filename.replace(self.relative_dir, "")
-
- def morf_name(self, morf):
- """ Return the name of morf as used in report.
- """
- if isinstance(morf, types.ModuleType):
- return morf.__name__
- else:
- return self.relative_filename(os.path.splitext(morf)[0])
-
- def filter_by_prefix(self, morfs, omit_prefixes):
- """ Return list of morfs where the morf name does not begin
- with any one of the omit_prefixes.
- """
- filtered_morfs = []
- for morf in morfs:
- for prefix in omit_prefixes:
- if self.morf_name(morf).startswith(prefix):
- break
- else:
- filtered_morfs.append(morf)
-
- return filtered_morfs
-
- def morf_name_compare(self, x, y):
- return cmp(self.morf_name(x), self.morf_name(y))
-
- def report(self, morfs, show_missing=1, ignore_errors=0, file=None, omit_prefixes=[]):
- if not isinstance(morfs, types.ListType):
- morfs = [morfs]
- # On windows, the shell doesn't expand wildcards. Do it here.
- globbed = []
- for morf in morfs:
- if isinstance(morf, basestring):
- globbed.extend(glob.glob(morf))
- else:
- globbed.append(morf)
- morfs = globbed
-
- morfs = self.filter_by_prefix(morfs, omit_prefixes)
- morfs.sort(self.morf_name_compare)
-
- max_name = max([5,] + map(len, map(self.morf_name, morfs)))
- fmt_name = "%%- %ds " % max_name
- fmt_err = fmt_name + "%s: %s"
- header = fmt_name % "Name" + " Stmts Exec Cover"
- fmt_coverage = fmt_name + "% 6d % 6d % 5d%%"
- if show_missing:
- header = header + " Missing"
- fmt_coverage = fmt_coverage + " %s"
- if not file:
- file = sys.stdout
- print >>file, header
- print >>file, "-" * len(header)
- total_statements = 0
- total_executed = 0
- for morf in morfs:
- name = self.morf_name(morf)
- try:
- _, statements, _, missing, readable = self.analysis2(morf)
- n = len(statements)
- m = n - len(missing)
- if n > 0:
- pc = 100.0 * m / n
- else:
- pc = 100.0
- args = (name, n, m, pc)
- if show_missing:
- args = args + (readable,)
- print >>file, fmt_coverage % args
- total_statements = total_statements + n
- total_executed = total_executed + m
- except KeyboardInterrupt: #pragma: no cover
- raise
- except:
- if not ignore_errors:
- typ, msg = sys.exc_info()[0:2]
- print >>file, fmt_err % (name, typ, msg)
- if len(morfs) > 1:
- print >>file, "-" * len(header)
- if total_statements > 0:
- pc = 100.0 * total_executed / total_statements
- else:
- pc = 100.0
- args = ("TOTAL", total_statements, total_executed, pc)
- if show_missing:
- args = args + ("",)
- print >>file, fmt_coverage % args
-
- # annotate(morfs, ignore_errors).
-
- blank_re = re.compile(r"\s*(#|$)")
- else_re = re.compile(r"\s*else\s*:\s*(#|$)")
-
- def annotate(self, morfs, directory=None, ignore_errors=0, omit_prefixes=[]):
- morfs = self.filter_by_prefix(morfs, omit_prefixes)
- for morf in morfs:
- try:
- filename, statements, excluded, missing, _ = self.analysis2(morf)
- self.annotate_file(filename, statements, excluded, missing, directory)
- except KeyboardInterrupt:
- raise
- except:
- if not ignore_errors:
- raise
-
- def annotate_file(self, filename, statements, excluded, missing, directory=None):
- source = open(filename, 'r')
- if directory:
- dest_file = os.path.join(directory,
- os.path.basename(filename)
- + ',cover')
- else:
- dest_file = filename + ',cover'
- dest = open(dest_file, 'w')
- lineno = 0
- i = 0
- j = 0
- covered = 1
- while 1:
- line = source.readline()
- if line == '':
- break
- lineno = lineno + 1
- while i < len(statements) and statements[i] < lineno:
- i = i + 1
- while j < len(missing) and missing[j] < lineno:
- j = j + 1
- if i < len(statements) and statements[i] == lineno:
- covered = j >= len(missing) or missing[j] > lineno
- if self.blank_re.match(line):
- dest.write(' ')
- elif self.else_re.match(line):
- # Special logic for lines containing only 'else:'.
- # See [GDR 2001-12-04b, 3.2].
- if i >= len(statements) and j >= len(missing):
- dest.write('! ')
- elif i >= len(statements) or j >= len(missing):
- dest.write('> ')
- elif statements[i] == missing[j]:
- dest.write('! ')
- else:
- dest.write('> ')
- elif lineno in excluded:
- dest.write('- ')
- elif covered:
- dest.write('> ')
- else:
- dest.write('! ')
- dest.write(line)
- source.close()
- dest.close()
-
-# Singleton object.
-the_coverage = coverage()
-
-# Module functions call methods in the singleton object.
-def use_cache(*args, **kw):
- return the_coverage.use_cache(*args, **kw)
-
-def start(*args, **kw):
- return the_coverage.start(*args, **kw)
-
-def stop(*args, **kw):
- return the_coverage.stop(*args, **kw)
-
-def erase(*args, **kw):
- return the_coverage.erase(*args, **kw)
-
-def begin_recursive(*args, **kw):
- return the_coverage.begin_recursive(*args, **kw)
-
-def end_recursive(*args, **kw):
- return the_coverage.end_recursive(*args, **kw)
-
-def exclude(*args, **kw):
- return the_coverage.exclude(*args, **kw)
-
-def analysis(*args, **kw):
- return the_coverage.analysis(*args, **kw)
-
-def analysis2(*args, **kw):
- return the_coverage.analysis2(*args, **kw)
-
-def report(*args, **kw):
- return the_coverage.report(*args, **kw)
-
-def annotate(*args, **kw):
- return the_coverage.annotate(*args, **kw)
-
-def annotate_file(*args, **kw):
- return the_coverage.annotate_file(*args, **kw)
-
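
The module-level functions above were the usual way to drive this coverage module from Python rather than from its command line. A minimal Python 2 sketch of that pattern, assuming the file is importable as 'coverage' (the measured function and filename are illustrative only):

    import coverage

    def add(a, b):
        # toy function whose execution we want to measure
        return a + b

    coverage.erase()                    # discard any cached .coverage data
    coverage.start()                    # install the trace function
    add(2, 3)                           # exercise the code under measurement
    coverage.stop()                     # uninstall the trace function
    coverage.report([__file__], show_missing=1)

The report() call prints the same Name/Stmts/Exec/Cover table produced by the -r command-line option.
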
-# Save coverage data when Python exits. (The atexit module wasn't
-# introduced until Python 2.0, so use sys.exitfunc when it's not
-# available.)
-try:
- import atexit
- atexit.register(the_coverage.save)
-except ImportError:
- sys.exitfunc = the_coverage.save
-
-# Command-line interface.
-if __name__ == '__main__':
- the_coverage.command_line(sys.argv[1:])
-
-
-# A. REFERENCES
-#
-# [GDR 2001-12-04a] "Statement coverage for Python"; Gareth Rees;
-# Ravenbrook Limited; 2001-12-04;
-# <http://www.nedbatchelder.com/code/modules/rees-coverage.html>.
-#
-# [GDR 2001-12-04b] "Statement coverage for Python: design and
-# analysis"; Gareth Rees; Ravenbrook Limited; 2001-12-04;
-# <http://www.nedbatchelder.com/code/modules/rees-design.html>.
-#
-# [van Rossum 2001-07-20a] "Python Reference Manual (release 2.1.1)";
-# Guido van Rossum; 2001-07-20;
-# <http://www.python.org/doc/2.1.1/ref/ref.html>.
-#
-# [van Rossum 2001-07-20b] "Python Library Reference"; Guido van Rossum;
-# 2001-07-20; <http://www.python.org/doc/2.1.1/lib/lib.html>.
-#
-#
-# B. DOCUMENT HISTORY
-#
-# 2001-12-04 GDR Created.
-#
-# 2001-12-06 GDR Added command-line interface and source code
-# annotation.
-#
-# 2001-12-09 GDR Moved design and interface to separate documents.
-#
-# 2001-12-10 GDR Open cache file as binary on Windows. Allow
-# simultaneous -e and -x, or -a and -r.
-#
-# 2001-12-12 GDR Added command-line help. Cache analysis so that it
-# only needs to be done once when you specify -a and -r.
-#
-# 2001-12-13 GDR Improved speed while recording. Portable between
-# Python 1.5.2 and 2.1.1.
-#
-# 2002-01-03 GDR Module-level functions work correctly.
-#
-# 2002-01-07 GDR Update sys.path when running a file with the -x option,
-# so that it matches the value the program would get if it were run on
-# its own.
-#
-# 2004-12-12 NMB Significant code changes.
-# - Finding executable statements has been rewritten so that docstrings and
-# other quirks of Python execution aren't mistakenly identified as missing
-# lines.
-# - Lines can be excluded from consideration, even entire suites of lines.
-# - The filesystem cache of covered lines can be disabled programmatically.
-# - Modernized the code.
-#
-# 2004-12-14 NMB Minor tweaks. Return 'analysis' to its original behavior
-# and add 'analysis2'. Add a global for 'annotate', and factor it, adding
-# 'annotate_file'.
-#
-# 2004-12-31 NMB Allow for keyword arguments in the module global functions.
-# Thanks, Allen.
-#
-# 2005-12-02 NMB Call threading.settrace so that all threads are measured.
-# Thanks Martin Fuzzey. Add a file argument to report so that reports can be
-# captured to a different destination.
-#
-# 2005-12-03 NMB coverage.py can now measure itself.
-#
-# 2005-12-04 NMB Adapted Greg Rogers' patch for using relative filenames,
-# and sorting and omitting files to report on.
-#
-# 2006-07-23 NMB Applied Joseph Tate's patch for function decorators.
-#
-# 2006-08-21 NMB Applied Sigve Tjora and Mark van der Wal's fixes for argument
-# handling.
-#
-# 2006-08-22 NMB Applied Geoff Bache's parallel mode patch.
-#
-# 2006-08-23 NMB Refactorings to improve testability. Fixes to command-line
-# logic for parallel mode and collect.
-#
-# 2006-08-25 NMB "#pragma: nocover" is excluded by default.
-#
-# 2006-09-10 NMB Properly ignore docstrings and other constant expressions that
-# appear in the middle of a function, a problem reported by Tim Leslie.
-# Minor changes to avoid lint warnings.
-#
-# 2006-09-17 NMB coverage.erase() shouldn't clobber the exclude regex.
-# Change how parallel mode is invoked, and fix erase() so that it erases the
-# cache when called programmatically.
-#
-# 2007-07-21 NMB In reports, ignore code executed from strings, since we can't
-# do anything useful with it anyway.
-# Better file handling on Linux, thanks Guillaume Chazarain.
-# Better shell support on Windows, thanks Noel O'Boyle.
-# Python 2.2 support maintained, thanks Catherine Proulx.
-#
-# 2007-07-22 NMB Python 2.5 now fully supported. The method of dealing with
-# multi-line statements is now less sensitive to the exact line that Python
-# reports during execution. Pass statements are handled specially so that their
-# disappearance during execution won't throw off the measurement.
-
-# C. COPYRIGHT AND LICENCE
-#
-# Copyright 2001 Gareth Rees. All rights reserved.
-# Copyright 2004-2007 Ned Batchelder. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# 1. Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the
-# distribution.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# HOLDERS AND CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
-# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
-# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
-# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
-# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
-# DAMAGE.
-#
-# $Id: coverage.py 67 2007-07-21 19:51:07Z nedbat $
+++ /dev/null
-#!/usr/bin/env python
-'''
-unittest.py from Python 2.5.
-
-SQLAlchemy extends unittest internals to provide setUpAll()/tearDownAll()
-so we include a fixed version here to insulate from changes. 2.6 and
-3.0's unittest is incompatible with our changes.
-
-Approaches to removing this dependency are:
-
-* find a unittest-supported method of grouping UnitTest classes within
-a setUpAll()/tearDownAll() pair, such that all tests within a single
-UnitTest class are executed within a single execution of setUpAll()/
-tearDownAll(). It may be possible to create nested TestSuite objects
-to accomplish this but it's not clear.
-* migrate to a different system such as nose.
-
-Copyright (c) 1999-2003 Steve Purcell
-This module is free software, and you may redistribute it and/or modify
-it under the same terms as Python itself, so long as this copyright message
-and disclaimer are retained in their original form.
-
-IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
-SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF
-THIS CODE, EVEN IF THE AUTHOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
-DAMAGE.
-
-THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
-PARTICULAR PURPOSE. THE CODE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
-AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
-SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
-'''
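
The nested-TestSuite idea floated in the docstring above was never implemented here; purely as a hypothetical sketch of what such a grouping could look like, assuming setUpAll()/tearDownAll() are classmethods on the TestCase subclass and writing against the standard unittest API that this copy mirrors:

    import unittest

    class GroupedSuite(unittest.TestSuite):
        # Hypothetical: run setUpAll()/tearDownAll() once around all member tests.
        def __init__(self, testcase_class, tests=()):
            unittest.TestSuite.__init__(self, tests)
            self._testcase_class = testcase_class

        def run(self, result):
            cls = self._testcase_class
            if hasattr(cls, 'setUpAll'):
                cls.setUpAll()
            try:
                return unittest.TestSuite.run(self, result)
            finally:
                if hasattr(cls, 'tearDownAll'):
                    cls.tearDownAll()
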
-
-__author__ = "Steve Purcell"
-__email__ = "stephen_purcell at yahoo dot com"
-__version__ = "#Revision: 1.63 $"[11:-2]
-
-import time
-import sys
-import traceback
-import os
-import types
-from testlib.compat import callable
-
-##############################################################################
-# Exported classes and functions
-##############################################################################
-__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',
- 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader']
-
-# Expose obsolete functions for backwards compatibility
-__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
-
-
-##############################################################################
-# Test framework core
-##############################################################################
-
-# All classes defined herein are 'new-style' classes, allowing use of 'super()'
-__metaclass__ = type
-
-def _strclass(cls):
- return "%s.%s" % (cls.__module__, cls.__name__)
-
-__unittest = 1
-
-class TestResult:
- """Holder for test result information.
-
- Test results are automatically managed by the TestCase and TestSuite
- classes, and do not need to be explicitly manipulated by writers of tests.
-
- Each instance holds the total number of tests run, and collections of
- failures and errors that occurred among those test runs. The collections
- contain tuples of (testcase, exceptioninfo), where exceptioninfo is the
- formatted traceback of the error that occurred.
- """
- def __init__(self):
- self.failures = []
- self.errors = []
- self.testsRun = 0
- self.shouldStop = 0
-
- def startTest(self, test):
- "Called when the given test is about to be run"
- self.testsRun = self.testsRun + 1
-
- def stopTest(self, test):
- "Called when the given test has been run"
- pass
-
- def addError(self, test, err):
- """Called when an error has occurred. 'err' is a tuple of values as
- returned by sys.exc_info().
- """
- self.errors.append((test, self._exc_info_to_string(err, test)))
-
- def addFailure(self, test, err):
-        """Called when a failure has occurred. 'err' is a tuple of values as
- returned by sys.exc_info()."""
- self.failures.append((test, self._exc_info_to_string(err, test)))
-
- def addSuccess(self, test):
- "Called when a test has completed successfully"
- pass
-
- def wasSuccessful(self):
- "Tells whether or not this result was a success"
- return len(self.failures) == len(self.errors) == 0
-
- def stop(self):
- "Indicates that the tests should be aborted"
- self.shouldStop = True
-
- def _exc_info_to_string(self, err, test):
- """Converts a sys.exc_info()-style tuple of values into a string."""
- exctype, value, tb = err
- # Skip test runner traceback levels
- while tb and self._is_relevant_tb_level(tb):
- tb = tb.tb_next
- if exctype is test.failureException:
- # Skip assert*() traceback levels
- length = self._count_relevant_tb_levels(tb)
- return ''.join(traceback.format_exception(exctype, value, tb, length))
- return ''.join(traceback.format_exception(exctype, value, tb))
-
- def _is_relevant_tb_level(self, tb):
- return tb.tb_frame.f_globals.has_key('__unittest')
-
- def _count_relevant_tb_levels(self, tb):
- length = 0
- while tb and not self._is_relevant_tb_level(tb):
- length += 1
- tb = tb.tb_next
- return length
-
- def __repr__(self):
- return "<%s run=%i errors=%i failures=%i>" % \
- (_strclass(self.__class__), self.testsRun, len(self.errors),
- len(self.failures))
-
-class TestCase:
- """A class whose instances are single test cases.
-
- By default, the test code itself should be placed in a method named
- 'runTest'.
-
- If the fixture may be used for many test cases, create as
- many test methods as are needed. When instantiating such a TestCase
- subclass, specify in the constructor arguments the name of the test method
- that the instance is to execute.
-
- Test authors should subclass TestCase for their own tests. Construction
- and deconstruction of the test's environment ('fixture') can be
- implemented by overriding the 'setUp' and 'tearDown' methods respectively.
-
- If it is necessary to override the __init__ method, the base class
- __init__ method must always be called. It is important that subclasses
- should not change the signature of their __init__ method, since instances
- of the classes are instantiated automatically by parts of the framework
- in order to be run.
- """
-
- # This attribute determines which exception will be raised when
- # the instance's assertion methods fail; test methods raising this
- # exception will be deemed to have 'failed' rather than 'errored'
-
- failureException = AssertionError
-
- def __init__(self, methodName='runTest'):
- """Create an instance of the class that will use the named test
- method when executed. Raises a ValueError if the instance does
- not have a method with the specified name.
- """
- try:
- self._testMethodName = methodName
- testMethod = getattr(self, methodName)
- self._testMethodDoc = testMethod.__doc__
- except AttributeError:
- raise ValueError, "no such test method in %s: %s" % \
- (self.__class__, methodName)
-
- def setUp(self):
- "Hook method for setting up the test fixture before exercising it."
- pass
-
- def tearDown(self):
- "Hook method for deconstructing the test fixture after testing it."
- pass
-
- def countTestCases(self):
- return 1
-
- def defaultTestResult(self):
- return TestResult()
-
- def shortDescription(self):
- """Returns a one-line description of the test, or None if no
- description has been provided.
-
- The default implementation of this method returns the first line of
- the specified test method's docstring.
- """
- doc = self._testMethodDoc
- return doc and doc.split("\n")[0].strip() or None
-
- def id(self):
- return "%s.%s" % (_strclass(self.__class__), self._testMethodName)
-
- def __str__(self):
- return "%s (%s)" % (self._testMethodName, _strclass(self.__class__))
-
- def __repr__(self):
- return "<%s testMethod=%s>" % \
- (_strclass(self.__class__), self._testMethodName)
-
- def run(self, result=None):
- if result is None: result = self.defaultTestResult()
- result.startTest(self)
- testMethod = getattr(self, self._testMethodName)
- try:
- try:
- self.setUp()
- except KeyboardInterrupt:
- raise
- except:
- result.addError(self, self._exc_info())
- return
-
- ok = False
- try:
- testMethod()
- ok = True
- except self.failureException:
- result.addFailure(self, self._exc_info())
- except KeyboardInterrupt:
- raise
- except:
- result.addError(self, self._exc_info())
-
- try:
- self.tearDown()
- except KeyboardInterrupt:
- raise
- except:
- result.addError(self, self._exc_info())
- ok = False
- if ok: result.addSuccess(self)
- finally:
- result.stopTest(self)
-
- def __call__(self, *args, **kwds):
- return self.run(*args, **kwds)
-
- def debug(self):
- """Run the test without collecting errors in a TestResult"""
- self.setUp()
- getattr(self, self._testMethodName)()
- self.tearDown()
-
- def _exc_info(self):
- """Return a version of sys.exc_info() with the traceback frame
- minimised; usually the top level of the traceback frame is not
- needed.
- """
- exctype, excvalue, tb = sys.exc_info()
- if sys.platform[:4] == 'java': ## tracebacks look different in Jython
- return (exctype, excvalue, tb)
- return (exctype, excvalue, tb)
-
- def fail(self, msg=None):
- """Fail immediately, with the given message."""
- raise self.failureException, msg
-
- def failIf(self, expr, msg=None):
- "Fail the test if the expression is true."
- if expr: raise self.failureException, msg
-
- def failUnless(self, expr, msg=None):
- """Fail the test unless the expression is true."""
- if not expr: raise self.failureException, msg
-
- def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
- """Fail unless an exception of class excClass is thrown
- by callableObj when invoked with arguments args and keyword
- arguments kwargs. If a different type of exception is
- thrown, it will not be caught, and the test case will be
- deemed to have suffered an error, exactly as for an
- unexpected exception.
- """
- try:
- callableObj(*args, **kwargs)
- except excClass:
- return
- else:
- if hasattr(excClass,'__name__'): excName = excClass.__name__
- else: excName = str(excClass)
- raise self.failureException, "%s not raised" % excName
-
- def failUnlessEqual(self, first, second, msg=None):
- """Fail if the two objects are unequal as determined by the '=='
- operator.
- """
- if not first == second:
- raise self.failureException, \
- (msg or '%r != %r' % (first, second))
-
- def failIfEqual(self, first, second, msg=None):
- """Fail if the two objects are equal as determined by the '=='
- operator.
- """
- if first == second:
- raise self.failureException, \
- (msg or '%r == %r' % (first, second))
-
- def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
- """Fail if the two objects are unequal as determined by their
- difference rounded to the given number of decimal places
- (default 7) and comparing to zero.
-
- Note that decimal places (from zero) are usually not the same
-           as significant digits (measured from the most significant digit).
- """
- if round(second-first, places) != 0:
- raise self.failureException, \
- (msg or '%r != %r within %r places' % (first, second, places))
-
- def failIfAlmostEqual(self, first, second, places=7, msg=None):
- """Fail if the two objects are equal as determined by their
- difference rounded to the given number of decimal places
- (default 7) and comparing to zero.
-
- Note that decimal places (from zero) are usually not the same
-           as significant digits (measured from the most significant digit).
- """
- if round(second-first, places) == 0:
- raise self.failureException, \
- (msg or '%r == %r within %r places' % (first, second, places))
-
- # Synonyms for assertion methods
-
- assertEqual = assertEquals = failUnlessEqual
-
- assertNotEqual = assertNotEquals = failIfEqual
-
- assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual
-
- assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual
-
- assertRaises = failUnlessRaises
-
- assert_ = assertTrue = failUnless
-
- assertFalse = failIf
-
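
As a concrete illustration of the TestCase contract described in the class docstring above, a small hypothetical test case (written against the standard unittest API, which this bundled copy mirrors) might look like:

    import unittest

    class StackTest(unittest.TestCase):
        # setUp()/tearDown() rebuild the fixture around every test method
        def setUp(self):
            self.stack = []

        def tearDown(self):
            del self.stack

        def test_push_then_pop(self):
            self.stack.append('x')
            self.assertEqual(self.stack.pop(), 'x')

        def test_pop_empty_raises(self):
            self.assertRaises(IndexError, self.stack.pop)
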
-
-
-class TestSuite:
- """A test suite is a composite test consisting of a number of TestCases.
-
- For use, create an instance of TestSuite, then add test case instances.
- When all tests have been added, the suite can be passed to a test
- runner, such as TextTestRunner. It will run the individual test cases
- in the order in which they were added, aggregating the results. When
- subclassing, do not forget to call the base class constructor.
- """
- def __init__(self, tests=()):
- self._tests = []
- self.addTests(tests)
-
- def __repr__(self):
- return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)
-
- __str__ = __repr__
-
- def __iter__(self):
- return iter(self._tests)
-
- def countTestCases(self):
- cases = 0
- for test in self._tests:
- cases += test.countTestCases()
- return cases
-
- def addTest(self, test):
- # sanity checks
- if not callable(test):
- raise TypeError("the test to add must be callable")
- if (isinstance(test, (type, types.ClassType)) and
- issubclass(test, (TestCase, TestSuite))):
- raise TypeError("TestCases and TestSuites must be instantiated "
- "before passing them to addTest()")
- self._tests.append(test)
-
- def addTests(self, tests):
- if isinstance(tests, basestring):
- raise TypeError("tests must be an iterable of tests, not a string")
- for test in tests:
- self.addTest(test)
-
- def run(self, result):
- for test in self._tests:
- if result.shouldStop:
- break
- test(result)
- return result
-
- def __call__(self, *args, **kwds):
- return self.run(*args, **kwds)
-
- def debug(self):
- """Run the tests without collecting errors in a TestResult"""
- for test in self._tests: test.debug()
-
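
To aggregate and run such cases the way the TestSuite docstring describes, a sketch reusing the hypothetical StackTest from above:

    import unittest

    suite = unittest.TestSuite()
    # add a single case by method name...
    suite.addTest(StackTest('test_push_then_pop'))
    # ...or pull in every 'test*' method at once
    suite.addTests(unittest.makeSuite(StackTest))

    unittest.TextTestRunner(verbosity=2).run(suite)
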
-
-class FunctionTestCase(TestCase):
- """A test case that wraps a test function.
-
- This is useful for slipping pre-existing test functions into the
- PyUnit framework. Optionally, set-up and tidy-up functions can be
- supplied. As with TestCase, the tidy-up ('tearDown') function will
- always be called if the set-up ('setUp') function ran successfully.
- """
-
- def __init__(self, testFunc, setUp=None, tearDown=None,
- description=None):
- TestCase.__init__(self)
- self.__setUpFunc = setUp
- self.__tearDownFunc = tearDown
- self.__testFunc = testFunc
- self.__description = description
-
- def setUp(self):
- if self.__setUpFunc is not None:
- self.__setUpFunc()
-
- def tearDown(self):
- if self.__tearDownFunc is not None:
- self.__tearDownFunc()
-
- def runTest(self):
- self.__testFunc()
-
- def id(self):
- return self.__testFunc.__name__
-
- def __str__(self):
- return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
-
- def __repr__(self):
- return "<%s testFunc=%s>" % (_strclass(self.__class__), self.__testFunc)
-
- def shortDescription(self):
- if self.__description is not None: return self.__description
- doc = self.__testFunc.__doc__
- return doc and doc.split("\n")[0].strip() or None
-
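
A minimal sketch of the FunctionTestCase wrapping described above (the wrapped function is made up):

    import unittest

    def check_addition():
        assert 1 + 1 == 2

    case = unittest.FunctionTestCase(
        check_addition, description="plain function wrapped as a test")
    unittest.TextTestRunner().run(unittest.TestSuite([case]))
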
-
-
-##############################################################################
-# Locating and loading tests
-##############################################################################
-
-class TestLoader:
- """This class is responsible for loading tests according to various
- criteria and returning them wrapped in a Test
- """
- testMethodPrefix = 'test'
- suiteClass = TestSuite
-
- def loadTestsFromTestCase(self, testCaseClass):
-        """Return a suite of all test cases contained in testCaseClass"""
- if issubclass(testCaseClass, TestSuite):
- raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?")
- testCaseNames = self.getTestCaseNames(testCaseClass)
- if not testCaseNames and hasattr(testCaseClass, 'runTest'):
- testCaseNames = ['runTest']
- return self.suiteClass(map(testCaseClass, testCaseNames))
-
- def loadTestsFromModule(self, module):
-        """Return a suite of all test cases contained in the given module"""
- tests = []
- for name in dir(module):
- obj = getattr(module, name)
- if (isinstance(obj, (type, types.ClassType)) and
- issubclass(obj, TestCase)):
- tests.append(self.loadTestsFromTestCase(obj))
- return self.suiteClass(tests)
-
- def loadTestsFromName(self, name, module=None):
-        """Return a suite of all test cases given a string specifier.
-
- The name may resolve either to a module, a test case class, a
- test method within a test case class, or a callable object which
- returns a TestCase or TestSuite instance.
-
- The method optionally resolves the names relative to a given module.
- """
- parts = name.split('.')
- if module is None:
- parts_copy = parts[:]
- while parts_copy:
- try:
- module = __import__('.'.join(parts_copy))
- break
- except ImportError:
- del parts_copy[-1]
- if not parts_copy: raise
- parts = parts[1:]
- obj = module
- for part in parts:
- parent, obj = obj, getattr(obj, part)
-
- if type(obj) == types.ModuleType:
- return self.loadTestsFromModule(obj)
- elif (isinstance(obj, (type, types.ClassType)) and
- issubclass(obj, TestCase)):
- return self.loadTestsFromTestCase(obj)
- elif type(obj) == types.UnboundMethodType:
- return parent(obj.__name__)
- elif isinstance(obj, TestSuite):
- return obj
- elif callable(obj):
- test = obj()
- if not isinstance(test, (TestCase, TestSuite)):
- raise ValueError, \
- "calling %s returned %s, not a test" % (obj,test)
- return test
- else:
- raise ValueError, "don't know how to make test from: %s" % obj
-
- def loadTestsFromNames(self, names, module=None):
-        """Return a suite of all test cases found using the given sequence
- of string specifiers. See 'loadTestsFromName()'.
- """
- suites = [self.loadTestsFromName(name, module) for name in names]
- return self.suiteClass(suites)
-
- def getTestCaseNames(self, testCaseClass):
- """Return a sorted sequence of method names found within testCaseClass
- """
-
- def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
- return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname))
- testFnNames = filter(isTestMethod, dir(testCaseClass))
- for baseclass in testCaseClass.__bases__:
- for testFnName in self.getTestCaseNames(baseclass):
- if testFnName not in testFnNames: # handle overridden methods
- testFnNames.append(testFnName)
- testFnNames.sort()
- return testFnNames
-
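
The dotted-name resolution that loadTestsFromName() describes is what allowed tests to be selected by string; a sketch with hypothetical module, class and method names:

    import unittest

    loader = unittest.TestLoader()
    suite = loader.loadTestsFromNames([
        'myproject.tests.test_widgets',                          # a whole module
        'myproject.tests.test_widgets.WidgetTest',               # one class
        'myproject.tests.test_widgets.WidgetTest.test_create',   # one method
    ])
    unittest.TextTestRunner().run(suite)
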
-
-
-defaultTestLoader = TestLoader()
-
-
-##############################################################################
-# Patches for old functions: these functions should be considered obsolete
-##############################################################################
-
-def _makeLoader(prefix, sortUsing, suiteClass=None):
- loader = TestLoader()
- loader.testMethodPrefix = prefix
- if suiteClass: loader.suiteClass = suiteClass
- return loader
-
-def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
- return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
-
-def makeSuite(testCaseClass, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
- return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
-
-def findTestCases(module, prefix='test', sortUsing=cmp, suiteClass=TestSuite):
- return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
-
-
-##############################################################################
-# Text UI
-##############################################################################
-
-class _WritelnDecorator:
- """Used to decorate file-like objects with a handy 'writeln' method"""
- def __init__(self,stream):
- self.stream = stream
-
- def __getattr__(self, attr):
- return getattr(self.stream,attr)
-
- def writeln(self, arg=None):
- if arg: self.write(arg)
- self.write('\n') # text-mode streams translate to \r\n if needed
-
-
-class _TextTestResult(TestResult):
- """A test result class that can print formatted text results to a stream.
-
- Used by TextTestRunner.
- """
- separator1 = '=' * 70
- separator2 = '-' * 70
-
- def __init__(self, stream, descriptions, verbosity):
- TestResult.__init__(self)
- self.stream = stream
- self.showAll = verbosity > 1
- self.dots = verbosity == 1
- self.descriptions = descriptions
-
- def getDescription(self, test):
- if self.descriptions:
- return test.shortDescription() or str(test)
- else:
- return str(test)
-
- def startTest(self, test):
- TestResult.startTest(self, test)
- if self.showAll:
- self.stream.write(self.getDescription(test))
- self.stream.write(" ... ")
-
- def addSuccess(self, test):
- TestResult.addSuccess(self, test)
- if self.showAll:
- self.stream.writeln("ok")
- elif self.dots:
- self.stream.write('.')
-
- def addError(self, test, err):
- TestResult.addError(self, test, err)
- if self.showAll:
- self.stream.writeln("ERROR")
- elif self.dots:
- self.stream.write('E')
-
- def addFailure(self, test, err):
- TestResult.addFailure(self, test, err)
- if self.showAll:
- self.stream.writeln("FAIL")
- elif self.dots:
- self.stream.write('F')
-
- def printErrors(self):
- if self.dots or self.showAll:
- self.stream.writeln()
- self.printErrorList('ERROR', self.errors)
- self.printErrorList('FAIL', self.failures)
-
- def printErrorList(self, flavour, errors):
- for test, err in errors:
- self.stream.writeln(self.separator1)
- self.stream.writeln("%s: %s" % (flavour,self.getDescription(test)))
- self.stream.writeln(self.separator2)
- self.stream.writeln("%s" % err)
-
-
-class TextTestRunner:
- """A test runner class that displays results in textual form.
-
- It prints out the names of tests as they are run, errors as they
- occur, and a summary of the results at the end of the test run.
- """
- def __init__(self, stream=sys.stderr, descriptions=1, verbosity=1):
- self.stream = _WritelnDecorator(stream)
- self.descriptions = descriptions
- self.verbosity = verbosity
-
- def _makeResult(self):
- return _TextTestResult(self.stream, self.descriptions, self.verbosity)
-
- def run(self, test):
- "Run the given test case or test suite."
- result = self._makeResult()
- startTime = time.time()
- test(result)
- stopTime = time.time()
- timeTaken = stopTime - startTime
- result.printErrors()
- self.stream.writeln(result.separator2)
- run = result.testsRun
- self.stream.writeln("Ran %d test%s in %.3fs" %
- (run, run != 1 and "s" or "", timeTaken))
- self.stream.writeln()
- if not result.wasSuccessful():
- self.stream.write("FAILED (")
- failed, errored = map(len, (result.failures, result.errors))
- if failed:
- self.stream.write("failures=%d" % failed)
- if errored:
- if failed: self.stream.write(", ")
- self.stream.write("errors=%d" % errored)
- self.stream.writeln(")")
- else:
- self.stream.writeln("OK")
- return result
-
-
-
-##############################################################################
-# Facilities for running tests from the command line
-##############################################################################
-
-class TestProgram:
- """A command-line program that runs a set of tests; this is primarily
- for making test modules conveniently executable.
- """
- USAGE = """\
-Usage: %(progName)s [options] [test] [...]
-
-Options:
- -h, --help Show this message
- -v, --verbose Verbose output
- -q, --quiet Minimal output
-
-Examples:
- %(progName)s - run default set of tests
- %(progName)s MyTestSuite - run suite 'MyTestSuite'
- %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething
- %(progName)s MyTestCase - run all 'test*' test methods
- in MyTestCase
-"""
- def __init__(self, module='__main__', defaultTest=None,
- argv=None, testRunner=None, testLoader=defaultTestLoader):
- if type(module) == type(''):
- self.module = __import__(module)
- for part in module.split('.')[1:]:
- self.module = getattr(self.module, part)
- else:
- self.module = module
- if argv is None:
- argv = sys.argv
- self.verbosity = 1
- self.defaultTest = defaultTest
- self.testRunner = testRunner
- self.testLoader = testLoader
- self.progName = os.path.basename(argv[0])
- self.parseArgs(argv)
- self.runTests()
-
- def usageExit(self, msg=None):
- if msg: print msg
- print self.USAGE % self.__dict__
- sys.exit(2)
-
- def parseArgs(self, argv):
- import getopt
- try:
- options, args = getopt.getopt(argv[1:], 'hHvq',
- ['help','verbose','quiet'])
- for opt, value in options:
- if opt in ('-h','-H','--help'):
- self.usageExit()
- if opt in ('-q','--quiet'):
- self.verbosity = 0
- if opt in ('-v','--verbose'):
- self.verbosity = 2
- if len(args) == 0 and self.defaultTest is None:
- self.test = self.testLoader.loadTestsFromModule(self.module)
- return
- if len(args) > 0:
- self.testNames = args
- else:
- self.testNames = (self.defaultTest,)
- self.createTests()
- except getopt.error, msg:
- self.usageExit(msg)
-
- def createTests(self):
- self.test = self.testLoader.loadTestsFromNames(self.testNames,
- self.module)
-
- def runTests(self):
- if self.testRunner is None:
- self.testRunner = TextTestRunner(verbosity=self.verbosity)
- result = self.testRunner.run(self.test)
- sys.exit(not result.wasSuccessful())
-
-main = TestProgram
-
-
-##############################################################################
-# Executing this module from the command line
-##############################################################################
-
-if __name__ == "__main__":
- main(module=None)
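
The deleted block above is the tail of the project's bundled copy of the
unittest module (imported by the old harness as testlib.sa_unittest): the
text-mode result/runner classes and the TestProgram/main command-line front
end. Under 0.5.5 nose supplies the runner, so none of this is carried
forward. As a rough sketch only -- assuming a test module that still wants
to be directly executable, which this patch does not require -- the old
'if __name__ == "__main__": main()' idiom can be replaced by delegating to
nose's own entry point:

    import nose

    if __name__ == "__main__":
        # nose.main() parses sys.argv, discovers and runs tests, then exits
        # with a status code, much like the removed TestProgram did.
        nose.main()
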
-import testenv; testenv.configure_for_tests()
-from testlib import sa_unittest as unittest
-
-def suite():
- modules_to_test = (
- 'zblog.tests',
- )
- alltests = unittest.TestSuite()
- for name in modules_to_test:
- mod = __import__(name)
- for token in name.split('.')[1:]:
- mod = getattr(mod, token)
- alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
- return alltests
-
-
-if __name__ == '__main__':
- testenv.main(suite())
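
The alltests.py module removed above assembled an explicit unittest suite
for the zblog example tests. With nose there is no suite() aggregation step;
the same tests are collected by ordinary discovery, by naming the module on
the command line as in the README. The module path below is assumed from the
imports above, not stated by the patch:

    $ nosetests test.zblog.tests
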
"""mapper.py - defines mappers for domain objects, mapping operations"""
-import zblog.tables as tables
-import zblog.user as user
-from zblog.blog import *
+import tables, user
+from blog import *
from sqlalchemy import *
from sqlalchemy.orm import *
import sqlalchemy.util as util
"""application table metadata objects are described here."""
from sqlalchemy import *
-from testlib import *
metadata = MetaData()
-import testenv; testenv.configure_for_tests()
from sqlalchemy import *
from sqlalchemy.orm import *
-from testlib import *
-from zblog import mappers, tables
-from zblog.user import *
-from zblog.blog import *
+from sqlalchemy.test import *
+import mappers, tables
+from user import *
+from blog import *
class ZBlogTest(TestBase, AssertsExecutionResults):
- def create_tables(self):
+ @classmethod
+ def create_tables(cls):
tables.metadata.drop_all(bind=testing.db)
tables.metadata.create_all(bind=testing.db)
- def drop_tables(self):
+
+ @classmethod
+ def drop_tables(cls):
tables.metadata.drop_all(bind=testing.db)
- def setUpAll(self):
- self.create_tables()
- def tearDownAll(self):
- self.drop_tables()
- def tearDown(self):
+ @classmethod
+ def setup_class(cls):
+ cls.create_tables()
+ @classmethod
+ def teardown_class(cls):
+ cls.drop_tables()
+ def teardown(self):
pass
- def setUp(self):
+ def setup(self):
pass
class SavePostTest(ZBlogTest):
- def setUpAll(self):
- super(SavePostTest, self).setUpAll()
+ @classmethod
+ def setup_class(cls):
+ super(SavePostTest, cls).setup_class()
+
mappers.zblog_mappers()
global blog_id, user_id
s = create_session(bind=testing.db)
user_id = user.id
s.close()
- def tearDownAll(self):
+ @classmethod
+ def teardown_class(cls):
clear_mappers()
- super(SavePostTest, self).tearDownAll()
+ super(SavePostTest, cls).teardown_class()
def testattach(self):
"""test that a transient/pending instance has proper bi-directional behavior.
s.rollback()
-if __name__ == "__main__":
- testenv.main()
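
The renames in the hunks above -- setUpAll/tearDownAll becoming the
setup_class/teardown_class classmethods, setUp/tearDown becoming
setup/teardown -- follow nose's fixture-naming convention for plain
(non-unittest.TestCase) test classes. A minimal sketch of that convention,
using illustrative names that are not part of this patch:

    class ExampleFixtureTest(object):
        @classmethod
        def setup_class(cls):
            # run once, before any test method in the class
            cls.resource = []

        @classmethod
        def teardown_class(cls):
            # run once, after the last test method in the class
            cls.resource = None

        def setup(self):
            # run before each test method
            self.resource.append('item')

        def teardown(self):
            # run after each test method
            self.resource.pop()

        def test_append(self):
            assert self.resource == ['item']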