]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add pow operator support
authorFederico Caselli <cfederico87@gmail.com>
Fri, 4 Apr 2025 20:23:31 +0000 (22:23 +0200)
committerMichael Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Apr 2025 15:54:50 +0000 (15:54 +0000)
Added support for the pow operator (``**``), with a default SQL
implementation of the ``POW()`` function.   On Oracle Database, PostgreSQL
and MSSQL it renders as ``POWER()``.   As part of this change, the operator
routes through a new first class ``func`` member :class:`_functions.pow`,
which renders on Oracle Database, PostgreSQL and MSSQL as ``POWER()``.

Fixes: #8579
Closes: #8580
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8580
Pull-request-sha: 041b2ef474a291c6b6172e49cc6e0d548e28761a

Change-Id: I371bd44ed3e58f2d55ef705aeec7d04710c97f23

13 files changed:
doc/build/changelog/unreleased_21/8579.rst [new file with mode: 0644]
doc/build/core/functions.rst
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/operators.py
test/dialect/mssql/test_compiler.py
test/dialect/oracle/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/sql/test_operators.py
test/typing/plain_files/sql/functions.py

diff --git a/doc/build/changelog/unreleased_21/8579.rst b/doc/build/changelog/unreleased_21/8579.rst
new file mode 100644 (file)
index 0000000..57fe7c9
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: usecase, sql
+    :tickets: 8579
+
+    Added support for the pow operator (``**``), with a default SQL
+    implementation of the ``POW()`` function.   On Oracle Database, PostgreSQL
+    and MSSQL it renders as ``POWER()``.   As part of this change, the operator
+    routes through a new first class ``func`` member :class:`_functions.pow`,
+    which renders on Oracle Database, PostgreSQL and MSSQL as ``POWER()``.
index 9771ffeedd9be20c76f037d32f7aadeb5a7f5d06..26c59a0bdda1ab1b0e2370af227e9bd6b3fdc699 100644 (file)
@@ -124,6 +124,9 @@ return types are in use.
 .. autoclass:: percentile_disc
     :no-members:
 
+.. autoclass:: pow
+    :no-members:
+
 .. autoclass:: random
     :no-members:
 
index 24425fc817092ce8f2332095264c9164cd689ca3..8c8e7f9c47c6ae093233dd707828059cfbac1964 100644 (file)
@@ -2040,6 +2040,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
         delimeter = fn.clauses.clauses[1]._compiler_dispatch(self, **kw)
         return f"string_agg({expr}, {delimeter})"
 
+    def visit_pow_func(self, fn, **kw):
+        return f"POWER{self.function_argspec(fn)}"
+
     def visit_concat_op_expression_clauselist(
         self, clauselist, operator, **kw
     ):
index 69af577d560be1abafe520a3f7adb0eb185ae0c7..c32dff2ea10691be49541aeb78216ef1b9d52645 100644 (file)
@@ -1021,6 +1021,9 @@ class OracleCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "LENGTH" + self.function_argspec(fn, **kw)
 
