]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Connection.encoding made writable again
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 25 Jul 2020 11:17:51 +0000 (12:17 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Oct 2020 16:18:18 +0000 (17:18 +0100)
I'm so flipflopping on this...

psycopg3/psycopg3/connection.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_copy.py
tests/test_cursor.py
tests/test_errors.py
tests/types/test_text.py

index bac9a8a0abbb3fef86ced7bed635e8fb02c0547b..12eeb0e790a13a9aafa79ba160add3d0e2c2de44 100644 (file)
@@ -154,6 +154,13 @@ class BaseConnection:
         else:
             return "UTF8"
 
+    @encoding.setter
+    def encoding(self, value: str) -> None:
+        self._set_client_encoding(value)
+
+    def _set_client_encoding(self, value: str) -> None:
+        raise NotImplementedError
+
     def cancel(self) -> None:
         c = self.pgconn.get_cancel()
         c.cancel()
@@ -283,7 +290,7 @@ class Connection(BaseConnection):
     ) -> proto.RV:
         return wait(gen, timeout=timeout)
 
-    def set_client_encoding(self, value: str) -> None:
+    def _set_client_encoding(self, value: str) -> None:
         with self.lock:
             self.pgconn.send_query_params(
                 b"select set_config('client_encoding', $1, false)",
@@ -393,6 +400,12 @@ class AsyncConnection(BaseConnection):
     async def wait(cls, gen: proto.PQGen[proto.RV]) -> proto.RV:
         return await wait_async(gen)
 
+    def _set_client_encoding(self, value: str) -> None:
+        raise AttributeError(
+            "'encoding' is read-only on async connections:"
+            " please use await .set_client_encoding() instead."
+        )
+
     async def set_client_encoding(self, value: str) -> None:
         async with self.lock:
             self.pgconn.send_query_params(
index 547030621084514d5913f9ec619658f432099c58..abbbbbc5711cfe4842c29810ab437242d83bbb59 100644 (file)
@@ -165,7 +165,7 @@ def test_get_encoding(conn):
 def test_set_encoding(conn):
     newenc = "LATIN1" if conn.encoding != "LATIN1" else "UTF8"
     assert conn.encoding != newenc
-    conn.set_client_encoding(newenc)
+    conn.encoding = newenc
     assert conn.encoding == newenc
     (enc,) = conn.cursor().execute("show client_encoding").fetchone()
     assert enc == newenc
@@ -182,7 +182,7 @@ def test_set_encoding(conn):
     ],
 )
 def test_normalize_encoding(conn, enc, out, codec):
-    conn.set_client_encoding(enc)
+    conn.encoding = enc
     assert conn.encoding == out
     assert conn.codec.name == codec
 
@@ -205,14 +205,14 @@ def test_encoding_env_var(dsn, monkeypatch, enc, out, codec):
 
 
 def test_set_encoding_unsupported(conn):
-    conn.set_client_encoding("EUC_TW")
+    conn.encoding = "EUC_TW"
     with pytest.raises(psycopg3.NotSupportedError):
         conn.cursor().execute("select 1")
 
 
 def test_set_encoding_bad(conn):
     with pytest.raises(psycopg3.DatabaseError):
-        conn.set_client_encoding("WAT")
+        conn.encoding = "WAT"
 
 
 @pytest.mark.parametrize(
index e62451b946a8e4100747e6c4d4229ce339d2f80c..acb45b5e7fc9059984ea634912b145e2a0deb143 100644 (file)
@@ -120,7 +120,7 @@ async def test_auto_transaction_fail(aconn):
 
 async def test_autocommit(aconn):
     assert aconn.autocommit is False
-    with pytest.raises(TypeError):
+    with pytest.raises(AttributeError):
         aconn.autocommit = True
     assert not aconn.autocommit
 
@@ -175,6 +175,9 @@ async def test_get_encoding(aconn):
 async def test_set_encoding(aconn):
     newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8"
     assert aconn.encoding != newenc
+    with pytest.raises(AttributeError):
+        aconn.encoding = newenc
+    assert aconn.encoding != newenc
     await aconn.set_client_encoding(newenc)
     assert aconn.encoding == newenc
     cur = aconn.cursor()
index a2265ea8ab3533dfe6378be79ef4e492d254b0a8..e7433e5a66089e886d5167a909de3774d8122425 100644 (file)
@@ -180,7 +180,7 @@ def test_copy_in_allchars(conn):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
 
-    conn.set_client_encoding("utf8")
+    conn.encoding = "utf8"
     with cur.copy("copy copy_in from stdin (format text)") as copy:
         for i in range(1, 256):
             copy.write_row((i, None, chr(i)))
index befd7f3f41a35a3488fdae975398e0221b7cb994..21cdd467e3f293eb2c85e562a042ade348f6ae62 100644 (file)
@@ -101,14 +101,14 @@ def test_execute_binary_result(conn):
 
 @pytest.mark.parametrize("encoding", ["utf8", "latin9"])
 def test_query_encode(conn, encoding):
-    conn.set_client_encoding(encoding)
+    conn.encoding = encoding
     cur = conn.cursor()
     (res,) = cur.execute("select '\u20ac'").fetchone()
     assert res == "\u20ac"
 
 
 def test_query_badenc(conn):
-    conn.set_client_encoding("latin1")
+    conn.encoding = "latin1"
     cur = conn.cursor()
     with pytest.raises(UnicodeEncodeError):
         cur.execute("select '\u20ac'")
index a1fa6c24a3a502e250678594303278606c855bef..f3dd8c2ae01d0c411ace962c38d11b0d2623386d 100644 (file)
@@ -49,7 +49,7 @@ def test_diag_encoding(conn, enc):
     msgs = []
     conn.pgconn.exec_(b"set client_min_messages to notice")
     conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary))
