]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Dropped Connection.pyenc and related support
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Nov 2020 22:53:27 +0000 (22:53 +0000)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 12 Nov 2020 00:14:42 +0000 (00:14 +0000)
Connection.client_encoding is now guaranteed to be a Python codec's
name.

18 files changed:
docs/connection.rst
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/encodings.py [moved from psycopg3/psycopg3/pq/encodings.py with 64% similarity]
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/pq/__init__.py
psycopg3/psycopg3/pq/misc.py
psycopg3/psycopg3/sql.py
psycopg3/psycopg3/types/text.py
psycopg3_c/psycopg3_c/transform.pyx
psycopg3_c/psycopg3_c/types/text.pyx
tests/test_connection.py
tests/test_connection_async.py
tests/test_cursor.py
tests/test_encodings.py [new file with mode: 0644]
tests/types/test_text.py

index 17a7ca9252cc854b98dc6587b4b44ad8a354f25d..1f0e869c31764952ce6f7858c61ab3a706a466f2 100644 (file)
@@ -46,7 +46,14 @@ Take a look to :ref:`transactions` for the details.
     .. automethod:: commit
     .. automethod:: rollback
     .. automethod:: close
+
+    .. rubric:: Checking the connection state
+
     .. autoproperty:: closed
+    .. autoproperty:: client_encoding
+
+        The property is writable for sync connections, read-only for async
+        ones: you can call `~AsyncConnection.set_client_encoding()` on those.
 
     .. rubric:: Methods you will need if you do something cool
 
@@ -76,6 +83,7 @@ Take a look to :ref:`transactions` for the details.
     .. automethod:: commit
     .. automethod:: rollback
     .. automethod:: notifies
+    .. automethod:: set_client_encoding
 
 
 .. autoclass:: psycopg3.Notify
index 383b3c1597651a434387dbf1e00317cc9295727f..8599ca73fc826b5410fc3b251d3f0da8d484ce81 100644 (file)
@@ -75,7 +75,7 @@ class Transformer:
 
         elif isinstance(context, BaseCursor):
             self._connection = context.connection
