From: Gleb Kisenkov Date: Mon, 5 Dec 2022 21:36:34 +0000 (+0100) Subject: Transient commit X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ab8af21e7bbf6ce1cd90528eb29b941d4d348bf8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Transient commit Some type annotation fixes --- diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 486540181b..64d1b7b6e2 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -19,12 +19,13 @@ from __future__ import annotations from typing import Any from typing import Callable from typing import Dict -from typing import List +from typing import Iterable from typing import Optional +from typing import Tuple +from typing import Type +from typing import TypeVar from typing import Union -from test.ext.test_horizontal_shard import LazyLoadIdentityKeyTest -from test.ext.test_horizontal_shard import WeatherLocation from .. import event from .. import exc from .. import inspect @@ -32,24 +33,32 @@ from .. import util from ..engine.base import Connection from ..engine.base import Engine from ..engine.base import OptionEngine -from ..engine.cursor import CursorResult -from ..engine.result import ChunkedIteratorResult from ..engine.result import IteratorResult -from ..engine.result import MergedResult +from ..engine.result import Result +from ..orm import LoaderCallableStatus +from ..orm import PassiveFlag from ..orm.bulk_persistence import BulkUDCompileState from ..orm.context import QueryContext from ..orm.mapper import Mapper from ..orm.query import Query +from ..orm.session import _SessionBind +from ..orm.session import _SessionBindKey from ..orm.session import ORMExecuteState from ..orm.session import Session from ..orm.state import InstanceState +from ..sql._typing import _TP from ..sql.base import _MetaOptions -from ..sql.selectable import Select +from ..sql.elements import ClauseElement + __all__ = ["ShardedSession", "ShardedQuery"] +_T = TypeVar("_T", bound=Any) + +SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]") -class ShardedQuery(Query): + +class ShardedQuery(Query[_T]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.id_chooser = self.session.id_chooser @@ -57,7 +66,7 @@ class ShardedQuery(Query): self.execute_chooser = self.session.execute_chooser self._shard_id = None - def set_shard(self, shard_id: str) -> ShardedQuery: + def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery: """Return a new query, limited to a single shard ID. All subsequent operations with the returned query will @@ -78,11 +87,15 @@ class ShardedQuery(Query): class ShardedSession(Session): def __init__( self, - shard_chooser: Callable, - id_chooser: Callable, - execute_chooser: Callable = None, - shards: Dict[str, Any] = None, - query_cls: type = ShardedQuery, + shard_chooser: Callable[ + [Mapper[_T], Any, Optional[ClauseElement]], Any + ], + id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]], + execute_chooser: Optional[ + Callable[[ORMExecuteState], Iterable[Any]] + ] = None, + shards: Optional[Dict[str, Any]] = None, + query_cls: Type[Query[_T]] = ShardedQuery, **kwargs: Any, ) -> None: """Construct a ShardedSession. @@ -131,7 +144,7 @@ class ShardedSession(Session): "at the same time." ) - def execute_chooser(orm_context): + def execute_chooser(orm_context: ORMExecuteState) -> Any: return query_chooser(orm_context.statement) self.execute_chooser = execute_chooser @@ -145,14 +158,13 @@ class ShardedSession(Session): def _identity_lookup( self, - mapper: Mapper, - primary_key_identity: List[int], + mapper: Mapper[_T], + primary_key_identity: Union[Any, Tuple[Any, ...]], identity_token: Optional[Any] = None, - lazy_loaded_from: Optional[InstanceState] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, **kw: Any, - ) -> Union[ - None, WeatherLocation, LazyLoadIdentityKeyTest.setup_classes.Book - ]: + ) -> Union[Optional[_T], LoaderCallableStatus]: """override the default :meth:`.Session._identity_lookup` method so that we search for a given non-token primary key identity across all possible identity tokens (e.g. shard ids). @@ -187,8 +199,11 @@ class ShardedSession(Session): return None def _choose_shard_and_assign( - self, mapper: Mapper, instance: WeatherLocation, **kw: Any - ) -> str: + self, + mapper: Mapper[_T], + instance: Any, + **kw: Any, + ) -> Any: if instance is not None: state = inspect(instance) if state.key: @@ -205,8 +220,8 @@ class ShardedSession(Session): def connection_callable( self, - mapper: Mapper = None, - instance: WeatherLocation = None, + mapper: Optional[Mapper[_T]] = None, + instance: Optional[Any] = None, shard_id: Optional[Any] = None, **kwargs: Any, ) -> Connection: @@ -219,7 +234,7 @@ class ShardedSession(Session): shard_id = self._choose_shard_and_assign(mapper, instance) if self.in_transaction(): - return self.get_transaction().connection(mapper, shard_id=shard_id) + return self.get_transaction().connection(mapper, shard_id=shard_id) # type: ignore [union-attr] # noqa: E501 else: return self.get_bind( mapper, shard_id=shard_id, instance=instance @@ -227,12 +242,12 @@ class ShardedSession(Session): def get_bind( self, - mapper: Mapper = None, - shard_id: str = None, + mapper: Optional[Mapper[_T]] = None, + shard_id: Optional[_SessionBindKey] = None, instance: Optional[Any] = None, - clause: Optional[Select] = None, + clause: Optional[ClauseElement] = None, **kw: Any, - ) -> Union[Engine, OptionEngine]: + ) -> _SessionBind: if shard_id is None: shard_id = self._choose_shard_and_assign( mapper, instance, clause=clause @@ -247,7 +262,7 @@ class ShardedSession(Session): def execute_and_instances( orm_context: ORMExecuteState, -) -> Union[CursorResult, ChunkedIteratorResult, MergedResult]: +) -> Union[Result[_T], IteratorResult[_TP]]: if orm_context.is_select: load_options = active_options = orm_context.load_options update_options = None @@ -266,17 +281,17 @@ def execute_and_instances( None, QueryContext.default_load_options, _MetaOptions ], update_options: Optional[BulkUDCompileState.default_update_options], - ) -> Union[CursorResult, ChunkedIteratorResult, IteratorResult]: + ) -> Union[Result[_T], IteratorResult[_TP]]: execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) bind_arguments["shard_id"] = shard_id if orm_context.is_select: - load_options += {"_refresh_identity_token": shard_id} + load_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501 execution_options["_sa_orm_load_options"] = load_options elif orm_context.is_update or orm_context.is_delete: - update_options += {"_refresh_identity_token": shard_id} + update_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501 execution_options["_sa_orm_update_options"] = update_options return orm_context.invoke_statement(