-    conn.set_client_encoding(enc)
+    conn.encoding = enc
     cur = conn.cursor()
     cur.execute(
         "do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql"
@@ -59,7 +59,7 @@ def test_diag_encoding(conn, enc):
 
 @pytest.mark.parametrize("enc", ["utf8", "latin9"])
 def test_error_encoding(conn, enc):
-    conn.set_client_encoding(enc)
+    conn.encoding = enc
     cur = conn.cursor()
     with pytest.raises(e.DatabaseError) as excinfo:
         cur.execute(
index 28c09ab6e47416312e44d195d870f3f687bd5332..758d482444910832581f2e4d25e0dcd1a0c42198 100644 (file)
@@ -38,7 +38,7 @@ def test_dump_enc(conn, fmt_in, encoding):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
 
-    conn.set_client_encoding(encoding)
+    conn.encoding = encoding
     (res,) = cur.execute(f"select {ph}::bytea", (eur,)).fetchone()
     assert res == eur.encode("utf8")
 
@@ -48,7 +48,7 @@ def test_dump_ascii(conn, fmt_in):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
 
-    conn.set_client_encoding("sql_ascii")
+    conn.encoding = "sql_ascii"
     (res,) = cur.execute(f"select ascii({ph})", (eur,)).fetchone()
     assert res == ord(eur)
 
@@ -58,7 +58,7 @@ def test_dump_badenc(conn, fmt_in):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
 
-    conn.set_client_encoding("latin1")
+    conn.encoding = "latin1"
     with pytest.raises(UnicodeEncodeError):
         cur.execute(f"select {ph}::bytea", (eur,))
 
@@ -69,7 +69,7 @@ def test_dump_badenc(conn, fmt_in):
 def test_load_enc(conn, typename, encoding, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
-    conn.set_client_encoding(encoding)
+    conn.encoding = encoding
     (res,) = cur.execute(
         f"select chr(%s::int)::{typename}", (ord(eur),)
     ).fetchone()
@@ -81,7 +81,7 @@ def test_load_enc(conn, typename, encoding, fmt_out):
 def test_load_badenc(conn, typename, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
-    conn.set_client_encoding("latin1")
+    conn.encoding = "latin1"
     with pytest.raises(psycopg3.DatabaseError):
         cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
 
@@ -91,7 +91,7 @@ def test_load_badenc(conn, typename, fmt_out):
 def test_load_ascii(conn, typename, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
-    conn.set_client_encoding("sql_ascii")
+    conn.encoding = "sql_ascii"
     (res,) = cur.execute(
         f"select chr(%s::int)::{typename}", (ord(eur),)
     ).fetchone()
@@ -103,7 +103,7 @@ def test_load_ascii(conn, typename, fmt_out):
 def test_load_ascii_encanyway(conn, typename, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
-    conn.set_client_encoding("sql_ascii")
+    conn.encoding = "sql_ascii"
     (res,) = cur.execute(f"select 'aa'::{typename}").fetchone()
     assert res == "aa"
 
@@ -123,7 +123,7 @@ def test_text_array(conn, typename, fmt_in, fmt_out):
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_text_array_ascii(conn, fmt_in, fmt_out):
-    conn.set_client_encoding("sql_ascii")
+    conn.encoding = "sql_ascii"
     cur = conn.cursor(format=fmt_out)
     a = list(map(chr, range(1, 256))) + [eur]
     exp = [s.encode("utf8") for s in a]