from typing import Any
from typing import Callable
from typing import Dict
-from typing import List
+from typing import Iterable
from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
from typing import Union
-from test.ext.test_horizontal_shard import LazyLoadIdentityKeyTest
-from test.ext.test_horizontal_shard import WeatherLocation
from .. import event
from .. import exc
from .. import inspect
from ..engine.base import Connection
from ..engine.base import Engine
from ..engine.base import OptionEngine
-from ..engine.cursor import CursorResult
-from ..engine.result import ChunkedIteratorResult
from ..engine.result import IteratorResult
-from ..engine.result import MergedResult
+from ..engine.result import Result
+from ..orm import LoaderCallableStatus
+from ..orm import PassiveFlag
from ..orm.bulk_persistence import BulkUDCompileState
from ..orm.context import QueryContext
from ..orm.mapper import Mapper
from ..orm.query import Query
+from ..orm.session import _SessionBind
+from ..orm.session import _SessionBindKey
from ..orm.session import ORMExecuteState
from ..orm.session import Session
from ..orm.state import InstanceState
+from ..sql._typing import _TP
from ..sql.base import _MetaOptions
-from ..sql.selectable import Select
+from ..sql.elements import ClauseElement
+
__all__ = ["ShardedSession", "ShardedQuery"]
+_T = TypeVar("_T", bound=Any)
+
+SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
-class ShardedQuery(Query):
+
+class ShardedQuery(Query[_T]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.id_chooser = self.session.id_chooser
self.execute_chooser = self.session.execute_chooser
self._shard_id = None
- def set_shard(self, shard_id: str) -> ShardedQuery:
+ def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
"""Return a new query, limited to a single shard ID.
All subsequent operations with the returned query will
class ShardedSession(Session):
def __init__(
self,
- shard_chooser: Callable,
- id_chooser: Callable,
- execute_chooser: Callable = None,
- shards: Dict[str, Any] = None,
- query_cls: type = ShardedQuery,
+ shard_chooser: Callable[
+ [Mapper[_T], Any, Optional[ClauseElement]], Any
+ ],
+ id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
+ execute_chooser: Optional[
+ Callable[[ORMExecuteState], Iterable[Any]]
+ ] = None,
+ shards: Optional[Dict[str, Any]] = None,
+ query_cls: Type[Query[_T]] = ShardedQuery,
**kwargs: Any,
) -> None:
"""Construct a ShardedSession.
"at the same time."
)
- def execute_chooser(orm_context):
+ def execute_chooser(orm_context: ORMExecuteState) -> Any:
return query_chooser(orm_context.statement)
self.execute_chooser = execute_chooser
def _identity_lookup(
self,
- mapper: Mapper,
- primary_key_identity: List[int],
+ mapper: Mapper[_T],
+ primary_key_identity: Union[Any, Tuple[Any, ...]],
identity_token: Optional[Any] = None,
- lazy_loaded_from: Optional[InstanceState] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ lazy_loaded_from: Optional[InstanceState[Any]] = None,
**kw: Any,
- ) -> Union[
- None, WeatherLocation, LazyLoadIdentityKeyTest.setup_classes.Book
- ]:
+ ) -> Union[Optional[_T], LoaderCallableStatus]:
"""override the default :meth:`.Session._identity_lookup` method so
that we search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
return None
def _choose_shard_and_assign(
- self, mapper: Mapper, instance: WeatherLocation, **kw: Any
- ) -> str:
+ self,
+ mapper: Mapper[_T],
+ instance: Any,
+ **kw: Any,
+ ) -> Any:
if instance is not None:
state = inspect(instance)
if state.key:
def connection_callable(
self,
- mapper: Mapper = None,
- instance: WeatherLocation = None,
+ mapper: Optional[Mapper[_T]] = None,
+ instance: Optional[Any] = None,
shard_id: Optional[Any] = None,
**kwargs: Any,
) -> Connection:
shard_id = self._choose_shard_and_assign(mapper, instance)
if self.in_transaction():
- return self.get_transaction().connection(mapper, shard_id=shard_id)
+ return self.get_transaction().connection(mapper, shard_id=shard_id) # type: ignore [union-attr] # noqa: E501
else:
return self.get_bind(
mapper, shard_id=shard_id, instance=instance
def get_bind(
self,
- mapper: Mapper = None,
- shard_id: str = None,
+ mapper: Optional[Mapper[_T]] = None,
+ shard_id: Optional[_SessionBindKey] = None,
instance: Optional[Any] = None,
- clause: Optional[Select] = None,
+ clause: Optional[ClauseElement] = None,
**kw: Any,
- ) -> Union[Engine, OptionEngine]:
+ ) -> _SessionBind:
if shard_id is None:
shard_id = self._choose_shard_and_assign(
mapper, instance, clause=clause
def execute_and_instances(
orm_context: ORMExecuteState,
-) -> Union[CursorResult, ChunkedIteratorResult, MergedResult]:
+) -> Union[Result[_T], IteratorResult[_TP]]:
if orm_context.is_select:
load_options = active_options = orm_context.load_options
update_options = None
None, QueryContext.default_load_options, _MetaOptions
],
update_options: Optional[BulkUDCompileState.default_update_options],
- ) -> Union[CursorResult, ChunkedIteratorResult, IteratorResult]:
+ ) -> Union[Result[_T], IteratorResult[_TP]]:
execution_options = dict(orm_context.local_execution_options)
bind_arguments = dict(orm_context.bind_arguments)
bind_arguments["shard_id"] = shard_id
if orm_context.is_select:
- load_options += {"_refresh_identity_token": shard_id}
+ load_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501
execution_options["_sa_orm_load_options"] = load_options
elif orm_context.is_update or orm_context.is_delete:
- update_options += {"_refresh_identity_token": shard_id}
+ update_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501
execution_options["_sa_orm_update_options"] = update_options
return orm_context.invoke_statement(