From: Mike Bayer Date: Mon, 3 Jul 2023 16:04:01 +0000 (-0400) Subject: support sql elements via standalone op functions X-Git-Tag: rel_2_0_18~13^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=784874d74b94bfe977212b6bfcabacaf12942ebe;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git support sql elements via standalone op functions Improved typing when using standalone operator functions from ``sqlalchemy.sql.operators`` such as ``sqlalchemy.sql.operators.eq``. Fixes: #10054 Change-Id: I7e39cb3ccddb354a3f04f749ef65b18088e136e1 --- diff --git a/doc/build/changelog/unreleased_20/10054.rst b/doc/build/changelog/unreleased_20/10054.rst new file mode 100644 index 0000000000..1e72273ecf --- /dev/null +++ b/doc/build/changelog/unreleased_20/10054.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, typing + :tickets: 10054 + + Improved typing when using standalone operator functions from + ``sqlalchemy.sql.operators`` such as ``sqlalchemy.sql.operators.eq``. diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index b3129d4bf2..b82f1bc699 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -447,12 +447,12 @@ class QueryableAttribute( def operate( self, op: OperatorType, *other: Any, **kwargs: Any ) -> ColumnElement[Any]: - return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(self.comparator, *other, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any ) -> ColumnElement[Any]: - return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(other, self.comparator, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def hasparent( self, state: InstanceState[Any], optimistic: bool = False diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index e215e061d5..4df5175d07 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -475,13 +475,13 @@ class ColumnProperty( def operate( self, op: OperatorType, *other: Any, **kwargs: Any ) -> ColumnElement[Any]: - return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any ) -> ColumnElement[Any]: col = self.__clause_element__() - return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def __str__(self) -> str: if not self.parent or not self.key: @@ -639,13 +639,13 @@ class MappedColumn( def operate( self, op: OperatorType, *other: Any, **kwargs: Any ) -> ColumnElement[Any]: - return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any ) -> ColumnElement[Any]: col = self.__clause_element__() - return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def found_in_pep593_annotated(self) -> Any: # return a blank mapped_column(). This mapped_column()'s diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 54d876a801..ba074db80c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1535,12 +1535,12 @@ class ColumnElement( *other: Any, **kwargs: Any, ) -> ColumnElement[Any]: - return op(self.comparator, *other, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(self.comparator, *other, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def reverse_operate( self, op: operators.OperatorType, other: Any, **kwargs: Any ) -> ColumnElement[Any]: - return op(other, self.comparator, **kwargs) # type: ignore[return-value] # noqa: E501 + return op(other, self.comparator, **kwargs) # type: ignore[return-value,no-any-return] # noqa: E501 def _bind_param( self, diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 352e5b62df..cff465972b 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -40,6 +40,7 @@ from typing import cast from typing import Dict from typing import Generic from typing import Optional +from typing import overload from typing import Set from typing import Tuple from typing import Type @@ -53,7 +54,9 @@ from ..util.typing import Literal from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import ColumnExpressionArgument from .cache_key import CacheConst + from .elements import ColumnElement from .type_api import TypeEngine _T = TypeVar("_T", bound=Any) @@ -67,6 +70,17 @@ class OperatorType(Protocol): __name__: str + @overload + def __call__( + self, + left: ColumnExpressionArgument[Any], + right: Optional[Any] = None, + *other: Any, + **kwargs: Any, + ) -> ColumnElement[Any]: + ... + + @overload def __call__( self, left: Operators, @@ -76,6 +90,15 @@ class OperatorType(Protocol): ) -> Operators: ... + def __call__( + self, + left: Any, + right: Optional[Any] = None, + *other: Any, + **kwargs: Any, + ) -> Operators: + ... + add = cast(OperatorType, _uncast_add) and_ = cast(OperatorType, _uncast_and_) @@ -436,15 +459,35 @@ class custom_op(OperatorType, Generic[_T]): self.return_type._static_cache_key if self.return_type else None, ) + @overload + def __call__( + self, + left: ColumnExpressionArgument[Any], + right: Optional[Any] = None, + *other: Any, + **kwargs: Any, + ) -> ColumnElement[Any]: + ... + + @overload def __call__( self, left: Operators, right: Optional[Any] = None, *other: Any, **kwargs: Any, + ) -> Operators: + ... + + def __call__( + self, + left: Any, + right: Optional[Any] = None, + *other: Any, + **kwargs: Any, ) -> Operators: if hasattr(left, "__sa_operate__"): - return left.operate(self, right, *other, **kwargs) + return left.operate(self, right, *other, **kwargs) # type: ignore elif self.python_impl: return self.python_impl(left, right, *other, **kwargs) # type: ignore # noqa: E501 else: diff --git a/test/typing/plain_files/sql/operators.py b/test/typing/plain_files/sql/operators.py index fea8daa6a2..8258ec65b1 100644 --- a/test/typing/plain_files/sql/operators.py +++ b/test/typing/plain_files/sql/operators.py @@ -7,10 +7,12 @@ from sqlalchemy import BigInteger from sqlalchemy import column from sqlalchemy import ColumnElement from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql import operators class Base(DeclarativeBase): @@ -136,3 +138,8 @@ op_b: "ColumnElement[int]" = col.op("&", return_type=Integer)(1) op_c: "ColumnElement[str]" = col.op("&", return_type=String)("1") op_d: "ColumnElement[int]" = col.op("&", return_type=BigInteger)("1") op_e: "ColumnElement[bool]" = col.bool_op("&")("1") + + +# op functions +t1 = operators.eq(A.id, 1) +select().where(t1)