from ..engine.result import Result
from ..orm import LoaderCallableStatus
from ..orm import PassiveFlag
+from ..orm._typing import _O
from ..orm.bulk_persistence import BulkUDCompileState
from ..orm.context import QueryContext
from ..orm.mapper import Mapper
from ..orm.query import Query
+from ..orm.session import _EntityBindKey
from ..orm.session import _SessionBind
from ..orm.session import _SessionBindKey
from ..orm.session import ORMExecuteState
from ..orm.session import Session
from ..orm.state import InstanceState
from ..sql._typing import _TP
-from ..sql.base import _MetaOptions
from ..sql.elements import ClauseElement
+from ..util.typing import Protocol
__all__ = ["ShardedSession", "ShardedQuery"]
SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
+class ShardChooser(Protocol):
+ def __call__(
+ self,
+ mapper: Optional[Mapper[_T]],
+ instance: Any,
+ clause: Optional[ClauseElement],
+ ) -> Any:
+ ...
+
+
class ShardedQuery(Query[_T]):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
- self.id_chooser = self.session.id_chooser
- self.query_chooser = self.session.query_chooser
- self.execute_chooser = self.session.execute_chooser
+ self.id_chooser = self.session.id_chooser # type: ignore [attr-defined] # noqa: E501
+ self.query_chooser = self.session.query_chooser # type: ignore [attr-defined] # noqa: E501
+ self.execute_chooser = self.session.execute_chooser # type: ignore [attr-defined] # noqa: E501
self._shard_id = None
def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
class ShardedSession(Session):
def __init__(
self,
- shard_chooser: Callable[
- [Mapper[_T], Any, Optional[ClauseElement]], Any
- ],
+ shard_chooser: ShardChooser,
id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
execute_chooser: Optional[
Callable[[ORMExecuteState], Iterable[Any]]
if identity_token is not None:
return super()._identity_lookup(
- mapper,
+ mapper, # type: ignore [arg-type]
primary_key_identity,
identity_token=identity_token,
**kw,
def _choose_shard_and_assign(
self,
- mapper: Mapper[_T],
+ mapper: Optional[Mapper[_T]],
instance: Any,
**kw: Any,
) -> Any:
state.identity_token = shard_id
return shard_id
- def connection_callable(
+ def connection_callable( # type: ignore [override]
self,
mapper: Optional[Mapper[_T]] = None,
instance: Optional[Any] = None,
if self.in_transaction():
return self.get_transaction().connection(mapper, shard_id=shard_id) # type: ignore [union-attr] # noqa: E501
else:
- return self.get_bind(
+ return self.get_bind( # type: ignore [union-attr]
mapper, shard_id=shard_id, instance=instance
).connect(**kwargs)
- def get_bind(
+ def get_bind( # type: ignore [override]
self,
- mapper: Optional[Mapper[_T]] = None,
+ mapper: Optional[_EntityBindKey[_O]] = None,
shard_id: Optional[_SessionBindKey] = None,
instance: Optional[Any] = None,
clause: Optional[ClauseElement] = None,
) -> _SessionBind:
if shard_id is None:
shard_id = self._choose_shard_and_assign(
- mapper, instance, clause=clause
+ mapper, instance, clause=clause # type: ignore [arg-type]
)
return self.__binds[shard_id]
def bind_shard(
- self, shard_id: str, bind: Union[Engine, OptionEngine]
+ self, shard_id: _SessionBindKey, bind: Union[Engine, OptionEngine]
) -> None:
self.__binds[shard_id] = bind
elif orm_context.is_update or orm_context.is_delete:
load_options = None
- update_options = active_options = orm_context.update_delete_options
+ update_options = active_options = orm_context.update_delete_options # type: ignore [assignment] # noqa: E501
else:
load_options = update_options = active_options = None
def iter_for_shard(
shard_id: str,
load_options: Union[
- None, QueryContext.default_load_options, _MetaOptions
+ QueryContext.default_load_options,
+ Type[QueryContext.default_load_options],
+ ],
+ update_options: Union[
+ BulkUDCompileState.default_update_options,
+ Type[QueryContext.default_load_options],
],
- update_options: Optional[BulkUDCompileState.default_update_options],
) -> Union[Result[_T], IteratorResult[_TP]]:
execution_options = dict(orm_context.local_execution_options)
bind_arguments["shard_id"] = shard_id
if orm_context.is_select:
- load_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501
+ load_options += {"_refresh_identity_token": shard_id}
execution_options["_sa_orm_load_options"] = load_options
elif orm_context.is_update or orm_context.is_delete:
- update_options += {"_refresh_identity_token": shard_id} # type: ignore [operator] # noqa: E501
+ update_options += {"_refresh_identity_token": shard_id}
execution_options["_sa_orm_update_options"] = update_options
return orm_context.invoke_statement(
shard_id = None
if shard_id is not None:
- return iter_for_shard(shard_id, load_options, update_options)
+ return iter_for_shard(shard_id, load_options, update_options) # type: ignore [arg-type] # noqa: E501
else:
partial = []
- for shard_id in session.execute_chooser(orm_context):
- result_ = iter_for_shard(shard_id, load_options, update_options)
+ for shard_id in session.execute_chooser(orm_context): # type: ignore [attr-defined] # noqa: E501
+ result_ = iter_for_shard(shard_id, load_options, update_options) # type: ignore [arg-type] # noqa: E501
partial.append(result_)
return partial[0].merge(*partial[1:])