+    def visit_pow_func(self, fn, **kw):
+        return f"POWER{self.function_argspec(fn)}"
+
     def visit_match_op_binary(self, binary, operator, **kw):
         return "CONTAINS (%s, %s)" % (
             self.process(binary.left),
index 864445026baabfdbb254278c07d1599ee7554444..32024f7d986cf25c91fc418bb3015d1fcfb25574 100644 (file)
@@ -2010,6 +2010,9 @@ class PGCompiler(compiler.SQLCompiler):
     def visit_aggregate_strings_func(self, fn, **kw):
         return "string_agg%s" % self.function_argspec(fn)
 
+    def visit_pow_func(self, fn, **kw):
+        return f"power{self.function_argspec(fn)}"
+
     def visit_sequence(self, seq, **kw):
         return "nextval('%s')" % self.preparer.format_sequence(seq)
 
index 7fa5dafe9ce4875e58fe79fedfc34cc34e3157fe..c1305be99470650eda61bd053f438ae48feb594b 100644 (file)
@@ -5,8 +5,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
-"""Default implementation of SQL comparison operations.
-"""
+"""Default implementation of SQL comparison operations."""
 
 from __future__ import annotations
 
@@ -21,6 +20,7 @@ from typing import Type
 from typing import Union
 
 from . import coercions
+from . import functions
 from . import operators
 from . import roles
 from . import type_api
@@ -351,6 +351,19 @@ def _between_impl(
     )
 
 
+def _pow_impl(
+    expr: ColumnElement[Any],
+    op: OperatorType,
+    other: Any,
+    reverse: bool = False,
+    **kw: Any,
+) -> ColumnElement[Any]:
+    if reverse:
+        return functions.pow(other, expr)
+    else:
+        return functions.pow(expr, other)
+
+
 def _collate_impl(
     expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any
 ) -> ColumnElement[str]:
@@ -549,4 +562,5 @@ operator_lookup: Dict[
     "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT),
     "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT),
     "regexp_replace_op": (_regexp_replace_impl, util.EMPTY_DICT),
+    "pow": (_pow_impl, util.EMPTY_DICT),
 }
index cd1a20a708e3415eeb19597e91c674e30bdafa75..050f94fd8087685aed5f7b0329bd3d2619ea6b88 100644 (file)
@@ -1199,6 +1199,42 @@ class _FunctionGenerator:
         @property
         def percentile_disc(self) -> Type[percentile_disc[Any]]: ...
 
+        # set ColumnElement[_T] as a separate overload, to appease
+        # mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument. Seems somewhat related to the covariant
+        # _HasClauseElement as of mypy 1.15
+
+        @overload
+        def pow(  # noqa: A001
+            self,
+            col: ColumnElement[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> pow[_T]: ...
+
+        @overload
+        def pow(  # noqa: A001
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> pow[_T]: ...
+
+        @overload
+        def pow(  # noqa: A001
+            self,
+            col: _T,
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> pow[_T]: ...
+
+        def pow(  # noqa: A001
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> pow[_T]: ...
+
         @property
         def random(self) -> Type[random]: ...
 
@@ -1690,6 +1726,23 @@ class now(GenericFunction[datetime.datetime]):
     inherit_cache = True
 
 
+class pow(ReturnTypeFromArgs[_T]):  # noqa: A001
+    """The SQL POW() function which performs the power operator.
+
+    E.g.:
+
+    .. sourcecode:: pycon+sql
+
+        >>> print(select(func.pow(2, 8)))
+        {printsql}SELECT pow(:pow_2, :pow_3) AS pow_1
+
+    .. versionadded:: 2.1
+
+    """
+
+    inherit_cache = True
+
+
 class concat(GenericFunction[str]):
     """The SQL CONCAT() function, which concatenates strings.
 
index f93864478f8abb081de05a78ca0920fa3c55d4e4..635e5712ad57b2d6abcb12e825f5f20944ca9181 100644 (file)
@@ -30,6 +30,7 @@ from operator import mul as _uncast_mul
 from operator import ne as _uncast_ne
 from operator import neg as _uncast_neg
 from operator import or_ as _uncast_or_
+from operator import pow as _uncast_pow
 from operator import rshift as _uncast_rshift
 from operator import sub as _uncast_sub
 from operator import truediv as _uncast_truediv
@@ -114,6 +115,7 @@ mul = cast(OperatorType, _uncast_mul)
 ne = cast(OperatorType, _uncast_ne)
 neg = cast(OperatorType, _uncast_neg)
 or_ = cast(OperatorType, _uncast_or_)
+pow_ = cast(OperatorType, _uncast_pow)
 rshift = cast(OperatorType, _uncast_rshift)
 sub = cast(OperatorType, _uncast_sub)
 truediv = cast(OperatorType, _uncast_truediv)
@@ -1938,6 +1940,29 @@ class ColumnOperators(Operators):
         """
         return self.reverse_operate(floordiv, other)
 
+    def __pow__(self, other: Any) -> ColumnOperators:
+        """Implement the ``**`` operator.
+
+        In a column context, produces the clause ``pow(a, b)``, or a similar
+        dialect-specific expression.
+
+        .. versionadded:: 2.1
+
+        """
+        return self.operate(pow_, other)
+
+    def __rpow__(self, other: Any) -> ColumnOperators:
+        """Implement the ``**`` operator in reverse.
+
+        .. seealso::
+
+            :meth:`.ColumnOperators.__pow__`.
+
+        .. versionadded:: 2.1
+
+        """
+        return self.reverse_operate(pow_, other)
+
 
 _commutative: Set[Any] = {eq, ne, add, mul}
 _comparison: Set[Any] = {eq, ne, lt, gt, ge, le}
@@ -2541,6 +2566,7 @@ _PRECEDENCE: Dict[OperatorType, int] = {
     getitem: 15,
     json_getitem_op: 15,
     json_path_getitem_op: 15,
+    pow_: 15,
     mul: 8,
     truediv: 8,
     floordiv: 8,
index eb4dba0a079683ea06ed3f7a20fc7e395adf2aff..627738f71357d809d1ea3bc505a9fe42fee1c713 100644 (file)
@@ -32,9 +32,10 @@ from sqlalchemy.sql import table
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_ignore_whitespace
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
-from sqlalchemy.testing.assertions import eq_ignore_whitespace
+from sqlalchemy.testing import resolve_lambda
 from sqlalchemy.types import TypeEngine
 
 tbl = table("t", column("a"))
@@ -1850,6 +1851,25 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         with testing.expect_raises_message(exc.CompileError, error):
             print(stmt.compile(dialect=self.__dialect__))
 
+    @testing.combinations(
+        (lambda t: t.c.a**t.c.b, "POWER(t.a, t.b)", {}),
+        (lambda t: t.c.a**3, "POWER(t.a, :pow_1)", {"pow_1": 3}),
+        (lambda t: t.c.c.match(t.c.d), "CONTAINS (t.c, t.d)", {}),
+        (lambda t: t.c.c.match("w"), "CONTAINS (t.c, :c_1)", {"c_1": "w"}),
+        (lambda t: func.pow(t.c.a, 3), "POWER(t.a, :pow_1)", {"pow_1": 3}),
+        (lambda t: func.power(t.c.a, t.c.b), "power(t.a, t.b)", {}),
+    )
+    def test_simple_compile(self, fn, string, params):
+        t = table(
+            "t",
+            column("a", Integer),
+            column("b", Integer),
+            column("c", String),
+            column("d", String),
+        )
+        expr = resolve_lambda(fn, t=t)
+        self.assert_compile(expr, string, params)
+
 
 class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = mssql.dialect()
index 0ab5052a1fe2406f4a77f45fe8ace6032afe4f55..c7f4a0c492b50623628b12111b1dc963281be525 100644 (file)
@@ -43,6 +43,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import eq_ignore_whitespace
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.testing.util import resolve_lambda
 from sqlalchemy.types import TypeEngine
 
 
@@ -1679,6 +1680,25 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             f"CREATE TABLE table1 (x INTEGER) {expected_sql}",
         )
 
+    @testing.combinations(
+        (lambda t: t.c.a**t.c.b, "POWER(t.a, t.b)", {}),
+        (lambda t: t.c.a**3, "POWER(t.a, :pow_1)", {"pow_1": 3}),
+        (lambda t: t.c.c.match(t.c.d), "CONTAINS (t.c, t.d)", {}),
+        (lambda t: t.c.c.match("w"), "CONTAINS (t.c, :c_1)", {"c_1": "w"}),
+        (lambda t: func.pow(t.c.a, 3), "POWER(t.a, :pow_1)", {"pow_1": 3}),
+        (lambda t: func.power(t.c.a, t.c.b), "power(t.a, t.b)", {}),
+    )
+    def test_simple_compile(self, fn, string, params):
+        t = table(
+            "t",
+            column("a", Integer),
+            column("b", Integer),
+            column("c", String),
+            column("d", String),
+        )
+        expr = resolve_lambda(fn, t=t)
+        self.assert_compile(expr, string, params)
+
 
 class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_basic(self):
index eda9f96662e1d90d8f533c565901d3eb94501766..f98ea9645b043e40d90bd7289c67ca0654752094 100644 (file)
@@ -79,6 +79,7 @@ from sqlalchemy.testing.assertions import eq_ignore_whitespace
 from sqlalchemy.testing.assertions import expect_deprecated
 from sqlalchemy.testing.assertions import expect_warnings
 from sqlalchemy.testing.assertions import is_
+from sqlalchemy.testing.util import resolve_lambda
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.util import OrderedDict
 
@@ -2766,6 +2767,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialect,
         )
 
+    @testing.combinations(
+        (lambda t: t.c.a**t.c.b, "power(t.a, t.b)", {}),
+        (lambda t: t.c.a**3, "power(t.a, %(pow_1)s)", {"pow_1": 3}),
+        (lambda t: func.pow(t.c.a, 3), "power(t.a, %(pow_1)s)", {"pow_1": 3}),
+        (lambda t: func.power(t.c.a, t.c.b), "power(t.a, t.b)", {}),
+    )
+    def test_simple_compile(self, fn, string, params):
+        t = table("t", column("a", Integer), column("b", Integer))
+        expr = resolve_lambda(fn, t=t)
+        self.assert_compile(expr, string, params)
+
 
 class InsertOnConflictTest(
     fixtures.TablesTest, AssertsCompiledSQL, fixtures.CacheKeySuite
index 6ed2c76d75050bdc24a5553924fba91ded711098..099301707fcd20f1d868c95cdabc476172dfb9d1 100644 (file)
@@ -2646,6 +2646,14 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         expr = column("bar", Integer()) // column("foo", Integer)
         assert isinstance(expr.type, Integer)
 
+    def test_power_operator(self):
+        expr = column("bar", Integer()) ** column("foo", Integer)
+        self.assert_compile(expr, "pow(bar, foo)")
+        expr = column("bar", Integer()) ** 42
+        self.assert_compile(expr, "pow(bar, :pow_1)", {"pow_1": 42})
+        expr = 99 ** column("bar", Integer())
+        self.assert_compile(expr, "pow(:pow_1, bar)", {"pow_1": 42})
+
 
 class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
index 800ed90a99060c89284268fcfc11fdee48405fd4..3660417887915a932c08df7c7ddf5698685c8653 100644 (file)
@@ -127,35 +127,41 @@ stmt19 = select(func.percent_rank())
 reveal_type(stmt19)
 
 
-stmt20 = select(func.rank())
+stmt20 = select(func.pow(column("x", Integer)))
 
 # EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt20)
 
 
-stmt21 = select(func.session_user())
+stmt21 = select(func.rank())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt21)
 
 
-stmt22 = select(func.sum(column("x", Integer)))
+stmt22 = select(func.session_user())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[.*str\]
 reveal_type(stmt22)
 
 
-stmt23 = select(func.sysdate())
+stmt23 = select(func.sum(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt23)
 
 
-stmt24 = select(func.user())
+stmt24 = select(func.sysdate())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
+# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
 reveal_type(stmt24)
 
+
+stmt25 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[.*str\]
+reveal_type(stmt25)
+
 # END GENERATED FUNCTION TYPING TESTS
 
 stmt_count: Select[int, int, int] = select(