]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add `__rlshift__` and `__rrshift__` for symmetry
authorAramís Segovia <aramissegovia@gmail.com>
Mon, 12 May 2025 20:43:47 +0000 (16:43 -0400)
committerAramís Segovia <aramissegovia@gmail.com>
Mon, 12 May 2025 20:43:47 +0000 (16:43 -0400)
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/operators.py
test/sql/test_operators.py

index d4486f5e55fdb7d15dbb3e1d5c380525cb245f40..737d67b6b5b830f0423059d8ac9098c292adeb10 100644 (file)
@@ -916,6 +916,14 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
 
         def __lshift__(self, other: Any) -> ColumnElement[Any]: ...
 
+        @overload
+        def __rlshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ...
+
+        @overload
+        def __rlshift__(self, other: Any) -> ColumnElement[Any]: ...
+
+        def __rlshift__(self, other: Any) -> ColumnElement[Any]: ...
+
         @overload
         def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ...
 
@@ -924,6 +932,14 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
 
         def __rshift__(self, other: Any) -> ColumnElement[Any]: ...
 
+        @overload
+        def __rrshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ...
+
+        @overload
+        def __rrshift__(self, other: Any) -> ColumnElement[Any]: ...
+
+        def __rrshift__(self, other: Any) -> ColumnElement[Any]: ...
+
         def __matmul__(self, other: Any) -> ColumnElement[Any]: ...
 
         def __rmatmul__(self, other: Any) -> ColumnElement[Any]: ...
index 1a3e24711b247f3e7b31424c6da87f88e9ba2c66..7c36bdc46963441b2ed46e5f65b77f5e39d8b06c 100644 (file)
@@ -663,7 +663,7 @@ class ColumnOperators(Operators):
         return self.operate(getitem, index)
 
     def __lshift__(self, other: Any) -> ColumnOperators:
-        """implement the << operator.
+        """Implement the ``<<`` operator.
 
         Not used by SQLAlchemy core, this is provided
         for custom operator systems which want to use
@@ -671,8 +671,17 @@ class ColumnOperators(Operators):
         """
         return self.operate(lshift, other)
 
+    def __rlshift__(self, other: Any) -> ColumnOperators:
+        """Implement the ``<<`` operator.
+
+        Not used by SQLAlchemy core, this is provided
+        for custom operator systems which want to use
+        << as an extension point.
+        """
+        return self.reverse_operate(lshift, other)
+
     def __rshift__(self, other: Any) -> ColumnOperators:
-        """implement the >> operator.
+        """Implement the ``>>`` operator.
 
         Not used by SQLAlchemy core, this is provided
         for custom operator systems which want to use
@@ -680,6 +689,15 @@ class ColumnOperators(Operators):
         """
         return self.operate(rshift, other)
 
+    def __rrshift__(self, other: Any) -> ColumnOperators:
+        """Implement the ``>>`` operator in reverse.
+
+        Not used by SQLAlchemy core, this is provided
+        for custom operator systems which want to use
+        >> as an extension point.
+        """
+        return self.reverse_operate(rshift, other)
+
     def __matmul__(self, other: Any) -> ColumnOperators:
         """Implement the ``@`` operator.
 
index be54cec8e5e49f836c539f68198e283c633b35d7..b78b3ac1f7626b6d0084b39096887908fb292b46 100644 (file)
@@ -967,6 +967,16 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(Column("x", MyType()) << 5, "x -> :x_1")
 
+    def test_rlshift(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            class comparator_factory(UserDefinedType.Comparator):
+                def __rlshift__(self, other):
+                    return self.op("->")(other)
+
+        self.assert_compile(5 << Column("x", MyType()), "x -> :x_1")
+
     def test_rshift(self):
         class MyType(UserDefinedType):
             cache_ok = True
@@ -977,6 +987,16 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(Column("x", MyType()) >> 5, "x -> :x_1")
 
+    def test_rrshift(self):
+        class MyType(UserDefinedType):
+            cache_ok = True
+
+            class comparator_factory(UserDefinedType.Comparator):
+                def __rrshift__(self, other):
+                    return self.op("->")(other)
+
+        self.assert_compile(5 >> Column("x", MyType()), "x -> :x_1")
+
     def test_matmul(self):
         class MyType(UserDefinedType):
             cache_ok = True