]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support sql elements via standalone op functions
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Jul 2023 16:04:01 +0000 (12:04 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Jul 2023 18:00:59 +0000 (14:00 -0400)
Improved typing when using standalone operator functions from
``sqlalchemy.sql.operators`` such as ``sqlalchemy.sql.operators.eq``.

Fixes: #10054
Change-Id: I7e39cb3ccddb354a3f04f749ef65b18088e136e1

doc/build/changelog/unreleased_20/10054.rst [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/operators.py
test/typing/plain_files/sql/operators.py

diff --git a/doc/build/changelog/unreleased_20/10054.rst b/doc/build/changelog/unreleased_20/10054.rst
new file mode 100644 (file)
index 0000000..1e72273
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, typing
+    :tickets: 10054
+
+    Improved typing when using standalone operator functions from
+    ``sqlalchemy.sql.operators`` such as ``sqlalchemy.sql.operators.eq``.
index b3129d4bf274aa344f849dd3181c4e49a3c3ccda..b82f1bc699ac31d0f531a1452f62780c901e8223 100644 (file)
@@ -447,12 +447,12 @@ class QueryableAttribute(
     def operate(
         self, op: OperatorType, *other: Any, **kwargs: Any
     ) -> ColumnElement[Any]:
-        return op(self.comparator, *other, **kwargs)  # type: ignore[return-value]  # noqa: E501
+        return op(self.comparator, *other, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
     def reverse_operate(
         self, op: OperatorType, other: Any, **kwargs: Any
     ) -> ColumnElement[Any]:
-        return op(other, self.comparator, **kwargs)  # type: ignore[return-value]  # noqa: E501
+        return op(other, self.comparator, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
     def hasparent(
         self, state: InstanceState[Any], optimistic: bool = False
index e215e061d5ee68cec4eeb192fa389efc82998a34..4df5175d07ec0cdcc2638aff45aadae358a015b2 100644 (file)
@@ -475,13 +475,13 @@ class ColumnProperty(
         def operate(
             self, op: OperatorType, *other: Any, **kwargs: Any
         ) -> ColumnElement[Any]:
-            return op(self.__clause_element__(), *other, **kwargs)  # type: ignore[return-value]  # noqa: E501
+            return op(self.__clause_element__(), *other, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
         def reverse_operate(
             self, op: OperatorType, other: Any, **kwargs: Any
         ) -> ColumnElement[Any]:
             col = self.__clause_element__()
-            return op(col._bind_param(op, other), col, **kwargs)  # type: ignore[return-value]  # noqa: E501
+            return op(col._bind_param(op, other), col, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
     def __str__(self) -> str:
         if not self.parent or not self.key:
@@ -639,13 +639,13 @@ class MappedColumn(
     def operate(
         self, op: OperatorType, *other: Any, **kwargs: Any
     ) -> ColumnElement[Any]:
-        return op(self.__clause_element__(), *other, **kwargs)  # type: ignore[return-value]  # noqa: E501
+        return op(self.__clause_element__(), *other, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
     def reverse_operate(
         self, op: OperatorType, other: Any, **kwargs: Any
     ) -> ColumnElement[Any]:
         col = self.__clause_element__()
-        return op(col._bind_param(op, other), col, **kwargs)  # type: ignore[return-value]  # noqa: E501
+        return op(col._bind_param(op, other), col, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
     def found_in_pep593_annotated(self) -> Any:
         # return a blank mapped_column().  This mapped_column()'s
index 54d876a8016001b946581fd3dd2a870a22270863..ba074db80c6ec11de4031f8f5ca0452883622387 100644 (file)
@@ -1535,12 +1535,12 @@ class ColumnElement(
         *other: Any,
         **kwargs: Any,
     ) -> ColumnElement[Any]:
-        return op(self.comparator, *other, **kwargs)  # type: ignore[return-value]  # noqa: E501
+        return op(self.comparator, *other, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
     def reverse_operate(
         self, op: operators.OperatorType, other: Any, **kwargs: Any
     ) -> ColumnElement[Any]:
-        return op(other, self.comparator, **kwargs)  # type: ignore[return-value]  # noqa: E501
+        return op(other, self.comparator, **kwargs)  # type: ignore[return-value,no-any-return]  # noqa: E501
 
     def _bind_param(
         self,
index 352e5b62df0b07c5d8fa749e24a49b4cfe012baa..cff465972b9ba231fd7279d4a67c5314f7d0dff6 100644 (file)
@@ -40,6 +40,7 @@ from typing import cast
 from typing import Dict
 from typing import Generic
 from typing import Optional
+from typing import overload
 from typing import Set
 from typing import Tuple
 from typing import Type
@@ -53,7 +54,9 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from ._typing import ColumnExpressionArgument
     from .cache_key import CacheConst
+    from .elements import ColumnElement
     from .type_api import TypeEngine
 
 _T = TypeVar("_T", bound=Any)
@@ -67,6 +70,17 @@ class OperatorType(Protocol):
 
     __name__: str
 
+    @overload
+    def __call__(
+        self,
+        left: ColumnExpressionArgument[Any],
+        right: Optional[Any] = None,
+        *other: Any,
+        **kwargs: Any,
+    ) -> ColumnElement[Any]:
+        ...
+
+    @overload
     def __call__(
         self,
         left: Operators,
@@ -76,6 +90,15 @@ class OperatorType(Protocol):
     ) -> Operators:
         ...
 
+    def __call__(
+        self,
+        left: Any,
+        right: Optional[Any] = None,
+        *other: Any,
+        **kwargs: Any,
+    ) -> Operators:
+        ...
+
 
 add = cast(OperatorType, _uncast_add)
 and_ = cast(OperatorType, _uncast_and_)
@@ -436,15 +459,35 @@ class custom_op(OperatorType, Generic[_T]):
             self.return_type._static_cache_key if self.return_type else None,
         )
 
+    @overload
+    def __call__(
+        self,
+        left: ColumnExpressionArgument[Any],
+        right: Optional[Any] = None,
+        *other: Any,
+        **kwargs: Any,
+    ) -> ColumnElement[Any]:
+        ...
+
+    @overload
     def __call__(
         self,
         left: Operators,
         right: Optional[Any] = None,
         *other: Any,
         **kwargs: Any,
+    ) -> Operators:
+        ...
+
+    def __call__(
+        self,
+        left: Any,
+        right: Optional[Any] = None,
+        *other: Any,
+        **kwargs: Any,
     ) -> Operators:
         if hasattr(left, "__sa_operate__"):
-            return left.operate(self, right, *other, **kwargs)
+            return left.operate(self, right, *other, **kwargs)  # type: ignore
         elif self.python_impl:
             return self.python_impl(left, right, *other, **kwargs)  # type: ignore  # noqa: E501
         else:
index fea8daa6a26da8a67bce9d6620cacaf79eb7bd52..8258ec65b1fe10cb4cedfe9282ff6845f251f894 100644 (file)
@@ -7,10 +7,12 @@ from sqlalchemy import BigInteger
 from sqlalchemy import column
 from sqlalchemy import ColumnElement
 from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql import operators
 
 
 class Base(DeclarativeBase):
@@ -136,3 +138,8 @@ op_b: "ColumnElement[int]" = col.op("&", return_type=Integer)(1)
 op_c: "ColumnElement[str]" = col.op("&", return_type=String)("1")
 op_d: "ColumnElement[int]" = col.op("&", return_type=BigInteger)("1")
 op_e: "ColumnElement[bool]" = col.bool_op("&")("1")
+
+
+# op functions
+t1 = operators.eq(A.id, 1)
+select().where(t1)