]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(crdb): allow non-normalized encoding name from the database
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 17 May 2022 01:27:34 +0000 (03:27 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:33 +0000 (12:58 +0100)
CockroachDB returns what stored: "utf_8", "utf-8" etc.

All ConnectionInfo tests pass on crdb.

psycopg/psycopg/_encodings.py
tests/crdb/test_conninfo.py
tests/test_conninfo.py

index f672fd521e3dd8214731aa8bb0e0a15b0f65ec01..d293949e34d13835608b8214d02f7d1e6f4dfe5d 100644 (file)
@@ -7,6 +7,7 @@ Mappings between PostgreSQL and Python encodings.
 import re
 import string
 import codecs
+from functools import lru_cache
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from .errors import NotSupportedError
@@ -107,13 +108,15 @@ def conninfo_encoding(conninfo: str) -> str:
     params = conninfo_to_dict(conninfo)
     pgenc = params.get("client_encoding")
     if pgenc:
-        pgenc = pgenc.replace("-", "").replace("_", "").upper().encode()
-        if pgenc in py_codecs:
-            return py_codecs[pgenc]
+        try:
+            return pg2pyenc(pgenc.encode())
+        except NotSupportedError:
+            pass
 
     return "utf-8"
 
 
+@lru_cache()
 def py2pgenc(name: str) -> bytes:
     """Convert a Python encoding name to PostgreSQL encoding name.
 
@@ -122,6 +125,7 @@ def py2pgenc(name: str) -> bytes:
     return pg_codecs[codecs.lookup(name).name]
 
 
+@lru_cache()
 def pg2pyenc(name: bytes) -> str:
     """Convert a Python encoding name to PostgreSQL encoding name.
 
@@ -129,7 +133,7 @@ def pg2pyenc(name: bytes) -> str:
     Python.
     """
     try:
-        return py_codecs[name]
+        return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()]
     except KeyError:
         sname = name.decode("utf8", "replace")
         raise NotSupportedError(f"codec not available in Python: {sname!r}")
index 2b9f58b09e86693628104750719f1d27876ba7b4..6fc36d0497acbc4c9712cec60b838f56eb752ff2 100644 (file)
@@ -9,3 +9,7 @@ def test_vendor(conn):
 
 def test_crdb_version(conn):
     assert conn.info.crdb_version > 200000
+
+
+def test_backend_pid(conn):
+    assert conn.info.backend_pid == 0
index 5f0c0338b8c34d6c278aa2db9a22c4df9a40fa64..8531cd573157920237f3890745a63885603b7d5c 100644 (file)
@@ -13,6 +13,10 @@ from psycopg._encodings import pg2pyenc
 snowman = "\u2603"
 
 
+def skip_crdb(*args, reason=None):
+    return pytest.param(*args, marks=pytest.mark.crdb("skip", reason=reason))
+
+
 class MyString(str):
     pass
 
@@ -226,6 +230,7 @@ class TestConnectionInfo:
         with pytest.raises(psycopg.OperationalError):
             conn.info.error_message
 
+    @pytest.mark.crdb("skip", reason="always 0 on crdb")
     def test_backend_pid(self, conn):
         assert conn.info.backend_pid
         assert conn.info.backend_pid == conn.pgconn.backend_pid
@@ -242,6 +247,7 @@ class TestConnectionInfo:
         offset = tz.utcoffset(dt.datetime(2000, 7, 1))
         assert offset and offset.total_seconds() == 7200
 
+    @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones")
     def test_timezone_warn(self, conn, caplog):
         conn.execute("set timezone to 'FOOBAR0'")
         assert len(caplog.records) == 0
@@ -263,6 +269,7 @@ class TestConnectionInfo:
         enc = conn.execute("show client_encoding").fetchone()[0]
         assert conn.info.encoding == pg2pyenc(enc.encode())
 
+    @pytest.mark.crdb("skip", reason="encoding not normalized")
     @pytest.mark.parametrize(
         "enc, out, codec",
         [
@@ -285,17 +292,22 @@ class TestConnectionInfo:
             ("utf8", "UTF8", "utf-8"),
             ("utf-8", "UTF8", "utf-8"),
             ("utf_8", "UTF8", "utf-8"),
-            ("eucjp", "EUC_JP", "euc_jp"),
-            ("euc-jp", "EUC_JP", "euc_jp"),
+            skip_crdb("eucjp", "EUC_JP", "euc_jp", reason="encoding"),
+            skip_crdb("euc-jp", "EUC_JP", "euc_jp", reason="encoding"),
         ],
     )
     def test_encoding_env_var(self, dsn, monkeypatch, enc, out, codec):
         monkeypatch.setenv("PGCLIENTENCODING", enc)
-        conn = psycopg.connect(dsn)
-        assert conn.info.parameter_status("client_encoding") == out
-        assert conn.info.encoding == codec
-        conn.close()
+        with psycopg.connect(dsn) as conn:
+            clienc = conn.info.parameter_status("client_encoding")
+            assert clienc
+            if conn.info.vendor == "PostgreSQL":
+                assert clienc == out
+            else:
+                assert clienc.replace("-", "").replace("_", "").upper() == out
+            assert conn.info.encoding == codec
 
+    @pytest.mark.crdb("skip", reason="encoding")
     def test_set_encoding_unsupported(self, conn):
         cur = conn.cursor()
         cur.execute("set client_encoding to EUC_TW")