From 1cdf94c10af923e89f0a82e2f756bf75abe7ac48 Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Sat, 3 Dec 2022 10:56:08 +0100 Subject: [PATCH] Collected runtime types --- lib/sqlalchemy/ext/horizontal_shard.py | 94 +++++++++++++++++++------- 1 file changed, 70 insertions(+), 24 deletions(-) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 8f6e2ffcd9..486540181b 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """Horizontal sharding support. @@ -15,26 +14,50 @@ For a usage example, see the :ref:`examples_sharding` example included in the source distribution. """ +from __future__ import annotations +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from test.ext.test_horizontal_shard import LazyLoadIdentityKeyTest +from test.ext.test_horizontal_shard import WeatherLocation from .. import event from .. import exc from .. import inspect from .. import util +from ..engine.base import Connection +from ..engine.base import Engine +from ..engine.base import OptionEngine +from ..engine.cursor import CursorResult +from ..engine.result import ChunkedIteratorResult +from ..engine.result import IteratorResult +from ..engine.result import MergedResult +from ..orm.bulk_persistence import BulkUDCompileState +from ..orm.context import QueryContext +from ..orm.mapper import Mapper from ..orm.query import Query +from ..orm.session import ORMExecuteState from ..orm.session import Session +from ..orm.state import InstanceState +from ..sql.base import _MetaOptions +from ..sql.selectable import Select __all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.id_chooser = self.session.id_chooser self.query_chooser = self.session.query_chooser self.execute_chooser = self.session.execute_chooser self._shard_id = None - def set_shard(self, shard_id): + def set_shard(self, shard_id: str) -> ShardedQuery: """Return a new query, limited to a single shard ID. All subsequent operations with the returned query will @@ -55,13 +78,13 @@ class ShardedQuery(Query): class ShardedSession(Session): def __init__( self, - shard_chooser, - id_chooser, - execute_chooser=None, - shards=None, - query_cls=ShardedQuery, - **kwargs, - ): + shard_chooser: Callable, + id_chooser: Callable, + execute_chooser: Callable = None, + shards: Dict[str, Any] = None, + query_cls: type = ShardedQuery, + **kwargs: Any, + ) -> None: """Construct a ShardedSession. :param shard_chooser: A callable which, passed a Mapper, a mapped @@ -122,12 +145,14 @@ class ShardedSession(Session): def _identity_lookup( self, - mapper, - primary_key_identity, - identity_token=None, - lazy_loaded_from=None, - **kw, - ): + mapper: Mapper, + primary_key_identity: List[int], + identity_token: Optional[Any] = None, + lazy_loaded_from: Optional[InstanceState] = None, + **kw: Any, + ) -> Union[ + None, WeatherLocation, LazyLoadIdentityKeyTest.setup_classes.Book + ]: """override the default :meth:`.Session._identity_lookup` method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). @@ -161,7 +186,9 @@ class ShardedSession(Session): return None - def _choose_shard_and_assign(self, mapper, instance, **kw): + def _choose_shard_and_assign( + self, mapper: Mapper, instance: WeatherLocation, **kw: Any + ) -> str: if instance is not None: state = inspect(instance) if state.key: @@ -177,8 +204,12 @@ class ShardedSession(Session): return shard_id def connection_callable( - self, mapper=None, instance=None, shard_id=None, **kwargs - ): + self, + mapper: Mapper = None, + instance: WeatherLocation = None, + shard_id: Optional[Any] = None, + **kwargs: Any, + ) -> Connection: """Provide a :class:`_engine.Connection` to use in the unit of work flush process. @@ -195,19 +226,28 @@ class ShardedSession(Session): ).connect(**kwargs) def get_bind( - self, mapper=None, shard_id=None, instance=None, clause=None, **kw - ): + self, + mapper: Mapper = None, + shard_id: str = None, + instance: Optional[Any] = None, + clause: Optional[Select] = None, + **kw: Any, + ) -> Union[Engine, OptionEngine]: if shard_id is None: shard_id = self._choose_shard_and_assign( mapper, instance, clause=clause ) return self.__binds[shard_id] - def bind_shard(self, shard_id, bind): + def bind_shard( + self, shard_id: str, bind: Union[Engine, OptionEngine] + ) -> None: self.__binds[shard_id] = bind -def execute_and_instances(orm_context): +def execute_and_instances( + orm_context: ORMExecuteState, +) -> Union[CursorResult, ChunkedIteratorResult, MergedResult]: if orm_context.is_select: load_options = active_options = orm_context.load_options update_options = None @@ -220,7 +260,13 @@ def execute_and_instances(orm_context): session = orm_context.session - def iter_for_shard(shard_id, load_options, update_options): + def iter_for_shard( + shard_id: str, + load_options: Union[ + None, QueryContext.default_load_options, _MetaOptions + ], + update_options: Optional[BulkUDCompileState.default_update_options], + ) -> Union[CursorResult, ChunkedIteratorResult, IteratorResult]: execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) -- 2.47.2