]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed some shenanigan around connection params
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 02:20:15 +0000 (14:20 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 02:20:15 +0000 (14:20 +1200)
The libpq function return None for an invalid connection: take that into
account in the function signature and usage.

Dropped asymmetry between sync and async connection for the encoding
setter: use for both a set_client_encoding() function. This way it's
easy to keep the getter on the base class.

psycopg3/connection.py
psycopg3/pq/pq_ctypes.py
tests/test_async_connection.py
tests/test_connection.py
tests/test_cursor.py
tests/types/test_text.py

index 0972ff63055bbcd051493da41bc10c0979d70e8a..9976fe014b2bd31a46e38026ea65277cfd75a2b8 100644 (file)
@@ -53,15 +53,21 @@ class BaseConnection:
     @property
     def codec(self) -> codecs.CodecInfo:
         # TODO: utf8 fastpath?
-        pgenc = self.pgconn.parameter_status(b"client_encoding")
+        pgenc = self.pgconn.parameter_status(b"client_encoding") or b""
         if self._pgenc != pgenc:
-            try:
-                pyenc = pq.py_codecs[pgenc.decode("ascii")]
-            except KeyError:
-                raise e.NotSupportedError(
-                    f"encoding {pgenc.decode('ascii')} not available in Python"
-                )
-            self._codec = codecs.lookup(pyenc)
+            if pgenc:
+                try:
+                    pyenc = pq.py_codecs[pgenc.decode("ascii")]
+                except KeyError:
+                    raise e.NotSupportedError(
+                        f"encoding {pgenc.decode('ascii')} not available in Python"
+                    )
+                self._codec = codecs.lookup(pyenc)
+            else:
+                # fallback for a connection closed whose codec was never asked
+                if not hasattr(self, "_codec"):
+                    self._codec = codecs.lookup("utf8")
+
             self._pgenc = pgenc
         return self._codec
 
@@ -73,7 +79,11 @@ class BaseConnection:
 
     @property
     def encoding(self) -> str:
-        return self.pgconn.parameter_status(b"client_encoding").decode("ascii")
+        rv = self.pgconn.parameter_status(b"client_encoding")
+        if rv is not None:
+            return rv.decode("ascii")
+        else:
+            return "UTF8"
 
     @classmethod
     def _connect_gen(cls, conninfo: str) -> ConnectGen:
@@ -217,12 +227,7 @@ class Connection(BaseConnection):
     ) -> RV:
         return wait(gen, timeout=timeout)
 
