]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: add conn_encoding function
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 13:49:55 +0000 (15:49 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 14:05:46 +0000 (16:05 +0200)
Simplifies the chore of obtaining an encoding from the connection,
defaulting to utf-8 if not available.

psycopg/psycopg/_encodings.py
psycopg/psycopg/_queries.py
psycopg/psycopg/sql.py
psycopg/psycopg/types/string.py

index 0a2d327919a9a6ba8be9dc458643440a3f76f80b..a44f3bda035b70fdde5fc4d804a8a6f0ae93e23b 100644 (file)
@@ -5,12 +5,13 @@ Mappings between PostgreSQL and Python encodings.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from typing import Dict, TYPE_CHECKING
+from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from .errors import NotSupportedError
 
 if TYPE_CHECKING:
     from .pq.abc import PGconn
+    from .connection import BaseConnection
 
 _py_codecs = {
     "BIG5": "big5",
@@ -70,9 +71,21 @@ py_codecs.update(
 pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
 
 
+def conn_encoding(conn: "Optional[BaseConnection[Any]]") -> str:
+    """
+    Return the Python encoding name of a psycopg connection.
+
+    Default to utf8 if the connection has no encoding info.
+    """
+    if not conn:
+        return "utf-8"
+    pgenc = conn.pgconn.parameter_status(b"client_encoding") or b"UTF8"
+    return pg2pyenc(pgenc)
+
+
 def pgconn_encoding(pgconn: "PGconn") -> str:
     """
-    Return the Python encoding name of a connection.
+    Return the Python encoding name of a libpq connection.
 
     Default to utf8 if the connection has no encoding info.
     """
index f98912103b9cb3179d10b89d5e4556523fe6d446..250454d909c106aab337a2fb30e65bb45dc2980e 100644 (file)
@@ -14,7 +14,7 @@ from . import errors as e
 from .sql import Composable
 from .abc import Buffer, Query, Params
 from ._enums import PyFormat
-from ._encodings import pgconn_encoding
+from ._encodings import conn_encoding
 
 if TYPE_CHECKING:
     from .abc import Transformer
@@ -47,15 +47,11 @@ class PostgresQuery:
         self._want_formats: Optional[List[PyFormat]] = None
         self.formats: Optional[Sequence[pq.Format]] = None
 
+        self._encoding = conn_encoding(transformer.connection)
         self._parts: List[QueryPart]
         self.query = b""
-        self._encoding = "utf-8"
         self._order: Optional[List[str]] = None
 
-        conn = transformer.connection
-        if conn:
-            self._encoding = pgconn_encoding(conn.pgconn)
-
     def convert(self, query: Query, vars: Optional[Params]) -> None:
         """
         Set up the query and parameters to convert.
index 76bd9539c22580fb3ee7e790dff4d870c563263f..ade3e4de4bbb4b5101c79d578affa898848edd59 100644 (file)
@@ -12,7 +12,7 @@ from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
 from .pq import Escaping
 from .abc import AdaptContext
 from .adapt import Transformer, PyFormat
-from ._encodings import pgconn_encoding
+from ._encodings import conn_encoding
 
 
 def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
@@ -76,7 +76,7 @@ class Composable(ABC):
 
         """
         conn = context.connection if context else None
-        enc = pgconn_encoding(conn.pgconn) if conn else "utf-8"
+        enc = conn_encoding(conn)
         b = self.as_bytes(context)
         if isinstance(b, bytes):
             return b.decode(enc)
@@ -204,9 +204,7 @@ class SQL(Composable):
     def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         enc = "utf-8"
         if context:
-            conn = context.connection
-            if conn:
-                enc = pgconn_encoding(conn.pgconn)
+            enc = conn_encoding(context.connection)
         return self._obj.encode(enc)
 
     def format(self, *args: Any, **kwargs: Any) -> Composed:
@@ -365,7 +363,7 @@ class Identifier(Composable):
         if not conn:
             raise ValueError("a connection is necessary for Identifier")
         esc = Escaping(conn.pgconn)
-        enc = pgconn_encoding(conn.pgconn)
+        enc = conn_encoding(conn)
         escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
         return b".".join(escs)
 
@@ -450,7 +448,7 @@ class Placeholder(Composable):
 
     def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         conn = context.connection if context else None
-        enc = pgconn_encoding(conn.pgconn) if conn else "utf-8"
+        enc = conn_encoding(conn)
         return self.as_string(context).encode(enc)
 
 
index 42cd319d14f97824558505c327e4f0a30d237df0..593055df591c248a8967316d97bb834d0f81f683 100644 (file)
@@ -11,24 +11,17 @@ from ..pq import Format, Escaping
 from ..abc import AdaptContext
 from ..adapt import Buffer, Dumper, Loader
 from ..errors import DataError
-from .._encodings import pgconn_encoding
+from .._encodings import conn_encoding
 
 if TYPE_CHECKING:
     from ..pq.abc import Escaping as EscapingProto
 
 
 class _BaseStrDumper(Dumper):
-
-    _encoding = "utf-8"
-
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
-
-        conn = self.connection
-        if conn:
-            enc = pgconn_encoding(conn.pgconn)
-            if enc != "ascii":
-                self._encoding = enc
+        enc = conn_encoding(self.connection)
+        self._encoding = enc if enc != "ascii" else "utf-8"
 
 
 class StrBinaryDumper(_BaseStrDumper):
@@ -77,15 +70,10 @@ class StrDumperUnknown(_StrDumper):
 
 
 class TextLoader(Loader):
-
-    _encoding = "utf-8"
-
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        conn = self.connection
-        if conn:
-            enc = pgconn_encoding(conn.pgconn)
-            self._encoding = enc if enc != "ascii" else ""
+        enc = conn_encoding(self.connection)
+        self._encoding = enc if enc != "ascii" else ""
 
     def load(self, data: Buffer) -> Union[bytes, str]:
         if self._encoding: