git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
rework wraps_column_expression logic to be purely compile time checking
author Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Jun 2025 23:53:30 +0000 (19:53 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jun 2025 14:31:31 +0000 (10:31 -0400)
Fixed issue where :func:`.select` of a free-standing, unnamed scalar expression that
has a unary operator applied, such as negation, would not apply result
processors to the selected column even though the correct type remains in
place for the unary expression.

This change opened up a typing rabbit hole that also led us to improve
and harden the typing for the Exists element; in particular, Exists now
always refers to a ScalarSelect object, and no longer to a
SelectStatementGrouping within the _regroup() cases, as there did not
seem to be any reason for this inconsistency.

Fixes: #12681
Change-Id: If9131807941030c627ab31ede4ccbd86e44e707f

doc/build/changelog/unreleased_20/12681.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/testing/assertions.py
test/sql/test_labels.py
test/sql/test_operators.py
test/sql/test_selectable.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/12681.rst b/doc/build/changelog/unreleased_20/12681.rst
new file mode 100644 (file)
index 0000000..72e7e1e
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 12681
+
+    Fixed issue where :func:`.select` of a free-standing scalar expression that
+    has a unary operator applied, such as negation, would not apply result
+    processors to the selected column even though the correct type remains in
+    place for the unary expression.
+
index 5e874b3799672bb9ad865bc93867eec40b7bda96..5b992269a593c369f792f71c34fd982110b5e537 100644 (file)
@@ -4562,7 +4562,52 @@ class SQLCompiler(Compiled):
             elif isinstance(column, elements.TextClause):
                 render_with_label = False
             elif isinstance(column, elements.UnaryExpression):
-                render_with_label = column.wraps_column_expression or asfrom
+                # unary expression.  notes added as of #12681
+                #
+                # By convention, the visit_unary() method
+                # itself does not add an entry to the result map, and relies
+                # upon either the inner expression creating a result map
+                # entry, or if not, by creating a label here that produces
+                # the result map entry.  Where that happens is based on whether
+                # or not the element immediately inside the unary is a
+                # NamedColumn subclass or not.
+                #
+                # Now, this also impacts how the SELECT is written; if
+                # we decide to generate a label here, we get the usual
+                # "~(x+y) AS anon_1" thing in the columns clause.   If we
+                # don't, we don't get an AS at all, we get like
+                # "~table.column".
+                #
+                # But here is the important thing as of modernish (like 1.4)
+                # versions of SQLAlchemy - **whether or not the AS <label>
+                # is present in the statement is not actually important**.
+                # We target result columns **positionally** for a fully
+                # compiled ``Select()`` object; before 1.4 we needed those
+                # labels to match in cursor.description etc etc but now it
+                # really doesn't matter.
+                # So really, we could set render_with_label True in all cases.
+                # Or we could just have visit_unary() populate the result map
+                # in all cases.
+                #
+                # What we're doing here is strictly trying to not rock the
+                # boat too much with when we do/don't render "AS label";
+                # labels being present helps in the edge cases that we
+                # "fall back" to named cursor.description matching, labels
+                # not being present for columns keeps us from having awkward
+                # phrases like "SELECT DISTINCT table.x AS x".
+                render_with_label = (
+                    (
+                        # exception case to detect if we render "not boolean"
+                        # as "not <col>" for native boolean or "<col> = 1"
+                        # for non-native boolean.   this is controlled by
+                        # visit_is_<true|false>_unary_operator
+                        column.operator
+                        in (operators.is_false, operators.is_true)
+                        and not self.dialect.supports_native_boolean
+                    )
+                    or column._wraps_unnamed_column()
+                    or asfrom
+                )
             elif (
                 # general class of expressions that don't have a SQL-column
                 # addressible name.  includes scalar selects, bind parameters,
index 4c75936b5804a1181229480f992b38b8fe2dbad5..84f813be5f5ba99f6886433e58c4716432deeeb5 100644 (file)
@@ -755,6 +755,8 @@ class ClauseElement(
             return self._negate()
 
     def _negate(self) -> ClauseElement:
+        # TODO: this code is uncovered and in all likelihood is not included
+        # in any codepath.  So this should raise NotImplementedError in 2.1
         grouped = self.self_group(against=operators.inv)
         assert isinstance(grouped, ColumnElement)
         return UnaryExpression(grouped, operator=operators.inv)
@@ -1466,6 +1468,10 @@ class ColumnElement(
 
     _alt_names: Sequence[str] = ()
 
+    if TYPE_CHECKING:
+
+        def _ungroup(self) -> ColumnElement[_T]: ...
+
     @overload
     def self_group(self, against: None = None) -> ColumnElement[_T]: ...
 
@@ -1500,7 +1506,8 @@ class ColumnElement(
             grouped = self.self_group(against=operators.inv)
             assert isinstance(grouped, ColumnElement)
             return UnaryExpression(
-                grouped, operator=operators.inv, wraps_column_expression=True
+                grouped,
+                operator=operators.inv,
             )
 
     type: TypeEngine[_T]
@@ -3045,9 +3052,7 @@ class ExpressionClauseList(OperatorExpression[_T]):
     def _negate(self) -> Any:
         grouped = self.self_group(against=operators.inv)
         assert isinstance(grouped, ColumnElement)
-        return UnaryExpression(
-            grouped, operator=operators.inv, wraps_column_expression=True
-        )
+        return UnaryExpression(grouped, operator=operators.inv)
 
 
 class BooleanClauseList(ExpressionClauseList[bool]):
@@ -3666,15 +3671,18 @@ class UnaryExpression(ColumnElement[_T]):
         ("modifier", InternalTraversal.dp_operator),
     ]
 
-    element: ClauseElement
+    element: ColumnElement[Any]
+    operator: Optional[OperatorType]
+    modifier: Optional[OperatorType]
 
     def __init__(
         self,
         element: ColumnElement[Any],
+        *,
         operator: Optional[OperatorType] = None,
         modifier: Optional[OperatorType] = None,
         type_: Optional[_TypeEngineArgument[_T]] = None,
-        wraps_column_expression: bool = False,
+        wraps_column_expression: bool = False,  # legacy, not used as of 2.0.42
     ):
         self.operator = operator
         self.modifier = modifier
@@ -3687,7 +3695,12 @@ class UnaryExpression(ColumnElement[_T]):
         # know how to get the overloads to express that correctly
         self.type = type_api.to_instance(type_)  # type: ignore
 
-        self.wraps_column_expression = wraps_column_expression
+    def _wraps_unnamed_column(self):
+        ungrouped = self.element._ungroup()
+        return (
+            not isinstance(ungrouped, NamedColumn)
+            or ungrouped._non_anon_label is None
+        )
 
     @classmethod
     def _create_nulls_first(
@@ -3697,7 +3710,6 @@ class UnaryExpression(ColumnElement[_T]):
         return UnaryExpression(
             coercions.expect(roles.ByOfRole, column),
             modifier=operators.nulls_first_op,
-            wraps_column_expression=False,
         )
 
     @classmethod
@@ -3708,7 +3720,6 @@ class UnaryExpression(ColumnElement[_T]):
         return UnaryExpression(
             coercions.expect(roles.ByOfRole, column),
             modifier=operators.nulls_last_op,
-            wraps_column_expression=False,
         )
 
     @classmethod
@@ -3718,7 +3729,6 @@ class UnaryExpression(ColumnElement[_T]):
         return UnaryExpression(
             coercions.expect(roles.ByOfRole, column),
             modifier=operators.desc_op,
-            wraps_column_expression=False,
         )
 
     @classmethod
@@ -3729,7 +3739,6 @@ class UnaryExpression(ColumnElement[_T]):
         return UnaryExpression(
             coercions.expect(roles.ByOfRole, column),
             modifier=operators.asc_op,
-            wraps_column_expression=False,
         )
 
     @classmethod
@@ -3744,7 +3753,6 @@ class UnaryExpression(ColumnElement[_T]):
             col_expr,
             operator=operators.distinct_op,
             type_=col_expr.type,
-            wraps_column_expression=False,
         )
 
     @classmethod
@@ -3759,7 +3767,6 @@ class UnaryExpression(ColumnElement[_T]):
             col_expr,
             operator=operators.bitwise_not_op,
             type_=col_expr.type,
-            wraps_column_expression=False,
         )
 
     @property
@@ -3773,16 +3780,15 @@ class UnaryExpression(ColumnElement[_T]):
     def _from_objects(self) -> List[FromClause]:
         return self.element._from_objects
 
-    def _negate(self):
+    def _negate(self) -> ColumnElement[Any]:
         if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
             return UnaryExpression(
                 self.self_group(against=operators.inv),
                 operator=operators.inv,
                 type_=type_api.BOOLEANTYPE,
-                wraps_column_expression=self.wraps_column_expression,
             )
         else:
-            return ClauseElement._negate(self)
+            return ColumnElement._negate(self)
 
     def self_group(
         self, against: Optional[OperatorType] = None
@@ -3819,7 +3825,6 @@ class CollectionAggregate(UnaryExpression[_T]):
             col_expr,
             operator=operators.any_op,
             type_=type_api.BOOLEANTYPE,
-            wraps_column_expression=False,
         )
 
     @classmethod
@@ -3835,7 +3840,6 @@ class CollectionAggregate(UnaryExpression[_T]):
             col_expr,
             operator=operators.all_op,
             type_=type_api.BOOLEANTYPE,
-            wraps_column_expression=False,
         )
 
     # operate and reverse_operate are hardwired to
@@ -3870,7 +3874,6 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]):
         self.operator = operator
         self.negate = negate
         self.modifier = None
-        self.wraps_column_expression = True
         self._is_implicitly_boolean = element._is_implicitly_boolean
 
     @property
@@ -4093,13 +4096,11 @@ class GroupedElement(DQLDMLClauseElement):
 
     __visit_name__ = "grouping"
 
-    element: ClauseElement
-
     def self_group(self, against: Optional[OperatorType] = None) -> Self:
         return self
 
-    def _ungroup(self):
-        return self.element._ungroup()
+    def _ungroup(self) -> ClauseElement:
+        raise NotImplementedError()
 
 
 class Grouping(GroupedElement, ColumnElement[_T]):
@@ -4128,6 +4129,10 @@ class Grouping(GroupedElement, ColumnElement[_T]):
     def _with_binary_element_type(self, type_):
         return self.__class__(self.element._with_binary_element_type(type_))
 
+    def _ungroup(self) -> ColumnElement[_T]:
+        assert isinstance(self.element, ColumnElement)
+        return self.element._ungroup()
+
     @util.memoized_property
     def _is_implicitly_boolean(self):
         return self.element._is_implicitly_boolean
index 462d96b27acdd2dcaffefd1a01a122094f571888..c7ca0ba795b46a2238a7ffa1dd6fc1652247875c 100644 (file)
@@ -3492,6 +3492,8 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
             self, against: Optional[OperatorType] = None
         ) -> Self: ...
 