-    @property
-    def encoding(self) -> str:
-        return self.pgconn.parameter_status(b"client_encoding").decode("ascii")
-
-    @encoding.setter
-    def encoding(self, value: str) -> None:
+    def set_client_encoding(self, value: str) -> None:
         with self.lock:
             self.pgconn.send_query_params(
                 b"select set_config('client_encoding', $1, false)",
@@ -286,18 +291,7 @@ class AsyncConnection(BaseConnection):
     async def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
         return await wait_async(gen)
 
-    @property
-    def encoding(self) -> str:
-        return self.pgconn.parameter_status(b"client_encoding").decode("ascii")
-
-    @encoding.setter
-    def encoding(self, value: str) -> None:
-        raise e.NotSupportedError(
-            "you can't set 'encoding' on an async connection."
-            " Use 'await conn.set_encoding()' instead"
-        )
-
-    async def set_encoding(self, value: str) -> None:
+    async def set_client_encoding(self, value: str) -> None:
         async with self.lock:
             self.pgconn.send_query_params(
                 b"select set_config('client_encoding', $1, false)",
index e84223dcb90712e34a4952b3373c8a96398ef23f..d0f9b3bb3640a5383ee15a307ddc3ad0d9fbd043 100644 (file)
@@ -144,7 +144,7 @@ class PGconn:
         rv = impl.PQtransactionStatus(self.pgconn_ptr)
         return TransactionStatus(rv)
 
-    def parameter_status(self, name: bytes) -> bytes:
+    def parameter_status(self, name: bytes) -> Optional[bytes]:
         return impl.PQparameterStatus(self.pgconn_ptr, name)
 
     @property
index d9255552304b91d5aa1234a85b78e9343880760b..e27b981e6992f59c80a0a27191cb8d06fcd308f0 100644 (file)
@@ -45,17 +45,10 @@ def test_get_encoding(aconn, loop):
     assert enc == aconn.encoding
 
 
-def test_set_encoding_noprop(aconn):
-    newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8"
-    assert aconn.encoding != newenc
-    with pytest.raises(psycopg3.NotSupportedError):
-        aconn.encoding = newenc
-
-
 def test_set_encoding(aconn, loop):
     newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8"
     assert aconn.encoding != newenc
-    loop.run_until_complete(aconn.set_encoding(newenc))
+    loop.run_until_complete(aconn.set_client_encoding(newenc))
     assert aconn.encoding == newenc
     cur = aconn.cursor()
     loop.run_until_complete(cur.execute("show client_encoding"))
@@ -65,4 +58,4 @@ def test_set_encoding(aconn, loop):
 
 def test_set_encoding_bad(aconn, loop):
     with pytest.raises(psycopg3.DatabaseError):
-        loop.run_until_complete(aconn.set_encoding("WAT"))
+        loop.run_until_complete(aconn.set_client_encoding("WAT"))
index 6c963f0d06093601a53b20ee3b3af997c2885c36..89d00fb3d2563d1da9fe3a5c0d1acfbe4268fc16 100644 (file)
@@ -46,18 +46,18 @@ def test_get_encoding(conn):
 def test_set_encoding(conn):
     newenc = "LATIN1" if conn.encoding != "LATIN1" else "UTF8"
     assert conn.encoding != newenc
-    conn.encoding = newenc
+    conn.set_client_encoding(newenc)
     assert conn.encoding == newenc
     (enc,) = conn.cursor().execute("show client_encoding").fetchone()
     assert enc == newenc
 
 
 def test_set_encoding_unsupported(conn):
-    conn.encoding = "EUC_TW"
+    conn.set_client_encoding("EUC_TW")
     with pytest.raises(psycopg3.NotSupportedError):
         conn.cursor().execute("select 1")
 
 
 def test_set_encoding_bad(conn):
     with pytest.raises(psycopg3.DatabaseError):
-        conn.encoding = "WAT"
+        conn.set_client_encoding("WAT")
index 5aaed71b57284db06c3cdf53ecd7b05cb8889967..faa46a42ccaebf51dbf296628979626eeef9423a 100644 (file)
@@ -50,14 +50,14 @@ def test_execute_binary_result(conn):
 
 @pytest.mark.parametrize("encoding", ["utf8", "latin9"])
 def test_query_encode(conn, encoding):
-    conn.encoding = encoding
+    conn.set_client_encoding(encoding)
     cur = conn.cursor()
     (res,) = cur.execute("select '\u20ac'").fetchone()
     assert res == "\u20ac"
 
 
 def test_query_badenc(conn):
-    conn.encoding = "latin1"
+    conn.set_client_encoding("latin1")
     cur = conn.cursor()
     with pytest.raises(UnicodeEncodeError):
         cur.execute("select '\u20ac'")
index 01bf2784d9da95b198ad8c391763ad83702d559f..977fa8111669035a4654640009ae2665bc7a146f 100644 (file)
@@ -38,7 +38,7 @@ def test_dump_enc(conn, fmt_in, encoding):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
 
-    conn.encoding = encoding
+    conn.set_client_encoding(encoding)
     (res,) = cur.execute(f"select {ph}::bytea", (eur,)).fetchone()
     assert res == eur.encode("utf8")
 
@@ -48,7 +48,7 @@ def test_dump_ascii(conn, fmt_in):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
 
-    conn.encoding = "sql_ascii"
+    conn.set_client_encoding("sql_ascii")
     (res,) = cur.execute(f"select ascii({ph})", (eur,)).fetchone()
     assert res == ord(eur)
 
@@ -58,7 +58,7 @@ def test_dump_badenc(conn, fmt_in):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
 
-    conn.encoding = "latin1"
+    conn.set_client_encoding("latin1")
     with pytest.raises(UnicodeEncodeError):
         cur.execute(f"select {ph}::bytea", (eur,))
 
@@ -69,7 +69,7 @@ def test_dump_badenc(conn, fmt_in):
 def test_load_enc(conn, typename, encoding, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
-    conn.encoding = encoding
+    conn.set_client_encoding(encoding)
     (res,) = cur.execute(
         f"select chr(%s::int)::{typename}", (ord(eur),)
     ).fetchone()
@@ -81,7 +81,7 @@ def test_load_enc(conn, typename, encoding, fmt_out):
 def test_load_badenc(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
-    conn.encoding = "latin1"
+    conn.set_client_encoding("latin1")
     with pytest.raises(psycopg3.DatabaseError):
         cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
 
@@ -91,7 +91,7 @@ def test_load_badenc(conn, typename, fmt_out):
 def test_load_ascii(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
-    conn.encoding = "sql_ascii"
+    conn.set_client_encoding("sql_ascii")
     (res,) = cur.execute(
         f"select chr(%s::int)::{typename}", (ord(eur),)
     ).fetchone()
@@ -103,7 +103,7 @@ def test_load_ascii(conn, typename, fmt_out):
 def test_load_ascii_encanyway(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
 
-    conn.encoding = "sql_ascii"
+    conn.set_client_encoding("sql_ascii")
     (res,) = cur.execute(f"select 'aa'::{typename}").fetchone()
     assert res == "aa"
 
@@ -123,7 +123,7 @@ def test_text_array(conn, typename, fmt_in, fmt_out):
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_text_array_ascii(conn, fmt_in, fmt_out):
-    conn.encoding = "sql_ascii"
+    conn.set_client_encoding("sql_ascii")
     cur = conn.cursor(binary=fmt_out == Format.BINARY)
     a = list(map(chr, range(1, 256))) + [eur]
     exp = [s.encode("utf8") for s in a]