-            self._encoding = context.connection.pyenc
+            self._encoding = context.connection.client_encoding
             self._dumpers = {}
             self._dumpers_maps.extend(
                 (self._dumpers, context.dumpers, context.connection.dumpers)
@@ -87,7 +87,7 @@ class Transformer:
 
         elif isinstance(context, BaseConnection):
             self._connection = context
-            self._encoding = context.pyenc
+            self._encoding = context.client_encoding
             self._dumpers = {}
             self._dumpers_maps.extend((self._dumpers, context.dumpers))
             self._loaders = {}
index abd8021274ec917dcb293fe084ce16157905305e..4d34116b3b16d948c6e972995c8f4aef9b777133 100644 (file)
@@ -14,12 +14,13 @@ from weakref import ref, ReferenceType
 from functools import partial
 
 from . import pq
-from . import errors as e
-from . import cursor
 from . import proto
+from . import cursor
+from . import errors as e
+from . import encodings
 from .pq import TransactionStatus, ExecStatus
-from .conninfo import make_conninfo
 from .waiting import wait, wait_async
+from .conninfo import make_conninfo
 from .generators import notifies
 
 logger = logging.getLogger(__name__)
@@ -85,10 +86,6 @@ class BaseConnection:
         self.loaders: proto.LoadersMap = {}
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
-        # postgres name of the client encoding (in bytes)
-        self._pgenc = b""
-        # python name of the client encoding
-        self._pyenc = "utf-8"
 
         wself = ref(self)
 
@@ -130,31 +127,17 @@ class BaseConnection:
             raise NotImplementedError
         return self.cursor_factory(self, format=format)
 
-    @property
-    def pyenc(self) -> str:
-        pgenc = self.pgconn.parameter_status(b"client_encoding") or b""
-        if self._pgenc != pgenc:
-            if pgenc:
-                try:
-                    self._pyenc = pq.py_codecs[pgenc]
-                except KeyError:
-                    raise e.NotSupportedError(
-                        f"encoding {pgenc.decode('ascii')} not available in Python"
-                    )
-
-            self._pgenc = pgenc
-        return self._pyenc
-
     @property
     def client_encoding(self) -> str:
-        rv = self.pgconn.parameter_status(b"client_encoding")
-        return rv.decode("utf-8") if rv else "UTF8"
+        """The Python codec name of the connection's client encoding."""
+        pgenc = self.pgconn.parameter_status(b"client_encoding") or b"UTF8"
+        return encodings.pg2py(pgenc)
 
     @client_encoding.setter
-    def client_encoding(self, value: str) -> None:
-        self._set_client_encoding(value)
+    def client_encoding(self, name: str) -> None:
+        self._set_client_encoding(name)
 
-    def _set_client_encoding(self, value: str) -> None:
+    def _set_client_encoding(self, name: str) -> None:
         raise NotImplementedError
 
     def cancel(self) -> None:
@@ -175,7 +158,7 @@ class BaseConnection:
         if not (self and self._notice_handler):
             return
 
-        diag = e.Diagnostic(res, self._pyenc)
+        diag = e.Diagnostic(res, self.client_encoding)
         for cb in self._notice_handlers:
             try:
                 cb(diag)
@@ -204,11 +187,8 @@ class BaseConnection:
         if not (self and self._notify_handlers):
             return
 
-        n = Notify(
-            pgn.relname.decode(self._pyenc),
-            pgn.extra.decode(self._pyenc),
-            pgn.be_pid,
-        )
+        enc = self.client_encoding
+        n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
         for cb in self._notify_handlers:
             cb(n)
 
@@ -282,7 +262,7 @@ class Connection(BaseConnection):
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 "error on begin:"
-                f" {pq.error_message(pgres, encoding=self._pyenc)}"
+                f" {pq.error_message(pgres, encoding=self.client_encoding)}"
             )
 
     def commit(self) -> None:
@@ -306,7 +286,7 @@ class Connection(BaseConnection):
         if results[-1].status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 f"error on {command.decode('utf8')}:"
-                f" {pq.error_message(results[-1], encoding=self._pyenc)}"
+                f" {pq.error_message(results[-1], encoding=self.client_encoding)}"
             )
 
     @classmethod
@@ -315,27 +295,28 @@ class Connection(BaseConnection):
     ) -> proto.RV:
         return wait(gen, timeout=timeout)
 
-    def _set_client_encoding(self, value: str) -> None:
+    def _set_client_encoding(self, name: str) -> None:
         with self.lock:
             self.pgconn.send_query_params(
                 b"select set_config('client_encoding', $1, false)",
-                [value.encode("ascii")],
+                [encodings.py2pg(name)],
             )
             gen = execute(self.pgconn)
             (result,) = self.wait(gen)
             if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(result, encoding=self._pyenc)
+                raise e.error_from_result(
+                    result, encoding=self.client_encoding
+                )
 
     def notifies(self) -> Iterator[Notify]:
         """Generate a stream of `Notify`"""
         while 1:
             with self.lock:
                 ns = self.wait(notifies(self.pgconn))
+            enc = self.client_encoding
             for pgn in ns:
                 n = Notify(
-                    pgn.relname.decode(self._pyenc),
-                    pgn.extra.decode(self._pyenc),
-                    pgn.be_pid,
+                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
                 )
                 yield n
 
@@ -405,7 +386,7 @@ class AsyncConnection(BaseConnection):
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 "error on begin:"
-                f" {pq.error_message(pgres, encoding=self._pyenc)}"
+                f" {pq.error_message(pgres, encoding=self.client_encoding)}"
             )
 
     async def commit(self) -> None:
@@ -427,39 +408,41 @@ class AsyncConnection(BaseConnection):
         if pgres.status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 f"error on {command.decode('utf8')}:"
