From: Daniele Varrazzo Date: Thu, 21 Apr 2022 13:49:55 +0000 (+0200) Subject: refactor: add conn_encoding function X-Git-Tag: 3.1~138 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=747893ab6793758de25227764849c364d9bd564b;p=thirdparty%2Fpsycopg.git refactor: add conn_encoding function Simplifies the chore of obtaining an encoding from the connection, defaulting to utf-8 if not available. --- diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py index 0a2d32791..a44f3bda0 100644 --- a/psycopg/psycopg/_encodings.py +++ b/psycopg/psycopg/_encodings.py @@ -5,12 +5,13 @@ Mappings between PostgreSQL and Python encodings. # Copyright (C) 2020 The Psycopg Team import codecs -from typing import Dict, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING from .errors import NotSupportedError if TYPE_CHECKING: from .pq.abc import PGconn + from .connection import BaseConnection _py_codecs = { "BIG5": "big5", @@ -70,9 +71,21 @@ py_codecs.update( pg_codecs = {v: k.encode() for k, v in _py_codecs.items()} +def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str: + """ + Return the Python encoding name of a psycopg connection. + + Default to utf8 if the connection has no encoding info. + """ + if not conn: + return "utf-8" + pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8" + return pg2pyenc(pgenc) + + def pgconn_encoding(pgconn: "PGconn") -> str: """ - Return the Python encoding name of a connection. + Return the Python encoding name of a libpq connection. Default to utf8 if the connection has no encoding info. """ diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index f98912103..250454d90 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -14,7 +14,7 @@ from . import errors as e from .sql import Composable from .abc import Buffer, Query, Params from ._enums import PyFormat -from ._encodings import pgconn_encoding +from ._encodings import conn_encoding if TYPE_CHECKING: from .abc import Transformer @@ -47,15 +47,11 @@ class PostgresQuery: self._want_formats: Optional[List[PyFormat]] = None self.formats: Optional[Sequence[pq.Format]] = None + self._encoding = conn_encoding(transformer.connection) self._parts: List[QueryPart] self.query = b"" - self._encoding = "utf-8" self._order: Optional[List[str]] = None - conn = transformer.connection - if conn: - self._encoding = pgconn_encoding(conn.pgconn) - def convert(self, query: Query, vars: Optional[Params]) -> None: """ Set up the query and parameters to convert. diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index 76bd9539c..ade3e4de4 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -12,7 +12,7 @@ from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union from .pq import Escaping from .abc import AdaptContext from .adapt import Transformer, PyFormat -from ._encodings import pgconn_encoding +from ._encodings import conn_encoding def quote(obj: Any, context: Optional[AdaptContext] = None) -> str: @@ -76,7 +76,7 @@ class Composable(ABC): """ conn = context.connection if context else None - enc = pgconn_encoding(conn.pgconn) if conn else "utf-8" + enc = conn_encoding(conn) b = self.as_bytes(context) if isinstance(b, bytes): return b.decode(enc) @@ -204,9 +204,7 @@ class SQL(Composable): def as_bytes(self, context: Optional[AdaptContext]) -> bytes: enc = "utf-8" if context: - conn = context.connection - if conn: - enc = pgconn_encoding(conn.pgconn) + enc = conn_encoding(context.connection) return self._obj.encode(enc) def format(self, *args: Any, **kwargs: Any) -> Composed: @@ -365,7 +363,7 @@ class Identifier(Composable): if not conn: raise ValueError("a connection is necessary for Identifier") esc = Escaping(conn.pgconn) - enc = pgconn_encoding(conn.pgconn) + enc = conn_encoding(conn) escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] return b".".join(escs) @@ -450,7 +448,7 @@ class Placeholder(Composable): def as_bytes(self, context: Optional[AdaptContext]) -> bytes: conn = context.connection if context else None - enc = pgconn_encoding(conn.pgconn) if conn else "utf-8" + enc = conn_encoding(conn) return self.as_string(context).encode(enc) diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py index 42cd319d1..593055df5 100644 --- a/psycopg/psycopg/types/string.py +++ b/psycopg/psycopg/types/string.py @@ -11,24 +11,17 @@ from ..pq import Format, Escaping from ..abc import AdaptContext from ..adapt import Buffer, Dumper, Loader from ..errors import DataError -from .._encodings import pgconn_encoding +from .._encodings import conn_encoding if TYPE_CHECKING: from ..pq.abc import Escaping as EscapingProto class _BaseStrDumper(Dumper): - - _encoding = "utf-8" - def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) - - conn = self.connection - if conn: - enc = pgconn_encoding(conn.pgconn) - if enc != "ascii": - self._encoding = enc + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "utf-8" class StrBinaryDumper(_BaseStrDumper): @@ -77,15 +70,10 @@ class StrDumperUnknown(_StrDumper): class TextLoader(Loader): - - _encoding = "utf-8" - def __init__(self, oid: int, context: Optional[AdaptContext] = None): super().__init__(oid, context) - conn = self.connection - if conn: - enc = pgconn_encoding(conn.pgconn) - self._encoding = enc if enc != "ascii" else "" + enc = conn_encoding(self.connection) + self._encoding = enc if enc != "ascii" else "" def load(self, data: Buffer) -> Union[bytes, str]: if self._encoding: