From: Daniele Varrazzo Date: Sat, 25 Jul 2020 11:17:51 +0000 (+0100) Subject: Connection.encoding made writable again X-Git-Tag: 3.0.dev0~453 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bf89ebb8fd862183a55b868708ca03e866cf60cd;p=thirdparty%2Fpsycopg.git Connection.encoding made writable again I'm so flipflopping on this... --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index bac9a8a0a..12eeb0e79 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -154,6 +154,13 @@ class BaseConnection: else: return "UTF8" + @encoding.setter + def encoding(self, value: str) -> None: + self._set_client_encoding(value) + + def _set_client_encoding(self, value: str) -> None: + raise NotImplementedError + def cancel(self) -> None: c = self.pgconn.get_cancel() c.cancel() @@ -283,7 +290,7 @@ class Connection(BaseConnection): ) -> proto.RV: return wait(gen, timeout=timeout) - def set_client_encoding(self, value: str) -> None: + def _set_client_encoding(self, value: str) -> None: with self.lock: self.pgconn.send_query_params( b"select set_config('client_encoding', $1, false)", @@ -393,6 +400,12 @@ class AsyncConnection(BaseConnection): async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV: return await wait_async(gen) + def _set_client_encoding(self, value: str) -> None: + raise AttributeError( + "'encoding' is read-only on async connections:" + " please use await .set_client_encoding() instead." + ) + async def set_client_encoding(self, value: str) -> None: async with self.lock: self.pgconn.send_query_params( diff --git a/tests/test_connection.py b/tests/test_connection.py index 547030621..abbbbbc57 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -165,7 +165,7 @@ def test_get_encoding(conn): def test_set_encoding(conn): newenc = "LATIN1" if conn.encoding != "LATIN1" else "UTF8" assert conn.encoding != newenc - conn.set_client_encoding(newenc) + conn.encoding = newenc assert conn.encoding == newenc (enc,) = conn.cursor().execute("show client_encoding").fetchone() assert enc == newenc @@ -182,7 +182,7 @@ def test_set_encoding(conn): ], ) def test_normalize_encoding(conn, enc, out, codec): - conn.set_client_encoding(enc) + conn.encoding = enc assert conn.encoding == out assert conn.codec.name == codec @@ -205,14 +205,14 @@ def test_encoding_env_var(dsn, monkeypatch, enc, out, codec): def test_set_encoding_unsupported(conn): - conn.set_client_encoding("EUC_TW") + conn.encoding = "EUC_TW" with pytest.raises(psycopg3.NotSupportedError): conn.cursor().execute("select 1") def test_set_encoding_bad(conn): with pytest.raises(psycopg3.DatabaseError): - conn.set_client_encoding("WAT") + conn.encoding = "WAT" @pytest.mark.parametrize( diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index e62451b94..acb45b5e7 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -120,7 +120,7 @@ async def test_auto_transaction_fail(aconn): async def test_autocommit(aconn): assert aconn.autocommit is False - with pytest.raises(TypeError): + with pytest.raises(AttributeError): aconn.autocommit = True assert not aconn.autocommit @@ -175,6 +175,9 @@ async def test_get_encoding(aconn): async def test_set_encoding(aconn): newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8" assert aconn.encoding != newenc + with pytest.raises(AttributeError): + aconn.encoding = newenc + assert aconn.encoding != newenc await aconn.set_client_encoding(newenc) assert aconn.encoding == newenc cur = aconn.cursor() diff --git a/tests/test_copy.py b/tests/test_copy.py index a2265ea8a..e7433e5a6 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -180,7 +180,7 @@ def test_copy_in_allchars(conn): cur = conn.cursor() ensure_table(cur, sample_tabledef) - conn.set_client_encoding("utf8") + conn.encoding = "utf8" with cur.copy("copy copy_in from stdin (format text)") as copy: for i in range(1, 256): copy.write_row((i, None, chr(i))) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index befd7f3f4..21cdd467e 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -101,14 +101,14 @@ def test_execute_binary_result(conn): @pytest.mark.parametrize("encoding", ["utf8", "latin9"]) def test_query_encode(conn, encoding): - conn.set_client_encoding(encoding) + conn.encoding = encoding cur = conn.cursor() (res,) = cur.execute("select '\u20ac'").fetchone() assert res == "\u20ac" def test_query_badenc(conn): - conn.set_client_encoding("latin1") + conn.encoding = "latin1" cur = conn.cursor() with pytest.raises(UnicodeEncodeError): cur.execute("select '\u20ac'") diff --git a/tests/test_errors.py b/tests/test_errors.py index a1fa6c24a..f3dd8c2ae 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -49,7 +49,7 @@ def test_diag_encoding(conn, enc): msgs = [] conn.pgconn.exec_(b"set client_min_messages to notice") conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary)) - conn.set_client_encoding(enc) + conn.encoding = enc cur = conn.cursor() cur.execute( "do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql" @@ -59,7 +59,7 @@ def test_diag_encoding(conn, enc): @pytest.mark.parametrize("enc", ["utf8", "latin9"]) def test_error_encoding(conn, enc): - conn.set_client_encoding(enc) + conn.encoding = enc cur = conn.cursor() with pytest.raises(e.DatabaseError) as excinfo: cur.execute( diff --git a/tests/types/test_text.py b/tests/types/test_text.py index 28c09ab6e..758d48244 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -38,7 +38,7 @@ def test_dump_enc(conn, fmt_in, encoding): cur = conn.cursor() ph = "%s" if fmt_in == Format.TEXT else "%b" - conn.set_client_encoding(encoding) + conn.encoding = encoding (res,) = cur.execute(f"select {ph}::bytea", (eur,)).fetchone() assert res == eur.encode("utf8") @@ -48,7 +48,7 @@ def test_dump_ascii(conn, fmt_in): cur = conn.cursor() ph = "%s" if fmt_in == Format.TEXT else "%b" - conn.set_client_encoding("sql_ascii") + conn.encoding = "sql_ascii" (res,) = cur.execute(f"select ascii({ph})", (eur,)).fetchone() assert res == ord(eur) @@ -58,7 +58,7 @@ def test_dump_badenc(conn, fmt_in): cur = conn.cursor() ph = "%s" if fmt_in == Format.TEXT else "%b" - conn.set_client_encoding("latin1") + conn.encoding = "latin1" with pytest.raises(UnicodeEncodeError): cur.execute(f"select {ph}::bytea", (eur,)) @@ -69,7 +69,7 @@ def test_dump_badenc(conn, fmt_in): def test_load_enc(conn, typename, encoding, fmt_out): cur = conn.cursor(format=fmt_out) - conn.set_client_encoding(encoding) + conn.encoding = encoding (res,) = cur.execute( f"select chr(%s::int)::{typename}", (ord(eur),) ).fetchone() @@ -81,7 +81,7 @@ def test_load_enc(conn, typename, encoding, fmt_out): def test_load_badenc(conn, typename, fmt_out): cur = conn.cursor(format=fmt_out) - conn.set_client_encoding("latin1") + conn.encoding = "latin1" with pytest.raises(psycopg3.DatabaseError): cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) @@ -91,7 +91,7 @@ def test_load_badenc(conn, typename, fmt_out): def test_load_ascii(conn, typename, fmt_out): cur = conn.cursor(format=fmt_out) - conn.set_client_encoding("sql_ascii") + conn.encoding = "sql_ascii" (res,) = cur.execute( f"select chr(%s::int)::{typename}", (ord(eur),) ).fetchone() @@ -103,7 +103,7 @@ def test_load_ascii(conn, typename, fmt_out): def test_load_ascii_encanyway(conn, typename, fmt_out): cur = conn.cursor(format=fmt_out) - conn.set_client_encoding("sql_ascii") + conn.encoding = "sql_ascii" (res,) = cur.execute(f"select 'aa'::{typename}").fetchone() assert res == "aa" @@ -123,7 +123,7 @@ def test_text_array(conn, typename, fmt_in, fmt_out): @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) def test_text_array_ascii(conn, fmt_in, fmt_out): - conn.set_client_encoding("sql_ascii") + conn.encoding = "sql_ascii" cur = conn.cursor(format=fmt_out) a = list(map(chr, range(1, 256))) + [eur] exp = [s.encode("utf8") for s in a]