From: Gleb Kisenkov Date: Wed, 7 Dec 2022 07:48:30 +0000 (+0100) Subject: Public API covered X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6a036551f862b78b1da9f45cf85971539e3745bb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Public API covered The whole module needs to be refined in the future. Especially addressing the internal types. --- diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 64d1b7b6e2..5075d3f39c 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -37,18 +37,20 @@ from ..engine.result import IteratorResult from ..engine.result import Result from ..orm import LoaderCallableStatus from ..orm import PassiveFlag +from ..orm._typing import _O from ..orm.bulk_persistence import BulkUDCompileState from ..orm.context import QueryContext from ..orm.mapper import Mapper from ..orm.query import Query +from ..orm.session import _EntityBindKey from ..orm.session import _SessionBind from ..orm.session import _SessionBindKey from ..orm.session import ORMExecuteState from ..orm.session import Session from ..orm.state import InstanceState from ..sql._typing import _TP -from ..sql.base import _MetaOptions from ..sql.elements import ClauseElement +from ..util.typing import Protocol __all__ = ["ShardedSession", "ShardedQuery"] @@ -58,12 +60,22 @@ _T = TypeVar("_T", bound=Any) SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]") +class ShardChooser(Protocol): + def __call__( + self, + mapper: Optional[Mapper[_T]], + instance: Any, + clause: Optional[ClauseElement], + ) -> Any: + ... + + class ShardedQuery(Query[_T]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.id_chooser = self.session.id_chooser - self.query_chooser = self.session.query_chooser - self.execute_chooser = self.session.execute_chooser + self.id_chooser = self.session.id_chooser # type: ignore [attr-defined] # noqa: E501 + self.query_chooser = self.session.query_chooser # type: ignore [attr-defined] # noqa: E501 + self.execute_chooser = self.session.execute_chooser # type: ignore [attr-defined] # noqa: E501 self._shard_id = None def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery: @@ -87,9 +99,7 @@ class ShardedQuery(Query[_T]): class ShardedSession(Session): def __init__( self, - shard_chooser: Callable[ - [Mapper[_T], Any, Optional[ClauseElement]], Any - ], + shard_chooser: ShardChooser, id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]], execute_chooser: Optional[ Callable[[ORMExecuteState], Iterable[Any]] @@ -176,7 +186,7 @@ class ShardedSession(Session): if identity_token is not None: return super()._identity_lookup( - mapper, + mapper, # type: ignore [arg-type] primary_key_identity, identity_token=identity_token, **kw, @@ -200,7 +210,7 @@ class ShardedSession(Session): def _choose_shard_and_assign( self, - mapper: Mapper[_T], + mapper: Optional[Mapper[_T]], instance: Any, **kw: Any, ) -> Any: @@ -218,7 +228,7 @@ class ShardedSession(Session): state.identity_token = shard_id return shard_id - def connection_callable( + def connection_callable( # type: ignore [override] self, mapper: Optional[Mapper[_T]] = None, instance: Optional[Any] = None, @@ -236,13 +246,13 @@ class ShardedSession(Session): if self.in_transaction(): return self.get_transaction().connection(mapper, shard_id=shard_id) # type: ignore [union-attr] # noqa: E501 else: - return self.get_bind( + return self.get_bind( # type: ignore [union-attr] mapper, shard_id=shard_id, instance=instance ).connect(**kwargs) - def get_bind( + def get_bind( # type: ignore [override] self, - mapper: Optional[Mapper[_T]] = None, + mapper: Optional[_EntityBindKey[_O]] = None, shard_id: Optional[_SessionBindKey] = None, instance: Optional[Any] = None, clause: Optional[ClauseElement] = None, @@ -250,12 +260,12 @@ class ShardedSession(Session): ) -> _SessionBind: if shard_id is None: shard_id = self._choose_shard_and_assign( - mapper, instance, clause=clause + mapper, instance, clause=clause # type: ignore [arg-type] ) return self.__binds[shard_id] def bind_shard( - self, shard_id: str, bind: Union[Engine, OptionEngine] + self, shard_id: _SessionBindKey, bind: Union[Engine, OptionEngine] ) -> None: self.__binds[shard_id] = bind @@ -269,7 +279,7 @@ def execute_and_instances( elif orm_context.is_update or orm_context.is_delete: load_options = None - update_options = active_options = orm_context.update_delete_options + update_options = active_options = orm_context.update_delete_options # type: ignore [assignment] # noqa: E501 else: load_options = update_options = active_options = None @@ -278,9 +288,13 @@ def execute_and_instances( def iter_for_shard( shard_id: str, load_options: Union[ - None, QueryContext.default_load_options, _MetaOptions + QueryContext.default_load_options, + Type[QueryContext.default_load_options], + ], + update_options: Union[ + BulkUDCompileState.default_update_options, + Type[QueryContext.default_load_options], ], - update_options: Optional[BulkUDCompileState.default_update_options], ) -> Union[Result[_T], IteratorResult[_TP]]: execution_options = dict(orm_context.local_execution_options) @@ -288,10 +302,10 @@ def execute_and_instances( bind_arguments["shard_id"] = shard_id if orm_context.is_select: - load_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501 + load_options += {"_refresh_identity_token": shard_id} execution_options["_sa_orm_load_options"] = load_options elif orm_context.is_update or orm_context.is_delete: - update_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501 + update_options += {"_refresh_identity_token": shard_id} execution_options["_sa_orm_update_options"] = update_options return orm_context.invoke_statement( @@ -308,10 +322,10 @@ def execute_and_instances( shard_id = None if shard_id is not None: - return iter_for_shard(shard_id, load_options, update_options) + return iter_for_shard(shard_id, load_options, update_options) # type: ignore [arg-type] # noqa: E501 else: partial = [] - for shard_id in session.execute_chooser(orm_context): - result_ = iter_for_shard(shard_id, load_options, update_options) + for shard_id in session.execute_chooser(orm_context): # type: ignore [attr-defined] # noqa: E501 + result_ = iter_for_shard(shard_id, load_options, update_options) # type: ignore [arg-type] # noqa: E501 partial.append(result_) return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 15cd587926..672e8a29d2 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -152,7 +152,7 @@ _PKIdentityArgument = Union[Any, Tuple[Any, ...]] _BindArguments = Dict[str, Any] _EntityBindKey = Union[Type[_O], "Mapper[_O]"] -_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"] +_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table", str] _SessionBind = Union["Engine", "Connection"]