]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Introduce lambda combinations
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 5 Dec 2019 00:18:57 +0000 (19:18 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 5 Dec 2019 00:32:13 +0000 (19:32 -0500)
As the ORM's combinatoric tests mostly use entities and
table metadata that's defined in fixtures, we can't use
@testing.combinations directly as it takes place at the
module level.   Instead we use lambdas, but to reduce
verbosity we use a code replacement so that the namespace
of the lambda can be provided at runtime rather than
module import time.

Change-Id: Ia63a510f9c1d08b055eef62cf047f1f427f0450c
(cherry picked from commit 1ab483ac5481cb60e898f0bfdad54e5ca45bbb80)

lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/testing/util.py
test/orm/test_query.py

index ed99b1eb2bb9d6a129dd8b6c79130c6e799c8066..9053af0a3199800eb33314ec9410734ca8115c79 100644 (file)
@@ -55,6 +55,7 @@ from .util import flag_combinations  # noqa
 from .util import force_drop_names  # noqa
 from .util import metadata_fixture  # noqa
 from .util import provide_metadata  # noqa
+from .util import resolve_lambda  # noqa
 from .util import rowset  # noqa
 from .util import run_as_contextmanager  # noqa
 from .util import teardown_events  # noqa
index dc83f1f51a19e6bd61f98ff13b4c56d388417b67..015fee22ba79254493f18e1b208bb6060f0bbf73 100644 (file)
@@ -365,6 +365,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
                 for idx, char in enumerate(id_)
                 if char in _combination_id_fns
             ]
+
             arg_sets = [
                 pytest.param(
                     *_arg_getter(_filter_exclusions(arg))[1:],
@@ -372,14 +373,21 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
                         comb_fn(getter(arg)) for getter, comb_fn in fns
                     )
                 )
-                for arg in arg_sets
+                for arg in [
+                    (arg,) if not isinstance(arg, tuple) else arg
+                    for arg in arg_sets
+                ]
             ]
         else:
             # ensure using pytest.param so that even a 1-arg paramset
             # still needs to be a tuple.  otherwise paramtrize tries to
             # interpret a single arg differently than tuple arg
             arg_sets = [
-                pytest.param(*_filter_exclusions(arg)) for arg in arg_sets
+                pytest.param(*_filter_exclusions(arg))
+                for arg in [
+                    (arg,) if not isinstance(arg, tuple) else arg
+                    for arg in arg_sets
+                ]
             ]
 
         def decorate(fn):
index 87c461fd2b31ef2574f900786d57b5234f83a100..dbe6a383de90a5f5f8366c449d329540f41048cf 100644 (file)
@@ -261,6 +261,21 @@ def flag_combinations(*combinations):
     )
 
 
+def resolve_lambda(__fn, **kw):
+    """Given a no-arg lambda and a namespace, return a new lambda that
+    has all the values filled in.
+
+    This is used so that we can have module-level fixtures that
+    refer to instance-level variables using lambdas.
+
+    """
+
+    glb = dict(__fn.__globals__)
+    glb.update(kw)
+    new_fn = types.FunctionType(__fn.__code__, glb)
+    return new_fn()
+
+
 def metadata_fixture(ddl="function"):
     """Provide MetaData for a pytest fixture."""
 
index 4e32d9c60fc37262090b446ed75ed314435b67ec..2791251bddb3ecceeec86daa8a1e20acb58c067d 100644 (file)
@@ -1,4 +1,5 @@
 import contextlib
+import functools
 
 import sqlalchemy as sa
 from sqlalchemy import and_
@@ -139,7 +140,169 @@ class RowTupleTest(QueryTest):
         assert row.id == 7
         assert row.uname == "jack"
 
