]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement facade for pytest parametrize, fixtures, classlevel
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 Oct 2019 17:09:24 +0000 (13:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Oct 2019 00:49:03 +0000 (20:49 -0400)
Add factilities to implement pytest.mark.parametrize and
pytest.fixtures patterns, which largely resemble things we are
already doing.

Ensure a facade is used, so that the test suite remains independent
of py.test, but also tailors the functions to the more limited
scope in which we are using them.

Additionally, create a class-based version that works from the
same facade.

Several old polymorphic tests as well as two of the sql test
are refactored to use the new features.

Change-Id: I6ef8af1dafff92534313016944d447f9439856cf
References: #4896

12 files changed:
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
test/aaa_profiling/test_memusage.py
test/orm/inheritance/test_abc_polymorphic.py
test/orm/inheritance/test_assorted_poly.py
test/orm/inheritance/test_magazine.py
test/orm/inheritance/test_poly_persistence.py
test/orm/test_descriptor.py
test/sql/test_operators.py
test/sql/test_types.py

index 2b8158fbb1bd84ad0870ef18b20d1fc5d2e574a0..4f28461e37e737b503aa2b2dbea01c46a4e9d4fe 100644 (file)
@@ -32,7 +32,9 @@ from .assertions import ne_  # noqa
 from .assertions import not_in_  # noqa
 from .assertions import startswith_  # noqa
 from .assertions import uses_deprecated  # noqa
+from .config import combinations  # noqa
 from .config import db  # noqa
+from .config import fixture  # noqa
 from .config import requirements as requires  # noqa
 from .exclusions import _is_excluded  # noqa
 from .exclusions import _server_version  # noqa
index f94c5b3086b6aa6e7021eeb61143608f79fbf91b..87bbc6a0f27845d0871e27e9d2c63b721cedd710 100644 (file)
@@ -6,7 +6,6 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import collections
-from unittest import SkipTest as _skip_test_exception
 
 requirements = None
 db = None
@@ -17,6 +16,75 @@ test_schema = None
 test_schema_2 = None
 _current = None
 
+_fixture_functions = None  # installed by plugin_base
+
+
+def combinations(*comb, **kw):
+    r"""Deliver multiple versions of a test based on positional combinations.
+
+    This is a facade over pytest.mark.parametrize.
+
+
+    :param \*comb: argument combinations.  These are tuples that will be passed
+     positionally to the decorated function.
+
+    :param argnames: optional list of argument names.   These are the names
+     of the arguments in the test function that correspond to the entries
+     in each argument tuple.   pytest.mark.parametrize requires this, however
+     the combinations function will derive it automatically if not present
+     by using ``inspect.getfullargspec(fn).args[1:]``.  Note this assumes the
+     first argument is "self" which is discarded.
+
+    :param id\_: optional id template.  This is a string template that
+     describes how the "id" for each parameter set should be defined, if any.
+     The number of characters in the template should match the number of
+     entries in each argument tuple.   Each character describes how the
+     corresponding entry in the argument tuple should be handled, as far as
+     whether or not it is included in the arguments passed to the function, as
+     well as if it is included in the tokens used to create the id of the
+     parameter set.
+
+     If omitted, the argment combinations are passed to parametrize as is.  If
+     passed, each argument combination is turned into a pytest.param() object,
+     mapping the elements of the argument tuple to produce an id based on a
+     character value in the same position within the string template using the
+     following scheme::
+
+        i - the given argument is a string that is part of the id only, don't
+            pass it as an argument
+
+        n - the given argument should be passed and it should be added to the
+            id by calling the .__name__ attribute
+
+        r - the given argument should be passed and it should be added to the
+            id by calling repr()
+
+        s-  the given argument should be passed and it should be added to the
+            id by calling str()
+
+     e.g.::
+
+        @testing.combinations(
+            (operator.eq, "eq"),
+            (operator.ne, "ne"),
+            (operator.gt, "gt"),
+            (operator.lt, "lt"),
+            id_="na"
+        )
+        def test_operator(self, opfunc, name):
+            pass
+
+    The above combination will call ``.__name__`` on the first member of
+    each tuple and use that as the "id" to pytest.param().
+
+
+    """
+    return _fixture_functions.combinations(*comb, **kw)
+
+
+def fixture(*arg, **kw):
+    return _fixture_functions.fixture(*arg, **kw)
+
 
 class Config(object):
     def __init__(self, db, db_opts, options, file_config):
@@ -94,4 +162,4 @@ class Config(object):
 
 
 def skip_test(msg):
-    raise _skip_test_exception(msg)
+    raise _fixture_functions.skip_test_exception(msg)
index 859d1d7799c2c64fbce1d68d3749b1624ada4600..a2f969a66ead422ce7955445520a199611a8008c 100644 (file)
@@ -16,6 +16,7 @@ is py.test.
 
 from __future__ import absolute_import
 
+import abc
 import re
 import sys
 
@@ -24,8 +25,15 @@ py3k = sys.version_info >= (3, 0)
 
 if py3k:
     import configparser
+
+    ABC = abc.ABC
 else:
     import ConfigParser as configparser
+    import collections as collections_abc  # noqa
+
+    class ABC(object):
+        __metaclass__ = abc.ABCMeta
+
 
 # late imports
 fixtures = None
@@ -238,14 +246,6 @@ def set_coverage_flag(value):
     options.has_coverage = value
 
 
-_skip_test_exception = None
-
-
-def set_skip_test(exc):
-    global _skip_test_exception
-    _skip_test_exception = exc
-
-
 def post_begin():
     """things to set up later, once we know coverage is running."""
     # Lazy setup of other options (post coverage)
@@ -331,10 +331,10 @@ def _monkeypatch_cdecimal(options, file_config):
 
 
 @post
-def _init_skiptest(options, file_config):
+def _init_symbols(options, file_config):
     from sqlalchemy.testing import config
 
-    config._skip_test_exception = _skip_test_exception
+    config._fixture_functions = _fixture_fn_class()
 
 
 @post
@@ -486,10 +486,10 @@ def _setup_profiling(options, file_config):
     )
 
 
-def want_class(cls):
+def want_class(name, cls):
     if not issubclass(cls, fixtures.TestBase):
         return False
-    elif cls.__name__.startswith("_"):
+    elif name.startswith("_"):
         return False
     elif (
         config.options.backend_only
@@ -711,3 +711,29 @@ def _do_skips(cls):
 
 def _setup_config(config_obj, ctx):
     config._current.push(config_obj, testing)
+
+
+class FixtureFunctions(ABC):
+    @abc.abstractmethod
+    def skip_test_exception(self, *arg, **kw):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def combinations(self, *args, **kw):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def param_ident(self, *args, **kw):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def fixture(self, fn):
+        raise NotImplementedError()
+
+
+_fixture_fn_class = None
+
+
+def set_fixture_functions(fixture_fn_class):
+    global _fixture_fn_class
+    _fixture_fn_class = fixture_fn_class
index e0335c1357931f14b4fc67be3ef01a81363e2079..5d91db5d7028f02a9d55bf1bdf1dd0ab2974541d 100644 (file)
@@ -8,7 +8,11 @@ except ImportError:
 import argparse
 import collections
 import inspect
+import itertools
+import operator
 import os
+import re
+import sys
 
 import pytest
 
@@ -87,7 +91,7 @@ def pytest_configure(config):
         bool(getattr(config.option, "cov_source", False))
     )
 
-    plugin_base.set_skip_test(pytest.skip.Exception)
+    plugin_base.set_fixture_functions(PytestFixtureFunctions)
 
 
 def pytest_sessionstart(session):
@@ -132,6 +136,7 @@ def pytest_collection_modifyitems(session, config, items):
     rebuilt_items = collections.defaultdict(
         lambda: collections.defaultdict(list)
     )
