]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: make is_param_sequence a type guard
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 13 Aug 2023 09:43:43 +0000 (10:43 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 15 Aug 2023 15:29:03 +0000 (16:29 +0100)
Reduce the need of typing hints here and there, although, if it returns
false, it doesn't guarantee "the other half of the union".

psycopg/psycopg/_queries.py
psycopg/psycopg/raw_cursor.py

index d9bbaa8418ac0d35f3b658f5815ecb29fc7e94b9..caf19a7cc3b9a8b87ca209649fe4b256cdcfc9fb 100644 (file)
@@ -14,6 +14,7 @@ from . import errors as e
 from .sql import Composable
 from .abc import Buffer, Query, Params
 from ._enums import PyFormat
+from ._compat import TypeGuard
 from ._encodings import conn_encoding
 
 if TYPE_CHECKING:
@@ -100,7 +101,7 @@ class PostgresQuery:
             self.formats = None
 
     @staticmethod
-    def is_params_sequence(vars: Params) -> bool:
+    def is_params_sequence(vars: Params) -> TypeGuard[Sequence[Any]]:
         # Try concrete types, then abstract types
         t = type(vars)
         if t is list or t is tuple:
@@ -125,9 +126,8 @@ class PostgresQuery:
         """
         Verify the compatibility between a query and a set of params.
         """
-        sequence = PostgresQuery.is_params_sequence(vars)
 
-        if sequence:
+        if PostgresQuery.is_params_sequence(vars):
             if len(vars) != len(parts) - 1:
                 raise e.ProgrammingError(
                     f"the query has {len(parts) - 1} placeholders but"
@@ -135,7 +135,7 @@ class PostgresQuery:
                 )
             if vars and not isinstance(parts[0].item, int):
                 raise TypeError("named placeholders require a mapping of parameters")
-            return vars  # type: ignore[return-value]
+            return vars
 
         else:
             if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
@@ -143,9 +143,11 @@ class PostgresQuery:
                     "positional placeholders (%s) require a sequence of parameters"
                 )
             try:
-                return [
-                    vars[item] for item in order or ()  # type: ignore[call-overload]
-                ]
+                if order:
+                    return [vars[item] for item in order]  # type: ignore[call-overload]
+                else:
+                    return ()
+
             except KeyError:
                 raise e.ProgrammingError(
                     "query parameter missing:"
index 1a584c47e7927b7c4fb60abd970d5485ddb4e96e..9bc3164d4b19859baa9fc5c301b64f95ca421168 100644 (file)
@@ -38,10 +38,9 @@ class RawPostgresQuery(PostgresQuery):
         """
         Verify the compatibility; params must be a sequence for raw query.
         """
-        sequence = PostgresQuery.is_params_sequence(vars)
-        if not sequence:
+        if not PostgresQuery.is_params_sequence(vars):
             raise TypeError("raw query require a sequence of parameters")
-        return vars  # type: ignore[return-value]
+        return vars
 
 
 class RawCursorMixin(BaseCursor[ConnectionType, Row]):