-    def test_column_metadata(self):
+    @testing.combinations(
+        lambda: (
+            sess.query(User),
+            [
+                {
+                    "name": "User",
+                    "type": User,
+                    "aliased": False,
+                    "expr": User,
+                    "entity": User,
+                }
+            ],
+        ),
+        lambda: (
+            sess.query(User.id, User),
+            [
+                {
+                    "name": "id",
+                    "type": users.c.id.type,
+                    "aliased": False,
+                    "expr": User.id,
+                    "entity": User,
+                },
+                {
+                    "name": "User",
+                    "type": User,
+                    "aliased": False,
+                    "expr": User,
+                    "entity": User,
+                },
+            ],
+        ),
+        lambda: (
+            sess.query(User.id, user_alias),
+            [
+                {
+                    "name": "id",
+                    "type": users.c.id.type,
+                    "aliased": False,
+                    "expr": User.id,
+                    "entity": User,
+                },
+                {
+                    "name": None,
+                    "type": User,
+                    "aliased": True,
+                    "expr": user_alias,
+                    "entity": user_alias,
+                },
+            ],
+        ),
+        lambda: (
+            sess.query(user_alias.id),
+            [
+                {
+                    "name": "id",
+                    "type": users.c.id.type,
+                    "aliased": True,
+                    "expr": user_alias.id,
+                    "entity": user_alias,
+                }
+            ],
+        ),
+        lambda: (
+            sess.query(user_alias_id_label),
+            [
+                {
+                    "name": "foo",
+                    "type": users.c.id.type,
+                    "aliased": True,
+                    "expr": user_alias_id_label,
+                    "entity": user_alias,
+                }
+            ],
+        ),
+        lambda: (
+            sess.query(address_alias),
+            [
+                {
+                    "name": "aalias",
+                    "type": Address,
+                    "aliased": True,
+                    "expr": address_alias,
+                    "entity": address_alias,
+                }
+            ],
+        ),
+        lambda: (
+            sess.query(name_label, fn),
+            [
+                {
+                    "name": "uname",
+                    "type": users.c.name.type,
+                    "aliased": False,
+                    "expr": name_label,
+                    "entity": User,
+                },
+                {
+                    "name": None,
+                    "type": fn.type,
+                    "aliased": False,
+                    "expr": fn,
+                    "entity": User,
+                },
+            ],
+        ),
+        lambda: (
+            sess.query(cte),
+            [
+                {
+                    "aliased": False,
+                    "expr": cte.c.id,
+                    "type": cte.c.id.type,
+                    "name": "id",
+                    "entity": None,
+                }
+            ],
+        ),
+        lambda: (
+            sess.query(users),
+            [
+                {
+                    "aliased": False,
+                    "expr": users.c.id,
+                    "type": users.c.id.type,
+                    "name": "id",
+                    "entity": None,
+                },
+                {
+                    "aliased": False,
+                    "expr": users.c.name,
+                    "type": users.c.name.type,
+                    "name": "name",
+                    "entity": None,
+                },
+            ],
+        ),
+        lambda: (
+            sess.query(users.c.name),
+            [
+                {
+                    "name": "name",
+                    "type": users.c.name.type,
+                    "aliased": False,
+                    "expr": users.c.name,
+                    "entity": None,
+                }
+            ],
+        ),
+        lambda: (
+            sess.query(bundle),
+            [
+                {
+                    "aliased": False,
+                    "expr": bundle,
+                    "type": Bundle,
+                    "name": "b1",
+                    "entity": User,
+                }
+            ],
+        ),
+    )
+    def test_column_metadata(self, test_case):
         users, Address, addresses, User = (
             self.tables.users,
             self.classes.Address,
@@ -157,169 +320,10 @@ class RowTupleTest(QueryTest):
         name_label = User.name.label("uname")
         bundle = Bundle("b1", User.id, User.name)
         cte = sess.query(User.id).cte()
-        for q, asserted in [
-            (
-                sess.query(User),
-                [
-                    {
-                        "name": "User",
-                        "type": User,
-                        "aliased": False,
-                        "expr": User,
-                        "entity": User,
-                    }
-                ],
-            ),
-            (
-                sess.query(User.id, User),
-                [
-                    {
-                        "name": "id",
-                        "type": users.c.id.type,
-                        "aliased": False,
-                        "expr": User.id,
-                        "entity": User,
-                    },
-                    {
-                        "name": "User",
-                        "type": User,
-                        "aliased": False,
-                        "expr": User,
-                        "entity": User,
-                    },
-                ],
-            ),
-            (
-                sess.query(User.id, user_alias),
-                [
-                    {
-                        "name": "id",
-                        "type": users.c.id.type,
-                        "aliased": False,
-                        "expr": User.id,
-                        "entity": User,
-                    },
-                    {
-                        "name": None,
-                        "type": User,
-                        "aliased": True,
-                        "expr": user_alias,
-                        "entity": user_alias,
-                    },
-                ],
-            ),
-            (
-                sess.query(user_alias.id),
-                [
-                    {
-                        "name": "id",
-                        "type": users.c.id.type,
-                        "aliased": True,
-                        "expr": user_alias.id,
-                        "entity": user_alias,
-                    }
-                ],
-            ),
-            (
-                sess.query(user_alias_id_label),
-                [
-                    {
-                        "name": "foo",
-                        "type": users.c.id.type,
-                        "aliased": True,
-                        "expr": user_alias_id_label,
-                        "entity": user_alias,
-                    }
-                ],
-            ),
-            (
-                sess.query(address_alias),
-                [
-                    {
-                        "name": "aalias",
-                        "type": Address,
-                        "aliased": True,
-                        "expr": address_alias,
-                        "entity": address_alias,
-                    }
-                ],
-            ),
-            (
-                sess.query(name_label, fn),
-                [
-                    {
-                        "name": "uname",
-                        "type": users.c.name.type,
-                        "aliased": False,
-                        "expr": name_label,
-                        "entity": User,
-                    },
-                    {
-                        "name": None,
-                        "type": fn.type,
-                        "aliased": False,
-                        "expr": fn,
-                        "entity": User,
-                    },
-                ],
-            ),
-            (
-                sess.query(cte),
-                [
-                    {
-                        "aliased": False,
-                        "expr": cte.c.id,
-                        "type": cte.c.id.type,
-                        "name": "id",
-                        "entity": None,
-                    }
-                ],
-            ),
-            (
-                sess.query(users),
-                [
-                    {
-                        "aliased": False,
-                        "expr": users.c.id,
-                        "type": users.c.id.type,
-                        "name": "id",
-                        "entity": None,
-                    },
-                    {
-                        "aliased": False,
-                        "expr": users.c.name,
-                        "type": users.c.name.type,
-                        "name": "name",
-                        "entity": None,
-                    },
-                ],
-            ),
-            (
-                sess.query(users.c.name),
-                [
-                    {
-                        "name": "name",
-                        "type": users.c.name.type,
-                        "aliased": False,
-                        "expr": users.c.name,
-                        "entity": None,
-                    }
-                ],
-            ),
-            (
-                sess.query(bundle),
-                [
-                    {
-                        "aliased": False,
-                        "expr": bundle,
-                        "type": Bundle,
-                        "name": "b1",
-                        "entity": User,
-                    }
-                ],
-            ),
-        ]:
-            eq_(q.column_descriptions, asserted)
+
+        q, asserted = testing.resolve_lambda(test_case, **locals())
+
+        eq_(q.column_descriptions, asserted)
 
     def test_unhashable_type(self):
         from sqlalchemy.types import TypeDecorator, Integer
@@ -917,34 +921,34 @@ class GetTest(QueryTest):
 
 
 class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
-    def test_no_limit_offset(self):
+    @testing.combinations(
+        lambda: s.query(User).limit(2),
+        lambda: s.query(User).filter(User.id == 1).offset(2),
+        lambda: s.query(User).limit(2).offset(2),
+    )
+    def test_no_limit_offset(self, test_case):
         User = self.classes.User
 
         s = create_session()
 
-        for q in (
-            s.query(User).limit(2),
-            s.query(User).offset(2),
-            s.query(User).limit(2).offset(2),
-        ):
-            assert_raises(sa_exc.InvalidRequestError, q.join, "addresses")
+        q = testing.resolve_lambda(test_case, User=User, s=s)
 
-            assert_raises(
-                sa_exc.InvalidRequestError, q.filter, User.name == "ed"
-            )
+        assert_raises(sa_exc.InvalidRequestError, q.join, "addresses")
 
-            assert_raises(sa_exc.InvalidRequestError, q.filter_by, name="ed")
+        assert_raises(sa_exc.InvalidRequestError, q.filter, User.name == "ed")
 
-            assert_raises(sa_exc.InvalidRequestError, q.order_by, "foo")
+        assert_raises(sa_exc.InvalidRequestError, q.filter_by, name="ed")
 
-            assert_raises(sa_exc.InvalidRequestError, q.group_by, "foo")
+        assert_raises(sa_exc.InvalidRequestError, q.order_by, "foo")
 
-            assert_raises(sa_exc.InvalidRequestError, q.having, "foo")
+        assert_raises(sa_exc.InvalidRequestError, q.group_by, "foo")
 
-            q.enable_assertions(False).join("addresses")
-            q.enable_assertions(False).filter(User.name == "ed")
-            q.enable_assertions(False).order_by("foo")
-            q.enable_assertions(False).group_by("foo")
+        assert_raises(sa_exc.InvalidRequestError, q.having, "foo")
+
+        q.enable_assertions(False).join("addresses")
+        q.enable_assertions(False).filter(User.name == "ed")
+        q.enable_assertions(False).order_by("foo")
+        q.enable_assertions(False).group_by("foo")
 
     def test_no_from(self):
         users, User = self.tables.users, self.classes.User
@@ -1077,27 +1081,46 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
         is_(q1._mapper_zero(), inspect(User))
         is_(q1._entity_zero(), inspect(User))
 
-    def test_from_statement(self):
+    @testing.combinations(
+        lambda: s.query(User).filter(User.id == 5),
+        lambda: s.query(User).filter_by(id=5),
+        lambda: s.query(User).limit(5),
+        lambda: s.query(User).group_by(User.name),
+        lambda: s.query(User).order_by(User.name),
+    )
+    def test_from_statement(self, test_case):
         User = self.classes.User
 
         s = create_session()
 
-        for meth, arg, kw in [
-            (Query.filter, (User.id == 5,), {}),
-            (Query.filter_by, (), {"id": 5}),
-            (Query.limit, (5,), {}),
-            (Query.group_by, (User.name,), {}),
-            (Query.order_by, (User.name,), {}),
-        ]:
-            q = s.query(User)
-            q = meth(q, *arg, **kw)
-            assert_raises(
-                sa_exc.InvalidRequestError, q.from_statement, text("x")
-            )
+        q = testing.resolve_lambda(test_case, User=User, s=s)
 
-            q = s.query(User)
-            q = q.from_statement(text("x"))
-            assert_raises(sa_exc.InvalidRequestError, meth, q, *arg, **kw)
+        assert_raises(sa_exc.InvalidRequestError, q.from_statement, text("x"))
+
+    @testing.combinations(
+        (Query.filter, lambda: meth(User.id == 5)),
+        (Query.filter_by, lambda: meth(id=5)),
+        (Query.limit, lambda: meth(5)),
+        (Query.group_by, lambda: meth(User.name)),
+        (Query.order_by, lambda: meth(User.name)),
+    )
+    def test_from_statement_text(self, meth, test_case):
+
+        User = self.classes.User
+        s = Session()
+        q = s.query(User)
+
+        q = q.from_statement(text("x"))
+        m = functools.partial(meth, q)
+
+        assert_raises(
+            sa_exc.InvalidRequestError,
+            testing.resolve_lambda,
+            test_case,
+            meth=m,
+            User=User,
+            s=s,
+        )
 
     def test_illegal_coercions(self):
         User = self.classes.User
@@ -1172,76 +1195,93 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
 
         self.assert_compile(full, expected, checkparams=checkparams)
 
-    def test_arithmetic(self):
-        User = self.classes.User
-
+    @testing.combinations(
+        (operators.add, "+"),
+        (operators.mul, "*"),
+        (operators.sub, "-"),
+        (operators.truediv, "/"),
+        (operators.div, "/"),
+        argnames="py_op, sql_op",
+        id_="ar",
+    )
+    @testing.combinations(
+        (lambda: 5, lambda: User.id, ":id_1 %s users.id"),
+        (lambda: 5, lambda: literal(6), ":param_1 %s :param_2"),
+        (lambda: User.id, lambda: 5, "users.id %s :id_1"),
+        (lambda: User.id, lambda: literal("b"), "users.id %s :param_1"),
+        (lambda: User.id, lambda: User.id, "users.id %s users.id"),
+        (lambda: literal(5), lambda: "b", ":param_1 %s :param_2"),
+        (lambda: literal(5), lambda: User.id, ":param_1 %s users.id"),
+        (lambda: literal(5), lambda: literal(6), ":param_1 %s :param_2"),
+        argnames="lhs, rhs, res",
+        id_="aar",
+    )
+    def test_arithmetic(self, py_op, sql_op, lhs, rhs, res):
+        User = self.classes.User
+
+        lhs = testing.resolve_lambda(lhs, User=User)
+        rhs = testing.resolve_lambda(rhs, User=User)
         create_session().query(User)
-        for (py_op, sql_op) in (
-            (operators.add, "+"),
-            (operators.mul, "*"),
-            (operators.sub, "-"),
-            (operators.truediv, "/"),
-            (operators.div, "/"),
-        ):
-            for (lhs, rhs, res) in (
-                (5, User.id, ":id_1 %s users.id"),
-                (5, literal(6), ":param_1 %s :param_2"),
-                (User.id, 5, "users.id %s :id_1"),
-                (User.id, literal("b"), "users.id %s :param_1"),
-                (User.id, User.id, "users.id %s users.id"),
-                (literal(5), "b", ":param_1 %s :param_2"),
-                (literal(5), User.id, ":param_1 %s users.id"),
-                (literal(5), literal(6), ":param_1 %s :param_2"),
-            ):
-                self._test(py_op(lhs, rhs), res % sql_op)
-
-    def test_comparison(self):
+        self._test(py_op(lhs, rhs), res % sql_op)
+
+    @testing.combinations(
+        (operators.lt, "<", ">"),
+        (operators.gt, ">", "<"),
+        (operators.eq, "=", "="),
+        (operators.ne, "!=", "!="),
+        (operators.le, "<=", ">="),
+        (operators.ge, ">=", "<="),
+        id_="arr",
+        argnames="py_op, fwd_op, rev_op",
+    )
+    @testing.combinations(
+        (lambda: "a", lambda: User.id, ":id_1", "users.id"),
+        (
+            lambda: "a",
+            lambda: literal("b"),
+            ":param_2",
+            ":param_1",
+        ),  # note swap!
+        (lambda: User.id, lambda: "b", "users.id", ":id_1"),
+        (lambda: User.id, lambda: literal("b"), "users.id", ":param_1"),
+        (lambda: User.id, lambda: User.id, "users.id", "users.id"),
+        (lambda: literal("a"), lambda: "b", ":param_1", ":param_2"),
+        (lambda: literal("a"), lambda: User.id, ":param_1", "users.id"),
+        (lambda: literal("a"), lambda: literal("b"), ":param_1", ":param_2"),
+        (lambda: ualias.id, lambda: literal("b"), "users_1.id", ":param_1"),
+        (lambda: User.id, lambda: ualias.name, "users.id", "users_1.name"),
+        (lambda: User.name, lambda: ualias.name, "users.name", "users_1.name"),
+        (lambda: ualias.name, lambda: User.name, "users_1.name", "users.name"),
+        argnames="lhs, rhs, l_sql, r_sql",
+        id_="aarr",
+    )
+    def test_comparison(self, py_op, fwd_op, rev_op, lhs, rhs, l_sql, r_sql):
         User = self.classes.User
 
         create_session().query(User)
         ualias = aliased(User)
 
-        for (py_op, fwd_op, rev_op) in (
-            (operators.lt, "<", ">"),
-            (operators.gt, ">", "<"),
-            (operators.eq, "=", "="),
-            (operators.ne, "!=", "!="),
-            (operators.le, "<=", ">="),
-            (operators.ge, ">=", "<="),
-        ):
-            for (lhs, rhs, l_sql, r_sql) in (
-                ("a", User.id, ":id_1", "users.id"),
-                ("a", literal("b"), ":param_2", ":param_1"),  # note swap!
-                (User.id, "b", "users.id", ":id_1"),
-                (User.id, literal("b"), "users.id", ":param_1"),
-                (User.id, User.id, "users.id", "users.id"),
-                (literal("a"), "b", ":param_1", ":param_2"),
-                (literal("a"), User.id, ":param_1", "users.id"),
-                (literal("a"), literal("b"), ":param_1", ":param_2"),
-                (ualias.id, literal("b"), "users_1.id", ":param_1"),
-                (User.id, ualias.name, "users.id", "users_1.name"),
-                (User.name, ualias.name, "users.name", "users_1.name"),
-                (ualias.name, User.name, "users_1.name", "users.name"),
-            ):
-
-                # the compiled clause should match either (e.g.):
-                # 'a' < 'b' -or- 'b' > 'a'.
-                compiled = str(
-                    py_op(lhs, rhs).compile(dialect=default.DefaultDialect())
-                )
-                fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
-                rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
-
-                self.assert_(
-                    compiled == fwd_sql or compiled == rev_sql,
-                    "\n'"
-                    + compiled
-                    + "'\n does not match\n'"
-                    + fwd_sql
-                    + "'\n or\n'"
-                    + rev_sql
-                    + "'",
-                )
+        lhs = testing.resolve_lambda(lhs, User=User, ualias=ualias)
+        rhs = testing.resolve_lambda(rhs, User=User, ualias=ualias)
+
+        # the compiled clause should match either (e.g.):
+        # 'a' < 'b' -or- 'b' > 'a'.
+        compiled = str(
+            py_op(lhs, rhs).compile(dialect=default.DefaultDialect())
+        )
+        fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
+        rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
+
+        self.assert_(
+            compiled == fwd_sql or compiled == rev_sql,
+            "\n'"
+            + compiled
+            + "'\n does not match\n'"
+            + fwd_sql
+            + "'\n or\n'"
+            + rev_sql
+            + "'",
+        )
 
     def test_o2m_compare_to_null(self):
         User = self.classes.User