From: Mike Bayer Date: Thu, 17 Oct 2019 17:09:24 +0000 (-0400) Subject: Implement facade for pytest parametrize, fixtures, classlevel X-Git-Tag: rel_1_4_0b1~662 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ed553fffd65a063d6dbdb3770d1fa0124bd55e23;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement facade for pytest parametrize, fixtures, classlevel Add factilities to implement pytest.mark.parametrize and pytest.fixtures patterns, which largely resemble things we are already doing. Ensure a facade is used, so that the test suite remains independent of py.test, but also tailors the functions to the more limited scope in which we are using them. Additionally, create a class-based version that works from the same facade. Several old polymorphic tests as well as two of the sql test are refactored to use the new features. Change-Id: I6ef8af1dafff92534313016944d447f9439856cf References: #4896 --- diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 2b8158fbb1..4f28461e37 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -32,7 +32,9 @@ from .assertions import ne_ # noqa from .assertions import not_in_ # noqa from .assertions import startswith_ # noqa from .assertions import uses_deprecated # noqa +from .config import combinations # noqa from .config import db # noqa +from .config import fixture # noqa from .config import requirements as requires # noqa from .exclusions import _is_excluded # noqa from .exclusions import _server_version # noqa diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index f94c5b3086..87bbc6a0f2 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -6,7 +6,6 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php import collections -from unittest import SkipTest as _skip_test_exception requirements = None db = None @@ -17,6 +16,75 @@ test_schema = None test_schema_2 = None _current = None +_fixture_functions = None # installed by plugin_base + + +def combinations(*comb, **kw): + r"""Deliver multiple versions of a test based on positional combinations. + + This is a facade over pytest.mark.parametrize. + + + :param \*comb: argument combinations. These are tuples that will be passed + positionally to the decorated function. + + :param argnames: optional list of argument names. These are the names + of the arguments in the test function that correspond to the entries + in each argument tuple. pytest.mark.parametrize requires this, however + the combinations function will derive it automatically if not present + by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the + first argument is "self" which is discarded. + + :param id\_: optional id template. This is a string template that + describes how the "id" for each parameter set should be defined, if any. + The number of characters in the template should match the number of + entries in each argument tuple. Each character describes how the + corresponding entry in the argument tuple should be handled, as far as + whether or not it is included in the arguments passed to the function, as + well as if it is included in the tokens used to create the id of the + parameter set. + + If omitted, the argment combinations are passed to parametrize as is. If + passed, each argument combination is turned into a pytest.param() object, + mapping the elements of the argument tuple to produce an id based on a + character value in the same position within the string template using the + following scheme:: + + i - the given argument is a string that is part of the id only, don't + pass it as an argument + + n - the given argument should be passed and it should be added to the + id by calling the .__name__ attribute + + r - the given argument should be passed and it should be added to the + id by calling repr() + + s- the given argument should be passed and it should be added to the + id by calling str() + + e.g.:: + + @testing.combinations( + (operator.eq, "eq"), + (operator.ne, "ne"), + (operator.gt, "gt"), + (operator.lt, "lt"), + id_="na" + ) + def test_operator(self, opfunc, name): + pass + + The above combination will call ``.__name__`` on the first member of + each tuple and use that as the "id" to pytest.param(). + + + """ + return _fixture_functions.combinations(*comb, **kw) + + +def fixture(*arg, **kw): + return _fixture_functions.fixture(*arg, **kw) + class Config(object): def __init__(self, db, db_opts, options, file_config): @@ -94,4 +162,4 @@ class Config(object): def skip_test(msg): - raise _skip_test_exception(msg) + raise _fixture_functions.skip_test_exception(msg) diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 859d1d7799..a2f969a66e 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -16,6 +16,7 @@ is py.test. from __future__ import absolute_import +import abc import re import sys @@ -24,8 +25,15 @@ py3k = sys.version_info >= (3, 0) if py3k: import configparser + + ABC = abc.ABC else: import ConfigParser as configparser + import collections as collections_abc # noqa + + class ABC(object): + __metaclass__ = abc.ABCMeta + # late imports fixtures = None @@ -238,14 +246,6 @@ def set_coverage_flag(value): options.has_coverage = value -_skip_test_exception = None - - -def set_skip_test(exc): - global _skip_test_exception - _skip_test_exception = exc - - def post_begin(): """things to set up later, once we know coverage is running.""" # Lazy setup of other options (post coverage) @@ -331,10 +331,10 @@ def _monkeypatch_cdecimal(options, file_config): @post -def _init_skiptest(options, file_config): +def _init_symbols(options, file_config): from sqlalchemy.testing import config - config._skip_test_exception = _skip_test_exception + config._fixture_functions = _fixture_fn_class() @post @@ -486,10 +486,10 @@ def _setup_profiling(options, file_config): ) -def want_class(cls): +def want_class(name, cls): if not issubclass(cls, fixtures.TestBase): return False - elif cls.__name__.startswith("_"): + elif name.startswith("_"): return False elif ( config.options.backend_only @@ -711,3 +711,29 @@ def _do_skips(cls): def _setup_config(config_obj, ctx): config._current.push(config_obj, testing) + + +class FixtureFunctions(ABC): + @abc.abstractmethod + def skip_test_exception(self, *arg, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def combinations(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def param_ident(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def fixture(self, fn): + raise NotImplementedError() + + +_fixture_fn_class = None + + +def set_fixture_functions(fixture_fn_class): + global _fixture_fn_class + _fixture_fn_class = fixture_fn_class diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index e0335c1357..5d91db5d70 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -8,7 +8,11 @@ except ImportError: import argparse import collections import inspect +import itertools +import operator import os +import re +import sys import pytest @@ -87,7 +91,7 @@ def pytest_configure(config): bool(getattr(config.option, "cov_source", False)) ) - plugin_base.set_skip_test(pytest.skip.Exception) + plugin_base.set_fixture_functions(PytestFixtureFunctions) def pytest_sessionstart(session): @@ -132,6 +136,7 @@ def pytest_collection_modifyitems(session, config, items): rebuilt_items = collections.defaultdict( lambda: collections.defaultdict(list) ) + items[:] = [ item for item in items @@ -173,21 +178,63 @@ def pytest_collection_modifyitems(session, config, items): def pytest_pycollect_makeitem(collector, name, obj): - if inspect.isclass(obj) and plugin_base.want_class(obj): - return pytest.Class(name, parent=collector) + + if inspect.isclass(obj) and plugin_base.want_class(name, obj): + return [ + pytest.Class(parametrize_cls.__name__, parent=collector) + for parametrize_cls in _parametrize_cls(collector.module, obj) + ] elif ( inspect.isfunction(obj) and isinstance(collector, pytest.Instance) and plugin_base.want_method(collector.cls, obj) ): - return pytest.Function(name, parent=collector) + # None means, fall back to default logic, which includes + # method-level parametrize + return None else: + # empty list means skip this item return [] _current_class = None +def _parametrize_cls(module, cls): + """implement a class-based version of pytest parametrize.""" + + if "_sa_parametrize" not in cls.__dict__: + return [cls] + + _sa_parametrize = cls._sa_parametrize + classes = [] + for full_param_set in itertools.product( + *[params for argname, params in _sa_parametrize] + ): + cls_variables = {} + + for argname, param in zip( + [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set + ): + if not argname: + raise TypeError("need argnames for class-based combinations") + argname_split = re.split(r",\s*", argname) + for arg, val in zip(argname_split, param.values): + cls_variables[arg] = val + parametrized_name = "_".join( + # token is a string, but in py2k py.test is giving us a unicode, + # so call str() on it. + str(re.sub(r"\W", "", token)) + for param in full_param_set + for token in param.id.split("-") + ) + name = "%s_%s" % (cls.__name__, parametrized_name) + newcls = type.__new__(type, name, (cls,), cls_variables) + setattr(module, name, newcls) + classes.append(newcls) + return classes + + def pytest_runtest_setup(item): # here we seem to get called only based on what we collected # in pytest_collection_modifyitems. So to do class-based stuff @@ -239,3 +286,99 @@ def class_setup(item): def class_teardown(item): plugin_base.stop_test_class(item.cls) + + +def getargspec(fn): + if sys.version_info.major == 3: + return inspect.getfullargspec(fn) + else: + return inspect.getargspec(fn) + + +class PytestFixtureFunctions(plugin_base.FixtureFunctions): + def skip_test_exception(self, *arg, **kw): + return pytest.skip.Exception(*arg, **kw) + + _combination_id_fns = { + "i": lambda obj: obj, + "r": repr, + "s": str, + "n": operator.attrgetter("__name__"), + } + + def combinations(self, *arg_sets, **kw): + """facade for pytest.mark.paramtrize. + + Automatically derives argument names from the callable which in our + case is always a method on a class with positional arguments. + + ids for parameter sets are derived using an optional template. + + """ + + if sys.version_info.major == 3: + if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"): + arg_sets = list(arg_sets[0]) + else: + if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"): + arg_sets = list(arg_sets[0]) + + argnames = kw.pop("argnames", None) + + id_ = kw.pop("id_", None) + + if id_: + _combination_id_fns = self._combination_id_fns + + # because itemgetter is not consistent for one argument vs. + # multiple, make it multiple in all cases and use a slice + # to omit the first argument + _arg_getter = operator.itemgetter( + 0, + *[ + idx + for idx, char in enumerate(id_) + if char in ("n", "r", "s", "a") + ] + ) + fns = [ + (operator.itemgetter(idx), _combination_id_fns[char]) + for idx, char in enumerate(id_) + if char in _combination_id_fns + ] + arg_sets = [ + pytest.param( + *_arg_getter(arg)[1:], + id="-".join( + comb_fn(getter(arg)) for getter, comb_fn in fns + ) + ) + for arg in arg_sets + ] + else: + # ensure using pytest.param so that even a 1-arg paramset + # still needs to be a tuple. otherwise paramtrize tries to + # interpret a single arg differently than tuple arg + arg_sets = [pytest.param(*arg) for arg in arg_sets] + + def decorate(fn): + if inspect.isclass(fn): + if "_sa_parametrize" not in fn.__dict__: + fn._sa_parametrize = [] + fn._sa_parametrize.append((argnames, arg_sets)) + return fn + else: + if argnames is None: + _argnames = getargspec(fn).args[1:] + else: + _argnames = argnames + return pytest.mark.parametrize(_argnames, arg_sets)(fn) + + return decorate + + def param_ident(self, *parameters): + ident = parameters[0] + return pytest.param(*parameters[1:], id=ident) + + def fixture(self, fn): + return pytest.fixture(fn) diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index cbfbc63ee5..431e53b1ba 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -921,7 +921,7 @@ class MemUsageWBackendTest(EnsureZeroed): metadata.drop_all() assert_no_mappers() - @testing.expect_deprecated + @testing.uses_deprecated() @testing.provide_metadata def test_key_fallback_result(self): e = self.engine diff --git a/test/orm/inheritance/test_abc_polymorphic.py b/test/orm/inheritance/test_abc_polymorphic.py index f430e761f8..cf06c9e263 100644 --- a/test/orm/inheritance/test_abc_polymorphic.py +++ b/test/orm/inheritance/test_abc_polymorphic.py @@ -1,13 +1,13 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String +from sqlalchemy import testing from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -from sqlalchemy.testing.util import function_named class ABCTest(fixtures.MappedTest): @@ -36,91 +36,85 @@ class ABCTest(fixtures.MappedTest): Column("cdata", String(30)), ) - def _make_test(fetchtype): - def test_roundtrip(self): - class A(fixtures.ComparableEntity): - pass + @testing.combinations(("union",), ("none",)) + def test_abc_poly_roundtrip(self, fetchtype): + class A(fixtures.ComparableEntity): + pass - class B(A): - pass + class B(A): + pass - class C(B): - pass + class C(B): + pass - if fetchtype == "union": - abc = a.outerjoin(b).outerjoin(c) - bc = a.join(b).outerjoin(c) - else: - abc = bc = None + if fetchtype == "union": + abc = a.outerjoin(b).outerjoin(c) + bc = a.join(b).outerjoin(c) + else: + abc = bc = None - mapper( - A, - a, - with_polymorphic=("*", abc), - polymorphic_on=a.c.type, - polymorphic_identity="a", - ) - mapper( - B, - b, - with_polymorphic=("*", bc), - inherits=A, - polymorphic_identity="b", - ) - mapper(C, c, inherits=B, polymorphic_identity="c") - - a1 = A(adata="a1") - b1 = B(bdata="b1", adata="b1") - b2 = B(bdata="b2", adata="b2") - b3 = B(bdata="b3", adata="b3") - c1 = C(cdata="c1", bdata="c1", adata="c1") - c2 = C(cdata="c2", bdata="c2", adata="c2") - c3 = C(cdata="c2", bdata="c2", adata="c2") - - sess = create_session() - for x in (a1, b1, b2, b3, c1, c2, c3): - sess.add(x) - sess.flush() - sess.expunge_all() + mapper( + A, + a, + with_polymorphic=("*", abc), + polymorphic_on=a.c.type, + polymorphic_identity="a", + ) + mapper( + B, + b, + with_polymorphic=("*", bc), + inherits=A, + polymorphic_identity="b", + ) + mapper(C, c, inherits=B, polymorphic_identity="c") - # for obj in sess.query(A).all(): - # print obj - eq_( - [ - A(adata="a1"), - B(bdata="b1", adata="b1"), - B(bdata="b2", adata="b2"), - B(bdata="b3", adata="b3"), - C(cdata="c1", bdata="c1", adata="c1"), - C(cdata="c2", bdata="c2", adata="c2"), - C(cdata="c2", bdata="c2", adata="c2"), - ], - sess.query(A).order_by(A.id).all(), - ) + a1 = A(adata="a1") + b1 = B(bdata="b1", adata="b1") + b2 = B(bdata="b2", adata="b2") + b3 = B(bdata="b3", adata="b3") + c1 = C(cdata="c1", bdata="c1", adata="c1") + c2 = C(cdata="c2", bdata="c2", adata="c2") + c3 = C(cdata="c2", bdata="c2", adata="c2") - eq_( - [ - B(bdata="b1", adata="b1"), - B(bdata="b2", adata="b2"), - B(bdata="b3", adata="b3"), - C(cdata="c1", bdata="c1", adata="c1"), - C(cdata="c2", bdata="c2", adata="c2"), - C(cdata="c2", bdata="c2", adata="c2"), - ], - sess.query(B).order_by(A.id).all(), - ) + sess = create_session() + for x in (a1, b1, b2, b3, c1, c2, c3): + sess.add(x) + sess.flush() + sess.expunge_all() - eq_( - [ - C(cdata="c1", bdata="c1", adata="c1"), - C(cdata="c2", bdata="c2", adata="c2"), - C(cdata="c2", bdata="c2", adata="c2"), - ], - sess.query(C).order_by(A.id).all(), - ) + # for obj in sess.query(A).all(): + # print obj + eq_( + [ + A(adata="a1"), + B(bdata="b1", adata="b1"), + B(bdata="b2", adata="b2"), + B(bdata="b3", adata="b3"), + C(cdata="c1", bdata="c1", adata="c1"), + C(cdata="c2", bdata="c2", adata="c2"), + C(cdata="c2", bdata="c2", adata="c2"), + ], + sess.query(A).order_by(A.id).all(), + ) - test_roundtrip = function_named(test_roundtrip, "test_%s" % fetchtype) - return test_roundtrip + eq_( + [ + B(bdata="b1", adata="b1"), + B(bdata="b2", adata="b2"), + B(bdata="b3", adata="b3"), + C(cdata="c1", bdata="c1", adata="c1"), + C(cdata="c2", bdata="c2", adata="c2"), + C(cdata="c2", bdata="c2", adata="c2"), + ], + sess.query(B).order_by(A.id).all(), + ) - test_union = _make_test("union") - test_none = _make_test("none") + eq_( + [ + C(cdata="c1", bdata="c1", adata="c1"), + C(cdata="c2", bdata="c2", adata="c2"), + C(cdata="c2", bdata="c2", adata="c2"), + ], + sess.query(C).order_by(A.id).all(), + ) diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 2f8677f8bf..ecab0a497d 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -7,7 +7,6 @@ from sqlalchemy import exists from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer -from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String @@ -15,7 +14,6 @@ from sqlalchemy import testing from sqlalchemy import Unicode from sqlalchemy import util from sqlalchemy.orm import class_mapper -from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import column_property from sqlalchemy.orm import contains_eager from sqlalchemy.orm import create_session @@ -23,7 +21,6 @@ from sqlalchemy.orm import join from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union -from sqlalchemy.orm import Query from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import with_polymorphic @@ -33,15 +30,6 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -from sqlalchemy.testing.util import function_named - - -class AttrSettable(object): - def __init__(self, **kwargs): - [setattr(self, k, v) for k, v in kwargs.items()] - - def __repr__(self): - return self.__class__.__name__ + "(%s)" % (hex(id(self))) class RelationshipTest1(fixtures.MappedTest): @@ -84,17 +72,17 @@ class RelationshipTest1(fixtures.MappedTest): Column("manager_name", String(50)), ) - def teardown(self): - people.update(values={people.c.manager_id: None}).execute() - super(RelationshipTest1, self).teardown() - - def test_parent_refs_descendant(self): - class Person(AttrSettable): + @classmethod + def setup_classes(cls): + class Person(cls.Comparable): pass class Manager(Person): pass + def test_parent_refs_descendant(self): + Person, Manager = self.classes("Person", "Manager") + mapper( Person, people, @@ -132,11 +120,7 @@ class RelationshipTest1(fixtures.MappedTest): assert p.manager is m def test_descendant_refs_parent(self): - class Person(AttrSettable): - pass - - class Manager(Person): - pass + Person, Manager = self.classes("Person", "Manager") mapper(Person, people) mapper( @@ -212,31 +196,22 @@ class RelationshipTest2(fixtures.MappedTest): Column("data", String(30)), ) - def test_relationshiponsubclass_j1_nodata(self): - self._do_test("join1", False) - - def test_relationshiponsubclass_j2_nodata(self): - self._do_test("join2", False) - - def test_relationshiponsubclass_j1_data(self): - self._do_test("join1", True) - - def test_relationshiponsubclass_j2_data(self): - self._do_test("join2", True) - - def test_relationshiponsubclass_j3_nodata(self): - self._do_test("join3", False) - - def test_relationshiponsubclass_j3_data(self): - self._do_test("join3", True) - - def _do_test(self, jointype="join1", usedata=False): - class Person(AttrSettable): + @classmethod + def setup_classes(cls): + class Person(cls.Comparable): pass class Manager(Person): pass + @testing.combinations( + ("join1",), ("join2",), ("join3",), argnames="jointype" + ) + @testing.combinations( + ("usedata", True), ("nodata", False), id_="ia", argnames="usedata" + ) + def test_relationshiponsubclass(self, jointype, usedata): + Person, Manager = self.classes("Person", "Manager") if jointype == "join1": poly_union = polymorphic_union( { @@ -382,21 +357,20 @@ class RelationshipTest3(fixtures.MappedTest): Column("data", String(30)), ) - -def _generate_test(jointype="join1", usedata=False): - def _do_test(self): - class Person(AttrSettable): + @classmethod + def setup_classes(cls): + class Person(cls.Comparable): pass class Manager(Person): pass - if usedata: - - class Data(object): - def __init__(self, data): - self.data = data + class Data(cls.Comparable): + def __init__(self, data): + self.data = data + def _setup_mappings(self, jointype, usedata): + Person, Manager, Data = self.classes("Person", "Manager", "Data") if jointype == "join1": poly_union = polymorphic_union( { @@ -427,6 +401,8 @@ def _generate_test(jointype="join1", usedata=False): poly_union = people.outerjoin(managers) elif jointype == "join4": poly_union = None + else: + assert False if usedata: mapper(Data, data) @@ -475,6 +451,16 @@ def _generate_test(jointype="join1", usedata=False): polymorphic_identity="manager", ) + @testing.combinations( + ("join1",), ("join2",), ("join3",), ("join4",), argnames="jointype" + ) + @testing.combinations( + ("usedata", True), ("nodata", False), id_="ia", argnames="usedata" + ) + def test_relationship_on_base_class(self, jointype, usedata): + self._setup_mappings(jointype, usedata) + Person, Manager, Data = self.classes("Person", "Manager", "Data") + sess = create_session() p = Person(name="person1") p2 = Person(name="person2") @@ -502,20 +488,6 @@ def _generate_test(jointype="join1", usedata=False): assert p.data.data == "ps data" assert m.data.data == "ms data" - do_test = function_named( - _do_test, - "test_relationship_on_base_class_%s_%s" - % (jointype, data and "nodata" or "data"), - ) - return do_test - - -for jointype in ["join1", "join2", "join3", "join4"]: - for data in (True, False): - _fn = _generate_test(jointype, data) - setattr(RelationshipTest3, _fn.__name__, _fn) -del _fn - class RelationshipTest4(fixtures.MappedTest): @classmethod @@ -853,13 +825,17 @@ class RelationshipTest6(fixtures.MappedTest): Column("status", String(30)), ) - def test_basic(self): - class Person(AttrSettable): + @classmethod + def setup_classes(cls): + class Person(cls.Comparable): pass class Manager(Person): pass + def test_basic(self): + Person, Manager = self.classes("Person", "Manager") + mapper(Person, people) mapper( @@ -1128,9 +1104,9 @@ class RelationshipTest8(fixtures.MappedTest): ) -class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): +class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): @classmethod - def setup_class(cls): + def define_tables(cls, metadata): # cars---owned by--- people (abstract) --- has a --- status # | ^ ^ | # | | | | @@ -1138,10 +1114,8 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): # | | # +--------------------------------------- has a ------+ - global metadata, status, people, engineers, managers, cars - metadata = MetaData(testing.db) # table definitions - status = Table( + Table( "status", metadata, Column( @@ -1153,7 +1127,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): Column("name", String(20)), ) - people = Table( + Table( "people", metadata, Column( @@ -1171,7 +1145,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): Column("name", String(50)), ) - engineers = Table( + Table( "engineers", metadata, Column( @@ -1183,7 +1157,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): Column("field", String(30)), ) - managers = Table( + Table( "managers", metadata, Column( @@ -1195,7 +1169,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): Column("category", String(70)), ) - cars = Table( + Table( "cars", metadata, Column( @@ -1218,52 +1192,31 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): ), ) - metadata.create_all() - @classmethod - def teardown_class(cls): - metadata.drop_all() - - def teardown(self): - clear_mappers() - for t in reversed(metadata.sorted_tables): - t.delete().execute() - - def test_join_to(self): - # class definitions - class PersistentObject(object): - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - class Status(PersistentObject): - def __repr__(self): - return "Status %s" % self.name + def setup_classes(cls): + class Status(cls.Comparable): + pass - class Person(PersistentObject): - def __repr__(self): - return "Ordinary person %s" % self.name + class Person(cls.Comparable): + pass class Engineer(Person): - def __repr__(self): - return "Engineer %s, field %s, status %s" % ( - self.name, - self.field, - self.status, - ) + pass class Manager(Person): - def __repr__(self): - return "Manager %s, category %s, status %s" % ( - self.name, - self.category, - self.status, - ) + pass - class Car(PersistentObject): - def __repr__(self): - return "Car number %d" % self.car_id + class Car(cls.Comparable): + pass + @classmethod + def setup_mappers(cls): + status, people, engineers, managers, cars = cls.tables( + "status", "people", "engineers", "managers", "cars" + ) + Status, Person, Engineer, Manager, Car = cls.classes( + "Status", "Person", "Engineer", "Manager", "Car" + ) # create a union that represents both types of joins. employee_join = polymorphic_union( { @@ -1283,7 +1236,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): polymorphic_identity="person", properties={"status": relationship(status_mapper)}, ) - engineer_mapper = mapper( + mapper( Engineer, engineers, inherits=person_mapper, @@ -1304,6 +1257,11 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): }, ) + @classmethod + def insert_data(cls): + Status, Person, Engineer, Manager, Car = cls.classes( + "Status", "Person", "Engineer", "Manager", "Car" + ) session = create_session() active = Status(name="active") @@ -1332,7 +1290,7 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): session.flush() # get E4 - engineer4 = session.query(engineer_mapper).filter_by(name="E4").one() + engineer4 = session.query(Engineer).filter_by(name="E4").one() # create 2 cars for E4, one active and one dead car1 = Car(employee=engineer4, status=active) @@ -1341,9 +1299,11 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): session.add(car2) session.flush() - # this particular adapt used to cause a recursion overflow; - # added here for testing - Query(Person)._adapt_clause(employee_join, False, False) + def test_join_to_q_person(self): + Status, Person, Engineer, Manager, Car = self.classes( + "Status", "Person", "Engineer", "Manager", "Car" + ) + session = create_session() r = ( session.query(Person) @@ -1353,31 +1313,52 @@ class GenerativeTest(fixtures.TestBase, AssertsExecutionResults): .order_by(Person.person_id) ) eq_( - str(list(r)), - "[Manager M2, category YYYYYYYYY, status " - "Status active, Engineer E2, field X, " - "status Status active]", + list(r), + [ + Manager( + name="M2", + category="YYYYYYYYY", + status=Status(name="active"), + ), + Engineer(name="E2", field="X", status=Status(name="active")), + ], + ) + + def test_join_to_q_engineer(self): + Status, Person, Engineer, Manager, Car = self.classes( + "Status", "Person", "Engineer", "Manager", "Car" ) + session = create_session() r = ( session.query(Engineer) .join("status") .filter( Person.name.in_(["E2", "E3", "E4", "M4", "M2", "M1"]) - & (status.c.name == "active") + & (Status.name == "active") ) .order_by(Person.name) ) eq_( - str(list(r)), - "[Engineer E2, field X, status Status " - "active, Engineer E3, field X, status " - "Status active]", + list(r), + [ + Engineer(name="E2", field="X", status=Status(name="active")), + Engineer(name="E3", field="X", status=Status(name="active")), + ], ) + def test_join_to_q_person_car(self): + Status, Person, Engineer, Manager, Car = self.classes( + "Status", "Person", "Engineer", "Manager", "Car" + ) + session = create_session() r = session.query(Person).filter( exists([1], Car.owner == Person.person_id) ) - eq_(str(list(r)), "[Engineer E4, field X, status Status dead]") + + eq_( + list(r), + [Engineer(name="E4", field="X", status=Status(name="dead"))], + ) class MultiLevelTest(fixtures.MappedTest): diff --git a/test/orm/inheritance/test_magazine.py b/test/orm/inheritance/test_magazine.py index 1abfb90322..228cb1273e 100644 --- a/test/orm/inheritance/test_magazine.py +++ b/test/orm/inheritance/test_magazine.py @@ -1,115 +1,54 @@ +"""A legacy test for a particular somewhat complicated mapping.""" + from sqlalchemy import CHAR from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String +from sqlalchemy import testing from sqlalchemy import Text from sqlalchemy.orm import backref -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -from sqlalchemy.testing.util import function_named - - -class BaseObject(object): - def __init__(self, *args, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - -class Publication(BaseObject): - pass - - -class Issue(BaseObject): - pass - - -class Location(BaseObject): - def __repr__(self): - return "%s(%s, %s)" % ( - self.__class__.__name__, - str(getattr(self, "issue_id", None)), - repr(str(self._name.name)), - ) - - def _get_name(self): - return self._name - - def _set_name(self, name): - session = create_session() - s = ( - session.query(LocationName) - .filter(LocationName.name == name) - .first() - ) - session.expunge_all() - if s is not None: - self._name = s - - return - - found = False - for i in session.new: - if isinstance(i, LocationName) and i.name == name: - self._name = i - found = True - break - - if found is False: - self._name = LocationName(name=name) - - name = property(_get_name, _set_name) - - -class LocationName(BaseObject): - def __repr__(self): - return "%s()" % (self.__class__.__name__) - - -class PageSize(BaseObject): - def __repr__(self): - return "%s(%sx%s, %s)" % ( - self.__class__.__name__, - self.width, - self.height, - self.name, - ) +class MagazineTest(fixtures.MappedTest): + @classmethod + def setup_classes(cls): + Base = cls.Comparable + class Publication(Base): + pass -class Magazine(BaseObject): - def __repr__(self): - return "%s(%s, %s)" % ( - self.__class__.__name__, - repr(self.location), - repr(self.size), - ) + class Issue(Base): + pass + class Location(Base): + pass -class Page(BaseObject): - def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, str(self.page_no)) + class LocationName(Base): + pass + class PageSize(Base): + pass -class MagazinePage(Page): - def __repr__(self): - return "%s(%s, %s)" % ( - self.__class__.__name__, - str(self.page_no), - repr(self.magazine), - ) + class Magazine(Base): + pass + class Page(Base): + pass -class ClassifiedPage(MagazinePage): - pass + class MagazinePage(Page): + pass + class ClassifiedPage(MagazinePage): + pass -class MagazineTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): Table( @@ -198,9 +137,65 @@ class MagazineTest(fixtures.MappedTest): Column("name", String(45), default=""), ) + def _generate_data(self): + ( + Publication, + Issue, + Location, + LocationName, + PageSize, + Magazine, + Page, + MagazinePage, + ClassifiedPage, + ) = self.classes( + "Publication", + "Issue", + "Location", + "LocationName", + "PageSize", + "Magazine", + "Page", + "MagazinePage", + "ClassifiedPage", + ) + london = LocationName(name="London") + pub = Publication(name="Test") + issue = Issue(issue=46, publication=pub) + location = Location(ref="ABC", name=london, issue=issue) + + page_size = PageSize(name="A4", width=210, height=297) -def _generate_round_trip_test(use_unions=False, use_joins=False): - def test_roundtrip(self): + magazine = Magazine(location=location, size=page_size) + + ClassifiedPage(magazine=magazine, page_no=1) + MagazinePage(magazine=magazine, page_no=2) + ClassifiedPage(magazine=magazine, page_no=3) + + return pub + + def _setup_mapping(self, use_unions, use_joins): + ( + Publication, + Issue, + Location, + LocationName, + PageSize, + Magazine, + Page, + MagazinePage, + ClassifiedPage, + ) = self.classes( + "Publication", + "Issue", + "Location", + "LocationName", + "PageSize", + "Magazine", + "Page", + "MagazinePage", + "ClassifiedPage", + ) mapper(Publication, self.tables.publication) mapper( @@ -228,7 +223,7 @@ def _generate_round_trip_test(use_unions=False, use_joins=False): cascade="all, delete-orphan", ), ), - "_name": relationship(LocationName), + "name": relationship(LocationName), }, ) @@ -354,42 +349,29 @@ def _generate_round_trip_test(use_unions=False, use_joins=False): primary_key=[self.tables.page.c.id], ) - session = create_session() - - pub = Publication(name="Test") - issue = Issue(issue=46, publication=pub) - location = Location(ref="ABC", name="London", issue=issue) + @testing.combinations( + ("unions", True, False), + ("joins", False, True), + ("plain", False, False), + id_="iaa", + ) + def test_magazine_round_trip(self, use_unions, use_joins): + self._setup_mapping(use_unions, use_joins) - page_size = PageSize(name="A4", width=210, height=297) + Publication = self.classes.Publication - magazine = Magazine(location=location, size=page_size) + session = Session() - page = ClassifiedPage(magazine=magazine, page_no=1) - page2 = MagazinePage(magazine=magazine, page_no=2) - page3 = ClassifiedPage(magazine=magazine, page_no=3) + pub = self._generate_data() session.add(pub) + session.commit() + session.close() - session.flush() - print([x for x in session]) - session.expunge_all() - - session.flush() - session.expunge_all() p = session.query(Publication).filter(Publication.name == "Test").one() - print(p.issues[0].locations[0].magazine.pages) - print([page, page2, page3]) - assert repr(p.issues[0].locations[0].magazine.pages) == repr( - [page, page2, page3] - ), repr(p.issues[0].locations[0].magazine.pages) - - test_roundtrip = function_named( - test_roundtrip, - "test_%s" - % (not use_union and (use_joins and "joins" or "select") or "unions"), - ) - setattr(MagazineTest, test_roundtrip.__name__, test_roundtrip) - - -for (use_union, use_join) in [(True, False), (False, True), (False, False)]: - _generate_round_trip_test(use_union, use_join) + test_pub = self._generate_data() + eq_(p, test_pub) + eq_( + p.issues[0].locations[0].magazine.pages, + test_pub.issues[0].locations[0].magazine.pages, + ) diff --git a/test/orm/inheritance/test_poly_persistence.py b/test/orm/inheritance/test_poly_persistence.py index 1cef654cd3..508cb99657 100644 --- a/test/orm/inheritance/test_poly_persistence.py +++ b/test/orm/inheritance/test_poly_persistence.py @@ -2,9 +2,7 @@ from sqlalchemy import exc as sa_exc from sqlalchemy import ForeignKey -from sqlalchemy import func from sqlalchemy import Integer -from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing @@ -12,12 +10,12 @@ from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing.schema import Column -from sqlalchemy.testing.util import function_named class Person(fixtures.ComparableEntity): @@ -115,8 +113,6 @@ class PolymorphTest(fixtures.MappedTest): Column("golf_swing", String(30)), ) - metadata.create_all() - class InsertOrderTest(PolymorphTest): def test_insert_order(self): @@ -198,28 +194,41 @@ class InsertOrderTest(PolymorphTest): eq_(session.query(Company).get(c.company_id), c) +@testing.combinations( + ("lazy", True), ("nonlazy", False), argnames="lazy_relationship", id_="ia" +) +@testing.combinations( + ("redefine", True), + ("noredefine", False), + argnames="redefine_colprop", + id_="ia", +) +@testing.combinations( + ("unions", True), + ("unions", False), + ("joins", False), + ("auto", False), + ("none", False), + argnames="with_polymorphic,include_base", + id_="rr", +) class RoundTripTest(PolymorphTest): - pass - - -def _generate_round_trip_test( - include_base, lazy_relationship, redefine_colprop, with_polymorphic -): - """generates a round trip test. - - include_base - whether or not to include the base 'person' type in - the union. + lazy_relationship = None + include_base = None + redefine_colprop = None + with_polymorphic = None - lazy_relationship - whether or not the Company relationship to - People is lazy or eager. + run_inserts = "once" + run_deletes = None + run_setup_mappers = "once" - redefine_colprop - if we redefine the 'name' column to be - 'people_name' on the base Person class - - use_literal_join - primary join condition is explicitly specified - """ + @classmethod + def setup_mappers(cls): + include_base = cls.include_base + lazy_relationship = cls.lazy_relationship + redefine_colprop = cls.redefine_colprop + with_polymorphic = cls.with_polymorphic - def test_roundtrip(self): if with_polymorphic == "unions": if include_base: person_join = polymorphic_union( @@ -308,6 +317,11 @@ def _generate_round_trip_test( }, ) + @classmethod + def insert_data(cls): + redefine_colprop = cls.redefine_colprop + include_base = cls.include_base + if redefine_colprop: person_attribute_name = "person_name" else: @@ -342,15 +356,48 @@ def _generate_round_trip_test( ), ] - dilbert = employees[1] - - session = create_session() + session = Session() c = Company(name="company1") c.employees = employees session.add(c) - session.flush() - session.expunge_all() + session.commit() + + @testing.fixture + def get_dilbert(self): + def run(session): + if self.redefine_colprop: + person_attribute_name = "person_name" + else: + person_attribute_name = "name" + + dilbert = ( + session.query(Engineer) + .filter_by(**{person_attribute_name: "dilbert"}) + .one() + ) + return dilbert + + return run + + def test_lazy_load(self): + lazy_relationship = self.lazy_relationship + with_polymorphic = self.with_polymorphic + + if self.redefine_colprop: + person_attribute_name = "person_name" + else: + person_attribute_name = "name" + + session = create_session() + + dilbert = ( + session.query(Engineer) + .filter_by(**{person_attribute_name: "dilbert"}) + .one() + ) + employees = session.query(Person).order_by(Person.person_id).all() + company = session.query(Company).first() eq_(session.query(Person).get(dilbert.person_id), dilbert) session.expunge_all() @@ -364,20 +411,29 @@ def _generate_round_trip_test( session.expunge_all() def go(): - cc = session.query(Company).get(c.company_id) + cc = session.query(Company).get(company.company_id) eq_(cc.employees, employees) if not lazy_relationship: if with_polymorphic != "none": self.assert_sql_count(testing.db, go, 1) else: - self.assert_sql_count(testing.db, go, 5) + self.assert_sql_count(testing.db, go, 2) else: if with_polymorphic != "none": self.assert_sql_count(testing.db, go, 2) else: - self.assert_sql_count(testing.db, go, 6) + self.assert_sql_count(testing.db, go, 3) + + def test_baseclass_lookup(self, get_dilbert): + session = Session() + dilbert = get_dilbert(session) + + if self.redefine_colprop: + person_attribute_name = "person_name" + else: + person_attribute_name = "name" # test selecting from the query, using the base # mapped table (people) as the selection criterion. @@ -390,12 +446,14 @@ def _generate_round_trip_test( dilbert, ) - assert ( - session.query(Person) - .filter(getattr(Person, person_attribute_name) == "dilbert") - .first() - .person_id - ) + def test_subclass_lookup(self, get_dilbert): + session = Session() + dilbert = get_dilbert(session) + + if self.redefine_colprop: + person_attribute_name = "person_name" + else: + person_attribute_name = "name" eq_( session.query(Engineer) @@ -404,6 +462,10 @@ def _generate_round_trip_test( dilbert, ) + def test_baseclass_base_alias_filter(self, get_dilbert): + session = Session() + dilbert = get_dilbert(session) + # test selecting from the query, joining against # an alias of the base "people" table. test that # the "palias" alias does *not* get sucked up @@ -419,6 +481,13 @@ def _generate_round_trip_test( ) .first(), ) + + def test_subclass_base_alias_filter(self, get_dilbert): + session = Session() + dilbert = get_dilbert(session) + + palias = people.alias("palias") + is_( dilbert, session.query(Engineer) @@ -428,6 +497,11 @@ def _generate_round_trip_test( ) .first(), ) + + def test_baseclass_sub_table_filter(self, get_dilbert): + session = Session() + dilbert = get_dilbert(session) + is_( dilbert, session.query(Person) @@ -437,6 +511,11 @@ def _generate_round_trip_test( ) .first(), ) + + def test_subclass_getitem(self, get_dilbert): + session = Session() + dilbert = get_dilbert(session) + is_( dilbert, session.query(Engineer).filter( @@ -444,17 +523,16 @@ def _generate_round_trip_test( )[0], ) - session.flush() - session.expunge_all() + def test_primary_table_only_for_requery(self): - def go(): - session.query(Person).filter( - getattr(Person, person_attribute_name) == "dilbert" - ).first() + session = Session() - self.assert_sql_count(testing.db, go, 1) - session.expunge_all() - dilbert = ( + if self.redefine_colprop: + person_attribute_name = "person_name" + else: + person_attribute_name = "name" + + dilbert = ( # noqa session.query(Person) .filter(getattr(Person, person_attribute_name) == "dilbert") .first() @@ -471,7 +549,14 @@ def _generate_round_trip_test( self.assert_sql_count(testing.db, go, 1) - # test standalone orphans + def test_standalone_orphans(self): + if self.redefine_colprop: + person_attribute_name = "person_name" + else: + person_attribute_name = "name" + + session = Session() + daboss = Boss( status="BBB", manager_name="boss", @@ -480,52 +565,3 @@ def _generate_round_trip_test( ) session.add(daboss) assert_raises(sa_exc.DBAPIError, session.flush) - - c = session.query(Company).first() - daboss.company = c - manager_list = [e for e in c.employees if isinstance(e, Manager)] - session.flush() - session.expunge_all() - - eq_( - session.query(Manager).order_by(Manager.person_id).all(), - manager_list, - ) - c = session.query(Company).first() - - session.delete(c) - session.flush() - - eq_(select([func.count("*")]).select_from(people).scalar(), 0) - - test_roundtrip = function_named( - test_roundtrip, - "test_%s%s%s_%s" - % ( - (lazy_relationship and "lazy" or "eager"), - (include_base and "_inclbase" or ""), - (redefine_colprop and "_redefcol" or ""), - with_polymorphic, - ), - ) - setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip) - - -for lazy_relationship in [True, False]: - for redefine_colprop in [True, False]: - for with_polymorphic_ in ["unions", "joins", "auto", "none"]: - if with_polymorphic_ == "unions": - for include_base in [True, False]: - _generate_round_trip_test( - include_base, - lazy_relationship, - redefine_colprop, - with_polymorphic_, - ) - else: - _generate_round_trip_test( - False, - lazy_relationship, - redefine_colprop, - with_polymorphic_, - ) diff --git a/test/orm/test_descriptor.py b/test/orm/test_descriptor.py index 1baa82d3d2..7b530b9281 100644 --- a/test/orm/test_descriptor.py +++ b/test/orm/test_descriptor.py @@ -13,7 +13,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.util import partial -class TestDescriptor(descriptor_props.DescriptorProperty): +class MockDescriptor(descriptor_props.DescriptorProperty): def __init__( self, cls, key, descriptor=None, doc=None, comparator_factory=None ): @@ -40,7 +40,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): def test_fixture(self): Foo = self._fixture() - d = TestDescriptor(Foo, "foo") + d = MockDescriptor(Foo, "foo") d.instrument_class(Foo.__mapper__) assert Foo.foo @@ -50,7 +50,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): prop = property(lambda self: None) Foo.foo = prop - d = TestDescriptor(Foo, "foo") + d = MockDescriptor(Foo, "foo") d.instrument_class(Foo.__mapper__) assert Foo().foo is None @@ -68,7 +68,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): prop = myprop(lambda self: None) Foo.foo = prop - d = TestDescriptor(Foo, "foo") + d = MockDescriptor(Foo, "foo") d.instrument_class(Foo.__mapper__) assert Foo().foo is None @@ -95,7 +95,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): return column("foo") == func.upper(other) Foo = self._fixture() - d = TestDescriptor(Foo, "foo", comparator_factory=Comparator) + d = MockDescriptor(Foo, "foo", comparator_factory=Comparator) d.instrument_class(Foo.__mapper__) eq_(Foo.foo.method1(), "method1") eq_(Foo.foo.method2("x"), "method2") @@ -119,7 +119,7 @@ class DescriptorInstrumentationTest(fixtures.ORMTest): prop = mapper._props["_name"] return Comparator(prop, mapper) - d = TestDescriptor(Foo, "foo", comparator_factory=comparator_factory) + d = MockDescriptor(Foo, "foo", comparator_factory=comparator_factory) d.instrument_class(Foo.__mapper__) eq_(str(Foo.foo == "ed"), "foobar(foo.name) = foobar(:foobar_1)") diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 66fe185983..637f1f8a5a 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -73,12 +73,41 @@ class LoopOperate(operators.ColumnOperators): class DefaultColumnComparatorTest(fixtures.TestBase): - def _do_scalar_test(self, operator, compare_to): + @testing.combinations((operators.desc_op, desc), (operators.asc_op, asc)) + def test_scalar(self, operator, compare_to): left = column("left") assert left.comparator.operate(operator).compare(compare_to(left)) self._loop_test(operator) - def _do_operate_test(self, operator, right=column("right")): + right_column = column("right") + + @testing.combinations( + (operators.add, right_column), + (operators.is_, None), + (operators.isnot, None), + (operators.is_, null()), + (operators.is_, true()), + (operators.is_, false()), + (operators.eq, True), + (operators.ne, True), + (operators.is_distinct_from, True), + (operators.is_distinct_from, False), + (operators.is_distinct_from, None), + (operators.isnot_distinct_from, True), + (operators.is_, True), + (operators.isnot, True), + (operators.is_, False), + (operators.isnot, False), + (operators.like_op, right_column), + (operators.notlike_op, right_column), + (operators.ilike_op, right_column), + (operators.notilike_op, right_column), + (operators.is_, right_column), + (operators.isnot, right_column), + (operators.concat_op, right_column), + id_="ns", + ) + def test_operate(self, operator, right): left = column("left") assert left.comparator.operate(operator, right).compare( @@ -109,84 +138,13 @@ class DefaultColumnComparatorTest(fixtures.TestBase): loop = LoopOperate() is_(operator(loop, *arg), operator) - def test_desc(self): - self._do_scalar_test(operators.desc_op, desc) - - def test_asc(self): - self._do_scalar_test(operators.asc_op, asc) - - def test_plus(self): - self._do_operate_test(operators.add) - - def test_is_null(self): - self._do_operate_test(operators.is_, None) - - def test_isnot_null(self): - self._do_operate_test(operators.isnot, None) - - def test_is_null_const(self): - self._do_operate_test(operators.is_, null()) - - def test_is_true_const(self): - self._do_operate_test(operators.is_, true()) - - def test_is_false_const(self): - self._do_operate_test(operators.is_, false()) - - def test_equals_true(self): - self._do_operate_test(operators.eq, True) - - def test_notequals_true(self): - self._do_operate_test(operators.ne, True) - - def test_is_distinct_from_true(self): - self._do_operate_test(operators.is_distinct_from, True) - - def test_is_distinct_from_false(self): - self._do_operate_test(operators.is_distinct_from, False) - - def test_is_distinct_from_null(self): - self._do_operate_test(operators.is_distinct_from, None) - - def test_isnot_distinct_from_true(self): - self._do_operate_test(operators.isnot_distinct_from, True) - - def test_is_true(self): - self._do_operate_test(operators.is_, True) - - def test_isnot_true(self): - self._do_operate_test(operators.isnot, True) - - def test_is_false(self): - self._do_operate_test(operators.is_, False) - - def test_isnot_false(self): - self._do_operate_test(operators.isnot, False) - - def test_like(self): - self._do_operate_test(operators.like_op) - - def test_notlike(self): - self._do_operate_test(operators.notlike_op) - - def test_ilike(self): - self._do_operate_test(operators.ilike_op) - - def test_notilike(self): - self._do_operate_test(operators.notilike_op) - - def test_is(self): - self._do_operate_test(operators.is_) - - def test_isnot(self): - self._do_operate_test(operators.isnot) - def test_no_getitem(self): assert_raises_message( NotImplementedError, "Operator 'getitem' is not supported on this expression", - self._do_operate_test, + self.test_operate, operators.getitem, + column("right"), ) assert_raises_message( NotImplementedError, @@ -274,9 +232,6 @@ class DefaultColumnComparatorTest(fixtures.TestBase): collate(left, right) ) - def test_concat(self): - self._do_operate_test(operators.concat_op) - def test_default_adapt(self): class TypeOne(TypeEngine): pass @@ -329,7 +284,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase): class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" - def _factorial_fixture(self): + @testing.fixture + def factorial(self): class MyInteger(Integer): class comparator_factory(Integer.Comparator): def factorial(self): @@ -355,24 +311,24 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): return MyInteger - def test_factorial(self): - col = column("somecol", self._factorial_fixture()) + def test_factorial(self, factorial): + col = column("somecol", factorial()) self.assert_compile(col.factorial(), "somecol !") - def test_double_factorial(self): - col = column("somecol", self._factorial_fixture()) + def test_double_factorial(self, factorial): + col = column("somecol", factorial()) self.assert_compile(col.factorial().factorial(), "somecol ! !") - def test_factorial_prefix(self): - col = column("somecol", self._factorial_fixture()) + def test_factorial_prefix(self, factorial): + col = column("somecol", factorial()) self.assert_compile(col.factorial_prefix(), "!! somecol") - def test_factorial_invert(self): - col = column("somecol", self._factorial_fixture()) + def test_factorial_invert(self, factorial): + col = column("somecol", factorial()) self.assert_compile(~col, "!!! somecol") - def test_double_factorial_invert(self): - col = column("somecol", self._factorial_fixture()) + def test_double_factorial_invert(self, factorial): + col = column("somecol", factorial()) self.assert_compile(~(~col), "!!! (!!! somecol)") def test_unary_no_ops(self): @@ -1845,7 +1801,15 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): table1 = table("mytable", column("myid", Integer)) - def _test_math_op(self, py_op, sql_op): + @testing.combinations( + ("add", operator.add, "+"), + ("mul", operator.mul, "*"), + ("sub", operator.sub, "-"), + ("div", operator.truediv if util.py3k else operator.div, "/"), + ("mod", operator.mod, "%"), + id_="iaa", + ) + def test_math_op(self, py_op, sql_op): for (lhs, rhs, res) in ( (5, self.table1.c.myid, ":myid_1 %s mytable.myid"), (5, literal(5), ":param_1 %s :param_2"), @@ -1862,24 +1826,6 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): ): self.assert_compile(py_op(lhs, rhs), res % sql_op) - def test_math_op_add(self): - self._test_math_op(operator.add, "+") - - def test_math_op_mul(self): - self._test_math_op(operator.mul, "*") - - def test_math_op_sub(self): - self._test_math_op(operator.sub, "-") - - def test_math_op_div(self): - if util.py3k: - self._test_math_op(operator.truediv, "/") - else: - self._test_math_op(operator.div, "/") - - def test_math_op_mod(self): - self._test_math_op(operator.mod, "%") - class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -1898,7 +1844,16 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): clause = tuple_(1, 2, 3) eq_(str(clause), str(util.pickle.loads(util.pickle.dumps(clause)))) - def _test_comparison_op(self, py_op, fwd_op, rev_op): + @testing.combinations( + (operator.lt, "<", ">"), + (operator.gt, ">", "<"), + (operator.eq, "=", "="), + (operator.ne, "!=", "!="), + (operator.le, "<=", ">="), + (operator.ge, ">=", "<="), + id_="naa", + ) + def test_comparison_op(self, py_op, fwd_op, rev_op): dt = datetime.datetime(2012, 5, 10, 15, 27, 18) for (lhs, rhs, l_sql, r_sql) in ( ("a", self.table1.c.myid, ":myid_1", "mytable.myid"), @@ -1935,24 +1890,6 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): + "'", ) - def test_comparison_operators_lt(self): - self._test_comparison_op(operator.lt, "<", ">"), - - def test_comparison_operators_gt(self): - self._test_comparison_op(operator.gt, ">", "<") - - def test_comparison_operators_eq(self): - self._test_comparison_op(operator.eq, "=", "=") - - def test_comparison_operators_ne(self): - self._test_comparison_op(operator.ne, "!=", "!=") - - def test_comparison_operators_le(self): - self._test_comparison_op(operator.le, "<=", ">=") - - def test_comparison_operators_ge(self): - self._test_comparison_op(operator.ge, ">=", "<=") - class NonZeroTest(fixtures.TestBase): def _raises(self, expr): @@ -2690,38 +2627,39 @@ class CustomOpTest(fixtures.TestBase): assert operators.is_comparison(op1) assert not operators.is_comparison(op2) - def test_return_types(self): + @testing.combinations( + (sqltypes.NULLTYPE,), + (Integer(),), + (ARRAY(String),), + (String(50),), + (Boolean(),), + (DateTime(),), + (sqltypes.JSON(),), + (postgresql.ARRAY(Integer),), + (sqltypes.Numeric(5, 2),), + id_="r", + ) + def test_return_types(self, typ): some_return_type = sqltypes.DECIMAL() - for typ in [ - sqltypes.NULLTYPE, - Integer(), - ARRAY(String), - String(50), - Boolean(), - DateTime(), - sqltypes.JSON(), - postgresql.ARRAY(Integer), - sqltypes.Numeric(5, 2), - ]: - c = column("x", typ) - expr = c.op("$", is_comparison=True)(None) - is_(expr.type, sqltypes.BOOLEANTYPE) + c = column("x", typ) + expr = c.op("$", is_comparison=True)(None) + is_(expr.type, sqltypes.BOOLEANTYPE) - c = column("x", typ) - expr = c.bool_op("$")(None) - is_(expr.type, sqltypes.BOOLEANTYPE) + c = column("x", typ) + expr = c.bool_op("$")(None) + is_(expr.type, sqltypes.BOOLEANTYPE) - expr = c.op("$")(None) - is_(expr.type, typ) + expr = c.op("$")(None) + is_(expr.type, typ) - expr = c.op("$", return_type=some_return_type)(None) - is_(expr.type, some_return_type) + expr = c.op("$", return_type=some_return_type)(None) + is_(expr.type, some_return_type) - expr = c.op("$", is_comparison=True, return_type=some_return_type)( - None - ) - is_(expr.type, some_return_type) + expr = c.op("$", is_comparison=True, return_type=some_return_type)( + None + ) + is_(expr.type, some_return_type) class TupleTypingTest(fixtures.TestBase): @@ -2756,7 +2694,8 @@ class TupleTypingTest(fixtures.TestBase): class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" - def _fixture(self): + @testing.fixture + def t_fixture(self): m = MetaData() t = Table( @@ -2767,8 +2706,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) return t - def test_any_array(self): - t = self._fixture() + def test_any_array(self, t_fixture): + t = t_fixture self.assert_compile( 5 == any_(t.c.arrval), @@ -2776,8 +2715,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_any_array_method(self): - t = self._fixture() + def test_any_array_method(self, t_fixture): + t = t_fixture self.assert_compile( 5 == t.c.arrval.any_(), @@ -2785,8 +2724,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_all_array(self): - t = self._fixture() + def test_all_array(self, t_fixture): + t = t_fixture self.assert_compile( 5 == all_(t.c.arrval), @@ -2794,8 +2733,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_all_array_method(self): - t = self._fixture() + def test_all_array_method(self, t_fixture): + t = t_fixture self.assert_compile( 5 == t.c.arrval.all_(), @@ -2803,8 +2742,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_any_comparator_array(self): - t = self._fixture() + def test_any_comparator_array(self, t_fixture): + t = t_fixture self.assert_compile( 5 > any_(t.c.arrval), @@ -2812,8 +2751,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_all_comparator_array(self): - t = self._fixture() + def test_all_comparator_array(self, t_fixture): + t = t_fixture self.assert_compile( 5 > all_(t.c.arrval), @@ -2821,8 +2760,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_any_comparator_array_wexpr(self): - t = self._fixture() + def test_any_comparator_array_wexpr(self, t_fixture): + t = t_fixture self.assert_compile( t.c.data > any_(t.c.arrval), @@ -2830,8 +2769,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={}, ) - def test_all_comparator_array_wexpr(self): - t = self._fixture() + def test_all_comparator_array_wexpr(self, t_fixture): + t = t_fixture self.assert_compile( t.c.data > all_(t.c.arrval), @@ -2839,8 +2778,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={}, ) - def test_illegal_ops(self): - t = self._fixture() + def test_illegal_ops(self, t_fixture): + t = t_fixture assert_raises_message( exc.ArgumentError, @@ -2856,8 +2795,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): t.c.data + all_(t.c.arrval), "tab1.data + ALL (tab1.arrval)" ) - def test_any_array_comparator_accessor(self): - t = self._fixture() + def test_any_array_comparator_accessor(self, t_fixture): + t = t_fixture self.assert_compile( t.c.arrval.any(5, operator.gt), @@ -2865,8 +2804,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_all_array_comparator_accessor(self): - t = self._fixture() + def test_all_array_comparator_accessor(self, t_fixture): + t = t_fixture self.assert_compile( t.c.arrval.all(5, operator.gt), @@ -2874,8 +2813,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"param_1": 5}, ) - def test_any_array_expression(self): - t = self._fixture() + def test_any_array_expression(self, t_fixture): + t = t_fixture self.assert_compile( 5 == any_(t.c.arrval[5:6] + postgresql.array([3, 4])), @@ -2891,8 +2830,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): dialect="postgresql", ) - def test_all_array_expression(self): - t = self._fixture() + def test_all_array_expression(self, t_fixture): + t = t_fixture self.assert_compile( 5 == all_(t.c.arrval[5:6] + postgresql.array([3, 4])), @@ -2908,8 +2847,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): dialect="postgresql", ) - def test_any_subq(self): - t = self._fixture() + def test_any_subq(self, t_fixture): + t = t_fixture self.assert_compile( 5 @@ -2919,8 +2858,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"data_1": 10, "param_1": 5}, ) - def test_any_subq_method(self): - t = self._fixture() + def test_any_subq_method(self, t_fixture): + t = t_fixture self.assert_compile( 5 @@ -2933,8 +2872,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"data_1": 10, "param_1": 5}, ) - def test_all_subq(self): - t = self._fixture() + def test_all_subq(self, t_fixture): + t = t_fixture self.assert_compile( 5 @@ -2944,8 +2883,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"data_1": 10, "param_1": 5}, ) - def test_all_subq_method(self): - t = self._fixture() + def test_all_subq_method(self, t_fixture): + t = t_fixture self.assert_compile( 5 diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 7bf83b461d..2ffdd83b74 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -5,6 +5,7 @@ import importlib import operator import os +import sqlalchemy as sa from sqlalchemy import and_ from sqlalchemy import ARRAY from sqlalchemy import BigInteger @@ -87,42 +88,83 @@ from sqlalchemy.testing.util import round_decimal from sqlalchemy.util import OrderedDict -class AdaptTest(fixtures.TestBase): - def _all_dialect_modules(self): - return [ - importlib.import_module("sqlalchemy.dialects.%s" % d) - for d in dialects.__all__ - if not d.startswith("_") - ] +def _all_dialect_modules(): + return [ + importlib.import_module("sqlalchemy.dialects.%s" % d) + for d in dialects.__all__ + if not d.startswith("_") + ] - def _all_dialects(self): - return [d.base.dialect() for d in self._all_dialect_modules()] - def _types_for_mod(self, mod): - for key in dir(mod): - typ = getattr(mod, key) - if not isinstance(typ, type) or not issubclass( - typ, types.TypeEngine - ): - continue - yield typ +def _all_dialects(): + return [d.base.dialect() for d in _all_dialect_modules()] - def _all_types(self): - for typ in self._types_for_mod(types): - yield typ - for dialect in self._all_dialect_modules(): - for typ in self._types_for_mod(dialect): - yield typ - def test_uppercase_importable(self): - import sqlalchemy as sa +def _types_for_mod(mod): + for key in dir(mod): + typ = getattr(mod, key) + if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine): + continue + yield typ - for typ in self._types_for_mod(types): - if typ.__name__ == typ.__name__.upper(): - assert getattr(sa, typ.__name__) is typ - assert typ.__name__ in types.__all__ - def test_uppercase_rendering(self): +def _all_types(omit_special_types=False): + seen = set() + for typ in _types_for_mod(types): + if omit_special_types and typ in ( + types.TypeDecorator, + types.TypeEngine, + types.Variant, + ): + continue + + if typ in seen: + continue + seen.add(typ) + yield typ + for dialect in _all_dialect_modules(): + for typ in _types_for_mod(dialect): + if typ in seen: + continue + seen.add(typ) + yield typ + + +class AdaptTest(fixtures.TestBase): + @testing.combinations(((t,) for t in _types_for_mod(types)), id_="n") + def test_uppercase_importable(self, typ): + if typ.__name__ == typ.__name__.upper(): + assert getattr(sa, typ.__name__) is typ + assert typ.__name__ in types.__all__ + + @testing.combinations( + ((d.name, d) for d in _all_dialects()), argnames="dialect", id_="ia" + ) + @testing.combinations( + (REAL(), "REAL"), + (FLOAT(), "FLOAT"), + (NUMERIC(), "NUMERIC"), + (DECIMAL(), "DECIMAL"), + (INTEGER(), "INTEGER"), + (SMALLINT(), "SMALLINT"), + (TIMESTAMP(), ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE")), + (DATETIME(), "DATETIME"), + (DATE(), "DATE"), + (TIME(), ("TIME", "TIME WITHOUT TIME ZONE")), + (CLOB(), "CLOB"), + (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")), + ( + NVARCHAR(10), + ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"), + ), + (CHAR(), "CHAR"), + (NCHAR(), ("NCHAR", "NATIONAL CHAR")), + (BLOB(), ("BLOB", "BLOB SUB_TYPE 0")), + (BOOLEAN(), ("BOOLEAN", "BOOL", "INTEGER")), + argnames="type_, expected", + id_="ra", + ) + def test_uppercase_rendering(self, dialect, type_, expected): """Test that uppercase types from types.py always render as their type. @@ -133,51 +175,48 @@ class AdaptTest(fixtures.TestBase): """ - for dialect in self._all_dialects(): - for type_, expected in ( - (REAL, "REAL"), - (FLOAT, "FLOAT"), - (NUMERIC, "NUMERIC"), - (DECIMAL, "DECIMAL"), - (INTEGER, "INTEGER"), - (SMALLINT, "SMALLINT"), - (TIMESTAMP, ("TIMESTAMP", "TIMESTAMP WITHOUT TIME ZONE")), - (DATETIME, "DATETIME"), - (DATE, "DATE"), - (TIME, ("TIME", "TIME WITHOUT TIME ZONE")), - (CLOB, "CLOB"), - (VARCHAR(10), ("VARCHAR(10)", "VARCHAR(10 CHAR)")), - ( - NVARCHAR(10), - ("NVARCHAR(10)", "NATIONAL VARCHAR(10)", "NVARCHAR2(10)"), - ), - (CHAR, "CHAR"), - (NCHAR, ("NCHAR", "NATIONAL CHAR")), - (BLOB, ("BLOB", "BLOB SUB_TYPE 0")), - (BOOLEAN, ("BOOLEAN", "BOOL", "INTEGER")), - ): - if isinstance(expected, str): - expected = (expected,) + if isinstance(expected, str): + expected = (expected,) - try: - compiled = types.to_instance(type_).compile( - dialect=dialect - ) - except NotImplementedError: - continue + try: + compiled = type_.compile(dialect=dialect) + except NotImplementedError: + return - assert compiled in expected, ( - "%r matches none of %r for dialect %s" - % (compiled, expected, dialect.name) - ) + assert compiled in expected, "%r matches none of %r for dialect %s" % ( + compiled, + expected, + dialect.name, + ) - assert str(types.to_instance(type_)) in expected, ( - "default str() of type %r not expected, %r" - % (type_, expected) - ) + assert ( + str(types.to_instance(type_)) in expected + ), "default str() of type %r not expected, %r" % (type_, expected) + + def _adaptions(): + for typ in _all_types(omit_special_types=True): + + # up adapt from LowerCase to UPPERCASE, + # as well as to all non-sqltypes + up_adaptions = [typ] + typ.__subclasses__() + yield "%s.%s" % ( + typ.__module__, + typ.__name__, + ), False, typ, up_adaptions + for subcl in typ.__subclasses__(): + if ( + subcl is not typ + and typ is not TypeDecorator + and "sqlalchemy" in subcl.__module__ + ): + yield "%s.%s" % ( + subcl.__module__, + subcl.__name__, + ), True, subcl, [typ] @testing.uses_deprecated(".*Binary.*") - def test_adapt_method(self): + @testing.combinations(_adaptions(), id_="iaaa") + def test_adapt_method(self, is_down_adaption, typ, target_adaptions): """ensure all types have a working adapt() method, which creates a distinct copy. @@ -190,67 +229,44 @@ class AdaptTest(fixtures.TestBase): """ - def adaptions(): - for typ in self._all_types(): - # up adapt from LowerCase to UPPERCASE, - # as well as to all non-sqltypes - up_adaptions = [typ] + typ.__subclasses__() - yield False, typ, up_adaptions - for subcl in typ.__subclasses__(): - if ( - subcl is not typ - and typ is not TypeDecorator - and "sqlalchemy" in subcl.__module__ - ): - yield True, subcl, [typ] - - for is_down_adaption, typ, target_adaptions in adaptions(): - if typ in (types.TypeDecorator, types.TypeEngine, types.Variant): + if issubclass(typ, ARRAY): + t1 = typ(String) + else: + t1 = typ() + for cls in target_adaptions: + if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or ( + not is_down_adaption and issubclass(cls, sqltypes.Emulated) + ): continue - elif issubclass(typ, ARRAY): - t1 = typ(String) - else: - t1 = typ() - for cls in target_adaptions: - if ( - is_down_adaption and issubclass(typ, sqltypes.Emulated) - ) or ( - not is_down_adaption and issubclass(cls, sqltypes.Emulated) - ): - continue - if cls.__module__.startswith("test"): + # print("ADAPT %s -> %s" % (t1.__class__, cls)) + t2 = t1.adapt(cls) + assert t1 is not t2 + + if is_down_adaption: + t2, t1 = t1, t2 + + for k in t1.__dict__: + if k in ( + "impl", + "_is_oracle_number", + "_create_events", + "create_constraint", + "inherit_schema", + "schema", + "metadata", + "name", + ): continue + # assert each value was copied, or that + # the adapted type has a more specific + # value than the original (i.e. SQL Server + # applies precision=24 for REAL) + assert ( + getattr(t2, k) == t1.__dict__[k] or t1.__dict__[k] is None + ) - # print("ADAPT %s -> %s" % (t1.__class__, cls)) - t2 = t1.adapt(cls) - assert t1 is not t2 - - if is_down_adaption: - t2, t1 = t1, t2 - - for k in t1.__dict__: - if k in ( - "impl", - "_is_oracle_number", - "_create_events", - "create_constraint", - "inherit_schema", - "schema", - "metadata", - "name", - ): - continue - # assert each value was copied, or that - # the adapted type has a more specific - # value than the original (i.e. SQL Server - # applies precision=24 for REAL) - assert ( - getattr(t2, k) == t1.__dict__[k] - or t1.__dict__[k] is None - ) - - eq_(t1.evaluates_none().should_evaluate_none, True) + eq_(t1.evaluates_none().should_evaluate_none, True) def test_python_type(self): eq_(types.Integer().python_type, int) @@ -270,15 +286,13 @@ class AdaptTest(fixtures.TestBase): ) @testing.uses_deprecated() - def test_repr(self): - for typ in self._all_types(): - if typ in (types.TypeDecorator, types.TypeEngine, types.Variant): - continue - elif issubclass(typ, ARRAY): - t1 = typ(String) - else: - t1 = typ() - repr(t1) + @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)]) + def test_repr(self, typ): + if issubclass(typ, ARRAY): + t1 = typ(String) + else: + t1 = typ() + repr(t1) def test_adapt_constructor_copy_override_kw(self): """test that adapt() can accept kw args that override @@ -299,27 +313,30 @@ class AdaptTest(fixtures.TestBase): class TypeAffinityTest(fixtures.TestBase): - def test_type_affinity(self): - for type_, affin in [ - (String(), String), - (VARCHAR(), String), - (Date(), Date), - (LargeBinary(), types._Binary), - ]: - eq_(type_._type_affinity, affin) - - for t1, t2, comp in [ - (Integer(), SmallInteger(), True), - (Integer(), String(), False), - (Integer(), Integer(), True), - (Text(), String(), True), - (Text(), Unicode(), True), - (LargeBinary(), Integer(), False), - (LargeBinary(), PickleType(), True), - (PickleType(), LargeBinary(), True), - (PickleType(), PickleType(), True), - ]: - eq_(t1._compare_type_affinity(t2), comp, "%s %s" % (t1, t2)) + @testing.combinations( + (String(), String), + (VARCHAR(), String), + (Date(), Date), + (LargeBinary(), types._Binary), + id_="rn", + ) + def test_type_affinity(self, type_, affin): + eq_(type_._type_affinity, affin) + + @testing.combinations( + (Integer(), SmallInteger(), True), + (Integer(), String(), False), + (Integer(), Integer(), True), + (Text(), String(), True), + (Text(), Unicode(), True), + (LargeBinary(), Integer(), False), + (LargeBinary(), PickleType(), True), + (PickleType(), LargeBinary(), True), + (PickleType(), PickleType(), True), + id_="rra", + ) + def test_compare_type_affinity(self, t1, t2, comp): + eq_(t1._compare_type_affinity(t2), comp, "%s %s" % (t1, t2)) def test_decorator_doesnt_cache(self): from sqlalchemy.dialects import postgresql @@ -340,30 +357,32 @@ class TypeAffinityTest(fixtures.TestBase): class PickleTypesTest(fixtures.TestBase): - def test_pickle_types(self): + @testing.combinations( + ("Boo", Boolean()), + ("Str", String()), + ("Tex", Text()), + ("Uni", Unicode()), + ("Int", Integer()), + ("Sma", SmallInteger()), + ("Big", BigInteger()), + ("Num", Numeric()), + ("Flo", Float()), + ("Dat", DateTime()), + ("Dat", Date()), + ("Tim", Time()), + ("Lar", LargeBinary()), + ("Pic", PickleType()), + ("Int", Interval()), + id_="ar", + ) + def test_pickle_types(self, name, type_): + column_type = Column(name, type_) + meta = MetaData() + Table("foo", meta, column_type) + for loads, dumps in picklers(): - column_types = [ - Column("Boo", Boolean()), - Column("Str", String()), - Column("Tex", Text()), - Column("Uni", Unicode()), - Column("Int", Integer()), - Column("Sma", SmallInteger()), - Column("Big", BigInteger()), - Column("Num", Numeric()), - Column("Flo", Float()), - Column("Dat", DateTime()), - Column("Dat", Date()), - Column("Tim", Time()), - Column("Lar", LargeBinary()), - Column("Pic", PickleType()), - Column("Int", Interval()), - ] - for column_type in column_types: - meta = MetaData() - Table("foo", meta, column_type) - loads(dumps(column_type)) - loads(dumps(meta)) + loads(dumps(column_type)) + loads(dumps(meta)) class _UserDefinedTypeFixture(object): @@ -2414,19 +2433,19 @@ class ExpressionTest( expr = column("foo", CHAR) == "asdf" eq_(expr.right.type.__class__, CHAR) - def test_actual_literal_adapters(self): - for data, expected in [ - (5, Integer), - (2.65, Float), - (True, Boolean), - (decimal.Decimal("2.65"), Numeric), - (datetime.date(2015, 7, 20), Date), - (datetime.time(10, 15, 20), Time), - (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime), - (datetime.timedelta(seconds=5), Interval), - (None, types.NullType), - ]: - is_(literal(data).type.__class__, expected) + @testing.combinations( + (5, Integer), + (2.65, Float), + (True, Boolean), + (decimal.Decimal("2.65"), Numeric), + (datetime.date(2015, 7, 20), Date), + (datetime.time(10, 15, 20), Time), + (datetime.datetime(2015, 7, 20, 10, 15, 20), DateTime), + (datetime.timedelta(seconds=5), Interval), + (None, types.NullType), + ) + def test_actual_literal_adapters(self, data, expected): + is_(literal(data).type.__class__, expected) def test_typedec_operator_adapt(self): expr = test_table.c.bvalue + "hi" @@ -2592,18 +2611,22 @@ class ExpressionTest( expr = column("bar", types.Interval) * column("foo", types.Numeric) eq_(expr.type._type_affinity, types.Interval) - def test_numerics_coercion(self): - - for op in (operator.add, operator.mul, operator.truediv, operator.sub): - for other in (Numeric(10, 2), Integer): - expr = op( - column("bar", types.Numeric(10, 2)), column("foo", other) - ) - assert isinstance(expr.type, types.Numeric) - expr = op( - column("foo", other), column("bar", types.Numeric(10, 2)) - ) - assert isinstance(expr.type, types.Numeric) + @testing.combinations( + (operator.add,), + (operator.mul,), + (operator.truediv,), + (operator.sub,), + argnames="op", + id_="n", + ) + @testing.combinations( + (Numeric(10, 2),), (Integer(),), argnames="other", id_="r" + ) + def test_numerics_coercion(self, op, other): + expr = op(column("bar", types.Numeric(10, 2)), column("foo", other)) + assert isinstance(expr.type, types.Numeric) + expr = op(column("foo", other), column("bar", types.Numeric(10, 2))) + assert isinstance(expr.type, types.Numeric) def test_asdecimal_int_to_numeric(self): expr = column("a", Integer) * column("b", Numeric(asdecimal=False))