+
     items[:] = [
         item
         for item in items
@@ -173,21 +178,63 @@ def pytest_collection_modifyitems(session, config, items):
 
 
 def pytest_pycollect_makeitem(collector, name, obj):
-    if inspect.isclass(obj) and plugin_base.want_class(obj):
-        return pytest.Class(name, parent=collector)
+
+    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
+        return [
+            pytest.Class(parametrize_cls.__name__, parent=collector)
+            for parametrize_cls in _parametrize_cls(collector.module, obj)
+        ]
     elif (
         inspect.isfunction(obj)
         and isinstance(collector, pytest.Instance)
         and plugin_base.want_method(collector.cls, obj)
     ):
-        return pytest.Function(name, parent=collector)
+        # None means, fall back to default logic, which includes
+        # method-level parametrize
+        return None
     else:
+        # empty list means skip this item
         return []
 
 
 _current_class = None
 
 
+def _parametrize_cls(module, cls):
+    """implement a class-based version of pytest parametrize."""
+
+    if "_sa_parametrize" not in cls.__dict__:
+        return [cls]
+
+    _sa_parametrize = cls._sa_parametrize
+    classes = []
+    for full_param_set in itertools.product(
+        *[params for argname, params in _sa_parametrize]
+    ):
+        cls_variables = {}
+
+        for argname, param in zip(
+            [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
+        ):
+            if not argname:
+                raise TypeError("need argnames for class-based combinations")
+            argname_split = re.split(r",\s*", argname)
+            for arg, val in zip(argname_split, param.values):
+                cls_variables[arg] = val
+        parametrized_name = "_".join(
+            # token is a string, but in py2k py.test is giving us a unicode,
+            # so call str() on it.
+            str(re.sub(r"\W", "", token))
+            for param in full_param_set
+            for token in param.id.split("-")
+        )
+        name = "%s_%s" % (cls.__name__, parametrized_name)
+        newcls = type.__new__(type, name, (cls,), cls_variables)
+        setattr(module, name, newcls)
+        classes.append(newcls)
+    return classes
+
+
 def pytest_runtest_setup(item):
     # here we seem to get called only based on what we collected
     # in pytest_collection_modifyitems.   So to do class-based stuff
@@ -239,3 +286,99 @@ def class_setup(item):
 
 def class_teardown(item):
     plugin_base.stop_test_class(item.cls)
+
+
+def getargspec(fn):
+    if sys.version_info.major == 3:
+        return inspect.getfullargspec(fn)
+    else:
+        return inspect.getargspec(fn)
+
+
+class PytestFixtureFunctions(plugin_base.FixtureFunctions):
+    def skip_test_exception(self, *arg, **kw):
+        return pytest.skip.Exception(*arg, **kw)
+
+    _combination_id_fns = {
+        "i": lambda obj: obj,
+        "r": repr,
+        "s": str,
+        "n": operator.attrgetter("__name__"),
+    }
+
+    def combinations(self, *arg_sets, **kw):
+        """facade for pytest.mark.paramtrize.
+
+        Automatically derives argument names from the callable which in our
+        case is always a method on a class with positional arguments.
+
+        ids for parameter sets are derived using an optional template.
+
+        """
+
+        if sys.version_info.major == 3:
+            if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
+                arg_sets = list(arg_sets[0])
+        else:
+            if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
+                arg_sets = list(arg_sets[0])
+
+        argnames = kw.pop("argnames", None)
+
+        id_ = kw.pop("id_", None)
+
+        if id_:
+            _combination_id_fns = self._combination_id_fns
+
+            # because itemgetter is not consistent for one argument vs.
+            # multiple, make it multiple in all cases and use a slice
+            # to omit the first argument
+            _arg_getter = operator.itemgetter(
+                0,
+                *[
+                    idx
+                    for idx, char in enumerate(id_)
+                    if char in ("n", "r", "s", "a")
+                ]
+            )
+            fns = [
+                (operator.itemgetter(idx), _combination_id_fns[char])
+                for idx, char in enumerate(id_)
+                if char in _combination_id_fns
+            ]
+            arg_sets = [
+                pytest.param(
+                    *_arg_getter(arg)[1:],
+                    id="-".join(
+                        comb_fn(getter(arg)) for getter, comb_fn in fns
+                    )
+                )
+                for arg in arg_sets
+            ]
+        else:
+            # ensure using pytest.param so that even a 1-arg paramset
+            # still needs to be a tuple.  otherwise paramtrize tries to
+            # interpret a single arg differently than tuple arg
+            arg_sets = [pytest.param(*arg) for arg in arg_sets]
+
+        def decorate(fn):
+            if inspect.isclass(fn):
+                if "_sa_parametrize" not in fn.__dict__:
+                    fn._sa_parametrize = []
+                fn._sa_parametrize.append((argnames, arg_sets))
+                return fn
+            else:
+                if argnames is None:
+                    _argnames = getargspec(fn).args[1:]
+                else:
+                    _argnames = argnames
+                return pytest.mark.parametrize(_argnames, arg_sets)(fn)
+
+        return decorate
+
+    def param_ident(self, *parameters):
+        ident = parameters[0]
+        return pytest.param(*parameters[1:], id=ident)
+
+    def fixture(self, fn):
+        return pytest.fixture(fn)
index cbfbc63ee5778e6efc87823a0943ca40f794390b..431e53b1ba5697ce122901be657f8ea896e5ee23 100644 (file)
@@ -921,7 +921,7 @@ class MemUsageWBackendTest(EnsureZeroed):
             metadata.drop_all()
         assert_no_mappers()
 
-    @testing.expect_deprecated
+    @testing.uses_deprecated()
     @testing.provide_metadata
     def test_key_fallback_result(self):
         e = self.engine
index f430e761f8e345d2d5b4fedb26df0d723e060a46..cf06c9e263c8ae781326cab30e68699a6ae56cfc 100644 (file)
@@ -1,13 +1,13 @@
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
+from sqlalchemy import testing
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import mapper
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
-from sqlalchemy.testing.util import function_named
 
 
 class ABCTest(fixtures.MappedTest):
@@ -36,91 +36,85 @@ class ABCTest(fixtures.MappedTest):
             Column("cdata", String(30)),
         )
 
-    def _make_test(fetchtype):
-        def test_roundtrip(self):
-            class A(fixtures.ComparableEntity):
-                pass
+    @testing.combinations(("union",), ("none",))
+    def test_abc_poly_roundtrip(self, fetchtype):
+        class A(fixtures.ComparableEntity):
+            pass
 
-            class B(A):
-                pass
+        class B(A):
+            pass
 
-            class C(B):
-                pass
+        class C(B):
+            pass
 
-            if fetchtype == "union":
-                abc = a.outerjoin(b).outerjoin(c)
-                bc = a.join(b).outerjoin(c)
-            else:
-                abc = bc = None
+        if fetchtype == "union":
+            abc = a.outerjoin(b).outerjoin(c)
+            bc = a.join(b).outerjoin(c)
+        else:
+            abc = bc = None
 
-            mapper(
-                A,
-                a,
-                with_polymorphic=("*", abc),
-                polymorphic_on=a.c.type,
-                polymorphic_identity="a",
-            )
-            mapper(
-                B,
-                b,
-                with_polymorphic=("*", bc),
-                inherits=A,
-                polymorphic_identity="b",
-            )
-            mapper(C, c, inherits=B, polymorphic_identity="c")
-
-            a1 = A(adata="a1")
-            b1 = B(bdata="b1", adata="b1")
-            b2 = B(bdata="b2", adata="b2")
-            b3 = B(bdata="b3", adata="b3")
-            c1 = C(cdata="c1", bdata="c1", adata="c1")
-            c2 = C(cdata="c2", bdata="c2", adata="c2")
-            c3 = C(cdata="c2", bdata="c2", adata="c2")
-
-            sess = create_session()
-            for x in (a1, b1, b2, b3, c1, c2, c3):
-                sess.add(x)
-            sess.flush()
-            sess.expunge_all()
+        mapper(
+            A,
+            a,
+            with_polymorphic=("*", abc),
+            polymorphic_on=a.c.type,
+            polymorphic_identity="a",
+        )
+        mapper(
+            B,
+            b,
+            with_polymorphic=("*", bc),
+            inherits=A,
+            polymorphic_identity="b",
+        )
+        mapper(C, c, inherits=B, polymorphic_identity="c")
 
-            # for obj in sess.query(A).all():
-            #    print obj
-            eq_(
-                [
-                    A(adata="a1"),
-                    B(bdata="b1", adata="b1"),
-                    B(bdata="b2", adata="b2"),
-                    B(bdata="b3", adata="b3"),
-                    C(cdata="c1", bdata="c1", adata="c1"),
-                    C(cdata="c2", bdata="c2", adata="c2"),
-                    C(cdata="c2", bdata="c2", adata="c2"),
-                ],
-                sess.query(A).order_by(A.id).all(),
-            )
+        a1 = A(adata="a1")
+        b1 = B(bdata="b1", adata="b1")
+        b2 = B(bdata="b2", adata="b2")
+        b3 = B(bdata="b3", adata="b3")
+        c1 = C(cdata="c1", bdata="c1", adata="c1")
+        c2 = C(cdata="c2", bdata="c2", adata="c2")
+        c3 = C(cdata="c2", bdata="c2", adata="c2")
 
-            eq_(
-                [
-                    B(bdata="b1", adata="b1"),
-                    B(bdata="b2", adata="b2"),
-                    B(bdata="b3", adata="b3"),
-                    C(cdata="c1", bdata="c1", adata="c1"),
-                    C(cdata="c2", bdata="c2", adata="c2"),
-                    C(cdata="c2", bdata="c2", adata="c2"),
-                ],
-                sess.query(B).order_by(A.id).all(),
-            )
+        sess = create_session()
+        for x in (a1, b1, b2, b3, c1, c2, c3):
+            sess.add(x)
+        sess.flush()
+        sess.expunge_all()
 
-            eq_(
-                [
-                    C(cdata="c1", bdata="c1", adata="c1"),
-                    C(cdata="c2", bdata="c2", adata="c2"),
-                    C(cdata="c2", bdata="c2", adata="c2"),
-                ],
-                sess.query(C).order_by(A.id).all(),
-            )
+        # for obj in sess.query(A).all():
+        #    print obj
+        eq_(
+            [
+                A(adata="a1"),
+                B(bdata="b1", adata="b1"),
+                B(bdata="b2", adata="b2"),
+                B(bdata="b3", adata="b3"),
+                C(cdata="c1", bdata="c1", adata="c1"),
+                C(cdata="c2", bdata="c2", adata="c2"),
+                C(cdata="c2", bdata="c2", adata="c2"),
+            ],
+            sess.query(A).order_by(A.id).all(),
+        )
 
-        test_roundtrip = function_named(test_roundtrip, "test_%s" % fetchtype)
-        return test_roundtrip
+        eq_(
+            [
+                B(bdata="b1", adata="b1"),
+                B(bdata="b2", adata="b2"),
+                B(bdata="b3", adata="b3"),
+                C(cdata="c1", bdata="c1", adata="c1"),
+                C(cdata="c2", bdata="c2", adata="c2"),
+                C(cdata="c2", bdata="c2", adata="c2"),
+            ],
+            sess.query(B).order_by(A.id).all(),
+        )
 
-    test_union = _make_test("union")
-    test_none = _make_test("none")
+        eq_(
+            [
+                C(cdata="c1", bdata="c1", adata="c1"),
+                C(cdata="c2", bdata="c2", adata="c2"),
+                C(cdata="c2", bdata="c2", adata="c2"),
+            ],
+            sess.query(C).order_by(A.id).all(),
+        )
index 2f8677f8bf41cc59a590f63108a8686bc9e0ad28..ecab0a497d0fbf14d0274e91cb1327f9643b3590 100644 (file)
@@ -7,7 +7,6 @@ from sqlalchemy import exists
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
-from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import Sequence
 from sqlalchemy import String
@@ -15,7 +14,6 @@ from sqlalchemy import testing
 from sqlalchemy import Unicode
 from sqlalchemy import util
 from sqlalchemy.orm import class_mapper
-from sqlalchemy.orm import clear_mappers
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm import create_session
@@ -23,7 +21,6 @@ from sqlalchemy.orm import join
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import polymorphic_union
-from sqlalchemy.orm import Query
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import with_polymorphic
@@ -33,15 +30,6 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
-from sqlalchemy.testing.util import function_named
-
-
-class AttrSettable(object):
-    def __init__(self, **kwargs):
-        [setattr(self, k, v) for k, v in kwargs.items()]
-
-    def __repr__(self):
-        return self.__class__.__name__ + "(%s)" % (hex(id(self)))
 
 
 class RelationshipTest1(fixtures.MappedTest):
@@ -84,17 +72,17 @@ class RelationshipTest1(fixtures.MappedTest):
             Column("manager_name", String(50)),
         )
 
-    def teardown(self):
-        people.update(values={people.c.manager_id: None}).execute()
-        super(RelationshipTest1, self).teardown()
-
-    def test_parent_refs_descendant(self):
-        class Person(AttrSettable):
+    @classmethod
+    def setup_classes(cls):
+        class Person(cls.Comparable):
             pass
 
         class Manager(Person):
             pass
 