-                f" {pq.error_message(pgres, encoding=self._pyenc)}"
+                f" {pq.error_message(pgres, encoding=self.client_encoding)}"
             )
 
     @classmethod
     async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV:
         return await wait_async(gen)
 
-    def _set_client_encoding(self, value: str) -> None:
+    def _set_client_encoding(self, name: str) -> None:
         raise AttributeError(
             "'client_encoding' is read-only on async connections:"
             " please use await .set_client_encoding() instead."
         )
 
-    async def set_client_encoding(self, value: str) -> None:
+    async def set_client_encoding(self, name: str) -> None:
+        """Async version of the `client_encoding` setter."""
         async with self.lock:
             self.pgconn.send_query_params(
                 b"select set_config('client_encoding', $1, false)",
-                [value.encode("ascii")],
+                [name.encode("utf-8")],
             )
             gen = execute(self.pgconn)
             (result,) = await self.wait(gen)
             if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(result, encoding=self._pyenc)
+                raise e.error_from_result(
+                    result, encoding=self.client_encoding
+                )
 
     async def notifies(self) -> AsyncIterator[Notify]:
         while 1:
             async with self.lock:
                 ns = await self.wait(notifies(self.pgconn))
+            enc = self.client_encoding
             for pgn in ns:
                 n = Notify(
-                    pgn.relname.decode(self._pyenc),
-                    pgn.extra.decode(self._pyenc),
-                    pgn.be_pid,
+                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
                 )
                 yield n
 
index 0c105850bd8baef4b584ae4a08f1702677261aed..f80900787e63246e44dd95936f5c73e25d749a30 100644 (file)
@@ -79,7 +79,7 @@ class BaseCopy:
                 raise TypeError(
                     "cannot copy str data in binary mode: use bytes instead"
                 )
-            self._encoding = self.connection.pyenc
+            self._encoding = self.connection.client_encoding
             return data.encode(self._encoding)
 
         else:
@@ -181,7 +181,7 @@ class Copy(BaseCopy):
 
     def finish(self, error: str = "") -> None:
         conn = self.connection
-        berr = error.encode(conn.pyenc, "replace") if error else None
+        berr = error.encode(conn.client_encoding, "replace") if error else None
         conn.wait(copy_end(conn.pgconn, berr))
         self._finished = True
 
@@ -238,7 +238,7 @@ class AsyncCopy(BaseCopy):
 
     async def finish(self, error: str = "") -> None:
         conn = self.connection
-        berr = error.encode(conn.pyenc, "replace") if error else None
+        berr = error.encode(conn.client_encoding, "replace") if error else None
         await conn.wait(copy_end(conn.pgconn, berr))
         self._finished = True
 
index fda21b1f82206d752aa37bd70f12614c1aba7b4f..995ba8124f7de37cb11fecba2eddd1645fb1262c 100644 (file)
@@ -187,7 +187,7 @@ class BaseCursor:
         res = self.pgresult
         if not res or res.status != self.ExecStatus.TUPLES_OK:
             return None
-        encoding = self.connection.pyenc
+        encoding = self.connection.client_encoding
         return [Column(res, i, encoding) for i in range(res.nfields)]
 
     @property
@@ -277,7 +277,7 @@ class BaseCursor:
 
         if results[-1].status == S.FATAL_ERROR:
             raise e.error_from_result(
-                results[-1], encoding=self.connection.pyenc
+                results[-1], encoding=self.connection.client_encoding
             )
 
         elif badstats & {S.COPY_IN, S.COPY_OUT, S.COPY_BOTH}:
@@ -387,7 +387,9 @@ class BaseCursor:
         if status in (pq.ExecStatus.COPY_IN, pq.ExecStatus.COPY_OUT):
             return
         elif status == pq.ExecStatus.FATAL_ERROR:
-            raise e.error_from_result(result, encoding=self.connection.pyenc)
+            raise e.error_from_result(
+                result, encoding=self.connection.client_encoding
+            )
         else:
             raise e.ProgrammingError(
                 "copy() should be used only with COPY ... TO STDOUT or COPY ..."
@@ -442,7 +444,7 @@ class Cursor(BaseCursor):
                     (result,) = self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(
-                            result, encoding=self.connection.pyenc
+                            result, encoding=self.connection.client_encoding
                         )
                 else:
                     pgq.dump(vars)
@@ -570,7 +572,7 @@ class AsyncCursor(BaseCursor):
                     (result,) = await self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(
-                            result, encoding=self.connection.pyenc
+                            result, encoding=self.connection.client_encoding
                         )
                 else:
                     pgq.dump(vars)
similarity index 64%
rename from psycopg3/psycopg3/pq/encodings.py
rename to psycopg3/psycopg3/encodings.py
index 28eef6c139eec6c6094df0a726c79149295a1678..2a51b15efc0cdf83b438a93fa6a86f1054018716 100644 (file)
@@ -4,8 +4,11 @@ Mappings between PostgreSQL and Python encodings.
 
 # Copyright (C) 2020 The Psycopg Team
 
+import codecs
 from typing import Dict, Union
 
+from .errors import NotSupportedError
+
 _py_codecs = {
     "BIG5": "big5",
     "EUC_CN": "gb2312",
@@ -53,6 +56,28 @@ _py_codecs = {
     "WIN874": "cp874",
 }
 
-py_codecs: Dict[Union[bytes, str, None], str] = {}
+py_codecs: Dict[Union[bytes, str], str] = {}
 py_codecs.update((k, v) for k, v in _py_codecs.items())
-py_codecs.update((k.encode("ascii"), v) for k, v in _py_codecs.items())
+py_codecs.update((k.encode("utf-8"), v) for k, v in _py_codecs.items())
+
+pg_codecs = {v: k.encode("utf-8") for k, v in _py_codecs.items()}
+
+
+def py2pg(name: str) -> bytes:
+    """Convert a Python encoding name to PostgreSQL encoding name.
+
+    Raise LookupError if the Python encoding is unknown.
+    """
+    return pg_codecs[codecs.lookup(name).name]
+
+
+def pg2py(name: Union[bytes, str]) -> str:
+    """Convert a PostgreSQL encoding name to Python encoding name.
+
+    Raise NotSupportedError if the PostgreSQL encoding is not supported by
+    Python.
+    """
+    try:
+        return py_codecs[name]
+    except KeyError:
+        raise NotSupportedError(f"codec not available in Python: {name!r}")
index 0cd7abcf5f56caab9d22de75e3757d2495369387..15d2ad2bbf233d6dcbae2281c062bbc12f6c34c9 100644 (file)
@@ -22,6 +22,7 @@ from . import pq
 from . import errors as e
 from .proto import PQGen
 from .waiting import Wait, Ready
+from .encodings import py_codecs
 
 logger = logging.getLogger(__name__)
 
@@ -172,8 +173,8 @@ def copy_from(pgconn: pq.proto.PGconn) -> PQGen[Optional[bytes]]:
             f"1 result expected from copy end, got {len(results)}"
         )
     if results[0].status != pq.ExecStatus.COMMAND_OK:
-        encoding = pq.py_codecs.get(
-            pgconn.parameter_status(b"client_encoding"), "utf8"
+        encoding = py_codecs.get(
+            pgconn.parameter_status(b"client_encoding") or "", "utf-8"
         )
         raise e.error_from_result(results[0], encoding=encoding)
 
@@ -205,7 +206,7 @@ def copy_end(pgconn: pq.proto.PGconn, error: Optional[bytes]) -> PQGen[None]:
             f"1 result expected from copy end, got {len(results)}"
         )
     if results[0].status != pq.ExecStatus.COMMAND_OK:
-        encoding = pq.py_codecs.get(
-            pgconn.parameter_status(b"client_encoding"), "utf8"
+        encoding = py_codecs.get(
+            pgconn.parameter_status(b"client_encoding") or "", "utf-8"
         )
         raise e.error_from_result(results[0], encoding=encoding)
index f13d240d34edac7df1fa310dac73bc64a99e8f40..6472fc83f5c71891f279e4c8c32730c849fd1aa5 100644 (file)
@@ -22,7 +22,6 @@ from .enums import (
     DiagnosticField,
     Format,
 )
-from .encodings import py_codecs
 from .misc import ConninfoOption, PQerror, PGnotify, PGresAttDesc
 from .misc import error_message
 from . import proto
@@ -107,6 +106,5 @@ __all__ = (
     "PQerror",
     "error_message",
     "ConninfoOption",
-    "py_codecs",
     "version",
 )
index 782e991cc979da1495cf0100597c01ad96db7b9d..c3eed579ac86b0913f828efedb876d74c9151f86 100644 (file)
@@ -9,7 +9,6 @@ from typing import cast, NamedTuple, Optional, Union
 from ..errors import OperationalError
 from .enums import DiagnosticField, ConnStatus
 from .proto import PGconn, PGresult
-from .encodings import py_codecs
 
 
 class PQerror(OperationalError):
@@ -66,11 +65,13 @@ def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
                 bmsg = bmsg.splitlines()[0].split(b":", 1)[-1].strip()
 
     elif hasattr(obj, "error_message"):
+        from psycopg3.encodings import py_codecs
+
         # obj is a PGconn
         obj = cast(PGconn, obj)
         if obj.status == ConnStatus.OK:
             encoding = py_codecs.get(
-                obj.parameter_status(b"client_encoding"), "utf8"
+                obj.parameter_status(b"client_encoding") or "", "utf-8"
             )
         bmsg = obj.error_message
 
index 9d16936f4d229160e071aa0382d5055bb9b09e41..0aa0948dfdbfa52fa0f501e427f724957ece24d4 100644 (file)
@@ -342,7 +342,7 @@ class Identifier(Composable):
             raise ValueError(f"no connection in the context: {context}")
 
         esc = Escaping(conn.pgconn)
-        enc = conn.pyenc
+        enc = conn.client_encoding
         escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
         return b".".join(escs).decode(enc)
 
@@ -375,7 +375,7 @@ class Literal(Composable):
         tx = context if isinstance(context, Transformer) else Transformer(conn)
         dumper = tx.get_dumper(self._obj, Format.TEXT)
         quoted = dumper.quote(self._obj)
-        return quoted.decode(conn.pyenc if conn else "utf-8")
+        return quoted.decode(conn.client_encoding if conn else "utf-8")
 
 
 class Placeholder(Composable):
index 1cac29ffc397f7cd9c9f53046c61c4a49c34cdd3..0be59bef03554122239e22c389b91863573db084 100644 (file)
@@ -20,13 +20,11 @@ class _StringDumper(Dumper):
     def __init__(self, src: type, context: AdaptContext):
         super().__init__(src, context)
 
+        self.encoding = "utf-8"
         if self.connection:
-            if self.connection.client_encoding != "SQL_ASCII":
-                self.encoding = self.connection.pyenc
-            else:
-                self.encoding = "utf-8"
-        else:
-            self.encoding = "utf-8"
+            enc = self.connection.client_encoding
+            if enc != "ascii":
+                self.encoding = enc
 
 
 @Dumper.binary(str)
@@ -56,8 +54,9 @@ class TextLoader(Loader):
         super().__init__(oid, context)
 
         if self.connection:
-            if self.connection.client_encoding != "SQL_ASCII":
-                self.encoding = self.connection.pyenc
+            enc = self.connection.client_encoding
+            if enc != "ascii":
+                self.encoding = enc
             else:
                 self.encoding = ""
         else:
@@ -78,7 +77,9 @@ class TextLoader(Loader):
 class UnknownLoader(Loader):
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
-        self.encoding = self.connection.pyenc if self.connection else "utf-8"
+        self.encoding = (
+            self.connection.client_encoding if self.connection else "utf-8"
+        )
 
     def load(self, data: bytes) -> str:
         return data.decode(self.encoding)
index 7d7e79fc63e5c2f31d1a43a0eb39ed91c71fc084..0fa2f1705a1ac693b6d31295910d4ab486c7d4cb 100644 (file)
@@ -91,7 +91,7 @@ cdef class Transformer:
 
         elif isinstance(context, BaseCursor):
             self._connection = context.connection
-            self._encoding = context.connection.pyenc
+            self._encoding = context.connection.client_encoding
             self._dumpers = {}
             self._dumpers_maps.extend(
                 (self._dumpers, context.dumpers, self.connection.dumpers)
@@ -103,7 +103,7 @@ cdef class Transformer:
 
         elif isinstance(context, BaseConnection):
             self._connection = context
-            self._encoding = context.pyenc
+            self._encoding = context.client_encoding
             self._dumpers = {}
             self._dumpers_maps.extend((self._dumpers, context.dumpers))
             self._loaders = {}
index 547fe8c6428e0af1c72426c822c662fca81751da..691d4b13ffa1df68a236c662db6f9ce518ca76f8 100644 (file)
@@ -17,23 +17,21 @@ cdef class TextLoader(CLoader):
         super().__init__(oid, context)
 
         self.is_utf8 = 0
-        self.encoding = NULL
+        self.encoding = "utf-8"
 
         conn = self.connection
-        if conn is not None:
-            if conn.client_encoding == "UTF8":
+        if conn:
+            self._bytes_encoding = conn.client_encoding.encode("utf-8")
+            self.encoding = self._bytes_encoding
+            if self._bytes_encoding == b"utf-8":
                 self.is_utf8 = 1
-            elif conn.client_encoding != "SQL_ASCII":
-                self._bytes_encoding = conn.pyenc.encode("utf-8")
-                self.encoding = self._bytes_encoding
-        else:
-            self.encoding = "utf-8"
+            elif self._bytes_encoding == b"ascii":
+                self.encoding = NULL
 
     cdef object cload(self, const char *data, size_t length):
         if self.is_utf8:
             return PyUnicode_DecodeUTF8(<char *>data, length, NULL)
-
-        if self.encoding:
+        elif self.encoding:
             return PyUnicode_Decode(<char *>data, length, self.encoding, NULL)
         else:
             return data[:length]
index eabd0aad10cd513010799a8b2a1c088f022e30ed..7f4503bf4f5f1d48a4c4b7854d1ffd771464b9ab 100644 (file)
@@ -5,6 +5,7 @@ import weakref
 
 import psycopg3
 from psycopg3 import Connection, Notify
+from psycopg3 import encodings
 from psycopg3.errors import UndefinedTable
 from psycopg3.conninfo import conninfo_to_dict
 
@@ -199,16 +200,16 @@ def test_autocommit_unknown(conn):
 
 def test_get_encoding(conn):
     (enc,) = conn.cursor().execute("show client_encoding").fetchone()
-    assert enc == conn.client_encoding
+    assert conn.client_encoding == encodings.pg2py(enc)
 
 
 def test_set_encoding(conn):
-    newenc = "LATIN1" if conn.client_encoding != "LATIN1" else "UTF8"
+    newenc = "iso8859-1" if conn.client_encoding != "iso8859-1" else "utf-8"
     assert conn.client_encoding != newenc
     conn.client_encoding = newenc
     assert conn.client_encoding == newenc
     (enc,) = conn.cursor().execute("show client_encoding").fetchone()
-    assert enc == newenc
+    assert encodings.pg2py(enc) == newenc
 
 
 @pytest.mark.parametrize(
@@ -223,8 +224,10 @@ def test_set_encoding(conn):
 )
 def test_normalize_encoding(conn, enc, out, codec):
     conn.client_encoding = enc
-    assert conn.client_encoding == out
-    assert conn.pyenc == codec
+    assert (
+        conn.pgconn.parameter_status(b"client_encoding").decode("utf-8") == out
+    )
+    assert conn.client_encoding == codec
 
 
 @pytest.mark.parametrize(
@@ -240,18 +243,21 @@ def test_normalize_encoding(conn, enc, out, codec):
 def test_encoding_env_var(dsn, monkeypatch, enc, out, codec):
     monkeypatch.setenv("PGCLIENTENCODING", enc)
     conn = psycopg3.connect(dsn)
-    assert conn.client_encoding == out
-    assert conn.pyenc == codec
+    assert (
+        conn.pgconn.parameter_status(b"client_encoding").decode("utf-8") == out
+    )
+    assert conn.client_encoding == codec
 
 
 def test_set_encoding_unsupported(conn):
-    conn.client_encoding = "EUC_TW"
+    cur = conn.cursor()
+    cur.execute("set client_encoding to EUC_TW")
     with pytest.raises(psycopg3.NotSupportedError):
-        conn.cursor().execute("select 1")
+        cur.execute("select 'x'")
 
 
 def test_set_encoding_bad(conn):
-    with pytest.raises(psycopg3.DatabaseError):
+    with pytest.raises(LookupError):
         conn.client_encoding = "WAT"
 
 
index 30b50eef3c38e4507d21d5f5bb5b1768f08014bb..ca53b9788da32a4534b3f2d43c98d6eed3090bcc 100644 (file)
@@ -4,6 +4,7 @@ import logging
 import weakref
 
 import psycopg3
+from psycopg3 import encodings
 from psycopg3 import AsyncConnection
 from psycopg3.errors import UndefinedTable
 from psycopg3.conninfo import conninfo_to_dict
@@ -209,11 +210,11 @@ async def test_get_encoding(aconn):
     cur = await aconn.cursor()
     await cur.execute("show client_encoding")
     (enc,) = await cur.fetchone()
-    assert enc == aconn.client_encoding
+    assert aconn.client_encoding == encodings.pg2py(enc)
 
 
 async def test_set_encoding(aconn):
-    newenc = "LATIN1" if aconn.client_encoding != "LATIN1" else "UTF8"
+    newenc = "iso8859-1" if aconn.client_encoding != "iso8859-1" else "utf-8"
     assert aconn.client_encoding != newenc
     with pytest.raises(AttributeError):
         aconn.client_encoding = newenc
@@ -223,7 +224,7 @@ async def test_set_encoding(aconn):
     cur = await aconn.cursor()
     await cur.execute("show client_encoding")
     (enc,) = await cur.fetchone()
-    assert enc == newenc
+    assert encodings.pg2py(enc) == newenc
 
 
 @pytest.mark.parametrize(
@@ -234,12 +235,16 @@ async def test_set_encoding(aconn):
         ("utf_8", "UTF8", "utf-8"),
         ("eucjp", "EUC_JP", "euc_jp"),
         ("euc-jp", "EUC_JP", "euc_jp"),
+        ("latin9", "LATIN9", "iso8859-15"),
     ],
 )
 async def test_normalize_encoding(aconn, enc, out, codec):
     await aconn.set_client_encoding(enc)
-    assert aconn.client_encoding == out
-    assert aconn.pyenc == codec
+    assert (
+        aconn.pgconn.parameter_status(b"client_encoding").decode("utf-8")
+        == out
+    )
+    assert aconn.client_encoding == codec
 
 
 @pytest.mark.parametrize(
@@ -255,8 +260,11 @@ async def test_normalize_encoding(aconn, enc, out, codec):
 async def test_encoding_env_var(dsn, monkeypatch, enc, out, codec):
     monkeypatch.setenv("PGCLIENTENCODING", enc)
     aconn = await psycopg3.AsyncConnection.connect(dsn)
-    assert aconn.client_encoding == out
-    assert aconn.pyenc == codec
+    assert (
+        aconn.pgconn.parameter_status(b"client_encoding").decode("utf-8")
+        == out
+    )
+    assert aconn.client_encoding == codec
 
 
 async def test_set_encoding_unsupported(aconn):
index c5b8fd9f6541d9b88f12daaf5ba750a8d0a5d6d6..5dc132eddece68dabd826662aa0606ad31c0628a 100644 (file)
@@ -219,7 +219,7 @@ def make_testfunc(conn):
         )
         .format(sql.Identifier(procname), sql.Identifier(paramname))
         .as_string(conn)
-        .encode(conn.pyenc)
+        .encode(conn.client_encoding)
     )
 
     # execute regardless of sync/async conn
diff --git a/tests/test_encodings.py b/tests/test_encodings.py
new file mode 100644 (file)
index 0000000..0267778
--- /dev/null
@@ -0,0 +1,43 @@
+import codecs
+import pytest
+
+import psycopg3
+from psycopg3 import encodings
+
+
+def test_names_normalised():
+    for name in encodings._py_codecs.values():
+        assert codecs.lookup(name).name == name
+
+
+@pytest.mark.parametrize(
+    "pyenc, pgenc",
+    [
+        ("ascii", "SQL_ASCII"),
+        ("utf8", "UTF8"),
+        ("utf-8", "UTF8"),
+        ("uTf-8", "UTF8"),
+        ("latin9", "LATIN9"),
+        ("iso8859-15", "LATIN9"),
+    ],
+)
+def test_py2pg(pyenc, pgenc):
+    assert encodings.py2pg(pyenc) == pgenc.encode("utf8")
+
+
+@pytest.mark.parametrize(
+    "pyenc, pgenc",
+    [
+        ("ascii", "SQL_ASCII"),
+        ("utf-8", "UTF8"),
+        ("iso8859-15", "LATIN9"),
+    ],
+)
+def test_pg2py(pyenc, pgenc):
+    assert encodings.pg2py(pgenc.encode("utf-8")) == pyenc
+
+
+@pytest.mark.parametrize("pgenc", ["MULE_INTERNAL", "EUC_TW"])
+def test_pg2py_missing(pgenc):
+    with pytest.raises(psycopg3.NotSupportedError):
+        encodings.pg2py(pgenc.encode("utf-8"))
index 01b94804076c7ade31d5156b69731dbd3baa7bd5..268b407a83b964d955c01868c9a1e84b7d92d3ab 100644 (file)
@@ -76,7 +76,7 @@ def test_load_1char(conn, typename, fmt_out):
 
 
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
-@pytest.mark.parametrize("encoding", ["utf8", "latin9", "sql_ascii"])
+@pytest.mark.parametrize("encoding", ["utf8", "latin9", "ascii"])
 def test_dump_enc(conn, fmt_in, encoding):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
@@ -124,7 +124,7 @@ def test_load_badenc(conn, typename, fmt_out):
 def test_load_ascii(conn, typename, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
-    conn.client_encoding = "sql_ascii"
+    conn.client_encoding = "ascii"
     (res,) = cur.execute(
         f"select chr(%s::int)::{typename}", (ord(eur),)
     ).fetchone()
@@ -136,7 +136,7 @@ def test_load_ascii(conn, typename, fmt_out):
 def test_load_ascii_encanyway(conn, typename, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
-    conn.client_encoding = "sql_ascii"
+    conn.client_encoding = "ascii"
     (res,) = cur.execute(f"select 'aa'::{typename}").fetchone()
     assert res == "aa"
 
@@ -156,7 +156,7 @@ def test_text_array(conn, typename, fmt_in, fmt_out):
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_text_array_ascii(conn, fmt_in, fmt_out):
-    conn.client_encoding = "sql_ascii"
+    conn.client_encoding = "ascii"
     cur = conn.cursor(format=fmt_out)
     a = list(map(chr, range(1, 256))) + [eur]
     exp = [s.encode("utf8") for s in a]