From 0c1a9eb6fd83ebcf1cf3f99cf7bfa00cac8a4bf8 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 25 May 2025 19:50:19 +0100 Subject: [PATCH] refactor(query): add private method to convert to bytes --- psycopg/psycopg/_queries.py | 46 +++++++++++++++---------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index 750d0585b..b25d1635c 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -64,12 +64,7 @@ class PostgresQuery: The results of this function can be obtained accessing the object attributes (`query`, `params`, `types`, `formats`). """ - if isinstance(query, str): - bquery = query.encode(self._encoding) - elif isinstance(query, Composable): - bquery = query.as_bytes(self._tx) - else: - bquery = query + query = self._ensure_bytes(query) if vars is not None: # Avoid caching queries extremely long or with a huge number of @@ -78,7 +73,7 @@ class PostgresQuery: # numbers of tuples. # see https://github.com/psycopg/psycopg/discussions/628 if ( - len(bquery) <= MAX_CACHED_STATEMENT_LENGTH + len(query) <= MAX_CACHED_STATEMENT_LENGTH and len(vars) <= MAX_CACHED_STATEMENT_PARAMS ): f: _Query2Pg = _query2pg @@ -86,10 +81,10 @@ class PostgresQuery: f = _query2pg_nocache (self.query, self._want_formats, self._order, self._parts) = f( - bquery, self._encoding + query, self._encoding ) else: - self.query = bquery + self.query = query self._want_formats = self._order = None self.dump(vars) @@ -126,7 +121,7 @@ class PostgresQuery: else: raise TypeError( "query parameters should be a sequence or a mapping," - f" got {type(vars).__name__}" + f" got {type(vars).__qualname__}" ) return sequence @@ -165,6 +160,14 @@ class PostgresQuery: f" {', '.join(sorted(i for i in order or () if i not in vars))}" ) + def _ensure_bytes(self, query: Query) -> bytes: + if isinstance(query, str): + return query.encode(self._tx.encoding) + elif isinstance(query, Composable): + return query.as_bytes(self._tx) + else: + return query + # The type of the _query2pg() and _query2pg_nocache() methods _Query2Pg: TypeAlias = Callable[ @@ -244,25 +247,20 @@ class PostgresClientQuery(PostgresQuery): The results of this function can be obtained accessing the object attributes (`query`, `params`, `types`, `formats`). """ - if isinstance(query, str): - bquery = query.encode(self._encoding) - elif isinstance(query, Composable): - bquery = query.as_bytes(self._tx) - else: - bquery = query + query = self._ensure_bytes(query) if vars is not None: if ( - len(bquery) <= MAX_CACHED_STATEMENT_LENGTH + len(query) <= MAX_CACHED_STATEMENT_LENGTH and len(vars) <= MAX_CACHED_STATEMENT_PARAMS ): f: _Query2PgClient = _query2pg_client else: f = _query2pg_client_nocache - (self.template, self._order, self._parts) = f(bquery, self._encoding) + (self.template, self._order, self._parts) = f(query, self._encoding) else: - self.query = bquery + self.query = query self._order = None self.dump(vars) @@ -425,14 +423,8 @@ _ph_to_fmt = { class PostgresRawQuery(PostgresQuery): def convert(self, query: Query, vars: Params | None) -> None: - if isinstance(query, str): - bquery = query.encode(self._encoding) - elif isinstance(query, Composable): - bquery = query.as_bytes(self._tx) - else: - bquery = query - - self.query = bquery + query = self._ensure_bytes(query) + self.query = query self._want_formats = self._order = None self.dump(vars) -- 2.47.3