From e40e768492685aa9ce57c4762c571f935e3fd3c7 Mon Sep 17 00:00:00 2001 From: Gleb Kisenkov Date: Thu, 8 Dec 2022 15:54:55 +0100 Subject: [PATCH] Fixed pep484 errors --- lib/sqlalchemy/ext/horizontal_shard.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 229eb6330e..d23e0f6424 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -52,6 +52,7 @@ if TYPE_CHECKING: from ..orm.session import _SessionBindKey from ..orm.session import ORMExecuteState from ..orm.state import InstanceState + from ..sql import Executable from ..sql._typing import _TP from ..sql.elements import ClauseElement @@ -109,6 +110,8 @@ class ShardedSession(Session): ] = None, shards: Optional[Dict[str, Any]] = None, query_cls: Type[Query[_T]] = ShardedQuery, + *, + query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, **kwargs: Any, ) -> None: """Construct a ShardedSession. @@ -136,7 +139,6 @@ class ShardedSession(Session): to :class:`~sqlalchemy.engine.Engine` objects. """ - query_chooser = kwargs.pop("query_chooser", None) super().__init__(query_cls=query_cls, **kwargs) event.listen( @@ -144,8 +146,10 @@ class ShardedSession(Session): ) self.shard_chooser = shard_chooser self.id_chooser = id_chooser + self.execute_chooser = execute_chooser if query_chooser: + _query_chooser = query_chooser util.warn_deprecated( "The ``query_choser`` parameter is deprecated; " "please use ``execute_chooser``.", @@ -157,11 +161,9 @@ class ShardedSession(Session): "at the same time." ) - def execute_chooser(orm_context: ORMExecuteState) -> Any: - return query_chooser(orm_context.statement) + def execute_chooser(orm_context: ORMExecuteState) -> Iterable[Any]: + return _query_chooser(orm_context.statement) - self.execute_chooser = execute_chooser - else: self.execute_chooser = execute_chooser self.query_chooser = query_chooser self.__binds = {} -- 2.47.2