]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement basic typing for lambda elements
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jan 2023 17:09:29 +0000 (12:09 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jan 2023 22:04:59 +0000 (17:04 -0500)
These weren't working at all, so fixed things up and
added a test suite.   Keeping things very basic with Any
returns etc. as having more specific return types
starts making it too cumbersome to write end-user code.

Corrected the type passed for "lambda statements" so that a plain lambda is
accepted by mypy, pyright, others without any errors about argument types.
Additionally implemented typing for more of the public API for lambda
statements and ensured :class:`.StatementLambdaElement` is part of the
:class:`.Executable` hierarchy so it's typed as accepted by
:meth:`_engine.Connection.execute`.

Fixes: #9120
Change-Id: Ia7fa34e5b6e43fba02c8f94ccc256f3a68a1f445

doc/build/changelog/unreleased_20/9120.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/lambdas.py
test/ext/mypy/plain_files/lambda_stmt.py [new file with mode: 0644]

diff --git a/doc/build/changelog/unreleased_20/9120.rst b/doc/build/changelog/unreleased_20/9120.rst
new file mode 100644 (file)
index 0000000..9e2a54d
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9120
+
+    Corrected the type passed for "lambda statements" so that a plain lambda is
+    accepted by mypy, pyright, others without any errors about argument types.
+    Additionally implemented typing for more of the public API for lambda
+    statements and ensured :class:`.StatementLambdaElement` is part of the
+    :class:`.Executable` hierarchy so it's typed as accepted by
+    :meth:`_engine.Connection.execute`.
index 6d19494253cb4cf0f49fdc588c2e585f22ca0f5d..043fb7a0305761c8f7be859445eceae770b9c8ca 100644 (file)
@@ -106,9 +106,9 @@ if typing.TYPE_CHECKING:
     from ..engine import Dialect
     from ..engine import Engine
     from ..engine.interfaces import _CoreMultiExecuteParams
-    from ..engine.interfaces import _ExecuteOptions
     from ..engine.interfaces import CacheStats
     from ..engine.interfaces import CompiledCacheType
+    from ..engine.interfaces import CoreExecuteOptionsParameter
     from ..engine.interfaces import SchemaTranslateMapType
     from ..engine.result import Result
 
@@ -481,7 +481,7 @@ class ClauseElement(
         self,
         connection: Connection,
         distilled_params: _CoreMultiExecuteParams,
-        execution_options: _ExecuteOptions,
+        execution_options: CoreExecuteOptionsParameter,
     ) -> Result[Any]:
         if self.supports_execution:
             if TYPE_CHECKING:
@@ -496,7 +496,7 @@ class ClauseElement(
         self,
         connection: Connection,
         distilled_params: _CoreMultiExecuteParams,
-        execution_options: _ExecuteOptions,
+        execution_options: CoreExecuteOptionsParameter,
     ) -> Any:
         """an additional hook for subclasses to provide a different
         implementation for connection.scalar() vs. connection.execute().
index b153ba999f648205ee88822fa6346bb35b262318..d737b1bcb75c4b8f4682893a7a543079cbeebb24 100644 (file)
@@ -18,13 +18,13 @@ from types import CodeType
 from typing import Any
 from typing import Callable
 from typing import cast
-from typing import Iterable
 from typing import List
 from typing import MutableMapping
 from typing import Optional
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 import weakref
 
@@ -43,7 +43,6 @@ from .. import exc
 from .. import inspection
 from .. import util
 from ..util.typing import Literal
-from ..util.typing import Protocol
 from ..util.typing import Self
 
 if TYPE_CHECKING:
@@ -60,12 +59,14 @@ _BoundParameterGetter = Callable[..., Any]
 _closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)
 
 
-class _LambdaType(Protocol):
-    __code__: CodeType
-    __closure__: Iterable[Tuple[Any, Any]]
+_LambdaType = Callable[[], Any]
 
-    def __call__(self, *arg: Any, **kw: Any) -> ClauseElement:
-        ...
+_AnyLambdaType = Callable[..., Any]
+
+_StmtLambdaType = Callable[[], Any]
+
+_E = TypeVar("_E", bound=Executable)
+_StmtLambdaElementType = Callable[[_E], Any]
 
 
 class LambdaOptions(Options):
@@ -78,7 +79,7 @@ class LambdaOptions(Options):
 
 
 def lambda_stmt(
-    lmb: _LambdaType,
+    lmb: _StmtLambdaType,
     enable_tracking: bool = True,
     track_closure_variables: bool = True,
     track_on: Optional[object] = None,
@@ -185,7 +186,7 @@ class LambdaElement(elements.ClauseElement):
     closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
     role: Type[SQLRole]
     _rec: Union[AnalyzedFunction, NonAnalyzedFunction]
-    fn: _LambdaType
+    fn: _AnyLambdaType
     tracker_key: Tuple[CodeType, ...]
 
     def __repr__(self):
@@ -416,8 +417,8 @@ class LambdaElement(elements.ClauseElement):
             bindparams.extend(self._resolved_bindparams)
         return cache_key
 
-    def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement:
-        return fn()
+    def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
+        return fn()  # type: ignore[no-any-return]
 
 
 class DeferredLambdaElement(LambdaElement):
@@ -494,7 +495,9 @@ class DeferredLambdaElement(LambdaElement):
             self._transforms += (deferred_copy_internals,)
 
 
-class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
+class StatementLambdaElement(
+    roles.AllowsLambdaRole, LambdaElement, Executable
+):
     """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
 
     The :class:`_sql.StatementLambdaElement` is constructed using the
@@ -520,17 +523,30 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
 
     """
 
-    def __add__(self, other):
+    if TYPE_CHECKING:
+
+        def __init__(
+            self,
+            fn: _StmtLambdaType,
+            role: Type[SQLRole],
+            opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
+            apply_propagate_attrs: Optional[ClauseElement] = None,
+        ):
+            ...
+
+    def __add__(
+        self, other: _StmtLambdaElementType[Any]
+    ) -> StatementLambdaElement:
         return self.add_criteria(other)
 
     def add_criteria(
         self,
-        other,
-        enable_tracking=True,
-        track_on=None,
-        track_closure_variables=True,
-        track_bound_values=True,
-    ):
+        other: _StmtLambdaElementType[Any],
+        enable_tracking: bool = True,
+        track_on: Optional[Any] = None,
+        track_closure_variables: bool = True,
+        track_bound_values: bool = True,
+    ) -> StatementLambdaElement:
         """Add new criteria to this :class:`_sql.StatementLambdaElement`.
 
         E.g.::
@@ -587,25 +603,51 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
         else:
             raise exc.ObjectNotExecutableError(self)
 
+    @property
+    def _proxied(self) -> Any:
+        return self._rec_expected_expr
+
     @property
     def _with_options(self):
-        if TYPE_CHECKING:
-            assert isinstance(self._rec.expected_expr, Executable)
-        return self._rec.expected_expr._with_options
+        return self._proxied._with_options
 
     @property
     def _effective_plugin_target(self):
-        if TYPE_CHECKING:
-            assert isinstance(self._rec.expected_expr, Executable)
-        return self._rec.expected_expr._effective_plugin_target
+        return self._proxied._effective_plugin_target
 
     @property
     def _execution_options(self):
-        if TYPE_CHECKING:
-            assert isinstance(self._rec.expected_expr, Executable)
-        return self._rec.expected_expr._execution_options
+        return self._proxied._execution_options
+
+    @property
+    def _all_selected_columns(self):
+        return self._proxied._all_selected_columns
+
+    @property
+    def is_select(self):
+        return self._proxied.is_select
+
+    @property
+    def is_update(self):
+        return self._proxied.is_update
+
+    @property
+    def is_insert(self):
+        return self._proxied.is_insert
 
-    def spoil(self):
+    @property
+    def is_text(self):
+        return self._proxied.is_text
+
+    @property
+    def is_delete(self):
+        return self._proxied.is_delete
+
+    @property
+    def is_dml(self):
+        return self._proxied.is_dml
+
+    def spoil(self) -> NullLambdaStatement:
         """Return a new :class:`.StatementLambdaElement` that will run
         all lambdas unconditionally each time.
 
@@ -667,12 +709,12 @@ class LinkedLambdaElement(StatementLambdaElement):
 
     def __init__(
         self,
-        fn: _LambdaType,
+        fn: _StmtLambdaElementType[Any],
         parent_lambda: StatementLambdaElement,
         opts: Union[Type[LambdaOptions], LambdaOptions],
     ):
         self.opts = opts
-        self.fn = fn
+        self.fn = fn  # type: ignore[assignment]
         self.parent_lambda = parent_lambda
 
         self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
diff --git a/test/ext/mypy/plain_files/lambda_stmt.py b/test/ext/mypy/plain_files/lambda_stmt.py
new file mode 100644 (file)
index 0000000..7e15778
--- /dev/null
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from typing import Tuple
+from typing import TYPE_CHECKING
+
+from sqlalchemy import Column
+from sqlalchemy import create_engine
+from sqlalchemy import Integer
+from sqlalchemy import lambda_stmt
+from sqlalchemy import MetaData
+from sqlalchemy import Result
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class User(Base):
+    __tablename__ = "a"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    email: Mapped[str]
+
+
+user_table = Table(
+    "user_table", MetaData(), Column("id", Integer), Column("email", String)
+)
+
+
+s1 = select(user_table).where(lambda: user_table.c.id == 5)
+
+s2 = select(User).where(lambda: User.id == 5)
+
+s3 = lambda_stmt(lambda: select(user_table).where(user_table.c.id == 5))
+
+s4 = lambda_stmt(lambda: select(User).where(User.id == 5))
+
+s5 = lambda_stmt(lambda: select(user_table)) + (
+    lambda s: s.where(user_table.c.id == 5)
+)
+
+s6 = lambda_stmt(lambda: select(User)) + (lambda s: s.where(User.id == 5))
+
+
+if TYPE_CHECKING:
+
+    # EXPECTED_TYPE: StatementLambdaElement
+    reveal_type(s5)
+
+    # EXPECTED_TYPE: StatementLambdaElement
+    reveal_type(s6)
+
+
+e = create_engine("sqlite://")
+
+with e.connect() as conn:
+    result = conn.execute(s6)
+
+    if TYPE_CHECKING:
+        # EXPECTED_TYPE: CursorResult[Any]
+        reveal_type(result)
+
+    # we can type these like this
+    my_result: Result[Tuple[User]] = conn.execute(s6)
+
+    if TYPE_CHECKING:
+        # pyright and mypy disagree on the specific type here,
+        # mypy sees Result as we said, pyright seems to upgrade it to
+        # CursorResult
+        # EXPECTED_RE_TYPE: .*(?:Cursor)?Result\[Tuple\[.*User\]\]
+        reveal_type(my_result)