git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type annotations for sqlalchemy.ext.horizontal_shard
author: Gleb Kisenkov <g.kisenkov@gmail.com>
Thu, 8 Dec 2022 22:48:55 +0000 (17:48 -0500)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Dec 2022 00:10:19 +0000 (19:10 -0500)
The horizontal sharding extension is now pep-484 typed. Thanks to Gleb
Kisenkov for their efforts on this.

Closes: #8948
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8948
Pull-request-sha: e40e768492685aa9ce57c4762c571f935e3fd3c7

Change-Id: I2374e174c9433846c453c20a37ec5e5584fd3b31

doc/build/changelog/unreleased_20/horiz_typing.rst [new file with mode: 0644]
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/session.py
test/ext/test_horizontal_shard.py

diff --git a/doc/build/changelog/unreleased_20/horiz_typing.rst b/doc/build/changelog/unreleased_20/horiz_typing.rst
new file mode 100644
index 0000000..55b794d
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/horiz_typing.rst
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 6810
+
+    The horizontal sharding extension is now pep-484 typed. Thanks to Gleb
+    Kisenkov for their efforts on this.
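diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 8f6e2ffcd977b171ea637f978e2422244d810726..69767ad6cbf74b2195c86322edd52b3f08cd1d20 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py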
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """Horizontal sharding support.
 
@@ -14,27 +13,93 @@ distribute queries and persistence operations across multiple databases.
 For a usage example, see the :ref:`examples_sharding` example included in
 the source distribution.
 
+.. legacy:: The horizontal sharding API is not fully updated for the
+   SQLAlchemy 2.0 API, and still relies in part on the
+   legacy :class:`.Query` architecture, in particular as part of the
+   signature for the :paramref:`.ShardedSession.id_chooser` parameter.
+   This may change in a future release.
+
 """
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
 
 from .. import event
 from .. import exc
 from .. import inspect
 from .. import util
+from ..orm import PassiveFlag
+from ..orm.mapper import Mapper
 from ..orm.query import Query
 from ..orm.session import Session
+from ..util.typing import Protocol
+
+if TYPE_CHECKING:
+    from ..engine.base import Connection
+    from ..engine.base import Engine
+    from ..engine.base import OptionEngine
+    from ..engine.result import IteratorResult
+    from ..engine.result import Result
+    from ..orm import LoaderCallableStatus
+    from ..orm._typing import _O
+    from ..orm.bulk_persistence import BulkUDCompileState
+    from ..orm.context import QueryContext
+    from ..orm.session import _EntityBindKey
+    from ..orm.session import _SessionBind
+    from ..orm.session import ORMExecuteState
+    from ..orm.state import InstanceState
+    from ..sql import Executable
+    from ..sql._typing import _TP
+    from ..sql.elements import ClauseElement
 
 __all__ = ["ShardedSession", "ShardedQuery"]
 
+_T = TypeVar("_T", bound=Any)
+
+SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
+
+_ShardKey = str
+
+
+class ShardChooser(Protocol):
+    def __call__(
+        self,
+        mapper: Optional[Mapper[_T]],
+        instance: Any,
+        clause: Optional[ClauseElement],
+    ) -> Any:
+        ...
+
 
-class ShardedQuery(Query):
-    def __init__(self, *args, **kwargs):
+class ShardedQuery(Query[_T]):
+    """Query class used with :class:`.ShardedSession`.
+
+    .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy
+       :class:`.Query` class.   The :class:`.ShardedSession` now supports
+       2.0 style execution via the :meth:`.ShardedSession.execute` method
+       as well.
+
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
+        assert isinstance(self.session, ShardedSession)
+
         self.id_chooser = self.session.id_chooser
         self.query_chooser = self.session.query_chooser
         self.execute_chooser = self.session.execute_chooser
         self._shard_id = None
 
-    def set_shard(self, shard_id):
+    def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
         """Return a new query, limited to a single shard ID.
 
         All subsequent operations with the returned query will
@@ -53,15 +118,23 @@ class ShardedQuery(Query):
 
 
 class ShardedSession(Session):
+    shard_chooser: ShardChooser
+    id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]]
+    execute_chooser: Callable[[ORMExecuteState], Iterable[Any]]
+
     def __init__(
         self,
-        shard_chooser,
-        id_chooser,
-        execute_chooser=None,
-        shards=None,
-        query_cls=ShardedQuery,
-        **kwargs,
-    ):
+        shard_chooser: ShardChooser,
+        id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
+        execute_chooser: Optional[
+            Callable[[ORMExecuteState], Iterable[Any]]
+        ] = None,
+        shards: Optional[Dict[str, Any]] = None,
+        query_cls: Type[Query[_T]] = ShardedQuery,
+        *,
+        query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
+        **kwargs: Any,
+    ) -> None:
         """Construct a ShardedSession.
 
         :param shard_chooser: A callable which, passed a Mapper, a mapped
@@ -71,9 +144,14 @@ class ShardedSession(Session):
           should set whatever state on the instance to mark it in the future as
           participating in that shard.
 
-        :param id_chooser: A callable, passed a query and a tuple of identity
-          values, which should return a list of shard ids where the ID might
-          reside.  The databases will be queried in the order of this listing.
+        :param id_chooser: A callable, passed a :class:`.ShardedQuery` and a
+          tuple of identity values, which should return a list of shard ids
+          where the ID might reside. The databases will be queried in the order
+          of this listing.
+
+          .. legacy:: This parameter still uses the legacy
+             :class:`.ShardedQuery` class as an argument passed to the
+             callable.
 
         :param execute_chooser: For a given :class:`.ORMExecuteState`,
           returns the list of shard_ids
@@ -87,7 +165,6 @@ class ShardedSession(Session):
           to :class:`~sqlalchemy.engine.Engine` objects.
 
         """
-        query_chooser = kwargs.pop("query_chooser", None)
         super().__init__(query_cls=query_cls, **kwargs)
 
         event.listen(
@@ -97,6 +174,7 @@ class ShardedSession(Session):
         self.id_chooser = id_chooser
 
         if query_chooser:
+            _query_chooser = query_chooser
             util.warn_deprecated(
                 "The ``query_choser`` parameter is deprecated; "
                 "please use ``execute_chooser``.",
@@ -108,26 +186,34 @@ class ShardedSession(Session):
                     "at the same time."
                 )
 
-            def execute_chooser(orm_context):
-                return query_chooser(orm_context.statement)
+            def _default_execute_chooser(
+                orm_context: ORMExecuteState,
+            ) -> Iterable[Any]:
+                return _query_chooser(orm_context.statement)
 
-            self.execute_chooser = execute_chooser
-        else:
-            self.execute_chooser = execute_chooser
+            if execute_chooser is None:
+                execute_chooser = _default_execute_chooser
+
+        if execute_chooser is None:
+            raise exc.ArgumentError(
+                "execute_chooser or query_chooser is required"
+            )
+        self.execute_chooser = execute_chooser
         self.query_chooser = query_chooser
-        self.__binds = {}
+        self.__shards: Dict[_ShardKey, _SessionBind] = {}
         if shards is not None:
             for k in shards:
                 self.bind_shard(k, shards[k])
 
     def _identity_lookup(
         self,
-        mapper,
-        primary_key_identity,
-        identity_token=None,
-        lazy_loaded_from=None,
-        **kw,
-    ):
+        mapper: Mapper[_O],
+        primary_key_identity: Union[Any, Tuple[Any, ...]],
+        identity_token: Optional[Any] = None,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+        lazy_loaded_from: Optional[InstanceState[Any]] = None,
+        **kw: Any,
+    ) -> Union[Optional[_O], LoaderCallableStatus]:
         """override the default :meth:`.Session._identity_lookup` method so
         that we search for a given non-token primary key identity across all
         possible identity tokens (e.g. shard ids).
@@ -138,30 +224,37 @@ class ShardedSession(Session):
         """
 
         if identity_token is not None:
-            return super()._identity_lookup(
+            obj = super()._identity_lookup(
                 mapper,
                 primary_key_identity,
                 identity_token=identity_token,
                 **kw,
             )
+
+            return obj
         else:
             q = self.query(mapper)
             if lazy_loaded_from:
                 q = q._set_lazyload_from(lazy_loaded_from)
             for shard_id in self.id_chooser(q, primary_key_identity):
-                obj = super()._identity_lookup(
+                obj2 = super()._identity_lookup(
                     mapper,
                     primary_key_identity,
                     identity_token=shard_id,
                     lazy_loaded_from=lazy_loaded_from,
                     **kw,
                 )
-                if obj is not None:
-                    return obj
+                if obj2 is not None:
+                    return obj2
 
             return None
 
-    def _choose_shard_and_assign(self, mapper, instance, **kw):
+    def _choose_shard_and_assign(
+        self,
+        mapper: Optional[_EntityBindKey[_O]],
+        instance: Any,
+        **kw: Any,
+    ) -> Any:
         if instance is not None:
             state = inspect(instance)
             if state.key:
@@ -171,14 +264,19 @@ class ShardedSession(Session):
             elif state.identity_token:
                 return state.identity_token
 
+        assert isinstance(mapper, Mapper)
         shard_id = self.shard_chooser(mapper, instance, **kw)
         if instance is not None:
             state.identity_token = shard_id
         return shard_id
 
-    def connection_callable(
-        self, mapper=None, instance=None, shard_id=None, **kwargs
-    ):
+    def connection_callable(  # type: ignore [override]
+        self,
+        mapper: Optional[Mapper[_T]] = None,
+        instance: Optional[Any] = None,
+        shard_id: Optional[Any] = None,
+        **kw: Any,
+    ) -> Connection:
         """Provide a :class:`_engine.Connection` to use in the unit of work
         flush process.
 
@@ -188,26 +286,63 @@ class ShardedSession(Session):
             shard_id = self._choose_shard_and_assign(mapper, instance)
 
         if self.in_transaction():
-            return self.get_transaction().connection(mapper, shard_id=shard_id)
+            trans = self.get_transaction()
+            assert trans is not None
+            return trans.connection(mapper, shard_id=shard_id)
         else:
-            return self.get_bind(
-                mapper, shard_id=shard_id, instance=instance
-            ).connect(**kwargs)
+            bind = self.get_bind(
+                mapper=mapper, shard_id=shard_id, instance=instance
+            )
+
+            if isinstance(bind, Engine):
+                return bind.connect(**kw)
+            else:
+                assert isinstance(bind, Connection)
+                return bind
 
     def get_bind(
-        self, mapper=None, shard_id=None, instance=None, clause=None, **kw
-    ):
+        self,
+        mapper: Optional[_EntityBindKey[_O]] = None,
+        *,
+        shard_id: Optional[_ShardKey] = None,
+        instance: Optional[Any] = None,
+        clause: Optional[ClauseElement] = None,
+        **kw: Any,
+    ) -> _SessionBind:
         if shard_id is None:
             shard_id = self._choose_shard_and_assign(
-                mapper, instance, clause=clause
+                mapper, instance=instance, clause=clause
             )
-        return self.__binds[shard_id]
+            assert shard_id is not None
+        return self.__shards[shard_id]
+
+    def bind_shard(
+        self, shard_id: _ShardKey, bind: Union[Engine, OptionEngine]
+    ) -> None:
+        self.__shards[shard_id] = bind
+
+
+def execute_and_instances(
+    orm_context: ORMExecuteState,
+) -> Union[Result[_T], IteratorResult[_TP]]:
+    update_options: Union[
+        None,
+        BulkUDCompileState.default_update_options,
+        Type[BulkUDCompileState.default_update_options],
+    ]
+    active_options: Union[
+        None,
+        QueryContext.default_load_options,
+        Type[QueryContext.default_load_options],
+        BulkUDCompileState.default_update_options,
+        Type[BulkUDCompileState.default_update_options],
+    ]
+    load_options: Union[
+        None,
+        QueryContext.default_load_options,
+        Type[QueryContext.default_load_options],
+    ]
 
-    def bind_shard(self, shard_id, bind):
-        self.__binds[shard_id] = bind
-
-
-def execute_and_instances(orm_context):
     if orm_context.is_select:
         load_options = active_options = orm_context.load_options
         update_options = None
@@ -219,17 +354,32 @@ def execute_and_instances(orm_context):
         load_options = update_options = active_options = None
 
     session = orm_context.session
-
-    def iter_for_shard(shard_id, load_options, update_options):
+    assert isinstance(session, ShardedSession)
+
+    def iter_for_shard(
+        shard_id: str,
+        load_options: Union[
+            None,
+            QueryContext.default_load_options,
+            Type[QueryContext.default_load_options],
+        ],
+        update_options: Union[
+            None,
+            BulkUDCompileState.default_update_options,
+            Type[BulkUDCompileState.default_update_options],
+        ],
+    ) -> Union[Result[_T], IteratorResult[_TP]]:
         execution_options = dict(orm_context.local_execution_options)
 
         bind_arguments = dict(orm_context.bind_arguments)
         bind_arguments["shard_id"] = shard_id
 
         if orm_context.is_select:
+            assert load_options is not None
             load_options += {"_refresh_identity_token": shard_id}
             execution_options["_sa_orm_load_options"] = load_options
         elif orm_context.is_update or orm_context.is_delete:
+            assert update_options is not None
             update_options += {"_refresh_identity_token": shard_id}
             execution_options["_sa_orm_update_options"] = update_options
 
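diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 15cd587926a7121a8f0ae209e80b856127733ce9..1f7bd6d73af8044a4d118244c7a8995f2be7b055 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py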
@@ -152,7 +152,7 @@ _PKIdentityArgument = Union[Any, Tuple[Any, ...]]
 _BindArguments = Dict[str, Any]
 
 _EntityBindKey = Union[Type[_O], "Mapper[_O]"]
-_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"]
+_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table", str]
 _SessionBind = Union["Engine", "Connection"]
 
 
@@ -2374,6 +2374,7 @@ class Session(_SessionClassMethods, EventTarget):
     def get_bind(
         self,
         mapper: Optional[_EntityBindKey[_O]] = None,
+        *,
         clause: Optional[ClauseElement] = None,
         bind: Optional[_SessionBind] = None,
         _sa_skip_events: Optional[bool] = None,
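diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py
index 8913478598eaeda8dad901ca44ad3d623fe57487..ab4a24f71c1e10c150354c3fde0197fe51c6b54e 100644
--- a/test/ext/test_horizontal_shard.py
+++ b/test/ext/test_horizontal_shard.py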
@@ -498,7 +498,7 @@ class ShardTest:
             eq_({t.temperature for t in temps}, {86.0, 75.0, 91.0})
 
         self.assert_sql_count(
-            sess._ShardedSession__binds["north_america"], go, 0
+            sess._ShardedSession__shards["north_america"], go, 0
         )
 
         eq_(
@@ -533,7 +533,7 @@ class ShardTest:
                 assert inspect(t).deleted is (t.temperature >= 80)
 
         self.assert_sql_count(
-            sess._ShardedSession__binds["north_america"], go, 0
+            sess._ShardedSession__shards["north_america"], go, 0
         )
 
         eq_(