]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Collected runtime types
authorGleb Kisenkov <g.kisenkov@gmail.com>
Sat, 3 Dec 2022 09:56:08 +0000 (10:56 +0100)
committerGleb Kisenkov <g.kisenkov@gmail.com>
Sat, 3 Dec 2022 09:56:08 +0000 (10:56 +0100)
lib/sqlalchemy/ext/horizontal_shard.py

index 8f6e2ffcd977b171ea637f978e2422244d810726..486540181b01be218a141ed4b10695dc8c533cab 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """Horizontal sharding support.
 
@@ -15,26 +14,50 @@ For a usage example, see the :ref:`examples_sharding` example included in
 the source distribution.
 
 """
+from __future__ import annotations
 
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+from test.ext.test_horizontal_shard import LazyLoadIdentityKeyTest
+from test.ext.test_horizontal_shard import WeatherLocation
 from .. import event
 from .. import exc
 from .. import inspect
 from .. import util
+from ..engine.base import Connection
+from ..engine.base import Engine
+from ..engine.base import OptionEngine
+from ..engine.cursor import CursorResult
+from ..engine.result import ChunkedIteratorResult
+from ..engine.result import IteratorResult
+from ..engine.result import MergedResult
+from ..orm.bulk_persistence import BulkUDCompileState
+from ..orm.context import QueryContext
+from ..orm.mapper import Mapper
 from ..orm.query import Query
+from ..orm.session import ORMExecuteState
 from ..orm.session import Session
+from ..orm.state import InstanceState
+from ..sql.base import _MetaOptions
+from ..sql.selectable import Select
 
 __all__ = ["ShardedSession", "ShardedQuery"]
 
 
 class ShardedQuery(Query):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.id_chooser = self.session.id_chooser
         self.query_chooser = self.session.query_chooser
         self.execute_chooser = self.session.execute_chooser
         self._shard_id = None
 
-    def set_shard(self, shard_id):
+    def set_shard(self, shard_id: str) -> ShardedQuery:
         """Return a new query, limited to a single shard ID.
 
         All subsequent operations with the returned query will
@@ -55,13 +78,13 @@ class ShardedQuery(Query):
 class ShardedSession(Session):
     def __init__(
         self,
-        shard_chooser,
-        id_chooser,
-        execute_chooser=None,
-        shards=None,
-        query_cls=ShardedQuery,
-        **kwargs,
-    ):
+        shard_chooser: Callable,
+        id_chooser: Callable,
+        execute_chooser: Callable = None,
+        shards: Dict[str, Any] = None,
+        query_cls: type = ShardedQuery,
+        **kwargs: Any,
+    ) -> None:
         """Construct a ShardedSession.
 
         :param shard_chooser: A callable which, passed a Mapper, a mapped
@@ -122,12 +145,14 @@ class ShardedSession(Session):
 
     def _identity_lookup(
         self,
-        mapper,
-        primary_key_identity,
-        identity_token=None,
-        lazy_loaded_from=None,
-        **kw,
-    ):
+        mapper: Mapper,
+        primary_key_identity: List[int],
+        identity_token: Optional[Any] = None,
+        lazy_loaded_from: Optional[InstanceState] = None,
+        **kw: Any,
+    ) -> Union[
+        None, WeatherLocation, LazyLoadIdentityKeyTest.setup_classes.Book
+    ]:
         """override the default :meth:`.Session._identity_lookup` method so
         that we search for a given non-token primary key identity across all
         possible identity tokens (e.g. shard ids).
@@ -161,7 +186,9 @@ class ShardedSession(Session):
 
             return None
 
-    def _choose_shard_and_assign(self, mapper, instance, **kw):
+    def _choose_shard_and_assign(
+        self, mapper: Mapper, instance: WeatherLocation, **kw: Any
+    ) -> str:
         if instance is not None:
             state = inspect(instance)
             if state.key:
@@ -177,8 +204,12 @@ class ShardedSession(Session):
         return shard_id
 
     def connection_callable(
-        self, mapper=None, instance=None, shard_id=None, **kwargs
-    ):
+        self,
+        mapper: Mapper = None,
+        instance: WeatherLocation = None,
+        shard_id: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> Connection:
         """Provide a :class:`_engine.Connection` to use in the unit of work
         flush process.
 
@@ -195,19 +226,28 @@ class ShardedSession(Session):
             ).connect(**kwargs)
 
     def get_bind(
-        self, mapper=None, shard_id=None, instance=None, clause=None, **kw
-    ):
+        self,
+        mapper: Mapper = None,
+        shard_id: str = None,
+        instance: Optional[Any] = None,
+        clause: Optional[Select] = None,
+        **kw: Any,
+    ) -> Union[Engine, OptionEngine]:
         if shard_id is None:
             shard_id = self._choose_shard_and_assign(
                 mapper, instance, clause=clause
             )
         return self.__binds[shard_id]
 
-    def bind_shard(self, shard_id, bind):
+    def bind_shard(
+        self, shard_id: str, bind: Union[Engine, OptionEngine]
+    ) -> None:
         self.__binds[shard_id] = bind
 
 
-def execute_and_instances(orm_context):
+def execute_and_instances(
+    orm_context: ORMExecuteState,
+) -> Union[CursorResult, ChunkedIteratorResult, MergedResult]:
     if orm_context.is_select:
         load_options = active_options = orm_context.load_options
         update_options = None
@@ -220,7 +260,13 @@ def execute_and_instances(orm_context):
 
     session = orm_context.session
 
-    def iter_for_shard(shard_id, load_options, update_options):
+    def iter_for_shard(
+        shard_id: str,
+        load_options: Union[
+            None, QueryContext.default_load_options, _MetaOptions
+        ],
+        update_options: Optional[BulkUDCompileState.default_update_options],
+    ) -> Union[CursorResult, ChunkedIteratorResult, IteratorResult]:
         execution_options = dict(orm_context.local_execution_options)
 
         bind_arguments = dict(orm_context.bind_arguments)