#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
"""Horizontal sharding support.
For a usage example, see the :ref:`examples_sharding` example included in
the source distribution.
+.. legacy:: The horizontal sharding API is not fully updated for the
+ SQLAlchemy 2.0 API, and still relies in part on the
+ legacy :class:`.Query` architecture, in particular as part of the
+ signature for the :paramref:`.ShardedSession.id_chooser` parameter.
+ This may change in a future release.
+
"""
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from .. import event
from .. import exc
from .. import inspect
from .. import util
+from ..orm import PassiveFlag
+from ..orm.mapper import Mapper
from ..orm.query import Query
from ..orm.session import Session
+from ..util.typing import Protocol
+
+if TYPE_CHECKING:
+ from ..engine.base import Connection
+ from ..engine.base import Engine
+ from ..engine.base import OptionEngine
+ from ..engine.result import IteratorResult
+ from ..engine.result import Result
+ from ..orm import LoaderCallableStatus
+ from ..orm._typing import _O
+ from ..orm.bulk_persistence import BulkUDCompileState
+ from ..orm.context import QueryContext
+ from ..orm.session import _EntityBindKey
+ from ..orm.session import _SessionBind
+ from ..orm.session import ORMExecuteState
+ from ..orm.state import InstanceState
+ from ..sql import Executable
+ from ..sql._typing import _TP
+ from ..sql.elements import ClauseElement
__all__ = ["ShardedSession", "ShardedQuery"]
+_T = TypeVar("_T", bound=Any)
+
+SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
+
+_ShardKey = str
+
+
+class ShardChooser(Protocol):
+ """Structural type of the ``shard_chooser`` callable accepted by
+ :class:`.ShardedSession`.
+
+ The callable is invoked with an optional :class:`.Mapper`, an instance
+ (which may be ``None``) and an optional SQL clause; its return value is
+ used as the shard id.
+
+ """
+
+ def __call__(
+ self,
+ mapper: Optional[Mapper[_T]],
+ instance: Any,
+ clause: Optional[ClauseElement],
+ ) -> Any:
+ ...
+
-class ShardedQuery(Query):
- def __init__(self, *args, **kwargs):
+class ShardedQuery(Query[_T]):
+ """Query class used with :class:`.ShardedSession`.
+
+ .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy
+ :class:`.Query` class. The :class:`.ShardedSession` now supports
+ 2.0 style execution via the :meth:`.ShardedSession.execute` method
+ as well.
+
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ """Construct the query and copy the chooser callables from its
+ :class:`.ShardedSession`.
+
+ """
super().__init__(*args, **kwargs)
+ # narrows ``self.session`` for type checkers; the chooser attributes
+ # read below are declared on ShardedSession only
+ assert isinstance(self.session, ShardedSession)
+
self.id_chooser = self.session.id_chooser
self.query_chooser = self.session.query_chooser
self.execute_chooser = self.session.execute_chooser
self._shard_id = None
- def set_shard(self, shard_id):
+ def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
"""Return a new query, limited to a single shard ID.
All subsequent operations with the returned query will
class ShardedSession(Session):
+ shard_chooser: ShardChooser
+ id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]]
+ execute_chooser: Callable[[ORMExecuteState], Iterable[Any]]
+
def __init__(
self,
- shard_chooser,
- id_chooser,
- execute_chooser=None,
- shards=None,
- query_cls=ShardedQuery,
- **kwargs,
- ):
+ shard_chooser: ShardChooser,
+ id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
+ execute_chooser: Optional[
+ Callable[[ORMExecuteState], Iterable[Any]]
+ ] = None,
+ shards: Optional[Dict[str, Any]] = None,
+ query_cls: Type[Query[_T]] = ShardedQuery,
+ *,
+ query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
+ **kwargs: Any,
+ ) -> None:
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
should set whatever state on the instance to mark it in the future as
participating in that shard.
- :param id_chooser: A callable, passed a query and a tuple of identity
- values, which should return a list of shard ids where the ID might
- reside. The databases will be queried in the order of this listing.
+ :param id_chooser: A callable, passed a :class:`.ShardedQuery` and a
+ tuple of identity values, which should return a list of shard ids
+ where the ID might reside. The databases will be queried in the order
+ of this listing.
+
+ .. legacy:: This parameter still uses the legacy
+ :class:`.ShardedQuery` class as an argument passed to the
+ callable.
:param execute_chooser: For a given :class:`.ORMExecuteState`,
returns the list of shard_ids
to :class:`~sqlalchemy.engine.Engine` objects.
"""
- query_chooser = kwargs.pop("query_chooser", None)
super().__init__(query_cls=query_cls, **kwargs)
event.listen(
self.id_chooser = id_chooser
if query_chooser:
+ _query_chooser = query_chooser
util.warn_deprecated(
"The ``query_choser`` parameter is deprecated; "
"please use ``execute_chooser``.",
"at the same time."
)
- def execute_chooser(orm_context):
- return query_chooser(orm_context.statement)
+ def _default_execute_chooser(
+ orm_context: ORMExecuteState,
+ ) -> Iterable[Any]:
+ return _query_chooser(orm_context.statement)
- self.execute_chooser = execute_chooser
- else:
- self.execute_chooser = execute_chooser
+ if execute_chooser is None:
+ execute_chooser = _default_execute_chooser
+
+ if execute_chooser is None:
+ raise exc.ArgumentError(
+ "execute_chooser or query_chooser is required"
+ )
+ self.execute_chooser = execute_chooser
self.query_chooser = query_chooser
- self.__binds = {}
+ self.__shards: Dict[_ShardKey, _SessionBind] = {}
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
def _identity_lookup(
self,
- mapper,
- primary_key_identity,
- identity_token=None,
- lazy_loaded_from=None,
- **kw,
- ):
+ mapper: Mapper[_O],
+ primary_key_identity: Union[Any, Tuple[Any, ...]],
+ identity_token: Optional[Any] = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ lazy_loaded_from: Optional[InstanceState[Any]] = None,
+ **kw: Any,
+ ) -> Union[Optional[_O], LoaderCallableStatus]:
"""override the default :meth:`.Session._identity_lookup` method so
that we search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
+
+ ``passive`` is a named parameter of this method and is therefore not
+ part of ``**kw``; it is passed on explicitly to each
+ :meth:`.Session._identity_lookup` call so that the caller's flag is
+ honored.
+
"""
if identity_token is not None:
- return super()._identity_lookup(
+ obj = super()._identity_lookup(
mapper,
primary_key_identity,
identity_token=identity_token,
+ # forward explicitly; no longer carried in **kw
+ passive=passive,
**kw,
)
+
+ return obj
else:
q = self.query(mapper)
if lazy_loaded_from:
q = q._set_lazyload_from(lazy_loaded_from)
+ # try each candidate shard in the order given by id_chooser and
+ # return the first non-None result
for shard_id in self.id_chooser(q, primary_key_identity):
- obj = super()._identity_lookup(
+ obj2 = super()._identity_lookup(
mapper,
primary_key_identity,
identity_token=shard_id,
+ # forward explicitly; no longer carried in **kw
+ passive=passive,
lazy_loaded_from=lazy_loaded_from,
**kw,
)
- if obj is not None:
- return obj
+ if obj2 is not None:
+ return obj2
return None
- def _choose_shard_and_assign(self, mapper, instance, **kw):
+ def _choose_shard_and_assign(
+ self,
+ mapper: Optional[_EntityBindKey[_O]],
+ instance: Any,
+ **kw: Any,
+ ) -> Any:
if instance is not None:
state = inspect(instance)
if state.key:
elif state.identity_token:
return state.identity_token
+ assert isinstance(mapper, Mapper)
shard_id = self.shard_chooser(mapper, instance, **kw)
if instance is not None:
state.identity_token = shard_id
return shard_id
- def connection_callable(
- self, mapper=None, instance=None, shard_id=None, **kwargs
- ):
+ def connection_callable( # type: ignore [override]
+ self,
+ mapper: Optional[Mapper[_T]] = None,
+ instance: Optional[Any] = None,
+ shard_id: Optional[Any] = None,
+ **kw: Any,
+ ) -> Connection:
"""Provide a :class:`_engine.Connection` to use in the unit of work
flush process.
shard_id = self._choose_shard_and_assign(mapper, instance)
if self.in_transaction():
- return self.get_transaction().connection(mapper, shard_id=shard_id)
+ trans = self.get_transaction()
+ assert trans is not None
+ return trans.connection(mapper, shard_id=shard_id)
else:
- return self.get_bind(
- mapper, shard_id=shard_id, instance=instance
- ).connect(**kwargs)
+ # Engine and Connection are imported at module level only under
+ # TYPE_CHECKING; import them here so that the isinstance() checks
+ # below do not raise NameError at runtime.
+ from ..engine.base import Connection
+ from ..engine.base import Engine
+
+ bind = self.get_bind(
+ mapper=mapper, shard_id=shard_id, instance=instance
+ )
+
+ # an Engine bind yields a new connection; a bind that is already a
+ # Connection is returned as is
+ if isinstance(bind, Engine):
+ return bind.connect(**kw)
+ else:
+ assert isinstance(bind, Connection)
+ return bind
def get_bind(
- self, mapper=None, shard_id=None, instance=None, clause=None, **kw
- ):
+ self,
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ *,
+ shard_id: Optional[_ShardKey] = None,
+ instance: Optional[Any] = None,
+ clause: Optional[ClauseElement] = None,
+ **kw: Any,
+ ) -> _SessionBind:
+ """Return the bind registered for a shard.
+
+ When ``shard_id`` is not given, it is determined by
+ ``_choose_shard_and_assign()`` from the given mapper, instance and
+ clause. The bind is then looked up in the mapping populated by
+ :meth:`.ShardedSession.bind_shard`; a shard id that was never
+ registered raises ``KeyError``.
+
+ """
if shard_id is None:
shard_id = self._choose_shard_and_assign(
- mapper, instance, clause=clause
+ mapper, instance=instance, clause=clause
)
- return self.__binds[shard_id]
+ assert shard_id is not None
+ return self.__shards[shard_id]
+
+ def bind_shard(
+ self, shard_id: _ShardKey, bind: Union[Engine, OptionEngine]
+ ) -> None:
+ """Register ``bind`` under ``shard_id``, replacing any bind
+ previously registered for that id.
+
+ """
+ self.__shards[shard_id] = bind
+
+
+def execute_and_instances(
+ orm_context: ORMExecuteState,
+) -> Union[Result[_T], IteratorResult[_TP]]:
+ update_options: Union[
+ None,
+ BulkUDCompileState.default_update_options,
+ Type[BulkUDCompileState.default_update_options],
+ ]
+ active_options: Union[
+ None,
+ QueryContext.default_load_options,
+ Type[QueryContext.default_load_options],
+ BulkUDCompileState.default_update_options,
+ Type[BulkUDCompileState.default_update_options],
+ ]
+ load_options: Union[
+ None,
+ QueryContext.default_load_options,
+ Type[QueryContext.default_load_options],
+ ]
- def bind_shard(self, shard_id, bind):
- self.__binds[shard_id] = bind
-
-
-def execute_and_instances(orm_context):
if orm_context.is_select:
load_options = active_options = orm_context.load_options
update_options = None
load_options = update_options = active_options = None
session = orm_context.session
-
- def iter_for_shard(shard_id, load_options, update_options):
+ assert isinstance(session, ShardedSession)
+
+ def iter_for_shard(
+ shard_id: str,
+ load_options: Union[
+ None,
+ QueryContext.default_load_options,
+ Type[QueryContext.default_load_options],
+ ],
+ update_options: Union[
+ None,
+ BulkUDCompileState.default_update_options,
+ Type[BulkUDCompileState.default_update_options],
+ ],
+ ) -> Union[Result[_T], IteratorResult[_TP]]:
execution_options = dict(orm_context.local_execution_options)
bind_arguments = dict(orm_context.bind_arguments)
bind_arguments["shard_id"] = shard_id
if orm_context.is_select:
+ assert load_options is not None
load_options += {"_refresh_identity_token": shard_id}
execution_options["_sa_orm_load_options"] = load_options
elif orm_context.is_update or orm_context.is_delete:
+ assert update_options is not None
update_options += {"_refresh_identity_token": shard_id}
execution_options["_sa_orm_update_options"] = update_options