#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
"""Horizontal sharding support.
the source distribution.
"""
+from __future__ import annotations
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+from test.ext.test_horizontal_shard import LazyLoadIdentityKeyTest
+from test.ext.test_horizontal_shard import WeatherLocation
from .. import event
from .. import exc
from .. import inspect
from .. import util
+from ..engine.base import Connection
+from ..engine.base import Engine
+from ..engine.base import OptionEngine
+from ..engine.cursor import CursorResult
+from ..engine.result import ChunkedIteratorResult
+from ..engine.result import IteratorResult
+from ..engine.result import MergedResult
+from ..orm.bulk_persistence import BulkUDCompileState
+from ..orm.context import QueryContext
+from ..orm.mapper import Mapper
from ..orm.query import Query
+from ..orm.session import ORMExecuteState
from ..orm.session import Session
+from ..orm.state import InstanceState
+from ..sql.base import _MetaOptions
+from ..sql.selectable import Select
__all__ = ["ShardedSession", "ShardedQuery"]
class ShardedQuery(Query):
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.id_chooser = self.session.id_chooser
self.query_chooser = self.session.query_chooser
self.execute_chooser = self.session.execute_chooser
self._shard_id = None
- def set_shard(self, shard_id):
+ def set_shard(self, shard_id: str) -> ShardedQuery:
"""Return a new query, limited to a single shard ID.
All subsequent operations with the returned query will
class ShardedSession(Session):
def __init__(
self,
- shard_chooser,
- id_chooser,
- execute_chooser=None,
- shards=None,
- query_cls=ShardedQuery,
- **kwargs,
- ):
+ shard_chooser: Callable,
+ id_chooser: Callable,
+ execute_chooser: Callable = None,
+ shards: Dict[str, Any] = None,
+ query_cls: type = ShardedQuery,
+ **kwargs: Any,
+ ) -> None:
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped
def _identity_lookup(
self,
- mapper,
- primary_key_identity,
- identity_token=None,
- lazy_loaded_from=None,
- **kw,
- ):
+ mapper: Mapper,
+ primary_key_identity: List[int],
+ identity_token: Optional[Any] = None,
+ lazy_loaded_from: Optional[InstanceState] = None,
+ **kw: Any,
+ ) -> Union[
+ None, WeatherLocation, LazyLoadIdentityKeyTest.setup_classes.Book
+ ]:
"""override the default :meth:`.Session._identity_lookup` method so
that we search for a given non-token primary key identity across all
possible identity tokens (e.g. shard ids).
return None
- def _choose_shard_and_assign(self, mapper, instance, **kw):
+ def _choose_shard_and_assign(
+ self, mapper: Mapper, instance: WeatherLocation, **kw: Any
+ ) -> str:
if instance is not None:
state = inspect(instance)
if state.key:
return shard_id
def connection_callable(
- self, mapper=None, instance=None, shard_id=None, **kwargs
- ):
+ self,
+ mapper: Mapper = None,
+ instance: WeatherLocation = None,
+ shard_id: Optional[Any] = None,
+ **kwargs: Any,
+ ) -> Connection:
"""Provide a :class:`_engine.Connection` to use in the unit of work
flush process.
).connect(**kwargs)
def get_bind(
- self, mapper=None, shard_id=None, instance=None, clause=None, **kw
- ):
+ self,
+ mapper: Mapper = None,
+ shard_id: str = None,
+ instance: Optional[Any] = None,
+ clause: Optional[Select] = None,
+ **kw: Any,
+ ) -> Union[Engine, OptionEngine]:
if shard_id is None:
shard_id = self._choose_shard_and_assign(
mapper, instance, clause=clause
)
return self.__binds[shard_id]
- def bind_shard(self, shard_id, bind):
+ def bind_shard(
+ self, shard_id: str, bind: Union[Engine, OptionEngine]
+ ) -> None:
self.__binds[shard_id] = bind
-def execute_and_instances(orm_context):
+def execute_and_instances(
+ orm_context: ORMExecuteState,
+) -> Union[CursorResult, ChunkedIteratorResult, MergedResult]:
if orm_context.is_select:
load_options = active_options = orm_context.load_options
update_options = None
session = orm_context.session
- def iter_for_shard(shard_id, load_options, update_options):
+ def iter_for_shard(
+ shard_id: str,
+ load_options: Union[
+ None, QueryContext.default_load_options, _MetaOptions
+ ],
+ update_options: Optional[BulkUDCompileState.default_update_options],
+ ) -> Union[CursorResult, ChunkedIteratorResult, IteratorResult]:
execution_options = dict(orm_context.local_execution_options)
bind_arguments = dict(orm_context.bind_arguments)