From: Daniele Varrazzo Date: Sat, 11 Apr 2020 02:20:15 +0000 (+1200) Subject: Fixed some shenanigan around connection params X-Git-Tag: 3.0.dev0~583 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=3294e3cf6a846677bdded934ca1ebb9f4fa7fe15;p=thirdparty%2Fpsycopg.git Fixed some shenanigan around connection params The libpq function return None for an invalid connection: take that into account in the function signature and usage. Dropped asymmetry between sync and async connection for the encoding setter: use for both a set_client_encoding() function. This way it's easy to keep the getter on the base class. --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 0972ff630..9976fe014 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -53,15 +53,21 @@ class BaseConnection: @property def codec(self) -> codecs.CodecInfo: # TODO: utf8 fastpath? - pgenc = self.pgconn.parameter_status(b"client_encoding") + pgenc = self.pgconn.parameter_status(b"client_encoding") or b"" if self._pgenc != pgenc: - try: - pyenc = pq.py_codecs[pgenc.decode("ascii")] - except KeyError: - raise e.NotSupportedError( - f"encoding {pgenc.decode('ascii')} not available in Python" - ) - self._codec = codecs.lookup(pyenc) + if pgenc: + try: + pyenc = pq.py_codecs[pgenc.decode("ascii")] + except KeyError: + raise e.NotSupportedError( + f"encoding {pgenc.decode('ascii')} not available in Python" + ) + self._codec = codecs.lookup(pyenc) + else: + # fallback for a connection closed whose codec was never asked + if not hasattr(self, "_codec"): + self._codec = codecs.lookup("utf8") + self._pgenc = pgenc return self._codec @@ -73,7 +79,11 @@ class BaseConnection: @property def encoding(self) -> str: - return self.pgconn.parameter_status(b"client_encoding").decode("ascii") + rv = self.pgconn.parameter_status(b"client_encoding") + if rv is not None: + return rv.decode("ascii") + else: + return "UTF8" @classmethod def _connect_gen(cls, conninfo: str) -> ConnectGen: @@ -217,12 +227,7 @@ class Connection(BaseConnection): ) -> RV: return wait(gen, timeout=timeout) - @property - def encoding(self) -> str: - return self.pgconn.parameter_status(b"client_encoding").decode("ascii") - - @encoding.setter - def encoding(self, value: str) -> None: + def set_client_encoding(self, value: str) -> None: with self.lock: self.pgconn.send_query_params( b"select set_config('client_encoding', $1, false)", @@ -286,18 +291,7 @@ class AsyncConnection(BaseConnection): async def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV: return await wait_async(gen) - @property - def encoding(self) -> str: - return self.pgconn.parameter_status(b"client_encoding").decode("ascii") - - @encoding.setter - def encoding(self, value: str) -> None: - raise e.NotSupportedError( - "you can't set 'encoding' on an async connection." - " Use 'await conn.set_encoding()' instead" - ) - - async def set_encoding(self, value: str) -> None: + async def set_client_encoding(self, value: str) -> None: async with self.lock: self.pgconn.send_query_params( b"select set_config('client_encoding', $1, false)", diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index e84223dcb..d0f9b3bb3 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -144,7 +144,7 @@ class PGconn: rv = impl.PQtransactionStatus(self.pgconn_ptr) return TransactionStatus(rv) - def parameter_status(self, name: bytes) -> bytes: + def parameter_status(self, name: bytes) -> Optional[bytes]: return impl.PQparameterStatus(self.pgconn_ptr, name) @property diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index d92555523..e27b981e6 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -45,17 +45,10 @@ def test_get_encoding(aconn, loop): assert enc == aconn.encoding -def test_set_encoding_noprop(aconn): - newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8" - assert aconn.encoding != newenc - with pytest.raises(psycopg3.NotSupportedError): - aconn.encoding = newenc - - def test_set_encoding(aconn, loop): newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8" assert aconn.encoding != newenc - loop.run_until_complete(aconn.set_encoding(newenc)) + loop.run_until_complete(aconn.set_client_encoding(newenc)) assert aconn.encoding == newenc cur = aconn.cursor() loop.run_until_complete(cur.execute("show client_encoding")) @@ -65,4 +58,4 @@ def test_set_encoding(aconn, loop): def test_set_encoding_bad(aconn, loop): with pytest.raises(psycopg3.DatabaseError): - loop.run_until_complete(aconn.set_encoding("WAT")) + loop.run_until_complete(aconn.set_client_encoding("WAT")) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6c963f0d0..89d00fb3d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -46,18 +46,18 @@ def test_get_encoding(conn): def test_set_encoding(conn): newenc = "LATIN1" if conn.encoding != "LATIN1" else "UTF8" assert conn.encoding != newenc - conn.encoding = newenc + conn.set_client_encoding(newenc) assert conn.encoding == newenc (enc,) = conn.cursor().execute("show client_encoding").fetchone() assert enc == newenc def test_set_encoding_unsupported(conn): - conn.encoding = "EUC_TW" + conn.set_client_encoding("EUC_TW") with pytest.raises(psycopg3.NotSupportedError): conn.cursor().execute("select 1") def test_set_encoding_bad(conn): with pytest.raises(psycopg3.DatabaseError): - conn.encoding = "WAT" + conn.set_client_encoding("WAT") diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 5aaed71b5..faa46a42c 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -50,14 +50,14 @@ def test_execute_binary_result(conn): @pytest.mark.parametrize("encoding", ["utf8", "latin9"]) def test_query_encode(conn, encoding): - conn.encoding = encoding + conn.set_client_encoding(encoding) cur = conn.cursor() (res,) = cur.execute("select '\u20ac'").fetchone() assert res == "\u20ac" def test_query_badenc(conn): - conn.encoding = "latin1" + conn.set_client_encoding("latin1") cur = conn.cursor() with pytest.raises(UnicodeEncodeError): cur.execute("select '\u20ac'") diff --git a/tests/types/test_text.py b/tests/types/test_text.py index 01bf2784d..977fa8111 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -38,7 +38,7 @@ def test_dump_enc(conn, fmt_in, encoding): cur = conn.cursor() ph = "%s" if fmt_in == Format.TEXT else "%b" - conn.encoding = encoding + conn.set_client_encoding(encoding) (res,) = cur.execute(f"select {ph}::bytea", (eur,)).fetchone() assert res == eur.encode("utf8") @@ -48,7 +48,7 @@ def test_dump_ascii(conn, fmt_in): cur = conn.cursor() ph = "%s" if fmt_in == Format.TEXT else "%b" - conn.encoding = "sql_ascii" + conn.set_client_encoding("sql_ascii") (res,) = cur.execute(f"select ascii({ph})", (eur,)).fetchone() assert res == ord(eur) @@ -58,7 +58,7 @@ def test_dump_badenc(conn, fmt_in): cur = conn.cursor() ph = "%s" if fmt_in == Format.TEXT else "%b" - conn.encoding = "latin1" + conn.set_client_encoding("latin1") with pytest.raises(UnicodeEncodeError): cur.execute(f"select {ph}::bytea", (eur,)) @@ -69,7 +69,7 @@ def test_dump_badenc(conn, fmt_in): def test_load_enc(conn, typename, encoding, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) - conn.encoding = encoding + conn.set_client_encoding(encoding) (res,) = cur.execute( f"select chr(%s::int)::{typename}", (ord(eur),) ).fetchone() @@ -81,7 +81,7 @@ def test_load_enc(conn, typename, encoding, fmt_out): def test_load_badenc(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) - conn.encoding = "latin1" + conn.set_client_encoding("latin1") with pytest.raises(psycopg3.DatabaseError): cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) @@ -91,7 +91,7 @@ def test_load_badenc(conn, typename, fmt_out): def test_load_ascii(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) - conn.encoding = "sql_ascii" + conn.set_client_encoding("sql_ascii") (res,) = cur.execute( f"select chr(%s::int)::{typename}", (ord(eur),) ).fetchone() @@ -103,7 +103,7 @@ def test_load_ascii(conn, typename, fmt_out): def test_load_ascii_encanyway(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out == Format.BINARY) - conn.encoding = "sql_ascii" + conn.set_client_encoding("sql_ascii") (res,) = cur.execute(f"select 'aa'::{typename}").fetchone() assert res == "aa" @@ -123,7 +123,7 @@ def test_text_array(conn, typename, fmt_in, fmt_out): @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) def test_text_array_ascii(conn, fmt_in, fmt_out): - conn.encoding = "sql_ascii" + conn.set_client_encoding("sql_ascii") cur = conn.cursor(binary=fmt_out == Format.BINARY) a = list(map(chr, range(1, 256))) + [eur] exp = [s.encode("utf8") for s in a]