]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add `__rmatmul__` for symmetry
authorAramís Segovia <aramissegovia@gmail.com>
Mon, 12 May 2025 19:59:26 +0000 (15:59 -0400)
committerAramís Segovia <aramissegovia@gmail.com>
Mon, 12 May 2025 19:59:26 +0000 (15:59 -0400)
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/operators.py
test/sql/test_operators.py

index 75bf72f03363b8a95a9a266476e68e27f0ffee02..d4486f5e55fdb7d15dbb3e1d5c380525cb245f40 100644 (file)
@@ -926,6 +926,8 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
 
         def __matmul__(self, other: Any) -> ColumnElement[Any]: ...
 
+        def __rmatmul__(self, other: Any) -> ColumnElement[Any]: ...
+
         @overload
         def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ...
 
index ed0582d309d71df57a5be2c797dd96635a3e2200..1a3e24711b247f3e7b31424c6da87f88e9ba2c66 100644 (file)
@@ -681,7 +681,7 @@ class ColumnOperators(Operators):
         return self.operate(rshift, other)
 
     def __matmul__(self, other: Any) -> ColumnOperators:
-        """Implement the @ operator.
+        """Implement the ``@`` operator.
 
         Not used by SQLAlchemy core, this is provided
         for custom operator systems which want to use
@@ -689,6 +689,15 @@ class ColumnOperators(Operators):
         """
         return self.operate(matmul, other)
 
+    def __rmatmul__(self, other: Any) -> ColumnOperators:
+        """Implement the ``@`` operator in reverse.
+
+        Not used by SQLAlchemy core, this is provided
+        for custom operator systems which want to use
+        @ as an extension point.
+        """
+        return self.reverse_operate(matmul, other)
+
     def concat(self, other: Any) -> ColumnOperators:
         """Implement the 'concat' operator.
 
index 753a2e3c1bde8882a375afe5f6220c8c89347367..be54cec8e5e49f836c539f68198e283c633b35d7 100644 (file)
@@ -987,6 +987,16 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(Column("x", MyType()) @ 5, "x -> :x_1")
 
+    def test_rmatmul(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            class comparator_factory(UserDefinedType.Comparator):
+                def __rmatmul__(self, other):
+                    return self.op("->")(other)
+
+        self.assert_compile(5 @ Column("x", MyType()), "x -> :x_1")
+
 
 class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     def setup_test(self):