]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Transient commit
authorGleb Kisenkov <g.kisenkov@gmail.com>
Mon, 5 Dec 2022 21:36:34 +0000 (22:36 +0100)
committerGleb Kisenkov <g.kisenkov@gmail.com>
Mon, 5 Dec 2022 21:36:34 +0000 (22:36 +0100)
Some type annotation fixes

lib/sqlalchemy/ext/horizontal_shard.py

index 486540181b01be218a141ed4b10695dc8c533cab..64d1b7b6e2f121b1b224424e38219de0b2ddb800 100644 (file)
@@ -19,12 +19,13 @@ from __future__ import annotations
 from typing import Any
 from typing import Callable
 from typing import Dict
-from typing import List
+from typing import Iterable
 from typing import Optional
+from typing import Tuple
+from typing import Type
+from typing import TypeVar
 from typing import Union
 
-from test.ext.test_horizontal_shard import LazyLoadIdentityKeyTest
-from test.ext.test_horizontal_shard import WeatherLocation
 from .. import event
 from .. import exc
 from .. import inspect
@@ -32,24 +33,32 @@ from .. import util
 from ..engine.base import Connection
 from ..engine.base import Engine
 from ..engine.base import OptionEngine
-from ..engine.cursor import CursorResult
-from ..engine.result import ChunkedIteratorResult
 from ..engine.result import IteratorResult
-from ..engine.result import MergedResult
+from ..engine.result import Result
+from ..orm import LoaderCallableStatus
+from ..orm import PassiveFlag
 from ..orm.bulk_persistence import BulkUDCompileState
 from ..orm.context import QueryContext
 from ..orm.mapper import Mapper
 from ..orm.query import Query
+from ..orm.session import _SessionBind
+from ..orm.session import _SessionBindKey
 from ..orm.session import ORMExecuteState
 from ..orm.session import Session
 from ..orm.state import InstanceState
+from ..sql._typing import _TP
 from ..sql.base import _MetaOptions
-from ..sql.selectable import Select
+from ..sql.elements import ClauseElement
+
 
 __all__ = ["ShardedSession", "ShardedQuery"]
 
+_T = TypeVar("_T", bound=Any)
+
+SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
 
-class ShardedQuery(Query):
+
+class ShardedQuery(Query[_T]):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.id_chooser = self.session.id_chooser
@@ -57,7 +66,7 @@ class ShardedQuery(Query):
         self.execute_chooser = self.session.execute_chooser
         self._shard_id = None
 
-    def set_shard(self, shard_id: str) -> ShardedQuery:
+    def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
         """Return a new query, limited to a single shard ID.
 
         All subsequent operations with the returned query will
@@ -78,11 +87,15 @@ class ShardedQuery(Query):
 class ShardedSession(Session):
     def __init__(
         self,
-        shard_chooser: Callable,
-        id_chooser: Callable,
-        execute_chooser: Callable = None,
-        shards: Dict[str, Any] = None,
-        query_cls: type = ShardedQuery,
+        shard_chooser: Callable[
+            [Mapper[_T], Any, Optional[ClauseElement]], Any
+        ],
+        id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
+        execute_chooser: Optional[
+            Callable[[ORMExecuteState], Iterable[Any]]
+        ] = None,
+        shards: Optional[Dict[str, Any]] = None,
+        query_cls: Type[Query[_T]] = ShardedQuery,
         **kwargs: Any,
     ) -> None:
         """Construct a ShardedSession.
@@ -131,7 +144,7 @@ class ShardedSession(Session):
                     "at the same time."
                 )
 
-            def execute_chooser(orm_context):
+            def execute_chooser(orm_context: ORMExecuteState) -> Any:
                 return query_chooser(orm_context.statement)
 
             self.execute_chooser = execute_chooser
