From e0eea374c2df82f879d69b99ba2230c743bbae27 Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Thu, 8 Dec 2022 17:48:55 -0500 Subject: [PATCH] Type annotations for sqlalchemy.ext.horizontal_shard The horizontal sharding extension is now pep-484 typed. Thanks to Gleb Kisenkov for their efforts on this. Closes: #8948 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8948 Pull-request-sha: e40e768492685aa9ce57c4762c571f935e3fd3c7 Change-Id: I2374e174c9433846c453c20a37ec5e5584fd3b31 --- .../changelog/unreleased_20/horiz_typing.rst | 6 + lib/sqlalchemy/ext/horizontal_shard.py | 250 ++++++++++++++---- lib/sqlalchemy/orm/session.py | 3 +- test/ext/test_horizontal_shard.py | 4 +- 4 files changed, 210 insertions(+), 53 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/horiz_typing.rst diff --git a/doc/build/changelog/unreleased_20/horiz_typing.rst b/doc/build/changelog/unreleased_20/horiz_typing.rst new file mode 100644 index 0000000000..55b794d0b0 --- /dev/null +++ b/doc/build/changelog/unreleased_20/horiz_typing.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, typing + :tickets: 6810 + + The horizontal sharding extension is now pep-484 typed. Thanks to Gleb + Kisenkov for their efforts on this. diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 8f6e2ffcd9..69767ad6cb 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """Horizontal sharding support. @@ -14,27 +13,93 @@ distribute queries and persistence operations across multiple databases. For a usage example, see the :ref:`examples_sharding` example included in the source distribution. +.. legacy:: The horizontal sharding API is not fully updated for the + SQLAlchemy 2.0 API, and still relies in part on the + legacy :class:`.Query` architecture, in particular as part of the + signature for the :paramref:`.ShardedSession.id_chooser` parameter. + This may change in a future release. + """ +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from .. import event from .. import exc from .. import inspect from .. import util +from ..orm import PassiveFlag +from ..orm.mapper import Mapper from ..orm.query import Query from ..orm.session import Session +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.base import Connection + from ..engine.base import Engine + from ..engine.base import OptionEngine + from ..engine.result import IteratorResult + from ..engine.result import Result + from ..orm import LoaderCallableStatus + from ..orm._typing import _O + from ..orm.bulk_persistence import BulkUDCompileState + from ..orm.context import QueryContext + from ..orm.session import _EntityBindKey + from ..orm.session import _SessionBind + from ..orm.session import ORMExecuteState + from ..orm.state import InstanceState + from ..sql import Executable + from ..sql._typing import _TP + from ..sql.elements import ClauseElement __all__ = ["ShardedSession", "ShardedQuery"] +_T = TypeVar("_T", bound=Any) + +SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]") + +_ShardKey = str + + +class ShardChooser(Protocol): + def __call__( + self, + mapper: Optional[Mapper[_T]], + instance: Any, + clause: Optional[ClauseElement], + ) -> Any: + ... + -class ShardedQuery(Query): - def __init__(self, *args, **kwargs): +class ShardedQuery(Query[_T]): + """Query class used with :class:`.ShardedSession`. + + .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy + :class:`.Query` class. The :class:`.ShardedSession` now supports + 2.0 style execution via the :meth:`.ShardedSession.execute` method + as well. + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) + assert isinstance(self.session, ShardedSession) + self.id_chooser = self.session.id_chooser self.query_chooser = self.session.query_chooser self.execute_chooser = self.session.execute_chooser self._shard_id = None - def set_shard(self, shard_id): + def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery: """Return a new query, limited to a single shard ID. All subsequent operations with the returned query will @@ -53,15 +118,23 @@ class ShardedQuery(Query): class ShardedSession(Session): + shard_chooser: ShardChooser + id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]] + execute_chooser: Callable[[ORMExecuteState], Iterable[Any]] + def __init__( self, - shard_chooser, - id_chooser, - execute_chooser=None, - shards=None, - query_cls=ShardedQuery, - **kwargs, - ): + shard_chooser: ShardChooser, + id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]], + execute_chooser: Optional[ + Callable[[ORMExecuteState], Iterable[Any]] + ] = None, + shards: Optional[Dict[str, Any]] = None, + query_cls: Type[Query[_T]] = ShardedQuery, + *, + query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, + **kwargs: Any, + ) -> None: """Construct a ShardedSession. :param shard_chooser: A callable which, passed a Mapper, a mapped @@ -71,9 +144,14 @@ class ShardedSession(Session): should set whatever state on the instance to mark it in the future as participating in that shard. - :param id_chooser: A callable, passed a query and a tuple of identity - values, which should return a list of shard ids where the ID might - reside. The databases will be queried in the order of this listing. + :param id_chooser: A callable, passed a :class:`.ShardedQuery` and a + tuple of identity values, which should return a list of shard ids + where the ID might reside. The databases will be queried in the order + of this listing. + + .. legacy:: This parameter still uses the legacy + :class:`.ShardedQuery` class as an argument passed to the + callable. :param execute_chooser: For a given :class:`.ORMExecuteState`, returns the list of shard_ids @@ -87,7 +165,6 @@ class ShardedSession(Session): to :class:`~sqlalchemy.engine.Engine` objects. """ - query_chooser = kwargs.pop("query_chooser", None) super().__init__(query_cls=query_cls, **kwargs) event.listen( @@ -97,6 +174,7 @@ class ShardedSession(Session): self.id_chooser = id_chooser if query_chooser: + _query_chooser = query_chooser util.warn_deprecated( "The ``query_choser`` parameter is deprecated; " "please use ``execute_chooser``.", @@ -108,26 +186,34 @@ class ShardedSession(Session): "at the same time." ) - def execute_chooser(orm_context): - return query_chooser(orm_context.statement) + def _default_execute_chooser( + orm_context: ORMExecuteState, + ) -> Iterable[Any]: + return _query_chooser(orm_context.statement) - self.execute_chooser = execute_chooser - else: - self.execute_chooser = execute_chooser + if execute_chooser is None: + execute_chooser = _default_execute_chooser + + if execute_chooser is None: + raise exc.ArgumentError( + "execute_chooser or query_chooser is required" + ) + self.execute_chooser = execute_chooser self.query_chooser = query_chooser - self.__binds = {} + self.__shards: Dict[_ShardKey, _SessionBind] = {} if shards is not None: for k in shards: self.bind_shard(k, shards[k]) def _identity_lookup( self, - mapper, - primary_key_identity, - identity_token=None, - lazy_loaded_from=None, - **kw, - ): + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Optional[Any] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + **kw: Any, + ) -> Union[Optional[_O], LoaderCallableStatus]: """override the default :meth:`.Session._identity_lookup` method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). @@ -138,30 +224,37 @@ class ShardedSession(Session): """ if identity_token is not None: - return super()._identity_lookup( + obj = super()._identity_lookup( mapper, primary_key_identity, identity_token=identity_token, **kw, ) + + return obj else: q = self.query(mapper) if lazy_loaded_from: q = q._set_lazyload_from(lazy_loaded_from) for shard_id in self.id_chooser(q, primary_key_identity): - obj = super()._identity_lookup( + obj2 = super()._identity_lookup( mapper, primary_key_identity, identity_token=shard_id, lazy_loaded_from=lazy_loaded_from, **kw, ) - if obj is not None: - return obj + if obj2 is not None: + return obj2 return None - def _choose_shard_and_assign(self, mapper, instance, **kw): + def _choose_shard_and_assign( + self, + mapper: Optional[_EntityBindKey[_O]], + instance: Any, + **kw: Any, + ) -> Any: if instance is not None: state = inspect(instance) if state.key: @@ -171,14 +264,19 @@ class ShardedSession(Session): elif state.identity_token: return state.identity_token + assert isinstance(mapper, Mapper) shard_id = self.shard_chooser(mapper, instance, **kw) if instance is not None: state.identity_token = shard_id return shard_id - def connection_callable( - self, mapper=None, instance=None, shard_id=None, **kwargs - ): + def connection_callable( # type: ignore [override] + self, + mapper: Optional[Mapper[_T]] = None, + instance: Optional[Any] = None, + shard_id: Optional[Any] = None, + **kw: Any, + ) -> Connection: """Provide a :class:`_engine.Connection` to use in the unit of work flush process. @@ -188,26 +286,63 @@ class ShardedSession(Session): shard_id = self._choose_shard_and_assign(mapper, instance) if self.in_transaction(): - return self.get_transaction().connection(mapper, shard_id=shard_id) + trans = self.get_transaction() + assert trans is not None + return trans.connection(mapper, shard_id=shard_id) else: - return self.get_bind( - mapper, shard_id=shard_id, instance=instance - ).connect(**kwargs) + bind = self.get_bind( + mapper=mapper, shard_id=shard_id, instance=instance + ) + + if isinstance(bind, Engine): + return bind.connect(**kw) + else: + assert isinstance(bind, Connection) + return bind def get_bind( - self, mapper=None, shard_id=None, instance=None, clause=None, **kw - ): + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + shard_id: Optional[_ShardKey] = None, + instance: Optional[Any] = None, + clause: Optional[ClauseElement] = None, + **kw: Any, + ) -> _SessionBind: if shard_id is None: shard_id = self._choose_shard_and_assign( - mapper, instance, clause=clause + mapper, instance=instance, clause=clause ) - return self.__binds[shard_id] + assert shard_id is not None + return self.__shards[shard_id] + + def bind_shard( + self, shard_id: _ShardKey, bind: Union[Engine, OptionEngine] + ) -> None: + self.__shards[shard_id] = bind + + +def execute_and_instances( + orm_context: ORMExecuteState, +) -> Union[Result[_T], IteratorResult[_TP]]: + update_options: Union[ + None, + BulkUDCompileState.default_update_options, + Type[BulkUDCompileState.default_update_options], + ] + active_options: Union[ + None, + QueryContext.default_load_options, + Type[QueryContext.default_load_options], + BulkUDCompileState.default_update_options, + Type[BulkUDCompileState.default_update_options], + ] + load_options: Union[ + None, + QueryContext.default_load_options, + Type[QueryContext.default_load_options], + ] - def bind_shard(self, shard_id, bind): - self.__binds[shard_id] = bind - - -def execute_and_instances(orm_context): if orm_context.is_select: load_options = active_options = orm_context.load_options update_options = None @@ -219,17 +354,32 @@ def execute_and_instances(orm_context): load_options = update_options = active_options = None session = orm_context.session - - def iter_for_shard(shard_id, load_options, update_options): + assert isinstance(session, ShardedSession) + + def iter_for_shard( + shard_id: str, + load_options: Union[ + None, + QueryContext.default_load_options, + Type[QueryContext.default_load_options], + ], + update_options: Union[ + None, + BulkUDCompileState.default_update_options, + Type[BulkUDCompileState.default_update_options], + ], + ) -> Union[Result[_T], IteratorResult[_TP]]: execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) bind_arguments["shard_id"] = shard_id if orm_context.is_select: + assert load_options is not None load_options += {"_refresh_identity_token": shard_id} execution_options["_sa_orm_load_options"] = load_options elif orm_context.is_update or orm_context.is_delete: + assert update_options is not None update_options += {"_refresh_identity_token": shard_id} execution_options["_sa_orm_update_options"] = update_options diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 15cd587926..1f7bd6d73a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -152,7 +152,7 @@ _PKIdentityArgument = Union[Any, Tuple[Any, ...]] _BindArguments = Dict[str, Any] _EntityBindKey = Union[Type[_O], "Mapper[_O]"] -_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"] +_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table", str] _SessionBind = Union["Engine", "Connection"] @@ -2374,6 +2374,7 @@ class Session(_SessionClassMethods, EventTarget): def get_bind( self, mapper: Optional[_EntityBindKey[_O]] = None, + *, clause: Optional[ClauseElement] = None, bind: Optional[_SessionBind] = None, _sa_skip_events: Optional[bool] = None, diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 8913478598..ab4a24f71c 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -498,7 +498,7 @@ class ShardTest: eq_({t.temperature for t in temps}, {86.0, 75.0, 91.0}) self.assert_sql_count( - sess._ShardedSession__binds["north_america"], go, 0 + sess._ShardedSession__shards["north_america"], go, 0 ) eq_( @@ -533,7 +533,7 @@ class ShardTest: assert inspect(t).deleted is (t.temperature >= 80) self.assert_sql_count( - sess._ShardedSession__binds["north_america"], go, 0 + sess._ShardedSession__shards["north_america"], go, 0 ) eq_( -- 2.47.2