From: Aramís Segovia Date: Tue, 13 May 2025 20:18:11 +0000 (-0400) Subject: Support `matmul` (@) as an optional operator. X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=8bd314378c1d477761346433c441c4a0c8a5abde;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support `matmul` (@) as an optional operator. Allow custom operator systems to use the @ Python operator (#12479). ### Description Add a dummy implementation for the `__matmul__` operator rasing `NotImplementedError` by default. ### Checklist This pull request is: - [ ] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [X] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #12583 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12583 Pull-request-sha: 7e69d23610f39468b24c0a9a1ffdbdab20ae34fb Change-Id: Ia0d565decd437b940efd3b97478c16d7a0377bc6 --- diff --git a/doc/build/changelog/unreleased_21/12479.rst b/doc/build/changelog/unreleased_21/12479.rst new file mode 100644 index 0000000000..4cced479b1 --- /dev/null +++ b/doc/build/changelog/unreleased_21/12479.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: core, feature, sql + :tickets: 12479 + + The Core operator system now includes the `matmul` operator, i.e. the + @ operator in Python as an optional operator. diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index c1305be994..eba769f892 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -558,6 +558,7 @@ operator_lookup: Dict[ "getitem": (_getitem_impl, util.EMPTY_DICT), "lshift": (_unsupported_impl, util.EMPTY_DICT), "rshift": (_unsupported_impl, util.EMPTY_DICT), + "matmul": (_unsupported_impl, util.EMPTY_DICT), "contains": (_unsupported_impl, util.EMPTY_DICT), "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 42dfe61106..737d67b6b5 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -916,6 +916,14 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): def __lshift__(self, other: Any) -> ColumnElement[Any]: ... + @overload + def __rlshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... + + @overload + def __rlshift__(self, other: Any) -> ColumnElement[Any]: ... + + def __rlshift__(self, other: Any) -> ColumnElement[Any]: ... + @overload def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @@ -924,6 +932,18 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): def __rshift__(self, other: Any) -> ColumnElement[Any]: ... + @overload + def __rrshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... + + @overload + def __rrshift__(self, other: Any) -> ColumnElement[Any]: ... + + def __rrshift__(self, other: Any) -> ColumnElement[Any]: ... + + def __matmul__(self, other: Any) -> ColumnElement[Any]: ... + + def __rmatmul__(self, other: Any) -> ColumnElement[Any]: ... + @overload def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ... diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 635e5712ad..7e751e13d0 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -25,6 +25,7 @@ from operator import inv as _uncast_inv from operator import le as _uncast_le from operator import lshift as _uncast_lshift from operator import lt as _uncast_lt +from operator import matmul as _uncast_matmul from operator import mod as _uncast_mod from operator import mul as _uncast_mul from operator import ne as _uncast_ne @@ -110,6 +111,7 @@ inv = cast(OperatorType, _uncast_inv) le = cast(OperatorType, _uncast_le) lshift = cast(OperatorType, _uncast_lshift) lt = cast(OperatorType, _uncast_lt) +matmul = cast(OperatorType, _uncast_matmul) mod = cast(OperatorType, _uncast_mod) mul = cast(OperatorType, _uncast_mul) ne = cast(OperatorType, _uncast_ne) @@ -661,7 +663,7 @@ class ColumnOperators(Operators): return self.operate(getitem, index) def __lshift__(self, other: Any) -> ColumnOperators: - """implement the << operator. + """Implement the ``<<`` operator. Not used by SQLAlchemy core, this is provided for custom operator systems which want to use @@ -669,8 +671,17 @@ class ColumnOperators(Operators): """ return self.operate(lshift, other) + def __rlshift__(self, other: Any) -> ColumnOperators: + """Implement the ``<<`` operator in reverse. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + << as an extension point. + """ + return self.reverse_operate(lshift, other) + def __rshift__(self, other: Any) -> ColumnOperators: - """implement the >> operator. + """Implement the ``>>`` operator. Not used by SQLAlchemy core, this is provided for custom operator systems which want to use @@ -678,6 +689,33 @@ class ColumnOperators(Operators): """ return self.operate(rshift, other) + def __rrshift__(self, other: Any) -> ColumnOperators: + """Implement the ``>>`` operator in reverse. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + >> as an extension point. + """ + return self.reverse_operate(rshift, other) + + def __matmul__(self, other: Any) -> ColumnOperators: + """Implement the ``@`` operator. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + @ as an extension point. + """ + return self.operate(matmul, other) + + def __rmatmul__(self, other: Any) -> ColumnOperators: + """Implement the ``@`` operator in reverse. + + Not used by SQLAlchemy core, this is provided + for custom operator systems which want to use + @ as an extension point. + """ + return self.reverse_operate(matmul, other) + def concat(self, other: Any) -> ColumnOperators: """Implement the 'concat' operator. diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 099301707f..b78b3ac1f7 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -967,6 +967,16 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile(Column("x", MyType()) << 5, "x -> :x_1") + def test_rlshift(self): + class MyType(UserDefinedType): + cache_ok = True + + class comparator_factory(UserDefinedType.Comparator): + def __rlshift__(self, other): + return self.op("->")(other) + + self.assert_compile(5 << Column("x", MyType()), "x -> :x_1") + def test_rshift(self): class MyType(UserDefinedType): cache_ok = True @@ -977,6 +987,36 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile(Column("x", MyType()) >> 5, "x -> :x_1") + def test_rrshift(self): + class MyType(UserDefinedType): + cache_ok = True + + class comparator_factory(UserDefinedType.Comparator): + def __rrshift__(self, other): + return self.op("->")(other) + + self.assert_compile(5 >> Column("x", MyType()), "x -> :x_1") + + def test_matmul(self): + class MyType(UserDefinedType): + cache_ok = True + + class comparator_factory(UserDefinedType.Comparator): + def __matmul__(self, other): + return self.op("->")(other) + + self.assert_compile(Column("x", MyType()) @ 5, "x -> :x_1") + + def test_rmatmul(self): + class MyType(UserDefinedType): + cache_ok = True + + class comparator_factory(UserDefinedType.Comparator): + def __rmatmul__(self, other): + return self.op("->")(other) + + self.assert_compile(5 @ Column("x", MyType()), "x -> :x_1") + class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): def setup_test(self):