]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support `matmul` (@) as an optional operator.
authorAramís Segovia <aramissegovia@gmail.com>
Tue, 13 May 2025 20:18:11 +0000 (16:18 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Tue, 13 May 2025 20:18:11 +0000 (16:18 -0400)
Allow custom operator systems to use the @ Python operator (#12479).

### Description
Add a dummy implementation for the  `__matmul__` operator rasing `NotImplementedError` by default.

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [X] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #12583
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12583
Pull-request-sha: 7e69d23610f39468b24c0a9a1ffdbdab20ae34fb

Change-Id: Ia0d565decd437b940efd3b97478c16d7a0377bc6

doc/build/changelog/unreleased_21/12479.rst [new file with mode: 0644]
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/operators.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_21/12479.rst b/doc/build/changelog/unreleased_21/12479.rst
new file mode 100644 (file)
index 0000000..4cced47
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: core, feature, sql
+    :tickets: 12479
+
+      The Core operator system now includes the `matmul` operator, i.e. the
+      @ operator in Python as an optional operator.
index c1305be99470650eda61bd053f438ae48feb594b..eba769f892af446d9293fbd39b8eb17755eef5b1 100644 (file)
@@ -558,6 +558,7 @@ operator_lookup: Dict[
     "getitem": (_getitem_impl, util.EMPTY_DICT),
     "lshift": (_unsupported_impl, util.EMPTY_DICT),
     "rshift": (_unsupported_impl, util.EMPTY_DICT),
+    "matmul": (_unsupported_impl, util.EMPTY_DICT),
     "contains": (_unsupported_impl, util.EMPTY_DICT),
     "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT),
     "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT),
index 42dfe6110647f1636bbf3649a9ddc6429c69d422..737d67b6b5b830f0423059d8ac9098c292adeb10 100644 (file)
@@ -916,6 +916,14 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
 
         def __lshift__(self, other: Any) -> ColumnElement[Any]: ...
 
+        @overload
+        def __rlshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ...
+
+        @overload
+        def __rlshift__(self, other: Any) -> ColumnElement[Any]: ...
+
+        def __rlshift__(self, other: Any) -> ColumnElement[Any]: ...
+
         @overload
         def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ...
 
@@ -924,6 +932,18 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
 
         def __rshift__(self, other: Any) -> ColumnElement[Any]: ...
 
+        @overload
+        def __rrshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ...
+
+        @overload
+        def __rrshift__(self, other: Any) -> ColumnElement[Any]: ...
+
+        def __rrshift__(self, other: Any) -> ColumnElement[Any]: ...
+
+        def __matmul__(self, other: Any) -> ColumnElement[Any]: ...
+
+        def __rmatmul__(self, other: Any) -> ColumnElement[Any]: ...
+
         @overload
         def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ...
 
index 635e5712ad57b2d6abcb12e825f5f20944ca9181..7e751e13d08d5aa7f41292867ce15af6ad81b986 100644 (file)
@@ -25,6 +25,7 @@ from operator import inv as _uncast_inv
 from operator import le as _uncast_le
 from operator import lshift as _uncast_lshift
 from operator import lt as _uncast_lt
+from operator import matmul as _uncast_matmul
 from operator import mod as _uncast_mod
 from operator import mul as _uncast_mul
 from operator import ne as _uncast_ne
@@ -110,6 +111,7 @@ inv = cast(OperatorType, _uncast_inv)
 le = cast(OperatorType, _uncast_le)
 lshift = cast(OperatorType, _uncast_lshift)
 lt = cast(OperatorType, _uncast_lt)
+matmul = cast(OperatorType, _uncast_matmul)
 mod = cast(OperatorType, _uncast_mod)
 mul = cast(OperatorType, _uncast_mul)
 ne = cast(OperatorType, _uncast_ne)
@@ -661,7 +663,7 @@ class ColumnOperators(Operators):
         return self.operate(getitem, index)
 
     def __lshift__(self, other: Any) -> ColumnOperators:
-        """implement the << operator.
+        """Implement the ``<<`` operator.
 
         Not used by SQLAlchemy core, this is provided
         for custom operator systems which want to use
@@ -669,8 +671,17 @@ class ColumnOperators(Operators):
         """
         return self.operate(lshift, other)
 
+    def __rlshift__(self, other: Any) -> ColumnOperators:
+        """Implement the ``<<`` operator in reverse.
+
+        Not used by SQLAlchemy core, this is provided
+        for custom operator systems which want to use
+        << as an extension point.
+        """
+        return self.reverse_operate(lshift, other)
+
     def __rshift__(self, other: Any) -> ColumnOperators:
-        """implement the >> operator.
+        """Implement the ``>>`` operator.
 
         Not used by SQLAlchemy core, this is provided
         for custom operator systems which want to use
@@ -678,6 +689,33 @@ class ColumnOperators(Operators):
         """
         return self.operate(rshift, other)
 
+    def __rrshift__(self, other: Any) -> ColumnOperators:
+        """Implement the ``>>`` operator in reverse.
+
+        Not used by SQLAlchemy core, this is provided
+        for custom operator systems which want to use
+        >> as an extension point.
+        """
+        return self.reverse_operate(rshift, other)
+
+    def __matmul__(self, other: Any) -> ColumnOperators:
+        """Implement the ``@`` operator.
+
+        Not used by SQLAlchemy core, this is provided
+        for custom operator systems which want to use
+        @ as an extension point.
+        """
+        return self.operate(matmul, other)
+
+    def __rmatmul__(self, other: Any) -> ColumnOperators:
+        """Implement the ``@`` operator in reverse.
+
+        Not used by SQLAlchemy core, this is provided
+        for custom operator systems which want to use
+        @ as an extension point.
+        """
+        return self.reverse_operate(matmul, other)
+
     def concat(self, other: Any) -> ColumnOperators:
         """Implement the 'concat' operator.
 
index 099301707fcd20f1d868c95cdabc476172dfb9d1..b78b3ac1f7626b6d0084b39096887908fb292b46 100644 (file)
@@ -967,6 +967,16 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(Column("x", MyType()) << 5, "x -> :x_1")
 
+    def test_rlshift(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            class comparator_factory(UserDefinedType.Comparator):
+                def __rlshift__(self, other):
+                    return self.op("->")(other)
+
+        self.assert_compile(5 << Column("x", MyType()), "x -> :x_1")
+
     def test_rshift(self):
         class MyType(UserDefinedType):
             cache_ok = True
@@ -977,6 +987,36 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(Column("x", MyType()) >> 5, "x -> :x_1")
 
+    def test_rrshift(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            class comparator_factory(UserDefinedType.Comparator):
+                def __rrshift__(self, other):
+                    return self.op("->")(other)
+
+        self.assert_compile(5 >> Column("x", MyType()), "x -> :x_1")
+
+    def test_matmul(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            class comparator_factory(UserDefinedType.Comparator):
+                def __matmul__(self, other):
+                    return self.op("->")(other)
+
+        self.assert_compile(Column("x", MyType()) @ 5, "x -> :x_1")
+
+    def test_rmatmul(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            class comparator_factory(UserDefinedType.Comparator):
+                def __rmatmul__(self, other):
+                    return self.op("->")(other)
+
+        self.assert_compile(5 @ Column("x", MyType()), "x -> :x_1")
+
 
 class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     def setup_test(self):