From: Daniele Varrazzo Date: Tue, 17 May 2022 01:27:34 +0000 (+0200) Subject: fix(crdb): allow non-normalized encoding name from the database X-Git-Tag: 3.1~49^2~63 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=60d0a7059b1ff1350a57f09bf7a247727db5172f;p=thirdparty%2Fpsycopg.git fix(crdb): allow non-normalized encoding name from the database CockroachDB returns what was stored: "utf_8", "utf-8" etc. All ConnectionInfo tests pass on crdb. --- diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py index f672fd521..d293949e3 100644 --- a/psycopg/psycopg/_encodings.py +++ b/psycopg/psycopg/_encodings.py @@ -7,6 +7,7 @@ Mappings between PostgreSQL and Python encodings. import re import string import codecs +from functools import lru_cache from typing import Any, Dict, Optional, TYPE_CHECKING from .errors import NotSupportedError @@ -107,13 +108,15 @@ def conninfo_encoding(conninfo: str) -> str: params = conninfo_to_dict(conninfo) pgenc = params.get("client_encoding") if pgenc: - pgenc = pgenc.replace("-", "").replace("_", "").upper().encode() - if pgenc in py_codecs: - return py_codecs[pgenc] + try: + return pg2pyenc(pgenc.encode()) + except NotSupportedError: + pass return "utf-8" +@lru_cache() def py2pgenc(name: str) -> bytes: """Convert a Python encoding name to PostgreSQL encoding name. @@ -122,6 +125,7 @@ def py2pgenc(name: str) -> bytes: return pg_codecs[codecs.lookup(name).name] +@lru_cache() def pg2pyenc(name: bytes) -> str: """Convert a Python encoding name to PostgreSQL encoding name. @@ -129,7 +133,7 @@ def pg2pyenc(name: bytes) -> str: Python. 
""" try: - return py_codecs[name] + return py_codecs[name.replace(b"-", b"").replace(b"_", b"").upper()] except KeyError: sname = name.decode("utf8", "replace") raise NotSupportedError(f"codec not available in Python: {sname!r}") diff --git a/tests/crdb/test_conninfo.py b/tests/crdb/test_conninfo.py index 2b9f58b09..6fc36d049 100644 --- a/tests/crdb/test_conninfo.py +++ b/tests/crdb/test_conninfo.py @@ -9,3 +9,7 @@ def test_vendor(conn): def test_crdb_version(conn): assert conn.info.crdb_version > 200000 + + +def test_backend_pid(conn): + assert conn.info.backend_pid == 0 diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 5f0c0338b..8531cd573 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -13,6 +13,10 @@ from psycopg._encodings import pg2pyenc snowman = "\u2603" +def skip_crdb(*args, reason=None): + return pytest.param(*args, marks=pytest.mark.crdb("skip", reason=reason)) + + class MyString(str): pass @@ -226,6 +230,7 @@ class TestConnectionInfo: with pytest.raises(psycopg.OperationalError): conn.info.error_message + @pytest.mark.crdb("skip", reason="always 0 on crdb") def test_backend_pid(self, conn): assert conn.info.backend_pid assert conn.info.backend_pid == conn.pgconn.backend_pid @@ -242,6 +247,7 @@ class TestConnectionInfo: offset = tz.utcoffset(dt.datetime(2000, 7, 1)) assert offset and offset.total_seconds() == 7200 + @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones") def test_timezone_warn(self, conn, caplog): conn.execute("set timezone to 'FOOBAR0'") assert len(caplog.records) == 0 @@ -263,6 +269,7 @@ class TestConnectionInfo: enc = conn.execute("show client_encoding").fetchone()[0] assert conn.info.encoding == pg2pyenc(enc.encode()) + @pytest.mark.crdb("skip", reason="encoding not normalized") @pytest.mark.parametrize( "enc, out, codec", [ @@ -285,17 +292,22 @@ class TestConnectionInfo: ("utf8", "UTF8", "utf-8"), ("utf-8", "UTF8", "utf-8"), ("utf_8", "UTF8", "utf-8"), - ("eucjp", "EUC_JP", 
"euc_jp"), - ("euc-jp", "EUC_JP", "euc_jp"), + skip_crdb("eucjp", "EUC_JP", "euc_jp", reason="encoding"), + skip_crdb("euc-jp", "EUC_JP", "euc_jp", reason="encoding"), ], ) def test_encoding_env_var(self, dsn, monkeypatch, enc, out, codec): monkeypatch.setenv("PGCLIENTENCODING", enc) - conn = psycopg.connect(dsn) - assert conn.info.parameter_status("client_encoding") == out - assert conn.info.encoding == codec - conn.close() + with psycopg.connect(dsn) as conn: + clienc = conn.info.parameter_status("client_encoding") + assert clienc + if conn.info.vendor == "PostgreSQL": + assert clienc == out + else: + assert clienc.replace("-", "").replace("_", "").upper() == out + assert conn.info.encoding == codec + @pytest.mark.crdb("skip", reason="encoding") def test_set_encoding_unsupported(self, conn): cur = conn.cursor() cur.execute("set client_encoding to EUC_TW")