from ..engine import Dialect
from ..engine import Engine
from ..engine.interfaces import _CoreMultiExecuteParams
- from ..engine.interfaces import _ExecuteOptions
from ..engine.interfaces import CacheStats
from ..engine.interfaces import CompiledCacheType
+ from ..engine.interfaces import CoreExecuteOptionsParameter
from ..engine.interfaces import SchemaTranslateMapType
from ..engine.result import Result
self,
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
- execution_options: _ExecuteOptions,
+ execution_options: CoreExecuteOptionsParameter,
) -> Result[Any]:
if self.supports_execution:
if TYPE_CHECKING:
self,
connection: Connection,
distilled_params: _CoreMultiExecuteParams,
- execution_options: _ExecuteOptions,
+ execution_options: CoreExecuteOptionsParameter,
) -> Any:
"""an additional hook for subclasses to provide a different
implementation for connection.scalar() vs. connection.execute().
from typing import Any
from typing import Callable
from typing import cast
-from typing import Iterable
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
from .. import inspection
from .. import util
from ..util.typing import Literal
-from ..util.typing import Protocol
from ..util.typing import Self
if TYPE_CHECKING:
_closure_per_cache_key: _LambdaCacheType = util.LRUCache(1000)
-class _LambdaType(Protocol):
- __code__: CodeType
- __closure__: Iterable[Tuple[Any, Any]]
+_LambdaType = Callable[[], Any]
- def __call__(self, *arg: Any, **kw: Any) -> ClauseElement:
- ...
+_AnyLambdaType = Callable[..., Any]
+
+_StmtLambdaType = Callable[[], Any]
+
+_E = TypeVar("_E", bound=Executable)
+_StmtLambdaElementType = Callable[[_E], Any]
class LambdaOptions(Options):
def lambda_stmt(
- lmb: _LambdaType,
+ lmb: _StmtLambdaType,
enable_tracking: bool = True,
track_closure_variables: bool = True,
track_on: Optional[object] = None,
closure_cache_key: Union[Tuple[Any, ...], Literal[CacheConst.NO_CACHE]]
role: Type[SQLRole]
_rec: Union[AnalyzedFunction, NonAnalyzedFunction]
- fn: _LambdaType
+ fn: _AnyLambdaType
tracker_key: Tuple[CodeType, ...]
def __repr__(self):
bindparams.extend(self._resolved_bindparams)
return cache_key
- def _invoke_user_fn(self, fn: _LambdaType, *arg: Any) -> ClauseElement:
- return fn()
+ def _invoke_user_fn(self, fn: _AnyLambdaType, *arg: Any) -> ClauseElement:
+ return fn() # type: ignore[no-any-return]
class DeferredLambdaElement(LambdaElement):
self._transforms += (deferred_copy_internals,)
-class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
+class StatementLambdaElement(
+ roles.AllowsLambdaRole, LambdaElement, Executable
+):
"""Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
The :class:`_sql.StatementLambdaElement` is constructed using the
"""
- def __add__(self, other):
+ if TYPE_CHECKING:
+
+ def __init__(
+ self,
+ fn: _StmtLambdaType,
+ role: Type[SQLRole],
+ opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
+ apply_propagate_attrs: Optional[ClauseElement] = None,
+ ):
+ ...
+
+ def __add__(
+ self, other: _StmtLambdaElementType[Any]
+ ) -> StatementLambdaElement:
return self.add_criteria(other)
def add_criteria(
self,
- other,
- enable_tracking=True,
- track_on=None,
- track_closure_variables=True,
- track_bound_values=True,
- ):
+ other: _StmtLambdaElementType[Any],
+ enable_tracking: bool = True,
+ track_on: Optional[Any] = None,
+ track_closure_variables: bool = True,
+ track_bound_values: bool = True,
+ ) -> StatementLambdaElement:
"""Add new criteria to this :class:`_sql.StatementLambdaElement`.
E.g.::
else:
raise exc.ObjectNotExecutableError(self)
+ @property
+ def _proxied(self) -> Any:
+ return self._rec_expected_expr
+
@property
def _with_options(self):
- if TYPE_CHECKING:
- assert isinstance(self._rec.expected_expr, Executable)
- return self._rec.expected_expr._with_options
+ return self._proxied._with_options
@property
def _effective_plugin_target(self):
- if TYPE_CHECKING:
- assert isinstance(self._rec.expected_expr, Executable)
- return self._rec.expected_expr._effective_plugin_target
+ return self._proxied._effective_plugin_target
@property
def _execution_options(self):
- if TYPE_CHECKING:
- assert isinstance(self._rec.expected_expr, Executable)
- return self._rec.expected_expr._execution_options
+ return self._proxied._execution_options
+
+ @property
+ def _all_selected_columns(self):
+ return self._proxied._all_selected_columns
+
+ @property
+ def is_select(self):
+ return self._proxied.is_select
+
+ @property
+ def is_update(self):
+ return self._proxied.is_update
+
+ @property
+ def is_insert(self):
+ return self._proxied.is_insert
- def spoil(self):
+ @property
+ def is_text(self):
+ return self._proxied.is_text
+
+ @property
+ def is_delete(self):
+ return self._proxied.is_delete
+
+ @property
+ def is_dml(self):
+ return self._proxied.is_dml
+
+ def spoil(self) -> NullLambdaStatement:
"""Return a new :class:`.StatementLambdaElement` that will run
all lambdas unconditionally each time.
def __init__(
self,
- fn: _LambdaType,
+ fn: _StmtLambdaElementType[Any],
parent_lambda: StatementLambdaElement,
opts: Union[Type[LambdaOptions], LambdaOptions],
):
self.opts = opts
- self.fn = fn
+ self.fn = fn # type: ignore[assignment]
self.parent_lambda = parent_lambda
self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
--- /dev/null
+from __future__ import annotations
+
+from typing import Tuple
+from typing import TYPE_CHECKING
+
+from sqlalchemy import Column
+from sqlalchemy import create_engine
+from sqlalchemy import Integer
+from sqlalchemy import lambda_stmt
+from sqlalchemy import MetaData
+from sqlalchemy import Result
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+class User(Base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ email: Mapped[str]
+
+
+user_table = Table(
+ "user_table", MetaData(), Column("id", Integer), Column("email", String)
+)
+
+
+s1 = select(user_table).where(lambda: user_table.c.id == 5)
+
+s2 = select(User).where(lambda: User.id == 5)
+
+s3 = lambda_stmt(lambda: select(user_table).where(user_table.c.id == 5))
+
+s4 = lambda_stmt(lambda: select(User).where(User.id == 5))
+
+s5 = lambda_stmt(lambda: select(user_table)) + (
+ lambda s: s.where(user_table.c.id == 5)
+)
+
+s6 = lambda_stmt(lambda: select(User)) + (lambda s: s.where(User.id == 5))
+
+
+if TYPE_CHECKING:
+
+ # EXPECTED_TYPE: StatementLambdaElement
+ reveal_type(s5)
+
+ # EXPECTED_TYPE: StatementLambdaElement
+ reveal_type(s6)
+
+
+e = create_engine("sqlite://")
+
+with e.connect() as conn:
+ result = conn.execute(s6)
+
+ if TYPE_CHECKING:
+ # EXPECTED_TYPE: CursorResult[Any]
+ reveal_type(result)
+
+ # we can type these like this
+ my_result: Result[Tuple[User]] = conn.execute(s6)
+
+ if TYPE_CHECKING:
+ # pyright and mypy disagree on the specific type here,
+ # mypy sees Result as we said, pyright seems to upgrade it to
+ # CursorResult
+ # EXPECTED_RE_TYPE: .*(?:Cursor)?Result\[Tuple\[.*User\]\]
+ reveal_type(my_result)