import contextlib
+import functools
import sqlalchemy as sa
from sqlalchemy import and_
assert row.id == 7
assert row.uname == "jack"
- def test_column_metadata(self):
+ @testing.combinations(
+ lambda: (
+ sess.query(User),
+ [
+ {
+ "name": "User",
+ "type": User,
+ "aliased": False,
+ "expr": User,
+ "entity": User,
+ }
+ ],
+ ),
+ lambda: (
+ sess.query(User.id, User),
+ [
+ {
+ "name": "id",
+ "type": users.c.id.type,
+ "aliased": False,
+ "expr": User.id,
+ "entity": User,
+ },
+ {
+ "name": "User",
+ "type": User,
+ "aliased": False,
+ "expr": User,
+ "entity": User,
+ },
+ ],
+ ),
+ lambda: (
+ sess.query(User.id, user_alias),
+ [
+ {
+ "name": "id",
+ "type": users.c.id.type,
+ "aliased": False,
+ "expr": User.id,
+ "entity": User,
+ },
+ {
+ "name": None,
+ "type": User,
+ "aliased": True,
+ "expr": user_alias,
+ "entity": user_alias,
+ },
+ ],
+ ),
+ lambda: (
+ sess.query(user_alias.id),
+ [
+ {
+ "name": "id",
+ "type": users.c.id.type,
+ "aliased": True,
+ "expr": user_alias.id,
+ "entity": user_alias,
+ }
+ ],
+ ),
+ lambda: (
+ sess.query(user_alias_id_label),
+ [
+ {
+ "name": "foo",
+ "type": users.c.id.type,
+ "aliased": True,
+ "expr": user_alias_id_label,
+ "entity": user_alias,
+ }
+ ],
+ ),
+ lambda: (
+ sess.query(address_alias),
+ [
+ {
+ "name": "aalias",
+ "type": Address,
+ "aliased": True,
+ "expr": address_alias,
+ "entity": address_alias,
+ }
+ ],
+ ),
+ lambda: (
+ sess.query(name_label, fn),
+ [
+ {
+ "name": "uname",
+ "type": users.c.name.type,
+ "aliased": False,
+ "expr": name_label,
+ "entity": User,
+ },
+ {
+ "name": None,
+ "type": fn.type,
+ "aliased": False,
+ "expr": fn,
+ "entity": User,
+ },
+ ],
+ ),
+ lambda: (
+ sess.query(cte),
+ [
+ {
+ "aliased": False,
+ "expr": cte.c.id,
+ "type": cte.c.id.type,
+ "name": "id",
+ "entity": None,
+ }
+ ],
+ ),
+ lambda: (
+ sess.query(users),
+ [
+ {
+ "aliased": False,
+ "expr": users.c.id,
+ "type": users.c.id.type,
+ "name": "id",
+ "entity": None,
+ },
+ {
+ "aliased": False,
+ "expr": users.c.name,
+ "type": users.c.name.type,
+ "name": "name",
+ "entity": None,
+ },
+ ],
+ ),
+ lambda: (
+ sess.query(users.c.name),
+ [
+ {
+ "name": "name",
+ "type": users.c.name.type,
+ "aliased": False,
+ "expr": users.c.name,
+ "entity": None,
+ }
+ ],
+ ),
+ lambda: (
+ sess.query(bundle),
+ [
+ {
+ "aliased": False,
+ "expr": bundle,
+ "type": Bundle,
+ "name": "b1",
+ "entity": User,
+ }
+ ],
+ ),
+ )
+ def test_column_metadata(self, test_case):
users, Address, addresses, User = (
self.tables.users,
self.classes.Address,
name_label = User.name.label("uname")
bundle = Bundle("b1", User.id, User.name)
cte = sess.query(User.id).cte()
- for q, asserted in [
- (
- sess.query(User),
- [
- {
- "name": "User",
- "type": User,
- "aliased": False,
- "expr": User,
- "entity": User,
- }
- ],
- ),
- (
- sess.query(User.id, User),
- [
- {
- "name": "id",
- "type": users.c.id.type,
- "aliased": False,
- "expr": User.id,
- "entity": User,
- },
- {
- "name": "User",
- "type": User,
- "aliased": False,
- "expr": User,
- "entity": User,
- },
- ],
- ),
- (
- sess.query(User.id, user_alias),
- [
- {
- "name": "id",
- "type": users.c.id.type,
- "aliased": False,
- "expr": User.id,
- "entity": User,
- },
- {
- "name": None,
- "type": User,
- "aliased": True,
- "expr": user_alias,
- "entity": user_alias,
- },
- ],
- ),
- (
- sess.query(user_alias.id),
- [
- {
- "name": "id",
- "type": users.c.id.type,
- "aliased": True,
- "expr": user_alias.id,
- "entity": user_alias,
- }
- ],
- ),
- (
- sess.query(user_alias_id_label),
- [
- {
- "name": "foo",
- "type": users.c.id.type,
- "aliased": True,
- "expr": user_alias_id_label,
- "entity": user_alias,
- }
- ],
- ),
- (
- sess.query(address_alias),
- [
- {
- "name": "aalias",
- "type": Address,
- "aliased": True,
- "expr": address_alias,
- "entity": address_alias,
- }
- ],
- ),
- (
- sess.query(name_label, fn),
- [
- {
- "name": "uname",
- "type": users.c.name.type,
- "aliased": False,
- "expr": name_label,
- "entity": User,
- },
- {
- "name": None,
- "type": fn.type,
- "aliased": False,
- "expr": fn,
- "entity": User,
- },
- ],
- ),
- (
- sess.query(cte),
- [
- {
- "aliased": False,
- "expr": cte.c.id,
- "type": cte.c.id.type,
- "name": "id",
- "entity": None,
- }
- ],
- ),
- (
- sess.query(users),
- [
- {
- "aliased": False,
- "expr": users.c.id,
- "type": users.c.id.type,
- "name": "id",
- "entity": None,
- },
- {
- "aliased": False,
- "expr": users.c.name,
- "type": users.c.name.type,
- "name": "name",
- "entity": None,
- },
- ],
- ),
- (
- sess.query(users.c.name),
- [
- {
- "name": "name",
- "type": users.c.name.type,
- "aliased": False,
- "expr": users.c.name,
- "entity": None,
- }
- ],
- ),
- (
- sess.query(bundle),
- [
- {
- "aliased": False,
- "expr": bundle,
- "type": Bundle,
- "name": "b1",
- "entity": User,
- }
- ],
- ),
- ]:
- eq_(q.column_descriptions, asserted)
+
+ q, asserted = testing.resolve_lambda(test_case, **locals())
+
+ eq_(q.column_descriptions, asserted)
def test_unhashable_type(self):
from sqlalchemy.types import TypeDecorator, Integer
class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
- def test_no_limit_offset(self):
+ @testing.combinations(
+ lambda: s.query(User).limit(2),
+ lambda: s.query(User).filter(User.id == 1).offset(2),
+ lambda: s.query(User).limit(2).offset(2),
+ )
+ def test_no_limit_offset(self, test_case):
User = self.classes.User
s = create_session()
- for q in (
- s.query(User).limit(2),
- s.query(User).offset(2),
- s.query(User).limit(2).offset(2),
- ):
- assert_raises(sa_exc.InvalidRequestError, q.join, "addresses")
+ q = testing.resolve_lambda(test_case, User=User, s=s)
- assert_raises(
- sa_exc.InvalidRequestError, q.filter, User.name == "ed"
- )
+ assert_raises(sa_exc.InvalidRequestError, q.join, "addresses")
- assert_raises(sa_exc.InvalidRequestError, q.filter_by, name="ed")
+ assert_raises(sa_exc.InvalidRequestError, q.filter, User.name == "ed")
- assert_raises(sa_exc.InvalidRequestError, q.order_by, "foo")
+ assert_raises(sa_exc.InvalidRequestError, q.filter_by, name="ed")
- assert_raises(sa_exc.InvalidRequestError, q.group_by, "foo")
+ assert_raises(sa_exc.InvalidRequestError, q.order_by, "foo")
- assert_raises(sa_exc.InvalidRequestError, q.having, "foo")
+ assert_raises(sa_exc.InvalidRequestError, q.group_by, "foo")
- q.enable_assertions(False).join("addresses")
- q.enable_assertions(False).filter(User.name == "ed")
- q.enable_assertions(False).order_by("foo")
- q.enable_assertions(False).group_by("foo")
+ assert_raises(sa_exc.InvalidRequestError, q.having, "foo")
+
+ q.enable_assertions(False).join("addresses")
+ q.enable_assertions(False).filter(User.name == "ed")
+ q.enable_assertions(False).order_by("foo")
+ q.enable_assertions(False).group_by("foo")
def test_no_from(self):
users, User = self.tables.users, self.classes.User
is_(q1._mapper_zero(), inspect(User))
is_(q1._entity_zero(), inspect(User))
- def test_from_statement(self):
+ @testing.combinations(
+ lambda: s.query(User).filter(User.id == 5),
+ lambda: s.query(User).filter_by(id=5),
+ lambda: s.query(User).limit(5),
+ lambda: s.query(User).group_by(User.name),
+ lambda: s.query(User).order_by(User.name),
+ )
+ def test_from_statement(self, test_case):
User = self.classes.User
s = create_session()
- for meth, arg, kw in [
- (Query.filter, (User.id == 5,), {}),
- (Query.filter_by, (), {"id": 5}),
- (Query.limit, (5,), {}),
- (Query.group_by, (User.name,), {}),
- (Query.order_by, (User.name,), {}),
- ]:
- q = s.query(User)
- q = meth(q, *arg, **kw)
- assert_raises(
- sa_exc.InvalidRequestError, q.from_statement, text("x")
- )
+ q = testing.resolve_lambda(test_case, User=User, s=s)
- q = s.query(User)
- q = q.from_statement(text("x"))
- assert_raises(sa_exc.InvalidRequestError, meth, q, *arg, **kw)
+ assert_raises(sa_exc.InvalidRequestError, q.from_statement, text("x"))
+
+ @testing.combinations(
+ (Query.filter, lambda: meth(User.id == 5)),
+ (Query.filter_by, lambda: meth(id=5)),
+ (Query.limit, lambda: meth(5)),
+ (Query.group_by, lambda: meth(User.name)),
+ (Query.order_by, lambda: meth(User.name)),
+ )
+ def test_from_statement_text(self, meth, test_case):
+
+ User = self.classes.User
+ s = Session()
+ q = s.query(User)
+
+ q = q.from_statement(text("x"))
+ m = functools.partial(meth, q)
+
+ assert_raises(
+ sa_exc.InvalidRequestError,
+ testing.resolve_lambda,
+ test_case,
+ meth=m,
+ User=User,
+ s=s,
+ )
def test_illegal_coercions(self):
User = self.classes.User
self.assert_compile(full, expected, checkparams=checkparams)
- def test_arithmetic(self):
- User = self.classes.User
-
+ @testing.combinations(
+ (operators.add, "+"),
+ (operators.mul, "*"),
+ (operators.sub, "-"),
+ (operators.truediv, "/"),
+ (operators.div, "/"),
+ argnames="py_op, sql_op",
+ id_="ar",
+ )
+ @testing.combinations(
+ (lambda: 5, lambda: User.id, ":id_1 %s users.id"),
+ (lambda: 5, lambda: literal(6), ":param_1 %s :param_2"),
+ (lambda: User.id, lambda: 5, "users.id %s :id_1"),
+ (lambda: User.id, lambda: literal("b"), "users.id %s :param_1"),
+ (lambda: User.id, lambda: User.id, "users.id %s users.id"),
+ (lambda: literal(5), lambda: "b", ":param_1 %s :param_2"),
+ (lambda: literal(5), lambda: User.id, ":param_1 %s users.id"),
+ (lambda: literal(5), lambda: literal(6), ":param_1 %s :param_2"),
+ argnames="lhs, rhs, res",
+ id_="aar",
+ )
+ def test_arithmetic(self, py_op, sql_op, lhs, rhs, res):
+ User = self.classes.User
+
+ lhs = testing.resolve_lambda(lhs, User=User)
+ rhs = testing.resolve_lambda(rhs, User=User)
create_session().query(User)
- for (py_op, sql_op) in (
- (operators.add, "+"),
- (operators.mul, "*"),
- (operators.sub, "-"),
- (operators.truediv, "/"),
- (operators.div, "/"),
- ):
- for (lhs, rhs, res) in (
- (5, User.id, ":id_1 %s users.id"),
- (5, literal(6), ":param_1 %s :param_2"),
- (User.id, 5, "users.id %s :id_1"),
- (User.id, literal("b"), "users.id %s :param_1"),
- (User.id, User.id, "users.id %s users.id"),
- (literal(5), "b", ":param_1 %s :param_2"),
- (literal(5), User.id, ":param_1 %s users.id"),
- (literal(5), literal(6), ":param_1 %s :param_2"),
- ):
- self._test(py_op(lhs, rhs), res % sql_op)
-
- def test_comparison(self):
+ self._test(py_op(lhs, rhs), res % sql_op)
+
+ @testing.combinations(
+ (operators.lt, "<", ">"),
+ (operators.gt, ">", "<"),
+ (operators.eq, "=", "="),
+ (operators.ne, "!=", "!="),
+ (operators.le, "<=", ">="),
+ (operators.ge, ">=", "<="),
+ id_="arr",
+ argnames="py_op, fwd_op, rev_op",
+ )
+ @testing.combinations(
+ (lambda: "a", lambda: User.id, ":id_1", "users.id"),
+ (
+ lambda: "a",
+ lambda: literal("b"),
+ ":param_2",
+ ":param_1",
+ ), # note swap!
+ (lambda: User.id, lambda: "b", "users.id", ":id_1"),
+ (lambda: User.id, lambda: literal("b"), "users.id", ":param_1"),
+ (lambda: User.id, lambda: User.id, "users.id", "users.id"),
+ (lambda: literal("a"), lambda: "b", ":param_1", ":param_2"),
+ (lambda: literal("a"), lambda: User.id, ":param_1", "users.id"),
+ (lambda: literal("a"), lambda: literal("b"), ":param_1", ":param_2"),
+ (lambda: ualias.id, lambda: literal("b"), "users_1.id", ":param_1"),
+ (lambda: User.id, lambda: ualias.name, "users.id", "users_1.name"),
+ (lambda: User.name, lambda: ualias.name, "users.name", "users_1.name"),
+ (lambda: ualias.name, lambda: User.name, "users_1.name", "users.name"),
+ argnames="lhs, rhs, l_sql, r_sql",
+ id_="aarr",
+ )
+ def test_comparison(self, py_op, fwd_op, rev_op, lhs, rhs, l_sql, r_sql):
User = self.classes.User
create_session().query(User)
ualias = aliased(User)
- for (py_op, fwd_op, rev_op) in (
- (operators.lt, "<", ">"),
- (operators.gt, ">", "<"),
- (operators.eq, "=", "="),
- (operators.ne, "!=", "!="),
- (operators.le, "<=", ">="),
- (operators.ge, ">=", "<="),
- ):
- for (lhs, rhs, l_sql, r_sql) in (
- ("a", User.id, ":id_1", "users.id"),
- ("a", literal("b"), ":param_2", ":param_1"), # note swap!
- (User.id, "b", "users.id", ":id_1"),
- (User.id, literal("b"), "users.id", ":param_1"),
- (User.id, User.id, "users.id", "users.id"),
- (literal("a"), "b", ":param_1", ":param_2"),
- (literal("a"), User.id, ":param_1", "users.id"),
- (literal("a"), literal("b"), ":param_1", ":param_2"),
- (ualias.id, literal("b"), "users_1.id", ":param_1"),
- (User.id, ualias.name, "users.id", "users_1.name"),
- (User.name, ualias.name, "users.name", "users_1.name"),
- (ualias.name, User.name, "users_1.name", "users.name"),
- ):
-
- # the compiled clause should match either (e.g.):
- # 'a' < 'b' -or- 'b' > 'a'.
- compiled = str(
- py_op(lhs, rhs).compile(dialect=default.DefaultDialect())
- )
- fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
- rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
-
- self.assert_(
- compiled == fwd_sql or compiled == rev_sql,
- "\n'"
- + compiled
- + "'\n does not match\n'"
- + fwd_sql
- + "'\n or\n'"
- + rev_sql
- + "'",
- )
+ lhs = testing.resolve_lambda(lhs, User=User, ualias=ualias)
+ rhs = testing.resolve_lambda(rhs, User=User, ualias=ualias)
+
+ # the compiled clause should match either (e.g.):
+ # 'a' < 'b' -or- 'b' > 'a'.
+ compiled = str(
+ py_op(lhs, rhs).compile(dialect=default.DefaultDialect())
+ )
+ fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
+ rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
+
+ self.assert_(
+ compiled == fwd_sql or compiled == rev_sql,
+ "\n'"
+ + compiled
+ + "'\n does not match\n'"
+ + fwd_sql
+ + "'\n or\n'"
+ + rev_sql
+ + "'",
+ )
def test_o2m_compare_to_null(self):
User = self.classes.User