]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed pep484 errors 8948/head
authorGleb Kisenkov <g.kisenkov@gmail.com>
Thu, 8 Dec 2022 14:54:55 +0000 (15:54 +0100)
committerGleb Kisenkov <g.kisenkov@gmail.com>
Thu, 8 Dec 2022 14:54:55 +0000 (15:54 +0100)
lib/sqlalchemy/ext/horizontal_shard.py

index 229eb6330e2677b807b66c7d791a0f107a7ddd09..d23e0f6424211d3be1f139f17419c411f94a0832 100644 (file)
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
     from ..orm.session import _SessionBindKey
     from ..orm.session import ORMExecuteState
     from ..orm.state import InstanceState
+    from ..sql import Executable
     from ..sql._typing import _TP
     from ..sql.elements import ClauseElement
 
@@ -109,6 +110,8 @@ class ShardedSession(Session):
         ] = None,
         shards: Optional[Dict[str, Any]] = None,
         query_cls: Type[Query[_T]] = ShardedQuery,
+        *,
+        query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
         **kwargs: Any,
     ) -> None:
         """Construct a ShardedSession.
@@ -136,7 +139,6 @@ class ShardedSession(Session):
           to :class:`~sqlalchemy.engine.Engine` objects.
 
         """
-        query_chooser = kwargs.pop("query_chooser", None)
         super().__init__(query_cls=query_cls, **kwargs)
 
         event.listen(
@@ -144,8 +146,10 @@ class ShardedSession(Session):
         )
         self.shard_chooser = shard_chooser
         self.id_chooser = id_chooser
+        self.execute_chooser = execute_chooser
 
         if query_chooser:
+            _query_chooser = query_chooser
             util.warn_deprecated(
                 "The ``query_choser`` parameter is deprecated; "
                 "please use ``execute_chooser``.",
@@ -157,11 +161,9 @@ class ShardedSession(Session):
                     "at the same time."
                 )
 
-            def execute_chooser(orm_context: ORMExecuteState) -> Any:
-                return query_chooser(orm_context.statement)
+            def execute_chooser(orm_context: ORMExecuteState) -> Iterable[Any]:
+                return _query_chooser(orm_context.statement)
 
-            self.execute_chooser = execute_chooser
-        else:
             self.execute_chooser = execute_chooser
         self.query_chooser = query_chooser
         self.__binds = {}