+        def _ungroup(self) -> ColumnElement[Any]: ...
+
 
 class SelectBase(
     roles.SelectStatementRole,
@@ -6848,9 +6850,8 @@ class ScalarSelect(
     def self_group(self, against: Optional[OperatorType] = None) -> Self:
         return self
 
-    if TYPE_CHECKING:
-
-        def _ungroup(self) -> Select[Unpack[TupleAny]]: ...
+    def _ungroup(self) -> Self:
+        return self
 
     @_generative
     def correlate(
@@ -6938,10 +6939,6 @@ class Exists(UnaryExpression[bool]):
     """
 
     inherit_cache = True
-    element: Union[
-        SelectStatementGrouping[Select[Unpack[TupleAny]]],
-        ScalarSelect[Any],
-    ]
 
     def __init__(
         self,
@@ -6968,7 +6965,6 @@ class Exists(UnaryExpression[bool]):
             s,
             operator=operators.exists,
             type_=type_api.BOOLEANTYPE,
-            wraps_column_expression=True,
         )
 
     @util.ro_non_memoized_property
@@ -6978,12 +6974,17 @@ class Exists(UnaryExpression[bool]):
     def _regroup(
         self,
         fn: Callable[[Select[Unpack[TupleAny]]], Select[Unpack[TupleAny]]],
-    ) -> SelectStatementGrouping[Select[Unpack[TupleAny]]]:
-        element = self.element._ungroup()
+    ) -> ScalarSelect[Any]:
+
+        assert isinstance(self.element, ScalarSelect)
+        element = self.element.element
+        if not isinstance(element, Select):
+            raise exc.InvalidRequestError(
+                "Can only apply this operation to a plain SELECT construct"
+            )
         new_element = fn(element)
 
-        return_value = new_element.self_group(against=operators.exists)
-        assert isinstance(return_value, SelectStatementGrouping)
+        return_value = new_element.scalar_subquery()
         return return_value
 
     def select(self) -> Select[bool]:
@@ -7038,7 +7039,6 @@ class Exists(UnaryExpression[bool]):
             :meth:`_sql.ScalarSelect.correlate_except`
 
         """
-
         e = self._clone()
         e.element = self._regroup(
             lambda element: element.correlate_except(*fromclauses)
index a22da65a625bc2ce6cd24908ef39490ebcc2e039..431c2b7e98a1954b529e466332ee03bb486b13d0 100644 (file)
@@ -519,6 +519,7 @@ class AssertsCompiledSQL:
         use_default_dialect=False,
         allow_dialect_select=False,
         supports_default_values=True,
+        supports_native_boolean=False,
         supports_default_metavalue=True,
         literal_binds=False,
         render_postcompile=False,
@@ -533,6 +534,7 @@ class AssertsCompiledSQL:
             dialect = default.DefaultDialect()
             dialect.supports_default_values = supports_default_values
             dialect.supports_default_metavalue = supports_default_metavalue
+            dialect.supports_native_boolean = supports_native_boolean
         elif allow_dialect_select:
             dialect = None
         else:
index 40ae2a65c8a5166b93da8b5e2714501976198430..b4a857771a00aca12a2b63c8c07c0fb1b52e00af 100644 (file)
@@ -1080,15 +1080,24 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
             "some_table.name AS name_1 FROM some_table",
         )
 
-    def test_boolean_auto_label(self):
+    @testing.variation("native_boolean", [True, False])
+    def test_boolean_auto_label(self, native_boolean):
         col = column("value", Boolean)
 
-        self.assert_compile(
-            select(~col, col),
-            # not sure if this SQL is right but this is what it was
-            # before the new labeling, just different label name
-            "SELECT value = 0 AS value, value",
-        )
+        if native_boolean:
+            self.assert_compile(
+                select(~col, col),
+                "SELECT NOT value, value",
+                supports_native_boolean=True,
+                use_default_dialect=True,
+            )
+        else:
+            self.assert_compile(
+                select(~col, col),
+                # not sure if this SQL is right but this is what it was
+                # before the new labeling, just different label name
+                "SELECT value = 0 AS value, value",
+            )
 
     def test_label_auto_label_use_labels(self):
         expr = self._fixture()
index b78b3ac1f7626b6d0084b39096887908fb292b46..48a5b6acb8641627644aaa74dbb7766c340bdd9a 100644 (file)
@@ -2833,6 +2833,15 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
     table1 = table("mytable", column("myid", Integer), column("name", String))
 
+    @testing.combinations(
+        (~literal(5), "NOT :param_1"), (~-literal(5), "NOT -:param_1")
+    )
+    def test_nonsensical_negates(self, expr, expected):
+        """exercise codepaths in the UnaryExpression._negate() method where the
+        type is not BOOLEAN"""
+
+        self.assert_compile(expr, expected)
+
     def test_negate_operators_1(self):
         for py_op, op in ((operator.neg, "-"), (operator.inv, "NOT ")):
             for expr, expected in (
@@ -4835,6 +4844,50 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 class BitOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
 
+    @testing.combinations(
+        ("neg", operators.neg, "-"),
+        ("inv", operators.inv, "NOT "),
+        ("not", operators.bitwise_not_op, "~"),
+        ("distinct", operators.distinct_op, "DISTINCT "),
+        id_="iaa",
+        argnames="py_op, sql_op",
+    )
+    @testing.variation("named", ["column", "unnamed", "label"])
+    def test_wraps_named_column_heuristic(self, py_op, sql_op, named):
+        """test for #12681"""
+
+        if named.column:
+            expr = py_op(column("q", String))
+            assert isinstance(expr, UnaryExpression)
+
+            self.assert_compile(
+                select(expr),
+                f"SELECT {sql_op}q",
+            )
+
+        elif named.unnamed:
+            expr = py_op(literal("x", String))
+            assert isinstance(expr, UnaryExpression)
+
+            self.assert_compile(
+                select(expr),
+                f"SELECT {sql_op}:param_1 AS anon_1",
+            )
+        elif named.label:
+            expr = py_op(literal("x", String).label("z"))
+            if py_op is operators.inv:
+                # special case for operators.inv due to Label._negate()
+                # not sure if this should be changed but still works out in the
+                # end
+                assert isinstance(expr.element, UnaryExpression)
+            else:
+                assert isinstance(expr, UnaryExpression)
+
+            self.assert_compile(
+                select(expr),
+                f"SELECT {sql_op}:param_1 AS z",
+            )
+
     def test_compile_not_column_lvl(self):
         c = column("c", Integer)
 
index fc3039fada73c66cf98acc919beff32a77955b9c..91eacd55bba6aa4c250d9c6ff86fb2a03b05f6c3 100644 (file)
@@ -3943,6 +3943,40 @@ class ResultMapTest(fixtures.TestBase):
             [Boolean],
         )
 
+    @testing.combinations(
+        lambda e, t: e.correlate(t),
+        lambda e, t: e.correlate_except(t),
+        lambda e, t: e.select_from(t),
+        lambda e, t: e.where(t.c.y == 5),
+        argnames="testcase",
+    )
+    @testing.variation("inner_select", ["select", "compound"])
+    def test_exists_regroup_modifiers(
+        self, testcase, inner_select: testing.Variation
+    ):
+        a = table("a", column("x"), column("y"))
+        b = table("b", column("x"), column("y"))
+        if inner_select.compound:
+            stmt = select(a.c.x).union_all(select(b.c.x))
+        elif inner_select.select:
+            stmt = select(a.c.x)
+        else:
+            inner_select.fail()
+
+        exists = stmt.exists()
+
+        if inner_select.compound:
+            with expect_raises_message(
+                exc.InvalidRequestError,
+                "Can only apply this operation to a plain SELECT construct",
+            ):
+                testcase(exists, b)
+        else:
+            regrouped = testcase(exists, b)
+            assert regrouped.element.compare(
+                testcase(exists.element, b).scalar_subquery()
+            )
+
     def test_column_subquery_plain(self):
         t = self._fixture()
         s1 = select(t.c.x).where(t.c.x > 5).scalar_subquery()
index 1a173f89d1f58115ab8b04e42b205f77523aff6e..bd147a415ae2e4e33e081d3814055d754541d715 100644 (file)
@@ -900,6 +900,35 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
             ],
         )
 
+    def test_unary_operator(self, connection):
+        users = self.tables.users
+        self._data_fixture(connection)
+
+        eq_(
+            connection.scalar(
+                select(-users.c.goofy8).order_by(users.c.user_id)
+            ),
+            -1200,
+        )
+
+    def test_unary_operator_standalone(self, connection):
+        """test #12681"""
+
+        class MyNewIntType(types.TypeDecorator):
+            impl = Integer
+            cache_ok = True
+
+            def process_bind_param(self, value, dialect):
+                if value is None:
+                    value = 29
+                return value * 10
+
+            def process_result_value(self, value, dialect):
+                return value * 10
+
+        eq_(connection.scalar(select(literal(12, MyNewIntType))), 1200)
+        eq_(connection.scalar(select(-literal(12, MyNewIntType))), -1200)
+
     def test_plain_in_typedec(self, connection):
         users = self.tables.users
         self._data_fixture(connection)