+    def test_parent_refs_descendant(self):
+        Person, Manager = self.classes("Person", "Manager")
+
         mapper(
             Person,
             people,
@@ -132,11 +120,7 @@ class RelationshipTest1(fixtures.MappedTest):
         assert p.manager is m
 
     def test_descendant_refs_parent(self):
-        class Person(AttrSettable):
-            pass
-
-        class Manager(Person):
-            pass
+        Person, Manager = self.classes("Person", "Manager")
 
         mapper(Person, people)
         mapper(
@@ -212,31 +196,22 @@ class RelationshipTest2(fixtures.MappedTest):
             Column("data", String(30)),
         )
 
-    def test_relationshiponsubclass_j1_nodata(self):
-        self._do_test("join1", False)
-
-    def test_relationshiponsubclass_j2_nodata(self):
-        self._do_test("join2", False)
-
-    def test_relationshiponsubclass_j1_data(self):
-        self._do_test("join1", True)
-
-    def test_relationshiponsubclass_j2_data(self):
-        self._do_test("join2", True)
-
-    def test_relationshiponsubclass_j3_nodata(self):
-        self._do_test("join3", False)
-
-    def test_relationshiponsubclass_j3_data(self):
-        self._do_test("join3", True)
-
-    def _do_test(self, jointype="join1", usedata=False):
-        class Person(AttrSettable):
+    @classmethod
+    def setup_classes(cls):
+        class Person(cls.Comparable):
             pass
 
         class Manager(Person):
             pass
 
+    @testing.combinations(
+        ("join1",), ("join2",), ("join3",), argnames="jointype"
+    )
+    @testing.combinations(
+        ("usedata", True), ("nodata", False), id_="ia", argnames="usedata"
+    )
+    def test_relationshiponsubclass(self, jointype, usedata):
+        Person, Manager = self.classes("Person", "Manager")
         if jointype == "join1":
             poly_union = polymorphic_union(
                 {
@@ -382,21 +357,20 @@ class RelationshipTest3(fixtures.MappedTest):
             Column("data", String(30)),
         )
 
-
-def _generate_test(jointype="join1", usedata=False):
-    def _do_test(self):
-        class Person(AttrSettable):
+    @classmethod
+    def setup_classes(cls):
+        class Person(cls.Comparable):
             pass
 
         class Manager(Person):
             pass
 
-        if usedata:
-
-            class Data(object):
-                def __init__(self, data):
-                    self.data = data
+        class Data(cls.Comparable):
+            def __init__(self, data):
+                self.data = data
 
+    def _setup_mappings(self, jointype, usedata):
+        Person, Manager, Data = self.classes("Person", "Manager", "Data")
         if jointype == "join1":
             poly_union = polymorphic_union(
                 {
@@ -427,6 +401,8 @@ def _generate_test(jointype="join1", usedata=False):
             poly_union = people.outerjoin(managers)
         elif jointype == "join4":
             poly_union = None
+        else:
+            assert False
 
         if usedata:
             mapper(Data, data)
@@ -475,6 +451,16 @@ def _generate_test(jointype="join1", usedata=False):
             polymorphic_identity="manager",
         )
 
+    @testing.combinations(
+        ("join1",), ("join2",), ("join3",), ("join4",), argnames="jointype"
+    )
+    @testing.combinations(
+        ("usedata", True), ("nodata", False), id_="ia", argnames="usedata"
+    )
+    def test_relationship_on_base_class(self, jointype, usedata):
+        self._setup_mappings(jointype, usedata)
+        Person, Manager, Data = self.classes("Person", "Manager", "Data")
+
         sess = create_session()
         p = Person(name="person1")
         p2 = Person(name="person2")
@@ -502,20 +488,6 @@ def _generate_test(jointype="join1", usedata=False):
             assert p.data.data == "ps data"
             assert m.data.data == "ms data"
 
-    do_test = function_named(
-        _do_test,
-        "test_relationship_on_base_class_%s_%s"
-        % (jointype, data and "nodata" or "data"),
-    )
-    return do_test
-
-
-for jointype in ["join1", "join2", "join3", "join4"]:
-    for data in (True, False):
-        _fn = _generate_test(jointype, data)
-        setattr(RelationshipTest3, _fn.__name__, _fn)
-del _fn
-
 
 class RelationshipTest4(fixtures.MappedTest):
     @classmethod
@@ -853,13 +825,17 @@ class RelationshipTest6(fixtures.MappedTest):
             Column("status", String(30)),
         )
 
-    def test_basic(self):
-        class Person(AttrSettable):
+    @classmethod
+    def setup_classes(cls):
+        class Person(cls.Comparable):
             pass
 
         class Manager(Person):
             pass
 
+    def test_basic(self):
+        Person, Manager = self.classes("Person", "Manager")
+
         mapper(Person, people)
 
         mapper(
@@ -1128,9 +1104,9 @@ class RelationshipTest8(fixtures.MappedTest):
         )
 
 
-class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
+class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults):
     @classmethod
-    def setup_class(cls):
+    def define_tables(cls, metadata):
         #  cars---owned by---  people (abstract) --- has a --- status
         #   |                  ^    ^                            |
         #   |                  |    |                            |
@@ -1138,10 +1114,8 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
         #   |                                                    |
         #   +--------------------------------------- has a ------+
 
-        global metadata, status, people, engineers, managers, cars
-        metadata = MetaData(testing.db)
         # table definitions
-        status = Table(
+        Table(
             "status",
             metadata,
             Column(
@@ -1153,7 +1127,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             Column("name", String(20)),
         )
 
-        people = Table(
+        Table(
             "people",
             metadata,
             Column(
@@ -1171,7 +1145,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             Column("name", String(50)),
         )
 
-        engineers = Table(
+        Table(
             "engineers",
             metadata,
             Column(
@@ -1183,7 +1157,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             Column("field", String(30)),
         )
 
-        managers = Table(
+        Table(
             "managers",
             metadata,
             Column(
@@ -1195,7 +1169,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             Column("category", String(70)),
         )
 
-        cars = Table(
+        Table(
             "cars",
             metadata,
             Column(
@@ -1218,52 +1192,31 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             ),
         )
 
-        metadata.create_all()
-
     @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
-
-    def teardown(self):
-        clear_mappers()
-        for t in reversed(metadata.sorted_tables):
-            t.delete().execute()
-
-    def test_join_to(self):
-        # class definitions
-        class PersistentObject(object):
-            def __init__(self, **kwargs):
-                for key, value in kwargs.items():
-                    setattr(self, key, value)
-
-        class Status(PersistentObject):
-            def __repr__(self):
-                return "Status %s" % self.name
+    def setup_classes(cls):
+        class Status(cls.Comparable):
+            pass
 
-        class Person(PersistentObject):
-            def __repr__(self):
-                return "Ordinary person %s" % self.name
+        class Person(cls.Comparable):
+            pass
 
         class Engineer(Person):
-            def __repr__(self):
-                return "Engineer %s, field %s, status %s" % (
-                    self.name,
-                    self.field,
-                    self.status,
-                )
+            pass
 
         class Manager(Person):
-            def __repr__(self):
-                return "Manager %s, category %s, status %s" % (
-                    self.name,
-                    self.category,
-                    self.status,
-                )
+            pass
 
-        class Car(PersistentObject):
-            def __repr__(self):
-                return "Car number %d" % self.car_id
+        class Car(cls.Comparable):
+            pass
 
+    @classmethod
+    def setup_mappers(cls):
+        status, people, engineers, managers, cars = cls.tables(
+            "status", "people", "engineers", "managers", "cars"
+        )
+        Status, Person, Engineer, Manager, Car = cls.classes(
+            "Status", "Person", "Engineer", "Manager", "Car"
+        )
         # create a union that represents both types of joins.
         employee_join = polymorphic_union(
             {
@@ -1283,7 +1236,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             polymorphic_identity="person",
             properties={"status": relationship(status_mapper)},
         )
-        engineer_mapper = mapper(
+        mapper(
             Engineer,
             engineers,
             inherits=person_mapper,
@@ -1304,6 +1257,11 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             },
         )
 
+    @classmethod
+    def insert_data(cls):
+        Status, Person, Engineer, Manager, Car = cls.classes(
+            "Status", "Person", "Engineer", "Manager", "Car"
+        )
         session = create_session()
 
         active = Status(name="active")
@@ -1332,7 +1290,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
         session.flush()
 
         # get E4
-        engineer4 = session.query(engineer_mapper).filter_by(name="E4").one()
+        engineer4 = session.query(Engineer).filter_by(name="E4").one()
 
         # create 2 cars for E4, one active and one dead
         car1 = Car(employee=engineer4, status=active)
@@ -1341,9 +1299,11 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
         session.add(car2)
         session.flush()
 
-        # this particular adapt used to cause a recursion overflow;
-        # added here for testing
-        Query(Person)._adapt_clause(employee_join, False, False)
+    def test_join_to_q_person(self):
+        Status, Person, Engineer, Manager, Car = self.classes(
+            "Status", "Person", "Engineer", "Manager", "Car"
+        )
+        session = create_session()
 
         r = (
             session.query(Person)
@@ -1353,31 +1313,52 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults):
             .order_by(Person.person_id)
         )
         eq_(
-            str(list(r)),
-            "[Manager M2, category YYYYYYYYY, status "
-            "Status active, Engineer E2, field X, "
-            "status Status active]",
+            list(r),
+            [
+                Manager(
+                    name="M2",
+                    category="YYYYYYYYY",
+                    status=Status(name="active"),
+                ),
+                Engineer(name="E2", field="X", status=Status(name="active")),
+            ],
+        )
+
+    def test_join_to_q_engineer(self):
+        Status, Person, Engineer, Manager, Car = self.classes(
+            "Status", "Person", "Engineer", "Manager", "Car"
         )
+        session = create_session()
         r = (
             session.query(Engineer)
             .join("status")
             .filter(
                 Person.name.in_(["E2", "E3", "E4", "M4", "M2", "M1"])
-                & (status.c.name == "active")
+                & (Status.name == "active")
             )
             .order_by(Person.name)
         )
         eq_(
-            str(list(r)),
-            "[Engineer E2, field X, status Status "
-            "active, Engineer E3, field X, status "
-            "Status active]",
+            list(r),
+            [
+                Engineer(name="E2", field="X", status=Status(name="active")),
+                Engineer(name="E3", field="X", status=Status(name="active")),
+            ],
         )
 
+    def test_join_to_q_person_car(self):
+        Status, Person, Engineer, Manager, Car = self.classes(
+            "Status", "Person", "Engineer", "Manager", "Car"
+        )
+        session = create_session()
         r = session.query(Person).filter(
             exists([1], Car.owner == Person.person_id)
         )
-        eq_(str(list(r)), "[Engineer E4, field X, status Status dead]")
+
+        eq_(
+            list(r),
+            [Engineer(name="E4", field="X", status=Status(name="dead"))],
+        )
 
 
 class MultiLevelTest(fixtures.MappedTest):
index 1abfb90322660a59a4cf2935a2f44393c6d1a307..228cb1273e9bf784f6d72b24d0c831d9fe70ac99 100644 (file)
+"""A legacy test for a particular somewhat complicated mapping."""
+
 from sqlalchemy import CHAR
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
+from sqlalchemy import testing
 from sqlalchemy import Text
 from sqlalchemy.orm import backref
-from sqlalchemy.orm import create_session
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import polymorphic_union
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
-from sqlalchemy.testing.util import function_named
-
-
-class BaseObject(object):
-    def __init__(self, *args, **kwargs):
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-
-class Publication(BaseObject):
-    pass
-
-
-class Issue(BaseObject):
-    pass
-
-
-class Location(BaseObject):
-    def __repr__(self):
-        return "%s(%s, %s)" % (
-            self.__class__.__name__,
-            str(getattr(self, "issue_id", None)),
-            repr(str(self._name.name)),
-        )
-
-    def _get_name(self):
-        return self._name
-
-    def _set_name(self, name):
-        session = create_session()
-        s = (
-            session.query(LocationName)
-            .filter(LocationName.name == name)
-            .first()
-        )
-        session.expunge_all()
-        if s is not None:
-            self._name = s
-
-            return
-
-        found = False
 
-        for i in session.new:
-            if isinstance(i, LocationName) and i.name == name:
-                self._name = i
-                found = True
 
-                break
-
-        if found is False:
-            self._name = LocationName(name=name)
-
-    name = property(_get_name, _set_name)
-
-
-class LocationName(BaseObject):
-    def __repr__(self):
-        return "%s()" % (self.__class__.__name__)
-
-
-class PageSize(BaseObject):
-    def __repr__(self):
-        return "%s(%sx%s, %s)" % (
-            self.__class__.__name__,
-            self.width,
-            self.height,
-            self.name,
-        )
+class MagazineTest(fixtures.MappedTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.Comparable
 
+        class Publication(Base):
+            pass
 
-class Magazine(BaseObject):
-    def __repr__(self):
-        return "%s(%s, %s)" % (
-            self.__class__.__name__,
-            repr(self.location),
-            repr(self.size),
-        )
+        class Issue(Base):
+            pass
 
+        class Location(Base):
+            pass
 
-class Page(BaseObject):
-    def __repr__(self):
-        return "%s(%s)" % (self.__class__.__name__, str(self.page_no))
+        class LocationName(Base):
+            pass
 
+        class PageSize(Base):
+            pass
 
-class MagazinePage(Page):
-    def __repr__(self):
-        return "%s(%s, %s)" % (
-            self.__class__.__name__,
-            str(self.page_no),
-            repr(self.magazine),
-        )
+        class Magazine(Base):
+            pass
 
+        class Page(Base):
+            pass
 
-class ClassifiedPage(MagazinePage):
-    pass
+        class MagazinePage(Page):
+            pass
 
+        class ClassifiedPage(MagazinePage):
+            pass
 
-class MagazineTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -198,9 +137,65 @@ class MagazineTest(fixtures.MappedTest):
             Column("name", String(45), default=""),
         )
 
+    def _generate_data(self):
+        (
+            Publication,
+            Issue,
+            Location,
+            LocationName,
+            PageSize,
+            Magazine,
+            Page,
+            MagazinePage,
+            ClassifiedPage,
+        ) = self.classes(
+            "Publication",
+            "Issue",
+            "Location",
+            "LocationName",
+            "PageSize",
+            "Magazine",
+            "Page",
+            "MagazinePage",
+            "ClassifiedPage",
+        )
+        london = LocationName(name="London")
+        pub = Publication(name="Test")
+        issue = Issue(issue=46, publication=pub)
+        location = Location(ref="ABC", name=london, issue=issue)
+
+        page_size = PageSize(name="A4", width=210, height=297)
 
-def _generate_round_trip_test(use_unions=False, use_joins=False):
-    def test_roundtrip(self):
+        magazine = Magazine(location=location, size=page_size)
+
+        ClassifiedPage(magazine=magazine, page_no=1)
+        MagazinePage(magazine=magazine, page_no=2)
+        ClassifiedPage(magazine=magazine, page_no=3)
+
+        return pub
+
+    def _setup_mapping(self, use_unions, use_joins):
+        (
+            Publication,
+            Issue,
+            Location,
+            LocationName,
+            PageSize,
+            Magazine,
+            Page,
+            MagazinePage,
+            ClassifiedPage,
+        ) = self.classes(
+            "Publication",
+            "Issue",
+            "Location",
+            "LocationName",
+            "PageSize",
+            "Magazine",
+            "Page",
+            "MagazinePage",
+            "ClassifiedPage",
+        )
         mapper(Publication, self.tables.publication)
 
         mapper(
@@ -228,7 +223,7 @@ def _generate_round_trip_test(use_unions=False, use_joins=False):
                         cascade="all, delete-orphan",
                     ),
                 ),
-                "_name": relationship(LocationName),
+                "name": relationship(LocationName),
             },
         )
 
@@ -354,42 +349,29 @@ def _generate_round_trip_test(use_unions=False, use_joins=False):
             primary_key=[self.tables.page.c.id],
         )
 
-        session = create_session()
-
-        pub = Publication(name="Test")
-        issue = Issue(issue=46, publication=pub)
-        location = Location(ref="ABC", name="London", issue=issue)
+    @testing.combinations(
+        ("unions", True, False),
+        ("joins", False, True),
+        ("plain", False, False),
+        id_="iaa",
+    )
+    def test_magazine_round_trip(self, use_unions, use_joins):
+        self._setup_mapping(use_unions, use_joins)
 
-        page_size = PageSize(name="A4", width=210, height=297)
+        Publication = self.classes.Publication
 
-        magazine = Magazine(location=location, size=page_size)
+        session = Session()
 
-        page = ClassifiedPage(magazine=magazine, page_no=1)
-        page2 = MagazinePage(magazine=magazine, page_no=2)
-        page3 = ClassifiedPage(magazine=magazine, page_no=3)
+        pub = self._generate_data()
         session.add(pub)
+        session.commit()
+        session.close()
 
-        session.flush()
-        print([x for x in session])
-        session.expunge_all()
-
-        session.flush()
-        session.expunge_all()
         p = session.query(Publication).filter(Publication.name == "Test").one()
 
-        print(p.issues[0].locations[0].magazine.pages)
-        print([page, page2, page3])
-        assert repr(p.issues[0].locations[0].magazine.pages) == repr(
-            [page, page2, page3]
-        ), repr(p.issues[0].locations[0].magazine.pages)
-
-    test_roundtrip = function_named(
-        test_roundtrip,
-        "test_%s"
-        % (not use_union and (use_joins and "joins" or "select") or "unions"),
-    )
-    setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip)
-
-
-for (use_union, use_join) in [(True, False), (False, True), (False, False)]:
-    _generate_round_trip_test(use_union, use_join)
+        test_pub = self._generate_data()
+        eq_(p, test_pub)
+        eq_(
+            p.issues[0].locations[0].magazine.pages,
+            test_pub.issues[0].locations[0].magazine.pages,
+        )
index 1cef654cd36f93925ecc05caf9bd1bb5aae1ed92..508cb99657b5f43d6e76d8ce8c25fa5a29c109b2 100644 (file)
@@ -2,9 +2,7 @@
 
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import ForeignKey
-from sqlalchemy import func
 from sqlalchemy import Integer
-from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
@@ -12,12 +10,12 @@ from sqlalchemy.orm import create_session
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import polymorphic_union
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.schema import Column
-from sqlalchemy.testing.util import function_named
 
 
 class Person(fixtures.ComparableEntity):
@@ -115,8 +113,6 @@ class PolymorphTest(fixtures.MappedTest):
             Column("golf_swing", String(30)),
         )
 
-        metadata.create_all()
-
 
 class InsertOrderTest(PolymorphTest):
     def test_insert_order(self):
@@ -198,28 +194,41 @@ class InsertOrderTest(PolymorphTest):
         eq_(session.query(Company).get(c.company_id), c)
 
 
+@testing.combinations(
+    ("lazy", True), ("nonlazy", False), argnames="lazy_relationship", id_="ia"
+)
+@testing.combinations(
+    ("redefine", True),
+    ("noredefine", False),
+    argnames="redefine_colprop",
+    id_="ia",
+)
+@testing.combinations(
+    ("unions", True),
+    ("unions", False),
+    ("joins", False),
+    ("auto", False),
+    ("none", False),
+    argnames="with_polymorphic,include_base",
+    id_="rr",
+)
 class RoundTripTest(PolymorphTest):
-    pass
-
-
-def _generate_round_trip_test(
-    include_base, lazy_relationship, redefine_colprop, with_polymorphic
-):
-    """generates a round trip test.
-
-    include_base - whether or not to include the base 'person' type in
-    the union.
+    lazy_relationship = None
+    include_base = None
+    redefine_colprop = None
+    with_polymorphic = None
 
-    lazy_relationship - whether or not the Company relationship to
-    People is lazy or eager.
+    run_inserts = "once"
+    run_deletes = None
+    run_setup_mappers = "once"
 
-    redefine_colprop - if we redefine the 'name' column to be
-    'people_name' on the base Person class
-
-    use_literal_join - primary join condition is explicitly specified
-    """
+    @classmethod
+    def setup_mappers(cls):
+        include_base = cls.include_base
+        lazy_relationship = cls.lazy_relationship
+        redefine_colprop = cls.redefine_colprop
+        with_polymorphic = cls.with_polymorphic
 
-    def test_roundtrip(self):
         if with_polymorphic == "unions":
             if include_base:
                 person_join = polymorphic_union(
@@ -308,6 +317,11 @@ def _generate_round_trip_test(
             },
         )
 
+    @classmethod
+    def insert_data(cls):
+        redefine_colprop = cls.redefine_colprop
+        include_base = cls.include_base
+
         if redefine_colprop:
             person_attribute_name = "person_name"
         else:
@@ -342,15 +356,48 @@ def _generate_round_trip_test(
             ),
         ]
 
-        dilbert = employees[1]
-
-        session = create_session()
+        session = Session()
         c = Company(name="company1")
         c.employees = employees
         session.add(c)
 
-        session.flush()
-        session.expunge_all()
+        session.commit()
+
+    @testing.fixture
+    def get_dilbert(self):
+        def run(session):
+            if self.redefine_colprop:
+                person_attribute_name = "person_name"
+            else:
+                person_attribute_name = "name"
+
+            dilbert = (
+                session.query(Engineer)
+                .filter_by(**{person_attribute_name: "dilbert"})
+                .one()
+            )
+            return dilbert
+
+        return run
+
+    def test_lazy_load(self):
+        lazy_relationship = self.lazy_relationship
+        with_polymorphic = self.with_polymorphic
+
+        if self.redefine_colprop:
+            person_attribute_name = "person_name"
+        else:
+            person_attribute_name = "name"
+
+        session = create_session()
+
+        dilbert = (
+            session.query(Engineer)
+            .filter_by(**{person_attribute_name: "dilbert"})
+            .one()
+        )
+        employees = session.query(Person).order_by(Person.person_id).all()
+        company = session.query(Company).first()
 
         eq_(session.query(Person).get(dilbert.person_id), dilbert)
         session.expunge_all()
@@ -364,20 +411,29 @@ def _generate_round_trip_test(
         session.expunge_all()
 
         def go():
-            cc = session.query(Company).get(c.company_id)
+            cc = session.query(Company).get(company.company_id)
             eq_(cc.employees, employees)
 
         if not lazy_relationship:
             if with_polymorphic != "none":
                 self.assert_sql_count(testing.db, go, 1)
             else:
-                self.assert_sql_count(testing.db, go, 5)
+                self.assert_sql_count(testing.db, go, 2)
 
         else:
             if with_polymorphic != "none":
                 self.assert_sql_count(testing.db, go, 2)
             else:
-                self.assert_sql_count(testing.db, go, 6)
+                self.assert_sql_count(testing.db, go, 3)
+
+    def test_baseclass_lookup(self, get_dilbert):
+        session = Session()
+        dilbert = get_dilbert(session)
+
+        if self.redefine_colprop:
+            person_attribute_name = "person_name"
+        else:
+            person_attribute_name = "name"
 
         # test selecting from the query, using the base
         # mapped table (people) as the selection criterion.
@@ -390,12 +446,14 @@ def _generate_round_trip_test(
             dilbert,
         )
 
-        assert (
-            session.query(Person)
-            .filter(getattr(Person, person_attribute_name) == "dilbert")
-            .first()
-            .person_id
-        )
+    def test_subclass_lookup(self, get_dilbert):
+        session = Session()
+        dilbert = get_dilbert(session)
+
+        if self.redefine_colprop:
+            person_attribute_name = "person_name"
+        else:
+            person_attribute_name = "name"
 
         eq_(
             session.query(Engineer)
@@ -404,6 +462,10 @@ def _generate_round_trip_test(
             dilbert,
         )
 
+    def test_baseclass_base_alias_filter(self, get_dilbert):
+        session = Session()
+        dilbert = get_dilbert(session)
+
         # test selecting from the query, joining against
         # an alias of the base "people" table.  test that
         # the "palias" alias does *not* get sucked up
@@ -419,6 +481,13 @@ def _generate_round_trip_test(
             )
             .first(),
         )
+
+    def test_subclass_base_alias_filter(self, get_dilbert):
+        session = Session()
+        dilbert = get_dilbert(session)
+
+        palias = people.alias("palias")
+
         is_(
             dilbert,
             session.query(Engineer)
@@ -428,6 +497,11 @@ def _generate_round_trip_test(
             )
             .first(),
         )
+
+    def test_baseclass_sub_table_filter(self, get_dilbert):
+        session = Session()
+        dilbert = get_dilbert(session)
+
         is_(
             dilbert,
             session.query(Person)
@@ -437,6 +511,11 @@ def _generate_round_trip_test(
             )
             .first(),
         )
+
+    def test_subclass_getitem(self, get_dilbert):
+        session = Session()
+        dilbert = get_dilbert(session)
+
         is_(
             dilbert,
             session.query(Engineer).filter(
@@ -444,17 +523,16 @@ def _generate_round_trip_test(
             )[0],
         )
 
-        session.flush()
-        session.expunge_all()
+    def test_primary_table_only_for_requery(self):
 
-        def go():
-            session.query(Person).filter(
-                getattr(Person, person_attribute_name) == "dilbert"
-            ).first()
+        session = Session()
 
-        self.assert_sql_count(testing.db, go, 1)
-        session.expunge_all()
-        dilbert = (
+        if self.redefine_colprop:
+            person_attribute_name = "person_name"
+        else:
+            person_attribute_name = "name"
+
+        dilbert = (  # noqa
             session.query(Person)
             .filter(getattr(Person, person_attribute_name) == "dilbert")
             .first()
@@ -471,7 +549,14 @@ def _generate_round_trip_test(
 
         self.assert_sql_count(testing.db, go, 1)
 
-        # test standalone orphans
+    def test_standalone_orphans(self):
+        if self.redefine_colprop:
+            person_attribute_name = "person_name"
+        else:
+            person_attribute_name = "name"
+
+        session = Session()
+
         daboss = Boss(
             status="BBB",
             manager_name="boss",
@@ -480,52 +565,3 @@ def _generate_round_trip_test(
         )
         session.add(daboss)
         assert_raises(sa_exc.DBAPIError, session.flush)
-
-        c = session.query(Company).first()
-        daboss.company = c
-        manager_list = [e for e in c.employees if isinstance(e, Manager)]
-        session.flush()
-        session.expunge_all()
-
-        eq_(
-            session.query(Manager).order_by(Manager.person_id).all(),
-            manager_list,
-        )
-        c = session.query(Company).first()
-
-        session.delete(c)
-        session.flush()
-
-        eq_(select([func.count("*")]).select_from(people).scalar(), 0)
-
-    test_roundtrip = function_named(
-        test_roundtrip,
-        "test_%s%s%s_%s"
-        % (
-            (lazy_relationship and "lazy" or "eager"),
-            (include_base and "_inclbase" or ""),
-            (redefine_colprop and "_redefcol" or ""),
-            with_polymorphic,
-        ),
-    )
-    setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
-
-
-for lazy_relationship in [True, False]:
-    for redefine_colprop in [True, False]:
-        for with_polymorphic_ in ["unions", "joins", "auto", "none"]:
-            if with_polymorphic_ == "unions":
-                for include_base in [True, False]:
-                    _generate_round_trip_test(
-                        include_base,
-                        lazy_relationship,
-                        redefine_colprop,
-                        with_polymorphic_,
-                    )
-            else:
-                _generate_round_trip_test(
-                    False,
-                    lazy_relationship,
-                    redefine_colprop,
-                    with_polymorphic_,
-                )
index 1baa82d3d277d431908629c83e9cc327550a7d16..7b530b928100cad528a5f51762631e87f33eb856 100644 (file)
@@ -13,7 +13,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.util import partial
 
 
-class TestDescriptor(descriptor_props.DescriptorProperty):
+class MockDescriptor(descriptor_props.DescriptorProperty):
     def __init__(
         self, cls, key, descriptor=None, doc=None, comparator_factory=None
     ):
@@ -40,7 +40,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
     def test_fixture(self):
         Foo = self._fixture()
 
-        d = TestDescriptor(Foo, "foo")
+        d = MockDescriptor(Foo, "foo")
         d.instrument_class(Foo.__mapper__)
 
         assert Foo.foo
@@ -50,7 +50,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
         prop = property(lambda self: None)
         Foo.foo = prop
 
-        d = TestDescriptor(Foo, "foo")
+        d = MockDescriptor(Foo, "foo")
         d.instrument_class(Foo.__mapper__)
 
         assert Foo().foo is None
@@ -68,7 +68,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
         prop = myprop(lambda self: None)
         Foo.foo = prop
 
-        d = TestDescriptor(Foo, "foo")
+        d = MockDescriptor(Foo, "foo")
         d.instrument_class(Foo.__mapper__)
 
         assert Foo().foo is None
@@ -95,7 +95,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
                 return column("foo") == func.upper(other)
 
         Foo = self._fixture()
-        d = TestDescriptor(Foo, "foo", comparator_factory=Comparator)
+        d = MockDescriptor(Foo, "foo", comparator_factory=Comparator)
         d.instrument_class(Foo.__mapper__)
         eq_(Foo.foo.method1(), "method1")
         eq_(Foo.foo.method2("x"), "method2")
@@ -119,7 +119,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest):
             prop = mapper._props["_name"]
             return Comparator(prop, mapper)
 
-        d = TestDescriptor(Foo, "foo", comparator_factory=comparator_factory)
+        d = MockDescriptor(Foo, "foo", comparator_factory=comparator_factory)
         d.instrument_class(Foo.__mapper__)
 
         eq_(str(Foo.foo == "ed"), "foobar(foo.name) = foobar(:foobar_1)")
index 66fe1859837ae2f2e8f6ea00726e10ac149c5c1d..637f1f8a5a4a3a032979f3bec84782e7c26b3f24 100644 (file)
@@ -73,12 +73,41 @@ class LoopOperate(operators.ColumnOperators):
 
 
 class DefaultColumnComparatorTest(fixtures.TestBase):
-    def _do_scalar_test(self, operator, compare_to):
+    @testing.combinations((operators.desc_op, desc), (operators.asc_op, asc))
+    def test_scalar(self, operator, compare_to):
         left = column("left")
         assert left.comparator.operate(operator).compare(compare_to(left))
         self._loop_test(operator)
 
-    def _do_operate_test(self, operator, right=column("right")):
+    right_column = column("right")
+
+    @testing.combinations(
+        (operators.add, right_column),
+        (operators.is_, None),
+        (operators.isnot, None),
+        (operators.is_, null()),
+        (operators.is_, true()),
+        (operators.is_, false()),
+        (operators.eq, True),
+        (operators.ne, True),
+        (operators.is_distinct_from, True),
+        (operators.is_distinct_from, False),
+        (operators.is_distinct_from, None),
+        (operators.isnot_distinct_from, True),
+        (operators.is_, True),
+        (operators.isnot, True),
+        (operators.is_, False),
+        (operators.isnot, False),
+        (operators.like_op, right_column),
+        (operators.notlike_op, right_column),
+        (operators.ilike_op, right_column),
+        (operators.notilike_op, right_column),
+        (operators.is_, right_column),
+        (operators.isnot, right_column),
+        (operators.concat_op, right_column),
+        id_="ns",
+    )
+    def test_operate(self, operator, right):
         left = column("left")
 
         assert left.comparator.operate(operator, right).compare(
@@ -109,84 +138,13 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         loop = LoopOperate()
         is_(operator(loop, *arg), operator)
 
-    def test_desc(self):
-        self._do_scalar_test(operators.desc_op, desc)
-
-    def test_asc(self):
-        self._do_scalar_test(operators.asc_op, asc)
-
-    def test_plus(self):
-        self._do_operate_test(operators.add)
-
-    def test_is_null(self):
-        self._do_operate_test(operators.is_, None)
-
-    def test_isnot_null(self):
-        self._do_operate_test(operators.isnot, None)
-
-    def test_is_null_const(self):
-        self._do_operate_test(operators.is_, null())
-
-    def test_is_true_const(self):
-        self._do_operate_test(operators.is_, true())
-
-    def test_is_false_const(self):
-        self._do_operate_test(operators.is_, false())
-
-    def test_equals_true(self):
-        self._do_operate_test(operators.eq, True)
-
-    def test_notequals_true(self):
-        self._do_operate_test(operators.ne, True)
-
-    def test_is_distinct_from_true(self):
-        self._do_operate_test(operators.is_distinct_from, True)
-
-    def test_is_distinct_from_false(self):
-        self._do_operate_test(operators.is_distinct_from, False)
-
-    def test_is_distinct_from_null(self):
-        self._do_operate_test(operators.is_distinct_from, None)
-
-    def test_isnot_distinct_from_true(self):
-        self._do_operate_test(operators.isnot_distinct_from, True)
-
-    def test_is_true(self):
-        self._do_operate_test(operators.is_, True)
-
-    def test_isnot_true(self):
-        self._do_operate_test(operators.isnot, True)
-
-    def test_is_false(self):
-        self._do_operate_test(operators.is_, False)
-
-    def test_isnot_false(self):
-        self._do_operate_test(operators.isnot, False)
-
-    def test_like(self):
-        self._do_operate_test(operators.like_op)
-
-    def test_notlike(self):
-        self._do_operate_test(operators.notlike_op)
-
-    def test_ilike(self):
-        self._do_operate_test(operators.ilike_op)
-
-    def test_notilike(self):
-        self._do_operate_test(operators.notilike_op)
-
-    def test_is(self):
-        self._do_operate_test(operators.is_)
-
-    def test_isnot(self):
-        self._do_operate_test(operators.isnot)
-
     def test_no_getitem(self):
         assert_raises_message(
             NotImplementedError,
             "Operator 'getitem' is not supported on this expression",
-            self._do_operate_test,
+            self.test_operate,
             operators.getitem,
+            column("right"),
         )
         assert_raises_message(
             NotImplementedError,
@@ -274,9 +232,6 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
             collate(left, right)
         )
 
-    def test_concat(self):
-        self._do_operate_test(operators.concat_op)
-
     def test_default_adapt(self):
         class TypeOne(TypeEngine):
             pass
@@ -329,7 +284,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
 class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def _factorial_fixture(self):
+    @testing.fixture
+    def factorial(self):
         class MyInteger(Integer):
             class comparator_factory(Integer.Comparator):
                 def factorial(self):
@@ -355,24 +311,24 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         return MyInteger
 
-    def test_factorial(self):
-        col = column("somecol", self._factorial_fixture())
+    def test_factorial(self, factorial):
+        col = column("somecol", factorial())
         self.assert_compile(col.factorial(), "somecol !")
 
-    def test_double_factorial(self):
-        col = column("somecol", self._factorial_fixture())
+    def test_double_factorial(self, factorial):
+        col = column("somecol", factorial())
         self.assert_compile(col.factorial().factorial(), "somecol ! !")
 
-    def test_factorial_prefix(self):
-        col = column("somecol", self._factorial_fixture())
+    def test_factorial_prefix(self, factorial):
+        col = column("somecol", factorial())
         self.assert_compile(col.factorial_prefix(), "!! somecol")
 
-    def test_factorial_invert(self):
-        col = column("somecol", self._factorial_fixture())
+    def test_factorial_invert(self, factorial):
+        col = column("somecol", factorial())
         self.assert_compile(~col, "!!! somecol")
 
-    def test_double_factorial_invert(self):
-        col = column("somecol", self._factorial_fixture())
+    def test_double_factorial_invert(self, factorial):
+        col = column("somecol", factorial())
         self.assert_compile(~(~col), "!!! (!!! somecol)")
 
     def test_unary_no_ops(self):
@@ -1845,7 +1801,15 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
     table1 = table("mytable", column("myid", Integer))
 
-    def _test_math_op(self, py_op, sql_op):
+    @testing.combinations(
+        ("add", operator.add, "+"),
+        ("mul", operator.mul, "*"),
+        ("sub", operator.sub, "-"),
+        ("div", operator.truediv if util.py3k else operator.div, "/"),
+        ("mod", operator.mod, "%"),
+        id_="iaa",
+    )
+    def test_math_op(self, py_op, sql_op):
         for (lhs, rhs, res) in (
             (5, self.table1.c.myid, ":myid_1 %s mytable.myid"),
             (5, literal(5), ":param_1 %s :param_2"),
@@ -1862,24 +1826,6 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         ):
             self.assert_compile(py_op(lhs, rhs), res % sql_op)
 
-    def test_math_op_add(self):
-        self._test_math_op(operator.add, "+")
-
-    def test_math_op_mul(self):
-        self._test_math_op(operator.mul, "*")
-
-    def test_math_op_sub(self):
-        self._test_math_op(operator.sub, "-")
-
-    def test_math_op_div(self):
-        if util.py3k:
-            self._test_math_op(operator.truediv, "/")
-        else:
-            self._test_math_op(operator.div, "/")
-
-    def test_math_op_mod(self):
-        self._test_math_op(operator.mod, "%")
-
 
 class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
@@ -1898,7 +1844,16 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         clause = tuple_(1, 2, 3)
         eq_(str(clause), str(util.pickle.loads(util.pickle.dumps(clause))))
 
-    def _test_comparison_op(self, py_op, fwd_op, rev_op):
+    @testing.combinations(
+        (operator.lt, "<", ">"),
+        (operator.gt, ">", "<"),
+        (operator.eq, "=", "="),
+        (operator.ne, "!=", "!="),
+        (operator.le, "<=", ">="),
+        (operator.ge, ">=", "<="),
+        id_="naa",
+    )
+    def test_comparison_op(self, py_op, fwd_op, rev_op):
         dt = datetime.datetime(2012, 5, 10, 15, 27, 18)
         for (lhs, rhs, l_sql, r_sql) in (
             ("a", self.table1.c.myid, ":myid_1", "mytable.myid"),
@@ -1935,24 +1890,6 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 + "'",
             )
 
-    def test_comparison_operators_lt(self):
-        self._test_comparison_op(operator.lt, "<", ">"),
-
-    def test_comparison_operators_gt(self):
-        self._test_comparison_op(operator.gt, ">", "<")
-
-    def test_comparison_operators_eq(self):
-        self._test_comparison_op(operator.eq, "=", "=")
-
-    def test_comparison_operators_ne(self):
-        self._test_comparison_op(operator.ne, "!=", "!=")
-
-    def test_comparison_operators_le(self):
-        self._test_comparison_op(operator.le, "<=", ">=")
-
-    def test_comparison_operators_ge(self):
-        self._test_comparison_op(operator.ge, ">=", "<=")
-
 
 class NonZeroTest(fixtures.TestBase):
     def _raises(self, expr):
@@ -2690,38 +2627,39 @@ class CustomOpTest(fixtures.TestBase):
         assert operators.is_comparison(op1)
         assert not operators.is_comparison(op2)
 
-    def test_return_types(self):
+    @testing.combinations(
+        (sqltypes.NULLTYPE,),
+        (Integer(),),
+        (ARRAY(String),),
+        (String(50),),
+        (Boolean(),),
+        (DateTime(),),
+        (sqltypes.JSON(),),
+        (postgresql.ARRAY(Integer),),
+        (sqltypes.Numeric(5, 2),),
+        id_="r",
+    )
+    def test_return_types(self, typ):
         some_return_type = sqltypes.DECIMAL()
 
-        for typ in [
-            sqltypes.NULLTYPE,
-            Integer(),
-            ARRAY(String),
-            String(50),
-            Boolean(),
-            DateTime(),
-            sqltypes.JSON(),
-            postgresql.ARRAY(Integer),
-            sqltypes.Numeric(5, 2),
-        ]:
-            c = column("x", typ)
-            expr = c.op("$", is_comparison=True)(None)
-            is_(expr.type, sqltypes.BOOLEANTYPE)
+        c = column("x", typ)
+        expr = c.op("$", is_comparison=True)(None)
+        is_(expr.type, sqltypes.BOOLEANTYPE)
 
-            c = column("x", typ)
-            expr = c.bool_op("$")(None)
-            is_(expr.type, sqltypes.BOOLEANTYPE)
+        c = column("x", typ)
+        expr = c.bool_op("$")(None)
+        is_(expr.type, sqltypes.BOOLEANTYPE)
 
-            expr = c.op("$")(None)
-            is_(expr.type, typ)
+        expr = c.op("$")(None)
+        is_(expr.type, typ)
 
-            expr = c.op("$", return_type=some_return_type)(None)
-            is_(expr.type, some_return_type)
+        expr = c.op("$", return_type=some_return_type)(None)
+        is_(expr.type, some_return_type)
 
-            expr = c.op("$", is_comparison=True, return_type=some_return_type)(
-                None
-            )
-            is_(expr.type, some_return_type)
+        expr = c.op("$", is_comparison=True, return_type=some_return_type)(
+            None
+        )
+        is_(expr.type, some_return_type)
 
 
 class TupleTypingTest(fixtures.TestBase):
@@ -2756,7 +2694,8 @@ class TupleTypingTest(fixtures.TestBase):
 class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def _fixture(self):
+    @testing.fixture
+    def t_fixture(self):
         m = MetaData()
 
         t = Table(
@@ -2767,8 +2706,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
         return t
 
-    def test_any_array(self):
-        t = self._fixture()
+    def test_any_array(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 == any_(t.c.arrval),
@@ -2776,8 +2715,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_any_array_method(self):
-        t = self._fixture()
+    def test_any_array_method(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 == t.c.arrval.any_(),
@@ -2785,8 +2724,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_all_array(self):
-        t = self._fixture()
+    def test_all_array(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 == all_(t.c.arrval),
@@ -2794,8 +2733,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_all_array_method(self):
-        t = self._fixture()
+    def test_all_array_method(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 == t.c.arrval.all_(),
@@ -2803,8 +2742,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_any_comparator_array(self):
-        t = self._fixture()
+    def test_any_comparator_array(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 > any_(t.c.arrval),
@@ -2812,8 +2751,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_all_comparator_array(self):
-        t = self._fixture()
+    def test_all_comparator_array(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 > all_(t.c.arrval),
@@ -2821,8 +2760,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_any_comparator_array_wexpr(self):
-        t = self._fixture()
+    def test_any_comparator_array_wexpr(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             t.c.data > any_(t.c.arrval),
@@ -2830,8 +2769,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={},
         )
 
-    def test_all_comparator_array_wexpr(self):
-        t = self._fixture()
+    def test_all_comparator_array_wexpr(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             t.c.data > all_(t.c.arrval),
@@ -2839,8 +2778,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={},
         )
 
-    def test_illegal_ops(self):
-        t = self._fixture()
+    def test_illegal_ops(self, t_fixture):
+        t = t_fixture
 
         assert_raises_message(
             exc.ArgumentError,
@@ -2856,8 +2795,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             t.c.data + all_(t.c.arrval), "tab1.data + ALL (tab1.arrval)"
         )
 
-    def test_any_array_comparator_accessor(self):
-        t = self._fixture()
+    def test_any_array_comparator_accessor(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             t.c.arrval.any(5, operator.gt),
@@ -2865,8 +2804,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_all_array_comparator_accessor(self):
-        t = self._fixture()
+    def test_all_array_comparator_accessor(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             t.c.arrval.all(5, operator.gt),
@@ -2874,8 +2813,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"param_1": 5},
         )
 
-    def test_any_array_expression(self):
-        t = self._fixture()
+    def test_any_array_expression(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 == any_(t.c.arrval[5:6] + postgresql.array([3, 4])),
@@ -2891,8 +2830,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             dialect="postgresql",
         )
 
-    def test_all_array_expression(self):
-        t = self._fixture()
+    def test_all_array_expression(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5 == all_(t.c.arrval[5:6] + postgresql.array([3, 4])),
@@ -2908,8 +2847,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             dialect="postgresql",
         )
 
-    def test_any_subq(self):
-        t = self._fixture()
+    def test_any_subq(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5
@@ -2919,8 +2858,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"data_1": 10, "param_1": 5},
         )
 
-    def test_any_subq_method(self):
-        t = self._fixture()
+    def test_any_subq_method(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5
@@ -2933,8 +2872,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"data_1": 10, "param_1": 5},
         )
 
-    def test_all_subq(self):
-        t = self._fixture()
+    def test_all_subq(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5
@@ -2944,8 +2883,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"data_1": 10, "param_1": 5},
         )
 
-    def test_all_subq_method(self):
-        t = self._fixture()
+    def test_all_subq_method(self, t_fixture):
+        t = t_fixture
 
         self.assert_compile(
             5
index 7bf83b461d3d3b255ec14ce5b3e1596cb11d2833..2ffdd83b74547220378c1ea5b2cfe79ad1617877 100644 (file)
@@ -5,6 +5,7 @@ import importlib
 import operator
 import os
 
+import sqlalchemy as sa
 from sqlalchemy import and_
 from sqlalchemy import ARRAY
 from sqlalchemy import BigInteger
@@ -87,42 +88,83 @@ from sqlalchemy.testing.util import round_decimal
 from sqlalchemy.util import OrderedDict
 
 
-class AdaptTest(fixtures.TestBase):
-    def _all_dialect_modules(self):
-        return [
-            importlib.import_module("sqlalchemy.dialects.%s" % d)
-            for d in dialects.__all__
-            if not d.startswith("_")
-        ]
+def _all_dialect_modules():
+    return [
+        importlib.import_module("sqlalchemy.dialects.%s" % d)
+        for d in dialects.__all__
+        if not d.startswith("_")
+    ]
 
-    def _all_dialects(self):
-        return [d.base.dialect() for d in self._all_dialect_modules()]
 
-    def _types_for_mod(self, mod):
-        for key in dir(mod):
-            typ = getattr(mod, key)
-            if not isinstance(typ, type) or not issubclass(
-                typ, types.TypeEngine
-            ):
-                continue
-            yield typ
+def _all_dialects():
+    return [d.base.dialect() for d in _all_dialect_modules()]
 
-    def _all_types(self):
-        for typ in self._types_for_mod(types):
-            yield typ
-        for dialect in self._all_dialect_modules():
-            for typ in self._types_for_mod(dialect):
-                yield typ
 
-    def test_uppercase_importable(self):
-        import sqlalchemy as sa
+def _types_for_mod(mod):
+    for key in dir(mod):
+        typ = getattr(mod, key)
+        if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine):
+            continue
+        yield typ
 
-        for typ in self._types_for_mod(types):
-            if typ.__name__ == typ.__name__.upper():
-                assert getattr(sa, typ.__name__) is typ
-                assert typ.__name__ in types.__all__
 
-    def test_uppercase_rendering(self):
+def _all_types(omit_special_types=False):
+    seen = set()
+    for typ in _types_for_mod(types):
+        if omit_special_types and typ in (
+            types.TypeDecorator,
+            types.TypeEngine,
+            types.Variant,
+        ):
+            continue
+
+        if typ in seen:
+            continue
+        seen.add(typ)
+        yield typ
+    for dialect in _all_dialect_modules():
+        for typ in _types_for_mod(dialect):
+            if typ in seen:
+                continue
+            seen.add(typ)
+            yield typ
+
+
+class AdaptTest(fixtures.TestBase):
+    @testing.combinations(((t,) for t in _types_for_mod(types)), id_="n")
+    def test_uppercase_importable(self, typ):
+        if typ.__name__ == typ.__name__.upper():
+            assert getattr(sa, typ.__name__) is typ
+            assert typ.__name__ in types.__all__
+
+    @testing.combinations(
+        ((d.name, d) for d in _all_dialects()), argnames="dialect", id_="ia"
+    )
+    @testing.combinations(
+        (REAL(), "REAL"),
+        (FLOAT(), "FLOAT"),
+        (NUMERIC(), "NUMERIC"),
+        (DECIMAL(), "DECIMAL"),
+        (INTEGER(), "INTEGER"),
+        (SMALLINT(), "SMALLINT"),
+        (TIMESTAMP(), ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE")),
+        (DATETIME(), "DATETIME"),
+        (DATE(), "DATE"),
+        (TIME(), ("TIME", "TIME WITHOUT TIME ZONE")),
+        (CLOB(), "CLOB"),
+        (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")),
+        (
+            NVARCHAR(10),
+            ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"),
+        ),
+        (CHAR(), "CHAR"),
+        (NCHAR(), ("NCHAR", "NATIONAL CHAR")),
+        (BLOB(), ("BLOB", "BLOB SUB_TYPE 0")),
+        (BOOLEAN(), ("BOOLEAN", "BOOL", "INTEGER")),
+        argnames="type_, expected",
+        id_="ra",
+    )
+    def test_uppercase_rendering(self, dialect, type_, expected):
         """Test that uppercase types from types.py always render as their
         type.
 
@@ -133,51 +175,48 @@ class AdaptTest(fixtures.TestBase):
 
         """
 
-        for dialect in self._all_dialects():
-            for type_, expected in (
-                (REAL, "REAL"),
-                (FLOAT, "FLOAT"),
-                (NUMERIC, "NUMERIC"),
-                (DECIMAL, "DECIMAL"),
-                (INTEGER, "INTEGER"),
-                (SMALLINT, "SMALLINT"),
-                (TIMESTAMP, ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE")),
-                (DATETIME, "DATETIME"),
-                (DATE, "DATE"),
-                (TIME, ("TIME", "TIME WITHOUT TIME ZONE")),
-                (CLOB, "CLOB"),
-                (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")),
-                (
-                    NVARCHAR(10),
-                    ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"),
-                ),
-                (CHAR, "CHAR"),
-                (NCHAR, ("NCHAR", "NATIONAL CHAR")),
-                (BLOB, ("BLOB", "BLOB SUB_TYPE 0")),
-                (BOOLEAN, ("BOOLEAN", "BOOL", "INTEGER")),
-            ):
-                if isinstance(expected, str):
-                    expected = (expected,)
+        if isinstance(expected, str):
+            expected = (expected,)
 
-                try:
-                    compiled = types.to_instance(type_).compile(
-                        dialect=dialect
-                    )
-                except NotImplementedError:
-                    continue
+        try:
+            compiled = type_.compile(dialect=dialect)
+        except NotImplementedError:
+            return
 
-                assert compiled in expected, (
-                    "%r matches none of %r for dialect %s"
-                    % (compiled, expected, dialect.name)
-                )
+        assert compiled in expected, "%r matches none of %r for dialect %s" % (
+            compiled,
+            expected,
+            dialect.name,
+        )
 
-                assert str(types.to_instance(type_)) in expected, (
-                    "default str() of type %r not expected, %r"
-                    % (type_, expected)
-                )
+        assert (
+            str(types.to_instance(type_)) in expected
+        ), "default str() of type %r not expected, %r" % (type_, expected)
+
+    def _adaptions():
+        for typ in _all_types(omit_special_types=True):
+
+            # up adapt from LowerCase to UPPERCASE,
+            # as well as to all non-sqltypes
+            up_adaptions = [typ] + typ.__subclasses__()
+            yield "%s.%s" % (
+                typ.__module__,
+                typ.__name__,
+            ), False, typ, up_adaptions
+            for subcl in typ.__subclasses__():
+                if (
+                    subcl is not typ
+                    and typ is not TypeDecorator
+                    and "sqlalchemy" in subcl.__module__
+                ):
+                    yield "%s.%s" % (
+                        subcl.__module__,
+                        subcl.__name__,
+                    ), True, subcl, [typ]
 
     @testing.uses_deprecated(".*Binary.*")
-    def test_adapt_method(self):
+    @testing.combinations(_adaptions(), id_="iaaa")
+    def test_adapt_method(self, is_down_adaption, typ, target_adaptions):
         """ensure all types have a working adapt() method,
         which creates a distinct copy.
 
@@ -190,67 +229,44 @@ class AdaptTest(fixtures.TestBase):
 
         """
 
-        def adaptions():
-            for typ in self._all_types():
-                # up adapt from LowerCase to UPPERCASE,
-                # as well as to all non-sqltypes
-                up_adaptions = [typ] + typ.__subclasses__()
-                yield False, typ, up_adaptions
-                for subcl in typ.__subclasses__():
-                    if (
-                        subcl is not typ
-                        and typ is not TypeDecorator
-                        and "sqlalchemy" in subcl.__module__
-                    ):
-                        yield True, subcl, [typ]
-
-        for is_down_adaption, typ, target_adaptions in adaptions():
-            if typ in (types.TypeDecorator, types.TypeEngine, types.Variant):
+        if issubclass(typ, ARRAY):
+            t1 = typ(String)
+        else:
+            t1 = typ()
+        for cls in target_adaptions:
+            if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or (
+                not is_down_adaption and issubclass(cls, sqltypes.Emulated)
+            ):
                 continue
-            elif issubclass(typ, ARRAY):
-                t1 = typ(String)
-            else:
-                t1 = typ()
-            for cls in target_adaptions:
-                if (
-                    is_down_adaption and issubclass(typ, sqltypes.Emulated)
-                ) or (
-                    not is_down_adaption and issubclass(cls, sqltypes.Emulated)
-                ):
-                    continue
 
-                if cls.__module__.startswith("test"):
+            # print("ADAPT %s -> %s" % (t1.__class__, cls))
+            t2 = t1.adapt(cls)
+            assert t1 is not t2
+
+            if is_down_adaption:
+                t2, t1 = t1, t2
+
+            for k in t1.__dict__:
+                if k in (
+                    "impl",
+                    "_is_oracle_number",
+                    "_create_events",
+                    "create_constraint",
+                    "inherit_schema",
+                    "schema",
+                    "metadata",
+                    "name",
+                ):
                     continue
+                # assert each value was copied, or that
+                # the adapted type has a more specific
+                # value than the original (i.e. SQL Server
+                # applies precision=24 for REAL)
+                assert (
+                    getattr(t2, k) == t1.__dict__[k] or t1.__dict__[k] is None
+                )
 
-                # print("ADAPT %s -> %s" % (t1.__class__, cls))
-                t2 = t1.adapt(cls)
-                assert t1 is not t2
-
-                if is_down_adaption:
-                    t2, t1 = t1, t2
-
-                for k in t1.__dict__:
-                    if k in (
-                        "impl",
-                        "_is_oracle_number",
-                        "_create_events",
-                        "create_constraint",
-                        "inherit_schema",
-                        "schema",
-                        "metadata",
-                        "name",
-                    ):
-                        continue
-                    # assert each value was copied, or that
-                    # the adapted type has a more specific
-                    # value than the original (i.e. SQL Server
-                    # applies precision=24 for REAL)
-                    assert (
-                        getattr(t2, k) == t1.__dict__[k]
-                        or t1.__dict__[k] is None
-                    )
-
-            eq_(t1.evaluates_none().should_evaluate_none, True)
+        eq_(t1.evaluates_none().should_evaluate_none, True)
 
     def test_python_type(self):
         eq_(types.Integer().python_type, int)
@@ -270,15 +286,13 @@ class AdaptTest(fixtures.TestBase):
         )
 
     @testing.uses_deprecated()
-    def test_repr(self):
-        for typ in self._all_types():
-            if typ in (types.TypeDecorator, types.TypeEngine, types.Variant):
-                continue
-            elif issubclass(typ, ARRAY):
-                t1 = typ(String)
-            else:
-                t1 = typ()
-            repr(t1)
+    @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
+    def test_repr(self, typ):
+        if issubclass(typ, ARRAY):
+            t1 = typ(String)
+        else:
+            t1 = typ()
+        repr(t1)
 
     def test_adapt_constructor_copy_override_kw(self):
         """test that adapt() can accept kw args that override
@@ -299,27 +313,30 @@ class AdaptTest(fixtures.TestBase):
 
 
 class TypeAffinityTest(fixtures.TestBase):
-    def test_type_affinity(self):
-        for type_, affin in [
-            (String(), String),
-            (VARCHAR(), String),
-            (Date(), Date),
-            (LargeBinary(), types._Binary),
-        ]:
-            eq_(type_._type_affinity, affin)
-
-        for t1, t2, comp in [
-            (Integer(), SmallInteger(), True),
-            (Integer(), String(), False),
-            (Integer(), Integer(), True),
-            (Text(), String(), True),
-            (Text(), Unicode(), True),
-            (LargeBinary(), Integer(), False),
-            (LargeBinary(), PickleType(), True),
-            (PickleType(), LargeBinary(), True),
-            (PickleType(), PickleType(), True),
-        ]:
-            eq_(t1._compare_type_affinity(t2), comp, "%s %s" % (t1, t2))
+    @testing.combinations(
+        (String(), String),
+        (VARCHAR(), String),
+        (Date(), Date),
+        (LargeBinary(), types._Binary),
+        id_="rn",
+    )
+    def test_type_affinity(self, type_, affin):
+        eq_(type_._type_affinity, affin)
+
+    @testing.combinations(
+        (Integer(), SmallInteger(), True),
+        (Integer(), String(), False),
+        (Integer(), Integer(), True),
+        (Text(), String(), True),
+        (Text(), Unicode(), True),
+        (LargeBinary(), Integer(), False),
+        (LargeBinary(), PickleType(), True),
+        (PickleType(), LargeBinary(), True),
+        (PickleType(), PickleType(), True),
+        id_="rra",
+    )
+    def test_compare_type_affinity(self, t1, t2, comp):
+        eq_(t1._compare_type_affinity(t2), comp, "%s %s" % (t1, t2))
 
     def test_decorator_doesnt_cache(self):
         from sqlalchemy.dialects import postgresql
@@ -340,30 +357,32 @@ class TypeAffinityTest(fixtures.TestBase):
 
 
 class PickleTypesTest(fixtures.TestBase):
-    def test_pickle_types(self):
+    @testing.combinations(
+        ("Boo", Boolean()),
+        ("Str", String()),
+        ("Tex", Text()),
+        ("Uni", Unicode()),
+        ("Int", Integer()),
+        ("Sma", SmallInteger()),
+        ("Big", BigInteger()),
+        ("Num", Numeric()),
+        ("Flo", Float()),
+        ("Dat", DateTime()),
+        ("Dat", Date()),
+        ("Tim", Time()),
+        ("Lar", LargeBinary()),
+        ("Pic", PickleType()),
+        ("Int", Interval()),
+        id_="ar",
+    )
+    def test_pickle_types(self, name, type_):
+        column_type = Column(name, type_)
+        meta = MetaData()
+        Table("foo", meta, column_type)
+
         for loads, dumps in picklers():
-            column_types = [
-                Column("Boo", Boolean()),
-                Column("Str", String()),
-                Column("Tex", Text()),
-                Column("Uni", Unicode()),
-                Column("Int", Integer()),
-                Column("Sma", SmallInteger()),
-                Column("Big", BigInteger()),
-                Column("Num", Numeric()),
-                Column("Flo", Float()),
-                Column("Dat", DateTime()),
-                Column("Dat", Date()),
-                Column("Tim", Time()),
-                Column("Lar", LargeBinary()),
-                Column("Pic", PickleType()),
-                Column("Int", Interval()),
-            ]
-            for column_type in column_types:
-                meta = MetaData()
-                Table("foo", meta, column_type)
-                loads(dumps(column_type))
-                loads(dumps(meta))
+            loads(dumps(column_type))
+            loads(dumps(meta))
 
 
 class _UserDefinedTypeFixture(object):
@@ -2414,19 +2433,19 @@ class ExpressionTest(
         expr = column("foo", CHAR) == "asdf"
         eq_(expr.right.type.__class__, CHAR)
 
-    def test_actual_literal_adapters(self):
-        for data, expected in [
-            (5, Integer),
-            (2.65, Float),
-            (True, Boolean),
-            (decimal.Decimal("2.65"), Numeric),
-            (datetime.date(2015, 7, 20), Date),
-            (datetime.time(10, 15, 20), Time),
-            (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime),
-            (datetime.timedelta(seconds=5), Interval),
-            (None, types.NullType),
-        ]:
-            is_(literal(data).type.__class__, expected)
+    @testing.combinations(
+        (5, Integer),
+        (2.65, Float),
+        (True, Boolean),
+        (decimal.Decimal("2.65"), Numeric),
+        (datetime.date(2015, 7, 20), Date),
+        (datetime.time(10, 15, 20), Time),
+        (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime),
+        (datetime.timedelta(seconds=5), Interval),
+        (None, types.NullType),
+    )
+    def test_actual_literal_adapters(self, data, expected):
+        is_(literal(data).type.__class__, expected)
 
     def test_typedec_operator_adapt(self):
         expr = test_table.c.bvalue + "hi"
@@ -2592,18 +2611,22 @@ class ExpressionTest(
         expr = column("bar", types.Interval) * column("foo", types.Numeric)
         eq_(expr.type._type_affinity, types.Interval)
 
-    def test_numerics_coercion(self):
-
-        for op in (operator.add, operator.mul, operator.truediv, operator.sub):
-            for other in (Numeric(10, 2), Integer):
-                expr = op(
-                    column("bar", types.Numeric(10, 2)), column("foo", other)
-                )
-                assert isinstance(expr.type, types.Numeric)
-                expr = op(
-                    column("foo", other), column("bar", types.Numeric(10, 2))
-                )
-                assert isinstance(expr.type, types.Numeric)
+    @testing.combinations(
+        (operator.add,),
+        (operator.mul,),
+        (operator.truediv,),
+        (operator.sub,),
+        argnames="op",
+        id_="n",
+    )
+    @testing.combinations(
+        (Numeric(10, 2),), (Integer(),), argnames="other", id_="r"
+    )
+    def test_numerics_coercion(self, op, other):
+        expr = op(column("bar", types.Numeric(10, 2)), column("foo", other))
+        assert isinstance(expr.type, types.Numeric)
+        expr = op(column("foo", other), column("bar", types.Numeric(10, 2)))
+        assert isinstance(expr.type, types.Numeric)
 
     def test_asdecimal_int_to_numeric(self):
         expr = column("a", Integer) * column("b", Numeric(asdecimal=False))