From be0831fea83247451628bc6643d5b130c63f6011 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 19 Jan 2023 12:09:29 -0500 Subject: [PATCH] implement basic typing for lambda elements These weren't working at all, so fixed things up and added a test suite. Keeping things very basic with Any returns etc. as having more specific return types starts making it too cumbersome to write end-user code. Corrected the type passed for "lambda statements" so that a plain lambda is accepted by mypy, pyright, others without any errors about argument types. Additionally implemented typing for more of the public API for lambda statements and ensured :class:`.StatementLambdaElement` is part of the :class:`.Executable` hierarchy so it's typed as accepted by :meth:`_engine.Connection.execute`. Fixes: #9120 Change-Id: Ia7fa34e5b6e43fba02c8f94ccc256f3a68a1f445 --- doc/build/changelog/unreleased_20/9120.rst | 10 ++ lib/sqlalchemy/sql/elements.py | 6 +- lib/sqlalchemy/sql/lambdas.py | 104 +++++++++++++++------ test/ext/mypy/plain_files/lambda_stmt.py | 77 +++++++++++++++ 4 files changed, 163 insertions(+), 34 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9120.rst create mode 100644 test/ext/mypy/plain_files/lambda_stmt.py diff --git a/doc/build/changelog/unreleased_20/9120.rst b/doc/build/changelog/unreleased_20/9120.rst new file mode 100644 index 0000000000..9e2a54d2fd --- /dev/null +++ b/doc/build/changelog/unreleased_20/9120.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, typing + :tickets: 9120 + + Corrected the type passed for "lambda statements" so that a plain lambda is + accepted by mypy, pyright, others without any errors about argument types. + Additionally implemented typing for more of the public API for lambda + statements and ensured :class:`.StatementLambdaElement` is part of the + :class:`.Executable` hierarchy so it's typed as accepted by + :meth:`_engine.Connection.execute`. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 6d19494253..043fb7a030 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -106,9 +106,9 @@ if typing.TYPE_CHECKING: from ..engine import Dialect from ..engine import Engine from ..engine.interfaces import _CoreMultiExecuteParams - from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import CoreExecuteOptionsParameter from ..engine.interfaces import SchemaTranslateMapType from ..engine.result import Result @@ -481,7 +481,7 @@ class ClauseElement( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: CoreExecuteOptionsParameter, ) -> Result[Any]: if self.supports_execution: if TYPE_CHECKING: @@ -496,7 +496,7 @@ class ClauseElement( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, - execution_options: _ExecuteOptions, + execution_options: CoreExecuteOptionsParameter, ) -> Any: """an additional hook for subclasses to provide a different implementation for connection.scalar() vs. connection.execute(). diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index b153ba999f..d737b1bcb7 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -18,13 +18,13 @@ from types import CodeType from typing import Any from typing import Callable from typing import cast -from typing import Iterable from typing import List from typing import MutableMapping from typing import Optional from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -43,7 +43,6 @@ from .. import exc from .. import inspection from .. import util from ..util.typing import Literal -from ..util.typing import Protocol from ..util.typing import Self if TYPE_CHECKING: @@ -60,12 +59,14 @@ _BoundParameterGetter = Callable[..., Any] _closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000) -class _LambdaType(Protocol): - __code__: CodeType - __closure__: Iterable[Tuple[Any, Any]] +_LambdaType = Callable[[], Any] - def __call__(self, *arg: Any, **kw: Any) -> ClauseElement: - ... +_AnyLambdaType = Callable[..., Any] + +_StmtLambdaType = Callable[[], Any] + +_E = TypeVar("_E", bound=Executable) +_StmtLambdaElementType = Callable[[_E], Any] class LambdaOptions(Options): @@ -78,7 +79,7 @@ class LambdaOptions(Options): def lambda_stmt( - lmb: _LambdaType, + lmb: _StmtLambdaType, enable_tracking: bool = True, track_closure_variables: bool = True, track_on: Optional[object] = None, @@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement): closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]] role: Type[SQLRole] _rec: Union[AnalyzedFunction, NonAnalyzedFunction] - fn: _LambdaType + fn: _AnyLambdaType tracker_key: Tuple[CodeType, ...] def __repr__(self): @@ -416,8 +417,8 @@ class LambdaElement(elements.ClauseElement): bindparams.extend(self._resolved_bindparams) return cache_key - def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement: - return fn() + def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement: + return fn() # type: ignore[no-any-return] class DeferredLambdaElement(LambdaElement): @@ -494,7 +495,9 @@ class DeferredLambdaElement(LambdaElement): self._transforms += (deferred_copy_internals,) -class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): +class StatementLambdaElement( + roles.AllowsLambdaRole, LambdaElement, Executable +): """Represent a composable SQL statement as a :class:`_sql.LambdaElement`. The :class:`_sql.StatementLambdaElement` is constructed using the @@ -520,17 +523,30 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ - def __add__(self, other): + if TYPE_CHECKING: + + def __init__( + self, + fn: _StmtLambdaType, + role: Type[SQLRole], + opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, + apply_propagate_attrs: Optional[ClauseElement] = None, + ): + ... + + def __add__( + self, other: _StmtLambdaElementType[Any] + ) -> StatementLambdaElement: return self.add_criteria(other) def add_criteria( self, - other, - enable_tracking=True, - track_on=None, - track_closure_variables=True, - track_bound_values=True, - ): + other: _StmtLambdaElementType[Any], + enable_tracking: bool = True, + track_on: Optional[Any] = None, + track_closure_variables: bool = True, + track_bound_values: bool = True, + ) -> StatementLambdaElement: """Add new criteria to this :class:`_sql.StatementLambdaElement`. E.g.:: @@ -587,25 +603,51 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): else: raise exc.ObjectNotExecutableError(self) + @property + def _proxied(self) -> Any: + return self._rec_expected_expr + @property def _with_options(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._with_options + return self._proxied._with_options @property def _effective_plugin_target(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._effective_plugin_target + return self._proxied._effective_plugin_target @property def _execution_options(self): - if TYPE_CHECKING: - assert isinstance(self._rec.expected_expr, Executable) - return self._rec.expected_expr._execution_options + return self._proxied._execution_options + + @property + def _all_selected_columns(self): + return self._proxied._all_selected_columns + + @property + def is_select(self): + return self._proxied.is_select + + @property + def is_update(self): + return self._proxied.is_update + + @property + def is_insert(self): + return self._proxied.is_insert - def spoil(self): + @property + def is_text(self): + return self._proxied.is_text + + @property + def is_delete(self): + return self._proxied.is_delete + + @property + def is_dml(self): + return self._proxied.is_dml + + def spoil(self) -> NullLambdaStatement: """Return a new :class:`.StatementLambdaElement` that will run all lambdas unconditionally each time. @@ -667,12 +709,12 @@ class LinkedLambdaElement(StatementLambdaElement): def __init__( self, - fn: _LambdaType, + fn: _StmtLambdaElementType[Any], parent_lambda: StatementLambdaElement, opts: Union[Type[LambdaOptions], LambdaOptions], ): self.opts = opts - self.fn = fn + self.fn = fn # type: ignore[assignment] self.parent_lambda = parent_lambda self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) diff --git a/test/ext/mypy/plain_files/lambda_stmt.py b/test/ext/mypy/plain_files/lambda_stmt.py new file mode 100644 index 0000000000..7e15778c1d --- /dev/null +++ b/test/ext/mypy/plain_files/lambda_stmt.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import Tuple +from typing import TYPE_CHECKING + +from sqlalchemy import Column +from sqlalchemy import create_engine +from sqlalchemy import Integer +from sqlalchemy import lambda_stmt +from sqlalchemy import MetaData +from sqlalchemy import Result +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + email: Mapped[str] + + +user_table = Table( + "user_table", MetaData(), Column("id", Integer), Column("email", String) +) + + +s1 = select(user_table).where(lambda: user_table.c.id == 5) + +s2 = select(User).where(lambda: User.id == 5) + +s3 = lambda_stmt(lambda: select(user_table).where(user_table.c.id == 5)) + +s4 = lambda_stmt(lambda: select(User).where(User.id == 5)) + +s5 = lambda_stmt(lambda: select(user_table)) + ( + lambda s: s.where(user_table.c.id == 5) +) + +s6 = lambda_stmt(lambda: select(User)) + (lambda s: s.where(User.id == 5)) + + +if TYPE_CHECKING: + + # EXPECTED_TYPE: StatementLambdaElement + reveal_type(s5) + + # EXPECTED_TYPE: StatementLambdaElement + reveal_type(s6) + + +e = create_engine("sqlite://") + +with e.connect() as conn: + result = conn.execute(s6) + + if TYPE_CHECKING: + # EXPECTED_TYPE: CursorResult[Any] + reveal_type(result) + + # we can type these like this + my_result: Result[Tuple[User]] = conn.execute(s6) + + if TYPE_CHECKING: + # pyright and mypy disagree on the specific type here, + # mypy sees Result as we said, pyright seems to upgrade it to + # CursorResult + # EXPECTED_RE_TYPE: .*(?:Cursor)?Result\[Tuple\[.*User\]\] + reveal_type(my_result) -- 2.47.2