From: Mike Bayer Date: Thu, 5 Dec 2019 00:18:57 +0000 (-0500) Subject: Introduce lambda combinations X-Git-Tag: rel_1_3_12~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=11ed5b100f158d9022abf7cd340c01fa80fd021b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Introduce lambda combinations As the ORM's combinatoric tests mostly use entities and table metadata that's defined in fixtures, we can't use @testing.combinations directly as it takes place at the module level. Instead we use lambdas, but to reduce verbosity we use a code replacement so that the namespace of the lambda can be provided at runtime rather than module import time. Change-Id: Ia63a510f9c1d08b055eef62cf047f1f427f0450c (cherry picked from commit 1ab483ac5481cb60e898f0bfdad54e5ca45bbb80) --- diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index ed99b1eb2b..9053af0a31 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -55,6 +55,7 @@ from .util import flag_combinations # noqa from .util import force_drop_names # noqa from .util import metadata_fixture # noqa from .util import provide_metadata # noqa +from .util import resolve_lambda # noqa from .util import rowset # noqa from .util import run_as_contextmanager # noqa from .util import teardown_events # noqa diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index dc83f1f51a..015fee22ba 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -365,6 +365,7 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): for idx, char in enumerate(id_) if char in _combination_id_fns ] + arg_sets = [ pytest.param( *_arg_getter(_filter_exclusions(arg))[1:], @@ -372,14 +373,21 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): comb_fn(getter(arg)) for getter, comb_fn in fns ) ) - for arg in arg_sets + for arg in [ + (arg,) if not isinstance(arg, tuple) else arg + for arg in arg_sets + ] ] else: # ensure using pytest.param so that even a 1-arg paramset # still needs to be a tuple. otherwise paramtrize tries to # interpret a single arg differently than tuple arg arg_sets = [ - pytest.param(*_filter_exclusions(arg)) for arg in arg_sets + pytest.param(*_filter_exclusions(arg)) + for arg in [ + (arg,) if not isinstance(arg, tuple) else arg + for arg in arg_sets + ] ] def decorate(fn): diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 87c461fd2b..dbe6a383de 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -261,6 +261,21 @@ def flag_combinations(*combinations): ) +def resolve_lambda(__fn, **kw): + """Given a no-arg lambda and a namespace, return a new lambda that + has all the values filled in. + + This is used so that we can have module-level fixtures that + refer to instance-level variables using lambdas. + + """ + + glb = dict(__fn.__globals__) + glb.update(kw) + new_fn = types.FunctionType(__fn.__code__, glb) + return new_fn() + + def metadata_fixture(ddl="function"): """Provide MetaData for a pytest fixture.""" diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 4e32d9c60f..2791251bdd 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -1,4 +1,5 @@ import contextlib +import functools import sqlalchemy as sa from sqlalchemy import and_ @@ -139,7 +140,169 @@ class RowTupleTest(QueryTest): assert row.id == 7 assert row.uname == "jack" - def test_column_metadata(self): + @testing.combinations( + lambda: ( + sess.query(User), + [ + { + "name": "User", + "type": User, + "aliased": False, + "expr": User, + "entity": User, + } + ], + ), + lambda: ( + sess.query(User.id, User), + [ + { + "name": "id", + "type": users.c.id.type, + "aliased": False, + "expr": User.id, + "entity": User, + }, + { + "name": "User", + "type": User, + "aliased": False, + "expr": User, + "entity": User, + }, + ], + ), + lambda: ( + sess.query(User.id, user_alias), + [ + { + "name": "id", + "type": users.c.id.type, + "aliased": False, + "expr": User.id, + "entity": User, + }, + { + "name": None, + "type": User, + "aliased": True, + "expr": user_alias, + "entity": user_alias, + }, + ], + ), + lambda: ( + sess.query(user_alias.id), + [ + { + "name": "id", + "type": users.c.id.type, + "aliased": True, + "expr": user_alias.id, + "entity": user_alias, + } + ], + ), + lambda: ( + sess.query(user_alias_id_label), + [ + { + "name": "foo", + "type": users.c.id.type, + "aliased": True, + "expr": user_alias_id_label, + "entity": user_alias, + } + ], + ), + lambda: ( + sess.query(address_alias), + [ + { + "name": "aalias", + "type": Address, + "aliased": True, + "expr": address_alias, + "entity": address_alias, + } + ], + ), + lambda: ( + sess.query(name_label, fn), + [ + { + "name": "uname", + "type": users.c.name.type, + "aliased": False, + "expr": name_label, + "entity": User, + }, + { + "name": None, + "type": fn.type, + "aliased": False, + "expr": fn, + "entity": User, + }, + ], + ), + lambda: ( + sess.query(cte), + [ + { + "aliased": False, + "expr": cte.c.id, + "type": cte.c.id.type, + "name": "id", + "entity": None, + } + ], + ), + lambda: ( + sess.query(users), + [ + { + "aliased": False, + "expr": users.c.id, + "type": users.c.id.type, + "name": "id", + "entity": None, + }, + { + "aliased": False, + "expr": users.c.name, + "type": users.c.name.type, + "name": "name", + "entity": None, + }, + ], + ), + lambda: ( + sess.query(users.c.name), + [ + { + "name": "name", + "type": users.c.name.type, + "aliased": False, + "expr": users.c.name, + "entity": None, + } + ], + ), + lambda: ( + sess.query(bundle), + [ + { + "aliased": False, + "expr": bundle, + "type": Bundle, + "name": "b1", + "entity": User, + } + ], + ), + ) + def test_column_metadata(self, test_case): users, Address, addresses, User = ( self.tables.users, self.classes.Address, @@ -157,169 +320,10 @@ class RowTupleTest(QueryTest): name_label = User.name.label("uname") bundle = Bundle("b1", User.id, User.name) cte = sess.query(User.id).cte() - for q, asserted in [ - ( - sess.query(User), - [ - { - "name": "User", - "type": User, - "aliased": False, - "expr": User, - "entity": User, - } - ], - ), - ( - sess.query(User.id, User), - [ - { - "name": "id", - "type": users.c.id.type, - "aliased": False, - "expr": User.id, - "entity": User, - }, - { - "name": "User", - "type": User, - "aliased": False, - "expr": User, - "entity": User, - }, - ], - ), - ( - sess.query(User.id, user_alias), - [ - { - "name": "id", - "type": users.c.id.type, - "aliased": False, - "expr": User.id, - "entity": User, - }, - { - "name": None, - "type": User, - "aliased": True, - "expr": user_alias, - "entity": user_alias, - }, - ], - ), - ( - sess.query(user_alias.id), - [ - { - "name": "id", - "type": users.c.id.type, - "aliased": True, - "expr": user_alias.id, - "entity": user_alias, - } - ], - ), - ( - sess.query(user_alias_id_label), - [ - { - "name": "foo", - "type": users.c.id.type, - "aliased": True, - "expr": user_alias_id_label, - "entity": user_alias, - } - ], - ), - ( - sess.query(address_alias), - [ - { - "name": "aalias", - "type": Address, - "aliased": True, - "expr": address_alias, - "entity": address_alias, - } - ], - ), - ( - sess.query(name_label, fn), - [ - { - "name": "uname", - "type": users.c.name.type, - "aliased": False, - "expr": name_label, - "entity": User, - }, - { - "name": None, - "type": fn.type, - "aliased": False, - "expr": fn, - "entity": User, - }, - ], - ), - ( - sess.query(cte), - [ - { - "aliased": False, - "expr": cte.c.id, - "type": cte.c.id.type, - "name": "id", - "entity": None, - } - ], - ), - ( - sess.query(users), - [ - { - "aliased": False, - "expr": users.c.id, - "type": users.c.id.type, - "name": "id", - "entity": None, - }, - { - "aliased": False, - "expr": users.c.name, - "type": users.c.name.type, - "name": "name", - "entity": None, - }, - ], - ), - ( - sess.query(users.c.name), - [ - { - "name": "name", - "type": users.c.name.type, - "aliased": False, - "expr": users.c.name, - "entity": None, - } - ], - ), - ( - sess.query(bundle), - [ - { - "aliased": False, - "expr": bundle, - "type": Bundle, - "name": "b1", - "entity": User, - } - ], - ), - ]: - eq_(q.column_descriptions, asserted) + + q, asserted = testing.resolve_lambda(test_case, **locals()) + + eq_(q.column_descriptions, asserted) def test_unhashable_type(self): from sqlalchemy.types import TypeDecorator, Integer @@ -917,34 +921,34 @@ class GetTest(QueryTest): class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): - def test_no_limit_offset(self): + @testing.combinations( + lambda: s.query(User).limit(2), + lambda: s.query(User).filter(User.id == 1).offset(2), + lambda: s.query(User).limit(2).offset(2), + ) + def test_no_limit_offset(self, test_case): User = self.classes.User s = create_session() - for q in ( - s.query(User).limit(2), - s.query(User).offset(2), - s.query(User).limit(2).offset(2), - ): - assert_raises(sa_exc.InvalidRequestError, q.join, "addresses") + q = testing.resolve_lambda(test_case, User=User, s=s) - assert_raises( - sa_exc.InvalidRequestError, q.filter, User.name == "ed" - ) + assert_raises(sa_exc.InvalidRequestError, q.join, "addresses") - assert_raises(sa_exc.InvalidRequestError, q.filter_by, name="ed") + assert_raises(sa_exc.InvalidRequestError, q.filter, User.name == "ed") - assert_raises(sa_exc.InvalidRequestError, q.order_by, "foo") + assert_raises(sa_exc.InvalidRequestError, q.filter_by, name="ed") - assert_raises(sa_exc.InvalidRequestError, q.group_by, "foo") + assert_raises(sa_exc.InvalidRequestError, q.order_by, "foo") - assert_raises(sa_exc.InvalidRequestError, q.having, "foo") + assert_raises(sa_exc.InvalidRequestError, q.group_by, "foo") - q.enable_assertions(False).join("addresses") - q.enable_assertions(False).filter(User.name == "ed") - q.enable_assertions(False).order_by("foo") - q.enable_assertions(False).group_by("foo") + assert_raises(sa_exc.InvalidRequestError, q.having, "foo") + + q.enable_assertions(False).join("addresses") + q.enable_assertions(False).filter(User.name == "ed") + q.enable_assertions(False).order_by("foo") + q.enable_assertions(False).group_by("foo") def test_no_from(self): users, User = self.tables.users, self.classes.User @@ -1077,27 +1081,46 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): is_(q1._mapper_zero(), inspect(User)) is_(q1._entity_zero(), inspect(User)) - def test_from_statement(self): + @testing.combinations( + lambda: s.query(User).filter(User.id == 5), + lambda: s.query(User).filter_by(id=5), + lambda: s.query(User).limit(5), + lambda: s.query(User).group_by(User.name), + lambda: s.query(User).order_by(User.name), + ) + def test_from_statement(self, test_case): User = self.classes.User s = create_session() - for meth, arg, kw in [ - (Query.filter, (User.id == 5,), {}), - (Query.filter_by, (), {"id": 5}), - (Query.limit, (5,), {}), - (Query.group_by, (User.name,), {}), - (Query.order_by, (User.name,), {}), - ]: - q = s.query(User) - q = meth(q, *arg, **kw) - assert_raises( - sa_exc.InvalidRequestError, q.from_statement, text("x") - ) + q = testing.resolve_lambda(test_case, User=User, s=s) - q = s.query(User) - q = q.from_statement(text("x")) - assert_raises(sa_exc.InvalidRequestError, meth, q, *arg, **kw) + assert_raises(sa_exc.InvalidRequestError, q.from_statement, text("x")) + + @testing.combinations( + (Query.filter, lambda: meth(User.id == 5)), + (Query.filter_by, lambda: meth(id=5)), + (Query.limit, lambda: meth(5)), + (Query.group_by, lambda: meth(User.name)), + (Query.order_by, lambda: meth(User.name)), + ) + def test_from_statement_text(self, meth, test_case): + + User = self.classes.User + s = Session() + q = s.query(User) + + q = q.from_statement(text("x")) + m = functools.partial(meth, q) + + assert_raises( + sa_exc.InvalidRequestError, + testing.resolve_lambda, + test_case, + meth=m, + User=User, + s=s, + ) def test_illegal_coercions(self): User = self.classes.User @@ -1172,76 +1195,93 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): self.assert_compile(full, expected, checkparams=checkparams) - def test_arithmetic(self): - User = self.classes.User - + @testing.combinations( + (operators.add, "+"), + (operators.mul, "*"), + (operators.sub, "-"), + (operators.truediv, "/"), + (operators.div, "/"), + argnames="py_op, sql_op", + id_="ar", + ) + @testing.combinations( + (lambda: 5, lambda: User.id, ":id_1 %s users.id"), + (lambda: 5, lambda: literal(6), ":param_1 %s :param_2"), + (lambda: User.id, lambda: 5, "users.id %s :id_1"), + (lambda: User.id, lambda: literal("b"), "users.id %s :param_1"), + (lambda: User.id, lambda: User.id, "users.id %s users.id"), + (lambda: literal(5), lambda: "b", ":param_1 %s :param_2"), + (lambda: literal(5), lambda: User.id, ":param_1 %s users.id"), + (lambda: literal(5), lambda: literal(6), ":param_1 %s :param_2"), + argnames="lhs, rhs, res", + id_="aar", + ) + def test_arithmetic(self, py_op, sql_op, lhs, rhs, res): + User = self.classes.User + + lhs = testing.resolve_lambda(lhs, User=User) + rhs = testing.resolve_lambda(rhs, User=User) create_session().query(User) - for (py_op, sql_op) in ( - (operators.add, "+"), - (operators.mul, "*"), - (operators.sub, "-"), - (operators.truediv, "/"), - (operators.div, "/"), - ): - for (lhs, rhs, res) in ( - (5, User.id, ":id_1 %s users.id"), - (5, literal(6), ":param_1 %s :param_2"), - (User.id, 5, "users.id %s :id_1"), - (User.id, literal("b"), "users.id %s :param_1"), - (User.id, User.id, "users.id %s users.id"), - (literal(5), "b", ":param_1 %s :param_2"), - (literal(5), User.id, ":param_1 %s users.id"), - (literal(5), literal(6), ":param_1 %s :param_2"), - ): - self._test(py_op(lhs, rhs), res % sql_op) - - def test_comparison(self): + self._test(py_op(lhs, rhs), res % sql_op) + + @testing.combinations( + (operators.lt, "<", ">"), + (operators.gt, ">", "<"), + (operators.eq, "=", "="), + (operators.ne, "!=", "!="), + (operators.le, "<=", ">="), + (operators.ge, ">=", "<="), + id_="arr", + argnames="py_op, fwd_op, rev_op", + ) + @testing.combinations( + (lambda: "a", lambda: User.id, ":id_1", "users.id"), + ( + lambda: "a", + lambda: literal("b"), + ":param_2", + ":param_1", + ), # note swap! + (lambda: User.id, lambda: "b", "users.id", ":id_1"), + (lambda: User.id, lambda: literal("b"), "users.id", ":param_1"), + (lambda: User.id, lambda: User.id, "users.id", "users.id"), + (lambda: literal("a"), lambda: "b", ":param_1", ":param_2"), + (lambda: literal("a"), lambda: User.id, ":param_1", "users.id"), + (lambda: literal("a"), lambda: literal("b"), ":param_1", ":param_2"), + (lambda: ualias.id, lambda: literal("b"), "users_1.id", ":param_1"), + (lambda: User.id, lambda: ualias.name, "users.id", "users_1.name"), + (lambda: User.name, lambda: ualias.name, "users.name", "users_1.name"), + (lambda: ualias.name, lambda: User.name, "users_1.name", "users.name"), + argnames="lhs, rhs, l_sql, r_sql", + id_="aarr", + ) + def test_comparison(self, py_op, fwd_op, rev_op, lhs, rhs, l_sql, r_sql): User = self.classes.User create_session().query(User) ualias = aliased(User) - for (py_op, fwd_op, rev_op) in ( - (operators.lt, "<", ">"), - (operators.gt, ">", "<"), - (operators.eq, "=", "="), - (operators.ne, "!=", "!="), - (operators.le, "<=", ">="), - (operators.ge, ">=", "<="), - ): - for (lhs, rhs, l_sql, r_sql) in ( - ("a", User.id, ":id_1", "users.id"), - ("a", literal("b"), ":param_2", ":param_1"), # note swap! - (User.id, "b", "users.id", ":id_1"), - (User.id, literal("b"), "users.id", ":param_1"), - (User.id, User.id, "users.id", "users.id"), - (literal("a"), "b", ":param_1", ":param_2"), - (literal("a"), User.id, ":param_1", "users.id"), - (literal("a"), literal("b"), ":param_1", ":param_2"), - (ualias.id, literal("b"), "users_1.id", ":param_1"), - (User.id, ualias.name, "users.id", "users_1.name"), - (User.name, ualias.name, "users.name", "users_1.name"), - (ualias.name, User.name, "users_1.name", "users.name"), - ): - - # the compiled clause should match either (e.g.): - # 'a' < 'b' -or- 'b' > 'a'. - compiled = str( - py_op(lhs, rhs).compile(dialect=default.DefaultDialect()) - ) - fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) - rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) - - self.assert_( - compiled == fwd_sql or compiled == rev_sql, - "\n'" - + compiled - + "'\n does not match\n'" - + fwd_sql - + "'\n or\n'" - + rev_sql - + "'", - ) + lhs = testing.resolve_lambda(lhs, User=User, ualias=ualias) + rhs = testing.resolve_lambda(rhs, User=User, ualias=ualias) + + # the compiled clause should match either (e.g.): + # 'a' < 'b' -or- 'b' > 'a'. + compiled = str( + py_op(lhs, rhs).compile(dialect=default.DefaultDialect()) + ) + fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) + rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) + + self.assert_( + compiled == fwd_sql or compiled == rev_sql, + "\n'" + + compiled + + "'\n does not match\n'" + + fwd_sql + + "'\n or\n'" + + rev_sql + + "'", + ) def test_o2m_compare_to_null(self): User = self.classes.User