From: Daniele Varrazzo Date: Sun, 13 Aug 2023 09:43:43 +0000 (+0100) Subject: refactor: make is_param_sequence a type guard X-Git-Tag: pool-3.2.0~66^2~3 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=b16e3bf044ea240cef3aa3d2dcbac4edb5943269;p=thirdparty%2Fpsycopg.git refactor: make is_param_sequence a type guard Reduce the need of typing hints here and there, although, if it returns false, it doesn't guarantee "the other half of the union". --- diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index d9bbaa841..caf19a7cc 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -14,6 +14,7 @@ from . import errors as e from .sql import Composable from .abc import Buffer, Query, Params from ._enums import PyFormat +from ._compat import TypeGuard from ._encodings import conn_encoding if TYPE_CHECKING: @@ -100,7 +101,7 @@ class PostgresQuery: self.formats = None @staticmethod - def is_params_sequence(vars: Params) -> bool: + def is_params_sequence(vars: Params) -> TypeGuard[Sequence[Any]]: # Try concrete types, then abstract types t = type(vars) if t is list or t is tuple: @@ -125,9 +126,8 @@ class PostgresQuery: """ Verify the compatibility between a query and a set of params. """ - sequence = PostgresQuery.is_params_sequence(vars) - if sequence: + if PostgresQuery.is_params_sequence(vars): if len(vars) != len(parts) - 1: raise e.ProgrammingError( f"the query has {len(parts) - 1} placeholders but" @@ -135,7 +135,7 @@ class PostgresQuery: ) if vars and not isinstance(parts[0].item, int): raise TypeError("named placeholders require a mapping of parameters") - return vars # type: ignore[return-value] + return vars else: if vars and len(parts) > 1 and not isinstance(parts[0][1], str): @@ -143,9 +143,11 @@ class PostgresQuery: "positional placeholders (%s) require a sequence of parameters" ) try: - return [ - vars[item] for item in order or () # type: ignore[call-overload] - ] + if order: + return [vars[item] for item in order] # type: ignore[call-overload] + else: + return () + except KeyError: raise e.ProgrammingError( "query parameter missing:" diff --git a/psycopg/psycopg/raw_cursor.py b/psycopg/psycopg/raw_cursor.py index 1a584c47e..9bc3164d4 100644 --- a/psycopg/psycopg/raw_cursor.py +++ b/psycopg/psycopg/raw_cursor.py @@ -38,10 +38,9 @@ class RawPostgresQuery(PostgresQuery): """ Verify the compatibility; params must be a sequence for raw query. """ - sequence = PostgresQuery.is_params_sequence(vars) - if not sequence: + if not PostgresQuery.is_params_sequence(vars): raise TypeError("raw query require a sequence of parameters") - return vars # type: ignore[return-value] + return vars class RawCursorMixin(BaseCursor[ConnectionType, Row]):