]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Public API covered
authorGleb Kisenkov <g.kisenkov@gmail.com>
Wed, 7 Dec 2022 07:48:30 +0000 (08:48 +0100)
committerGleb Kisenkov <g.kisenkov@gmail.com>
Wed, 7 Dec 2022 07:48:30 +0000 (08:48 +0100)
The whole module needs to be refined in the future. Especially addressing the internal types.

lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/session.py

index 64d1b7b6e2f121b1b224424e38219de0b2ddb800..5075d3f39cbf9e738baf9fb8395243f0bf04670e 100644 (file)
@@ -37,18 +37,20 @@ from ..engine.result import IteratorResult
 from ..engine.result import Result
 from ..orm import LoaderCallableStatus
 from ..orm import PassiveFlag
+from ..orm._typing import _O
 from ..orm.bulk_persistence import BulkUDCompileState
 from ..orm.context import QueryContext
 from ..orm.mapper import Mapper
 from ..orm.query import Query
+from ..orm.session import _EntityBindKey
 from ..orm.session import _SessionBind
 from ..orm.session import _SessionBindKey
 from ..orm.session import ORMExecuteState
 from ..orm.session import Session
 from ..orm.state import InstanceState
 from ..sql._typing import _TP
-from ..sql.base import _MetaOptions
 from ..sql.elements import ClauseElement
+from ..util.typing import Protocol
 
 
 __all__ = ["ShardedSession", "ShardedQuery"]
@@ -58,12 +60,22 @@ _T = TypeVar("_T", bound=Any)
 SelfShardedQuery = TypeVar("SelfShardedQuery", bound="ShardedQuery[Any]")
 
 
+class ShardChooser(Protocol):
+    def __call__(
+        self,
+        mapper: Optional[Mapper[_T]],
+        instance: Any,
+        clause: Optional[ClauseElement],
+    ) -> Any:
+        ...
+
+
 class ShardedQuery(Query[_T]):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
-        self.id_chooser = self.session.id_chooser
-        self.query_chooser = self.session.query_chooser
-        self.execute_chooser = self.session.execute_chooser
+        self.id_chooser = self.session.id_chooser  # type: ignore [attr-defined] # noqa: E501
+        self.query_chooser = self.session.query_chooser  # type: ignore [attr-defined] # noqa: E501
+        self.execute_chooser = self.session.execute_chooser  # type: ignore [attr-defined] # noqa: E501
         self._shard_id = None
 
     def set_shard(self: SelfShardedQuery, shard_id: str) -> SelfShardedQuery:
@@ -87,9 +99,7 @@ class ShardedQuery(Query[_T]):
 class ShardedSession(Session):
     def __init__(
         self,
-        shard_chooser: Callable[
-            [Mapper[_T], Any, Optional[ClauseElement]], Any
-        ],
+        shard_chooser: ShardChooser,
         id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
         execute_chooser: Optional[
             Callable[[ORMExecuteState], Iterable[Any]]
@@ -176,7 +186,7 @@ class ShardedSession(Session):
 
         if identity_token is not None:
             return super()._identity_lookup(
-                mapper,
+                mapper,  # type: ignore [arg-type]
                 primary_key_identity,
                 identity_token=identity_token,
                 **kw,
@@ -200,7 +210,7 @@ class ShardedSession(Session):
 
     def _choose_shard_and_assign(
         self,
-        mapper: Mapper[_T],
+        mapper: Optional[Mapper[_T]],
         instance: Any,
         **kw: Any,
     ) -> Any:
@@ -218,7 +228,7 @@ class ShardedSession(Session):
             state.identity_token = shard_id
         return shard_id
 
-    def connection_callable(
+    def connection_callable(  # type: ignore [override]
         self,
         mapper: Optional[Mapper[_T]] = None,
         instance: Optional[Any] = None,
@@ -236,13 +246,13 @@ class ShardedSession(Session):
         if self.in_transaction():
             return self.get_transaction().connection(mapper, shard_id=shard_id)  # type: ignore [union-attr] # noqa: E501
         else:
-            return self.get_bind(
+            return self.get_bind(  # type: ignore [union-attr]
                 mapper, shard_id=shard_id, instance=instance
             ).connect(**kwargs)
 
-    def get_bind(
+    def get_bind(  # type: ignore [override]
         self,
-        mapper: Optional[Mapper[_T]] = None,
+        mapper: Optional[_EntityBindKey[_O]] = None,
         shard_id: Optional[_SessionBindKey] = None,
         instance: Optional[Any] = None,
         clause: Optional[ClauseElement] = None,
@@ -250,12 +260,12 @@ class ShardedSession(Session):
     ) -> _SessionBind:
         if shard_id is None:
             shard_id = self._choose_shard_and_assign(
-                mapper, instance, clause=clause
+                mapper, instance, clause=clause  # type: ignore [arg-type]
             )
         return self.__binds[shard_id]
 
     def bind_shard(
-        self, shard_id: str, bind: Union[Engine, OptionEngine]
+        self, shard_id: _SessionBindKey, bind: Union[Engine, OptionEngine]
     ) -> None:
         self.__binds[shard_id] = bind
 
@@ -269,7 +279,7 @@ def execute_and_instances(
 
     elif orm_context.is_update or orm_context.is_delete:
         load_options = None
-        update_options = active_options = orm_context.update_delete_options
+        update_options = active_options = orm_context.update_delete_options  # type: ignore [assignment] # noqa: E501
     else:
         load_options = update_options = active_options = None
 
@@ -278,9 +288,13 @@ def execute_and_instances(
     def iter_for_shard(
         shard_id: str,
         load_options: Union[
-            None, QueryContext.default_load_options, _MetaOptions
+            QueryContext.default_load_options,
+            Type[QueryContext.default_load_options],
+        ],
+        update_options: Union[
+            BulkUDCompileState.default_update_options,
+            Type[QueryContext.default_load_options],
         ],
-        update_options: Optional[BulkUDCompileState.default_update_options],
     ) -> Union[Result[_T], IteratorResult[_TP]]:
         execution_options = dict(orm_context.local_execution_options)
 
@@ -288,10 +302,10 @@ def execute_and_instances(
         bind_arguments["shard_id"] = shard_id
 
         if orm_context.is_select:
-            load_options += {"_refresh_identity_token": shard_id}  # type: ignore [operator] # noqa: E501
+            load_options += {"_refresh_identity_token": shard_id}
             execution_options["_sa_orm_load_options"] = load_options
         elif orm_context.is_update or orm_context.is_delete:
-            update_options += {"_refresh_identity_token": shard_id}  # type: ignore [operator] # noqa: E501
+            update_options += {"_refresh_identity_token": shard_id}
             execution_options["_sa_orm_update_options"] = update_options
 
         return orm_context.invoke_statement(
@@ -308,10 +322,10 @@ def execute_and_instances(
         shard_id = None
 
     if shard_id is not None:
-        return iter_for_shard(shard_id, load_options, update_options)
+        return iter_for_shard(shard_id, load_options, update_options)  # type: ignore [arg-type] # noqa: E501
     else:
         partial = []
-        for shard_id in session.execute_chooser(orm_context):
-            result_ = iter_for_shard(shard_id, load_options, update_options)
+        for shard_id in session.execute_chooser(orm_context):  # type: ignore [attr-defined] # noqa: E501
+            result_ = iter_for_shard(shard_id, load_options, update_options)  # type: ignore [arg-type] # noqa: E501
             partial.append(result_)
         return partial[0].merge(*partial[1:])
index 15cd587926a7121a8f0ae209e80b856127733ce9..672e8a29d20fadc26689cc6521083bec556bf051 100644 (file)
@@ -152,7 +152,7 @@ _PKIdentityArgument = Union[Any, Tuple[Any, ...]]
 _BindArguments = Dict[str, Any]
 
 _EntityBindKey = Union[Type[_O], "Mapper[_O]"]
-_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"]
+_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table", str]
 _SessionBind = Union["Engine", "Connection"]