From: Federico Caselli Date: Fri, 4 Apr 2025 20:23:31 +0000 (+0200) Subject: Add pow operator support X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=571bb909320b6285fd3839fb52111c241a3ea8c4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add pow operator support Added support for the pow operator (``**``), with a default SQL implementation of the ``POW()`` function. On Oracle Database, PostgreSQL and MSSQL it renders as ``POWER()``. As part of this change, the operator routes through a new first class ``func`` member :class:`_functions.pow`, which renders on Oracle Database, PostgreSQL and MSSQL as ``POWER()``. Fixes: #8579 Closes: #8580 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8580 Pull-request-sha: 041b2ef474a291c6b6172e49cc6e0d548e28761a Change-Id: I371bd44ed3e58f2d55ef705aeec7d04710c97f23 --- diff --git a/doc/build/changelog/unreleased_21/8579.rst b/doc/build/changelog/unreleased_21/8579.rst new file mode 100644 index 0000000000..57fe7c91f2 --- /dev/null +++ b/doc/build/changelog/unreleased_21/8579.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: usecase, sql + :tickets: 8579 + + Added support for the pow operator (``**``), with a default SQL + implementation of the ``POW()`` function. On Oracle Database, PostgreSQL + and MSSQL it renders as ``POWER()``. As part of this change, the operator + routes through a new first class ``func`` member :class:`_functions.pow`, + which renders on Oracle Database, PostgreSQL and MSSQL as ``POWER()``. diff --git a/doc/build/core/functions.rst b/doc/build/core/functions.rst index 9771ffeedd..26c59a0bdd 100644 --- a/doc/build/core/functions.rst +++ b/doc/build/core/functions.rst @@ -124,6 +124,9 @@ return types are in use. .. autoclass:: percentile_disc :no-members: +.. autoclass:: pow + :no-members: + .. autoclass:: random :no-members: diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 24425fc817..8c8e7f9c47 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2040,6 +2040,9 @@ class MSSQLCompiler(compiler.SQLCompiler): delimeter = fn.clauses.clauses[1]._compiler_dispatch(self, **kw) return f"string_agg({expr}, {delimeter})" + def visit_pow_func(self, fn, **kw): + return f"POWER{self.function_argspec(fn)}" + def visit_concat_op_expression_clauselist( self, clauselist, operator, **kw ): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 69af577d56..c32dff2ea1 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1021,6 +1021,9 @@ class OracleCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "LENGTH" + self.function_argspec(fn, **kw) + def visit_pow_func(self, fn, **kw): + return f"POWER{self.function_argspec(fn)}" + def visit_match_op_binary(self, binary, operator, **kw): return "CONTAINS (%s, %s)" % ( self.process(binary.left), diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 864445026b..32024f7d98 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2010,6 +2010,9 @@ class PGCompiler(compiler.SQLCompiler): def visit_aggregate_strings_func(self, fn, **kw): return "string_agg%s" % self.function_argspec(fn) + def visit_pow_func(self, fn, **kw): + return f"power{self.function_argspec(fn)}" + def visit_sequence(self, seq, **kw): return "nextval('%s')" % self.preparer.format_sequence(seq) diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 7fa5dafe9c..c1305be994 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -5,8 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Default implementation of SQL comparison operations. -""" +"""Default implementation of SQL comparison operations.""" from __future__ import annotations @@ -21,6 +20,7 @@ from typing import Type from typing import Union from . import coercions +from . import functions from . import operators from . import roles from . import type_api @@ -351,6 +351,19 @@ def _between_impl( ) +def _pow_impl( + expr: ColumnElement[Any], + op: OperatorType, + other: Any, + reverse: bool = False, + **kw: Any, +) -> ColumnElement[Any]: + if reverse: + return functions.pow(other, expr) + else: + return functions.pow(expr, other) + + def _collate_impl( expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any ) -> ColumnElement[str]: @@ -549,4 +562,5 @@ operator_lookup: Dict[ "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), "regexp_replace_op": (_regexp_replace_impl, util.EMPTY_DICT), + "pow": (_pow_impl, util.EMPTY_DICT), } diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index cd1a20a708..050f94fd80 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1199,6 +1199,42 @@ class _FunctionGenerator: @property def percentile_disc(self) -> Type[percentile_disc[Any]]: ... + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 + + @overload + def pow( # noqa: A001 + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> pow[_T]: ... + + @overload + def pow( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> pow[_T]: ... + + @overload + def pow( # noqa: A001 + self, + col: _T, + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> pow[_T]: ... + + def pow( # noqa: A001 + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> pow[_T]: ... + @property def random(self) -> Type[random]: ... @@ -1690,6 +1726,23 @@ class now(GenericFunction[datetime.datetime]): inherit_cache = True +class pow(ReturnTypeFromArgs[_T]): # noqa: A001 + """The SQL POW() function which performs the power operator. + + E.g.: + + .. sourcecode:: pycon+sql + + >>> print(select(func.pow(2, 8))) + {printsql}SELECT pow(:pow_2, :pow_3) AS pow_1 + + .. versionadded:: 2.1 + + """ + + inherit_cache = True + + class concat(GenericFunction[str]): """The SQL CONCAT() function, which concatenates strings. diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index f93864478f..635e5712ad 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -30,6 +30,7 @@ from operator import mul as _uncast_mul from operator import ne as _uncast_ne from operator import neg as _uncast_neg from operator import or_ as _uncast_or_ +from operator import pow as _uncast_pow from operator import rshift as _uncast_rshift from operator import sub as _uncast_sub from operator import truediv as _uncast_truediv @@ -114,6 +115,7 @@ mul = cast(OperatorType, _uncast_mul) ne = cast(OperatorType, _uncast_ne) neg = cast(OperatorType, _uncast_neg) or_ = cast(OperatorType, _uncast_or_) +pow_ = cast(OperatorType, _uncast_pow) rshift = cast(OperatorType, _uncast_rshift) sub = cast(OperatorType, _uncast_sub) truediv = cast(OperatorType, _uncast_truediv) @@ -1938,6 +1940,29 @@ class ColumnOperators(Operators): """ return self.reverse_operate(floordiv, other) + def __pow__(self, other: Any) -> ColumnOperators: + """Implement the ``**`` operator. + + In a column context, produces the clause ``pow(a, b)``, or a similar + dialect-specific expression. + + .. versionadded:: 2.1 + + """ + return self.operate(pow_, other) + + def __rpow__(self, other: Any) -> ColumnOperators: + """Implement the ``**`` operator in reverse. + + .. seealso:: + + :meth:`.ColumnOperators.__pow__`. + + .. versionadded:: 2.1 + + """ + return self.reverse_operate(pow_, other) + _commutative: Set[Any] = {eq, ne, add, mul} _comparison: Set[Any] = {eq, ne, lt, gt, ge, le} @@ -2541,6 +2566,7 @@ _PRECEDENCE: Dict[OperatorType, int] = { getitem: 15, json_getitem_op: 15, json_path_getitem_op: 15, + pow_: 15, mul: 8, truediv: 8, floordiv: 8, diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index eb4dba0a07..627738f713 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -32,9 +32,10 @@ from sqlalchemy.sql import table from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_ignore_whitespace from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ -from sqlalchemy.testing.assertions import eq_ignore_whitespace +from sqlalchemy.testing import resolve_lambda from sqlalchemy.types import TypeEngine tbl = table("t", column("a")) @@ -1850,6 +1851,25 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): with testing.expect_raises_message(exc.CompileError, error): print(stmt.compile(dialect=self.__dialect__)) + @testing.combinations( + (lambda t: t.c.a**t.c.b, "POWER(t.a, t.b)", {}), + (lambda t: t.c.a**3, "POWER(t.a, :pow_1)", {"pow_1": 3}), + (lambda t: t.c.c.match(t.c.d), "CONTAINS (t.c, t.d)", {}), + (lambda t: t.c.c.match("w"), "CONTAINS (t.c, :c_1)", {"c_1": "w"}), + (lambda t: func.pow(t.c.a, 3), "POWER(t.a, :pow_1)", {"pow_1": 3}), + (lambda t: func.power(t.c.a, t.c.b), "power(t.a, t.b)", {}), + ) + def test_simple_compile(self, fn, string, params): + t = table( + "t", + column("a", Integer), + column("b", Integer), + column("c", String), + column("d", String), + ) + expr = resolve_lambda(fn, t=t) + self.assert_compile(expr, string, params) + class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = mssql.dialect() diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 0ab5052a1f..c7f4a0c492 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -43,6 +43,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import eq_ignore_whitespace from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +from sqlalchemy.testing.util import resolve_lambda from sqlalchemy.types import TypeEngine @@ -1679,6 +1680,25 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): f"CREATE TABLE table1 (x INTEGER) {expected_sql}", ) + @testing.combinations( + (lambda t: t.c.a**t.c.b, "POWER(t.a, t.b)", {}), + (lambda t: t.c.a**3, "POWER(t.a, :pow_1)", {"pow_1": 3}), + (lambda t: t.c.c.match(t.c.d), "CONTAINS (t.c, t.d)", {}), + (lambda t: t.c.c.match("w"), "CONTAINS (t.c, :c_1)", {"c_1": "w"}), + (lambda t: func.pow(t.c.a, 3), "POWER(t.a, :pow_1)", {"pow_1": 3}), + (lambda t: func.power(t.c.a, t.c.b), "power(t.a, t.b)", {}), + ) + def test_simple_compile(self, fn, string, params): + t = table( + "t", + column("a", Integer), + column("b", Integer), + column("c", String), + column("d", String), + ) + expr = resolve_lambda(fn, t=t) + self.assert_compile(expr, string, params) + class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): def test_basic(self): diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index eda9f96662..f98ea9645b 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -79,6 +79,7 @@ from sqlalchemy.testing.assertions import eq_ignore_whitespace from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.util import resolve_lambda from sqlalchemy.types import TypeEngine from sqlalchemy.util import OrderedDict @@ -2766,6 +2767,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect, ) + @testing.combinations( + (lambda t: t.c.a**t.c.b, "power(t.a, t.b)", {}), + (lambda t: t.c.a**3, "power(t.a, %(pow_1)s)", {"pow_1": 3}), + (lambda t: func.pow(t.c.a, 3), "power(t.a, %(pow_1)s)", {"pow_1": 3}), + (lambda t: func.power(t.c.a, t.c.b), "power(t.a, t.b)", {}), + ) + def test_simple_compile(self, fn, string, params): + t = table("t", column("a", Integer), column("b", Integer)) + expr = resolve_lambda(fn, t=t) + self.assert_compile(expr, string, params) + class InsertOnConflictTest( fixtures.TablesTest, AssertsCompiledSQL, fixtures.CacheKeySuite diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 6ed2c76d75..099301707f 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -2646,6 +2646,14 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): expr = column("bar", Integer()) // column("foo", Integer) assert isinstance(expr.type, Integer) + def test_power_operator(self): + expr = column("bar", Integer()) ** column("foo", Integer) + self.assert_compile(expr, "pow(bar, foo)") + expr = column("bar", Integer()) ** 42 + self.assert_compile(expr, "pow(bar, :pow_1)", {"pow_1": 42}) + expr = 99 ** column("bar", Integer()) + self.assert_compile(expr, "pow(:pow_1, bar)", {"pow_1": 42}) + class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index 800ed90a99..3660417887 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -127,35 +127,41 @@ stmt19 = select(func.percent_rank()) reveal_type(stmt19) -stmt20 = select(func.rank()) +stmt20 = select(func.pow(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[.*int\] reveal_type(stmt20) -stmt21 = select(func.session_user()) +stmt21 = select(func.rank()) -# EXPECTED_RE_TYPE: .*Select\[.*str\] +# EXPECTED_RE_TYPE: .*Select\[.*int\] reveal_type(stmt21) -stmt22 = select(func.sum(column("x", Integer))) +stmt22 = select(func.session_user()) -# EXPECTED_RE_TYPE: .*Select\[.*int\] +# EXPECTED_RE_TYPE: .*Select\[.*str\] reveal_type(stmt22) -stmt23 = select(func.sysdate()) +stmt23 = select(func.sum(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[.*datetime\] +# EXPECTED_RE_TYPE: .*Select\[.*int\] reveal_type(stmt23) -stmt24 = select(func.user()) +stmt24 = select(func.sysdate()) -# EXPECTED_RE_TYPE: .*Select\[.*str\] +# EXPECTED_RE_TYPE: .*Select\[.*datetime\] reveal_type(stmt24) + +stmt25 = select(func.user()) + +# EXPECTED_RE_TYPE: .*Select\[.*str\] +reveal_type(stmt25) + # END GENERATED FUNCTION TYPING TESTS stmt_count: Select[int, int, int] = select(