]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(query): add private method to convert to bytes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 25 May 2025 18:50:19 +0000 (19:50 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Sep 2025 09:46:55 +0000 (11:46 +0200)
psycopg/psycopg/_queries.py

index 750d0585b823a513d729bfacfc7fafada491963f..b25d1635c74c85644a5fcdbe2fc12e895dd0cbc7 100644 (file)
@@ -64,12 +64,7 @@ class PostgresQuery:
         The results of this function can be obtained accessing the object
         attributes (`query`, `params`, `types`, `formats`).
         """
-        if isinstance(query, str):
-            bquery = query.encode(self._encoding)
-        elif isinstance(query, Composable):
-            bquery = query.as_bytes(self._tx)
-        else:
-            bquery = query
+        query = self._ensure_bytes(query)
 
         if vars is not None:
             # Avoid caching queries extremely long or with a huge number of
@@ -78,7 +73,7 @@ class PostgresQuery:
             # numbers of tuples.
             # see https://github.com/psycopg/psycopg/discussions/628
             if (
-                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
+                len(query) <= MAX_CACHED_STATEMENT_LENGTH
                 and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
             ):
                 f: _Query2Pg = _query2pg
@@ -86,10 +81,10 @@ class PostgresQuery:
                 f = _query2pg_nocache
 
             (self.query, self._want_formats, self._order, self._parts) = f(
-                bquery, self._encoding
+                query, self._encoding
             )
         else:
-            self.query = bquery
+            self.query = query
             self._want_formats = self._order = None
 
         self.dump(vars)
@@ -126,7 +121,7 @@ class PostgresQuery:
         else:
             raise TypeError(
                 "query parameters should be a sequence or a mapping,"
-                f" got {type(vars).__name__}"
+                f" got {type(vars).__qualname__}"
             )
         return sequence
 
@@ -165,6 +160,14 @@ class PostgresQuery:
                     f" {', '.join(sorted(i for i in order or () if i not in vars))}"
                 )
 
+    def _ensure_bytes(self, query: Query) -> bytes:
+        if isinstance(query, str):
+            return query.encode(self._tx.encoding)
+        elif isinstance(query, Composable):
+            return query.as_bytes(self._tx)
+        else:
+            return query
+
 
 # The type of the _query2pg() and _query2pg_nocache() methods
 _Query2Pg: TypeAlias = Callable[
@@ -244,25 +247,20 @@ class PostgresClientQuery(PostgresQuery):
         The results of this function can be obtained accessing the object
         attributes (`query`, `params`, `types`, `formats`).
         """
-        if isinstance(query, str):
-            bquery = query.encode(self._encoding)
-        elif isinstance(query, Composable):
-            bquery = query.as_bytes(self._tx)
-        else:
-            bquery = query
+        query = self._ensure_bytes(query)
 
         if vars is not None:
             if (
-                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
+                len(query) <= MAX_CACHED_STATEMENT_LENGTH
                 and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
             ):
                 f: _Query2PgClient = _query2pg_client
             else:
                 f = _query2pg_client_nocache
 
-            (self.template, self._order, self._parts) = f(bquery, self._encoding)
+            (self.template, self._order, self._parts) = f(query, self._encoding)
         else:
-            self.query = bquery
+            self.query = query
             self._order = None
 
         self.dump(vars)
@@ -425,14 +423,8 @@ _ph_to_fmt = {
 
 class PostgresRawQuery(PostgresQuery):
     def convert(self, query: Query, vars: Params | None) -> None:
-        if isinstance(query, str):
-            bquery = query.encode(self._encoding)
-        elif isinstance(query, Composable):
-            bquery = query.as_bytes(self._tx)
-        else:
-            bquery = query
-
-        self.query = bquery
+        query = self._ensure_bytes(query)
+        self.query = query
         self._want_formats = self._order = None
         self.dump(vars)