]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement ScalarValue
authorFederico Caselli <cfederico87@gmail.com>
Tue, 8 Nov 2022 21:12:47 +0000 (22:12 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Nov 2022 23:49:06 +0000 (18:49 -0500)
Added :class:`_expression.ScalarValues` that can be used as a column
element allowing using :class:`_expression.Values` inside IN clauses
or in conjunction with ``ANY`` or ``ALL`` collection aggregates.
This new class is generated using the method
:meth:`_expression.Values.scalar_values`.
The :class:`_expression.Values` instance is now coerced to a
:class:`_expression.ScalarValues` when used in a ``IN`` or ``NOT IN``
operation.

Fixes: #6289
Change-Id: Iac22487ccb01553684b908e54d01c0687fa739f1

doc/build/changelog/unreleased_20/6289.rst [new file with mode: 0644]
doc/build/core/selectable.rst
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/sql/test_compare.py
test/sql/test_compiler.py
test/sql/test_operators.py
test/sql/test_roles.py

diff --git a/doc/build/changelog/unreleased_20/6289.rst b/doc/build/changelog/unreleased_20/6289.rst
new file mode 100644 (file)
index 0000000..f7789a2
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: sql, usecase
+    :tickets: 6289
+
+    Added :class:`_expression.ScalarValues` that can be used as a column
+    element allowing using :class:`_expression.Values` inside IN clauses
+    or in conjunction with ``ANY`` or ``ALL`` collection aggregates.
+    This new class is generated using the method
+    :meth:`_expression.Values.scalar_values`.
+    The :class:`_expression.Values` instance is now coerced to a
+    :class:`_expression.ScalarValues` when used in a ``IN`` or ``NOT IN``
+    operation.
index 7537df7d56d034cacc030649f9df38cc3841f015..e81c88cc49432412b1a22db08bf2a2f53cea241b 100644 (file)
@@ -155,6 +155,9 @@ The classes here are generated using the constructors listed at
 .. autoclass:: Values
    :members:
 
+.. autoclass:: ScalarValues
+   :members:
+
 Label Style Constants
 ---------------------
 
index f48a3ccb007dd13ad54a98386775a783bbbaf1b4..9c3e7480ae950f91ec4b57cb11af205b57355268 100644 (file)
@@ -41,6 +41,7 @@ from .. import util
 from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
+    # elements lambdas schema selectable are set by __init__
     from . import elements
     from . import lambdas
     from . import schema
@@ -354,11 +355,7 @@ def expect(
 
     if not isinstance(
         element,
-        (
-            elements.CompilerElement,
-            schema.SchemaItem,
-            schema.FetchedValue,
-        ),
+        (elements.CompilerElement, schema.SchemaItem, schema.FetchedValue),
     ):
         resolved = None
 
@@ -773,10 +770,15 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
                 self._raise_for_expected(element, err=err)
 
     def _raise_for_expected(self, element, argname=None, resolved=None, **kw):
-        if isinstance(element, roles.AnonymizedFromClauseRole):
+        # select uses implicit coercion with warning instead of raising
+        if isinstance(element, selectable.Values):
             advice = (
-                "To create a "
-                "column expression from a FROM clause row "
+                "To create a column expression from a VALUES clause, "
+                "use the .scalar_values() method."
+            )
+        elif isinstance(element, roles.AnonymizedFromClauseRole):
+            advice = (
+                "To create a column expression from a FROM clause row "
                 "as a whole, use the .table_valued() method."
             )
         else:
@@ -886,6 +888,8 @@ class InElementImpl(RoleImpl):
             element.expand_op = operator
 
             return element
+        elif isinstance(element, selectable.Values):
+            return element.scalar_values()
         else:
             return element
 
index 17aafddadb6468ee1c34202526381cc63d1905fe..9e4422fbd07e61235dceed7e422ceee7912a2b54 100644 (file)
@@ -3612,9 +3612,9 @@ class SQLCompiler(Compiled):
 
         return text
 
-    def visit_values(self, element, asfrom=False, from_linter=None, **kw):
+    def _render_values(self, element, **kw):
         kw.setdefault("literal_binds", element.literal_binds)
-        v = "VALUES %s" % ", ".join(
+        tuples = ", ".join(
             self.process(
                 elements.Tuple(
                     types=element._column_types, *elem
@@ -3624,6 +3624,10 @@ class SQLCompiler(Compiled):
             for chunk in element._data
             for elem in chunk
         )
+        return f"VALUES {tuples}"
+
+    def visit_values(self, element, asfrom=False, from_linter=None, **kw):
+        v = self._render_values(element, **kw)
 
         if element._unnamed:
             name = None
@@ -3661,6 +3665,9 @@ class SQLCompiler(Compiled):
                 v = "%s(%s)" % (lateral, v)
         return v
 
+    def visit_scalar_values(self, element, **kw):
+        return f"({self._render_values(element, **kw)})"
+
     def get_render_as_alias_suffix(self, alias_name_text):
         return " AS " + alias_name_text
 
index 914d2b32634a8d66bc9fa8bfb6d67b5da1d50a77..6a9bd74caaafd397c8e892e4b8dba844a1790911 100644 (file)
@@ -97,7 +97,6 @@ if typing.TYPE_CHECKING:
     from .selectable import _SelectIterable
     from .selectable import FromClause
     from .selectable import NamedFromClause
-    from .selectable import Select
     from .sqltypes import TupleType
     from .type_api import TypeEngine
     from .visitors import _CloneCallableType
@@ -860,13 +859,17 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
 
         def in_(
             self,
-            other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
+            other: Union[
+                Sequence[Any], BindParameter[Any], roles.InElementRole
+            ],
         ) -> BinaryExpression[bool]:
             ...
 
         def not_in(
             self,
-            other: Union[Sequence[Any], BindParameter[Any], Select[Any]],
+            other: Union[
+                Sequence[Any], BindParameter[Any], roles.InElementRole
+            ],
         ) -> BinaryExpression[bool]:
             ...
 
index fcffc324fb6b17026c068e0fc63a2830d164f0c6..97336d41663d3f845886c647ed60a4b053f12cee 100644 (file)
@@ -3127,7 +3127,7 @@ class ForUpdateArg(ClauseElement):
 SelfValues = typing.TypeVar("SelfValues", bound="Values")
 
 
-class Values(Generative, LateralFromClause):
+class Values(roles.InElementRole, Generative, LateralFromClause):
     """Represent a ``VALUES`` construct that can be used as a FROM element
     in a statement.
 
@@ -3228,8 +3228,7 @@ class Values(Generative, LateralFromClause):
     @_generative
     def data(self: SelfValues, values: List[Tuple[Any, ...]]) -> SelfValues:
         """Return a new :class:`_expression.Values` construct,
-        adding the given data
-        to the data list.
+        adding the given data to the data list.
 
         E.g.::
 
@@ -3244,6 +3243,15 @@ class Values(Generative, LateralFromClause):
         self._data += (values,)
         return self
 
+    def scalar_values(self) -> ScalarValues:
+        """Returns a scalar ``VALUES`` construct that can be used as a
+        COLUMN element in a statement.
+
+        .. versionadded:: 2.0.0b4
+
+        """
+        return ScalarValues(self._column_args, self._data, self.literal_binds)
+
     def _populate_column_collection(self) -> None:
         for c in self._column_args:
             self._columns.add(c)
@@ -3254,6 +3262,46 @@ class Values(Generative, LateralFromClause):
         return [self]
 
 
+class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
+    """Represent a scalar ``VALUES`` construct that can be used as a
+    COLUMN element in a statement.
+
+    The :class:`_expression.ScalarValues` object is created from the
+    :meth:`_expression.Values.scalar_values` method. It's also
+    automatically generated when a :class:`_expression.Values` is used in
+    an ``IN`` or ``NOT IN`` condition.
+
+    .. versionadded:: 2.0.0b4
+
+    """
+
+    __visit_name__ = "scalar_values"
+
+    _traverse_internals: _TraverseInternalsType = [
+        ("_column_args", InternalTraversal.dp_clauseelement_list),
+        ("_data", InternalTraversal.dp_dml_multi_values),
+        ("literal_binds", InternalTraversal.dp_boolean),
+    ]
+
+    def __init__(
+        self,
+        columns: Sequence[ColumnClause[Any]],
+        data: Tuple[List[Tuple[Any, ...]], ...],
+        literal_binds: bool,
+    ):
+        super().__init__()
+        self._column_args = columns
+        self._data = data
+        self.literal_binds = literal_binds
+
+    @property
+    def _column_types(self):
+        return [col.type for col in self._column_args]
+
+    def __clause_element__(self):
+        return self
+
+
 SelfSelectBase = TypeVar("SelfSelectBase", bound=Any)
 
 
index f18c79c7b22cc14c67da4508d15367896ebfb2de..87710fdd9ed40d984ee40b1acb9700504d60ad94 100644 (file)
@@ -621,6 +621,15 @@ class CoreFixtures:
             )
             .data([(1, "textA", 99), (2, "textB", 88)])
             ._annotate({"nocache": True}),
+            values(
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myint", Integer),
+                name="myvalues",
+                literal_binds=True,
+            )
+            .data([(1, "textA", 99), (2, "textB", 88)])
+            ._annotate({"nocache": True}),
             values(
                 column("mykey", Integer),
                 column("mytext", String),
@@ -647,15 +656,62 @@ class CoreFixtures:
             ._annotate({"nocache": True}),
             # TODO: difference in type
             # values(
-            #    [
-            #        column("mykey", Integer),
-            #        column("mytext", Text),
-            #        column("myint", Integer),
-            #    ],
-            #    (1, "textA", 99),
-            #    (2, "textB", 88),
-            #    alias_name="myvalues",
-            # ),
+            #     column("mykey", Integer),
+            #     column("mytext", Text),
+            #     column("myint", Integer),
+            #     name="myvalues",
+            # )
+            # .data([(1, "textA", 99), (2, "textB", 88)])
+            # ._annotate({"nocache": True}),
+        ),
+        lambda: (
+            values(
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myint", Integer),
+                name="myvalues",
+            )
+            .data([(1, "textA", 99), (2, "textB", 88)])
+            .scalar_values()
+            ._annotate({"nocache": True}),
+            values(
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myint", Integer),
+                name="myvalues",
+                literal_binds=True,
+            )
+            .data([(1, "textA", 99), (2, "textB", 88)])
+            .scalar_values()
+            ._annotate({"nocache": True}),
+            values(
+                column("mykey", Integer),
+                column("mytext", String),
+                column("myint", Integer),
+                name="myvalues",
+            )
+            .data([(1, "textA", 89), (2, "textG", 88)])
+            .scalar_values()
+            ._annotate({"nocache": True}),
+            values(
+                column("mykey", Integer),
+                column("mynottext", String),
+                column("myint", Integer),
+                name="myvalues",
+            )
+            .data([(1, "textA", 99), (2, "textB", 88)])
+            .scalar_values()
+            ._annotate({"nocache": True}),
+            # TODO: difference in type
+            # values(
+            #     column("mykey", Integer),
+            #     column("mytext", Text),
+            #     column("myint", Integer),
+            #     name="myvalues",
+            # )
+            # .data([(1, "textA", 99), (2, "textB", 88)])
+            # .scalar_values()
+            # ._annotate({"nocache": True}),
         ),
         lambda: (
             select(table_a.c.a),
@@ -1304,9 +1360,8 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
                                 compare_annotations=True,
                                 compare_values=compare_values,
                             ),
-                            "%r != %r" % (case_a[a], case_b[b]),
+                            f"{case_a[a]!r} != {case_b[b]!r} (index {a} {b})",
                         )
-
                     else:
                         is_false(
                             case_a[a].compare(
@@ -1314,7 +1369,7 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
                                 compare_annotations=True,
                                 compare_values=compare_values,
                             ),
-                            "%r == %r" % (case_a[a], case_b[b]),
+                            f"{case_a[a]!r} == {case_b[b]!r} (index {a} {b})",
                         )
 
     def test_compare_col_identity(self):
index e5a149c49e7354099456f61aa8f211701d7111f1..c71cfd61f0a765077db500127712846afe97c100 100644 (file)
@@ -4526,10 +4526,16 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
             "((myothertable.otherid, myothertable.othername))",
         )
 
+    @testing.variation("scalar_subquery", [True, False])
+    def test_select_in(self, scalar_subquery):
+
+        stmt = select(table2.c.otherid, table2.c.othername)
+
+        if scalar_subquery:
+            stmt = stmt.scalar_subquery()
+
         self.assert_compile(
-            tuple_(table1.c.myid, table1.c.name).in_(
-                select(table2.c.otherid, table2.c.othername)
-            ),
+            tuple_(table1.c.myid, table1.c.name).in_(stmt),
             "(mytable.myid, mytable.name) IN (SELECT "
             "myothertable.otherid, myothertable.othername FROM myothertable)",
         )
index e00cacad899d45072483a0b817154f6a7fa48e1d..103520f1faa98a7007f9c43695e3c3ae778cfc40 100644 (file)
@@ -43,6 +43,7 @@ from sqlalchemy.sql import roles
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
 from sqlalchemy.sql import true
+from sqlalchemy.sql import values
 from sqlalchemy.sql.elements import BindParameter
 from sqlalchemy.sql.elements import BooleanClauseList
 from sqlalchemy.sql.elements import Label
@@ -2390,6 +2391,23 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             dialect="default_enhanced",
         )
 
+    @testing.combinations(lambda v: v, lambda v: v.scalar_values())
+    def test_in_values(self, scalar):
+        t1, t2 = self.table1, self.table2
+        v = scalar(values(t2.c.otherid).data([(1,), (42,)]))
+        self.assert_compile(
+            select(t1.c.myid.in_(v)),
+            "SELECT mytable.myid IN (VALUES (:param_1), (:param_2)) "
+            "AS anon_1 FROM mytable",
+            params={"param_1": 1, "param_2": 42},
+        )
+        self.assert_compile(
+            select(t1.c.myid.not_in(v)),
+            "SELECT (mytable.myid NOT IN (VALUES (:param_1), (:param_2))) "
+            "AS anon_1 FROM mytable",
+            params={"param_1": 1, "param_2": 42},
+        )
+
 
 class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
@@ -2708,6 +2726,25 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         assert not text("x = y")._is_implicitly_boolean
         assert not literal_column("x = y")._is_implicitly_boolean
 
+    def test_scalar_select(self):
+        t = self.table1
+        expr = select(t.c.myid).where(t.c.myid > 5).scalar_subquery()
+        self.assert_compile(
+            not_(expr),
+            "NOT (SELECT mytable.myid FROM mytable "
+            "WHERE mytable.myid > :myid_1)",
+            params={"myid_1": 5},
+        )
+
+    def test_scalar_values(self):
+        t = self.table1
+        expr = values(t.c.myid).data([(7,), (42,)]).scalar_values()
+        self.assert_compile(
+            not_(expr),
+            "NOT (VALUES (:param_1), (:param_2))",
+            params={"param_1": 7, "param_2": 42},
+        )
+
 
 class LikeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
@@ -4324,85 +4361,61 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(expr(col), "NULL = ANY (tab1.%s)" % col.name)
 
-    def test_any_array(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            5 == any_(t.c.arrval),
-            ":param_1 = ANY (tab1.arrval)",
-            checkparams={"param_1": 5},
-        )
-
-    def test_any_array_method(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            5 == t.c.arrval.any_(),
-            ":param_1 = ANY (tab1.arrval)",
-            checkparams={"param_1": 5},
-        )
-
-    def test_all_array(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            5 == all_(t.c.arrval),
-            ":param_1 = ALL (tab1.arrval)",
-            checkparams={"param_1": 5},
-        )
-
-    def test_all_array_method(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            5 == t.c.arrval.all_(),
-            ":param_1 = ALL (tab1.arrval)",
-            checkparams={"param_1": 5},
-        )
+    @testing.fixture(
+        params=[
+            ("ANY", any_),
+            ("ANY", lambda x: x.any_()),
+            ("ALL", all_),
+            ("ALL", lambda x: x.all_()),
+        ]
+    )
+    def operator(self, request):
+        return request.param
+
+    @testing.fixture(
+        params=[
+            ("ANY", lambda x, *o: x.any(*o)),
+            ("ALL", lambda x, *o: x.all(*o)),
+        ]
+    )
+    def array_op(self, request):
+        return request.param
 
-    def test_any_comparator_array(self, t_fixture):
+    def test_array(self, t_fixture, operator):
         t = t_fixture
-
+        op, fn = operator
         self.assert_compile(
-            5 > any_(t.c.arrval),
-            ":param_1 > ANY (tab1.arrval)",
+            5 == fn(t.c.arrval),
+            f":param_1 = {op} (tab1.arrval)",
             checkparams={"param_1": 5},
         )
 
-    def test_all_comparator_array(self, t_fixture):
+    def test_comparator_array(self, t_fixture, operator):
         t = t_fixture
-
+        op, fn = operator
         self.assert_compile(
-            5 > all_(t.c.arrval),
-            ":param_1 > ALL (tab1.arrval)",
+            5 > fn(t.c.arrval),
+            f":param_1 > {op} (tab1.arrval)",
             checkparams={"param_1": 5},
         )
 
-    def test_any_comparator_array_wexpr(self, t_fixture):
+    def test_comparator_array_wexpr(self, t_fixture, operator):
         t = t_fixture
-
-        self.assert_compile(
-            t.c.data > any_(t.c.arrval),
-            "tab1.data > ANY (tab1.arrval)",
-            checkparams={},
-        )
-
-    def test_all_comparator_array_wexpr(self, t_fixture):
-        t = t_fixture
-
+        op, fn = operator
         self.assert_compile(
-            t.c.data > all_(t.c.arrval),
-            "tab1.data > ALL (tab1.arrval)",
+            t.c.data > fn(t.c.arrval),
+            f"tab1.data > {op} (tab1.arrval)",
             checkparams={},
         )
 
-    def test_illegal_ops(self, t_fixture):
+    def test_illegal_ops(self, t_fixture, operator):
         t = t_fixture
+        op, fn = operator
 
         assert_raises_message(
             exc.ArgumentError,
             "Only comparison operators may be used with ANY/ALL",
-            lambda: 5 + all_(t.c.arrval),
+            lambda: 5 + fn(t.c.arrval),
         )
 
         # TODO:
@@ -4410,86 +4423,47 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         # as the left-hand side just does its thing.  Types
         # would need to reject their right-hand side.
         self.assert_compile(
-            t.c.data + all_(t.c.arrval), "tab1.data + ALL (tab1.arrval)"
+            t.c.data + fn(t.c.arrval), f"tab1.data + {op} (tab1.arrval)"
         )
 
-    @testing.combinations("all", "any", argnames="op")
-    def test_any_all_bindparam_coercion(self, t_fixture, op):
+    def test_bindparam_coercion(self, t_fixture, array_op):
         """test #7979"""
         t = t_fixture
+        op, fn = array_op
 
-        if op == "all":
-            expr = t.c.arrval.all(bindparam("param"))
-            expected = "%(param)s = ALL (tab1.arrval)"
-        elif op == "any":
-            expr = t.c.arrval.any(bindparam("param"))
-            expected = "%(param)s = ANY (tab1.arrval)"
-        else:
-            assert False
-
+        expr = fn(t.c.arrval, bindparam("param"))
+        expected = f"%(param)s = {op} (tab1.arrval)"
         is_(expr.left.type._type_affinity, Integer)
 
         self.assert_compile(expr, expected, dialect="postgresql")
 
-    def test_any_array_comparator_accessor(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            t.c.arrval.any(5, operator.gt),
-            ":arrval_1 > ANY (tab1.arrval)",
-            checkparams={"arrval_1": 5},
-        )
-
-    def test_any_array_comparator_negate_accessor(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            ~t.c.arrval.any(5, operator.gt),
-            "NOT (:arrval_1 > ANY (tab1.arrval))",
-            checkparams={"arrval_1": 5},
-        )
-
-    def test_all_array_comparator_accessor(self, t_fixture):
+    def test_array_comparator_accessor(self, t_fixture, array_op):
         t = t_fixture
+        op, fn = array_op
 
         self.assert_compile(
-            t.c.arrval.all(5, operator.gt),
-            ":arrval_1 > ALL (tab1.arrval)",
+            fn(t.c.arrval, 5, operator.gt),
+            f":arrval_1 > {op} (tab1.arrval)",
             checkparams={"arrval_1": 5},
         )
 
-    def test_all_array_comparator_negate_accessor(self, t_fixture):
+    def test_array_comparator_negate_accessor(self, t_fixture, array_op):
         t = t_fixture
+        op, fn = array_op
 
         self.assert_compile(
-            ~t.c.arrval.all(5, operator.gt),
-            "NOT (:arrval_1 > ALL (tab1.arrval))",
+            ~fn(t.c.arrval, 5, operator.gt),
+            f"NOT (:arrval_1 > {op} (tab1.arrval))",
             checkparams={"arrval_1": 5},
         )
 
-    def test_any_array_expression(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            5 == any_(t.c.arrval[5:6] + postgresql.array([3, 4])),
-            "%(param_1)s = ANY (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || "
-            "ARRAY[%(param_2)s, %(param_3)s])",
-            checkparams={
-                "arrval_2": 6,
-                "param_1": 5,
-                "param_3": 4,
-                "arrval_1": 5,
-                "param_2": 3,
-            },
-            dialect="postgresql",
-        )
-
-    def test_all_array_expression(self, t_fixture):
+    def test_array_expression(self, t_fixture, operator):
         t = t_fixture
+        op, fn = operator
 
         self.assert_compile(
-            5 == all_(t.c.arrval[5:6] + postgresql.array([3, 4])),
-            "%(param_1)s = ALL (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || "
+            5 == fn(t.c.arrval[5:6] + postgresql.array([3, 4])),
+            f"%(param_1)s = {op} (tab1.arrval[%(arrval_1)s:%(arrval_2)s] || "
             "ARRAY[%(param_2)s, %(param_3)s])",
             checkparams={
                 "arrval_2": 6,
@@ -4501,44 +4475,35 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             dialect="postgresql",
         )
 
-    def test_any_subq(self, t_fixture):
-        t = t_fixture
-
-        self.assert_compile(
-            5 == any_(select(t.c.data).where(t.c.data < 10).scalar_subquery()),
-            ":param_1 = ANY (SELECT tab1.data "
-            "FROM tab1 WHERE tab1.data < :data_1)",
-            checkparams={"data_1": 10, "param_1": 5},
-        )
-
-    def test_any_subq_method(self, t_fixture):
+    def test_subq(self, t_fixture, operator):
         t = t_fixture
+        op, fn = operator
 
         self.assert_compile(
-            5
-            == select(t.c.data).where(t.c.data < 10).scalar_subquery().any_(),
-            ":param_1 = ANY (SELECT tab1.data "
+            5 == fn(select(t.c.data).where(t.c.data < 10).scalar_subquery()),
+            f":param_1 = {op} (SELECT tab1.data "
             "FROM tab1 WHERE tab1.data < :data_1)",
             checkparams={"data_1": 10, "param_1": 5},
         )
 
-    def test_all_subq(self, t_fixture):
+    def test_scalar_values(self, t_fixture, operator):
         t = t_fixture
+        op, fn = operator
 
         self.assert_compile(
-            5 == all_(select(t.c.data).where(t.c.data < 10).scalar_subquery()),
-            ":param_1 = ALL (SELECT tab1.data "
-            "FROM tab1 WHERE tab1.data < :data_1)",
-            checkparams={"data_1": 10, "param_1": 5},
+            5 == fn(values(t.c.data).data([(1,), (42,)]).scalar_values()),
+            f":param_1 = {op} (VALUES (:param_2), (:param_3))",
+            checkparams={"param_1": 5, "param_2": 1, "param_3": 42},
         )
 
-    def test_all_subq_method(self, t_fixture):
+    @testing.combinations(any_, all_, argnames="fn")
+    def test_values_illegal(self, t_fixture, fn):
         t = t_fixture
 
-        self.assert_compile(
-            5
-            == select(t.c.data).where(t.c.data < 10).scalar_subquery().all_(),
-            ":param_1 = ALL (SELECT tab1.data "
-            "FROM tab1 WHERE tab1.data < :data_1)",
-            checkparams={"data_1": 10, "param_1": 5},
-        )
+        with expect_raises_message(
+            exc.ArgumentError,
+            "SQL expression element expected, got .* "
+            "To create a column expression from a VALUES clause, "
+            r"use the .scalar_values\(\) method.",
+        ):
+            fn(values(t.c.data).data([(1,), (42,)]))
index 5c9ed3588a57ab90bd203308a3b9a6b59968515d..d181e0d1ac78d00bbb368b5653aa0e196a6f33f5 100644 (file)
@@ -12,6 +12,7 @@ from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import update
+from sqlalchemy import values
 from sqlalchemy.schema import DDL
 from sqlalchemy.schema import Sequence
 from sqlalchemy.sql import ClauseElement
@@ -189,6 +190,22 @@ class RoleTest(fixtures.TestBase):
                 select(column("q")).alias(),
             )
 
+    def test_values_advice(self):
+        value_expr = values(
+            column("id", Integer), column("name", String), name="my_values"
+        ).data([(1, "name1"), (2, "name2"), (3, "name3")])
+
+        assert_raises_message(
+            exc.ArgumentError,
+            r"SQL expression element expected, got <.*Values.*my_values>. To "
+            r"create a "
+            r"column expression from a VALUES clause, "
+            r"use the .scalar_values\(\) method.",
+            expect,
+            roles.ExpressionElementRole,
+            value_expr,
+        )
+
     def test_table_valued_advice(self):
         msg = (
             r"SQL expression element expected, got %s. To create a "