@@ -145,14 +158,13 @@ class ShardedSession(Session):
 
     def _identity_lookup(
         self,
-        mapper: Mapper,
-        primary_key_identity: List[int],
+        mapper: Mapper[_T],
+        primary_key_identity: Union[Any, Tuple[Any, ...]],
         identity_token: Optional[Any] = None,
-        lazy_loaded_from: Optional[InstanceState] = None,
+        passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+        lazy_loaded_from: Optional[InstanceState[Any]] = None,
         **kw: Any,
-    ) -> Union[
-        None, WeatherLocation, LazyLoadIdentityKeyTest.setup_classes.Book
-    ]:
+    ) -> Union[Optional[_T], LoaderCallableStatus]:
         """override the default :meth:`.Session._identity_lookup` method so
         that we search for a given non-token primary key identity across all
         possible identity tokens (e.g. shard ids).
@@ -187,8 +199,11 @@ class ShardedSession(Session):
             return None
 
     def _choose_shard_and_assign(
-        self, mapper: Mapper, instance: WeatherLocation, **kw: Any
-    ) -> str:
+        self,
+        mapper: Mapper[_T],
+        instance: Any,
+        **kw: Any,
+    ) -> Any:
         if instance is not None:
             state = inspect(instance)
             if state.key:
@@ -205,8 +220,8 @@ class ShardedSession(Session):
 
     def connection_callable(
         self,
-        mapper: Mapper = None,
-        instance: WeatherLocation = None,
+        mapper: Optional[Mapper[_T]] = None,
+        instance: Optional[Any] = None,
         shard_id: Optional[Any] = None,
         **kwargs: Any,
     ) -> Connection:
@@ -219,7 +234,7 @@ class ShardedSession(Session):
             shard_id = self._choose_shard_and_assign(mapper, instance)
 
         if self.in_transaction():
-            return self.get_transaction().connection(mapper, shard_id=shard_id)
+            return self.get_transaction().connection(mapper, shard_id=shard_id)  # type: ignore [union-attr] # noqa: E501
         else:
             return self.get_bind(
                 mapper, shard_id=shard_id, instance=instance
@@ -227,12 +242,12 @@ class ShardedSession(Session):
 
     def get_bind(
         self,
-        mapper: Mapper = None,
-        shard_id: str = None,
+        mapper: Optional[Mapper[_T]] = None,
+        shard_id: Optional[_SessionBindKey] = None,
         instance: Optional[Any] = None,
-        clause: Optional[Select] = None,
+        clause: Optional[ClauseElement] = None,
         **kw: Any,
-    ) -> Union[Engine, OptionEngine]:
+    ) -> _SessionBind:
         if shard_id is None:
             shard_id = self._choose_shard_and_assign(
                 mapper, instance, clause=clause
@@ -247,7 +262,7 @@ class ShardedSession(Session):
 
 def execute_and_instances(
     orm_context: ORMExecuteState,
-) -> Union[CursorResult, ChunkedIteratorResult, MergedResult]:
+) -> Union[Result[_T], IteratorResult[_TP]]:
     if orm_context.is_select:
         load_options = active_options = orm_context.load_options
         update_options = None
@@ -266,17 +281,17 @@ def execute_and_instances(
             None, QueryContext.default_load_options, _MetaOptions
         ],
         update_options: Optional[BulkUDCompileState.default_update_options],
-    ) -> Union[CursorResult, ChunkedIteratorResult, IteratorResult]:
+    ) -> Union[Result[_T], IteratorResult[_TP]]:
         execution_options = dict(orm_context.local_execution_options)
 
         bind_arguments = dict(orm_context.bind_arguments)
         bind_arguments["shard_id"] = shard_id
 
         if orm_context.is_select:
-            load_options += {"_refresh_identity_token": shard_id}
+            load_options += {"_refresh_identity_token": shard_id}  # type: ignore [operator] # noqa: E501
             execution_options["_sa_orm_load_options"] = load_options
         elif orm_context.is_update or orm_context.is_delete:
-            update_options += {"_refresh_identity_token": shard_id}
+            update_options += {"_refresh_identity_token": shard_id}  # type: ignore [operator] # noqa: E501
             execution_options["_sa_orm_update_options"] = update_options
 
         return orm_context.invoke_statement(