]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added connection.encoding
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 12:26:57 +0000 (01:26 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 12:45:50 +0000 (01:45 +1300)
psycopg3/connection.py
psycopg3/cursor.py
psycopg3/errors.py
tests/test_async_connection.py
tests/test_connection.py

index 0de2b9d5001d30b13a1fda0b986bb1c5d7d103a0..bfe1237ca7a01f9e5cedea3604df4e7348c2bc40 100644 (file)
@@ -209,6 +209,22 @@ class Connection(BaseConnection):
     ) -> RV:
         return wait(gen, timeout=timeout)
 
+    @property
+    def encoding(self) -> str:
+        return self.pgconn.parameter_status(b"client_encoding").decode("ascii")
+
+    @encoding.setter
+    def encoding(self, value: str) -> None:
+        with self.lock:
+            self.pgconn.send_query_params(
+                b"select set_config('client_encoding', $1, false)",
+                [value.encode("ascii")],
+            )
+            gen = self._exec_gen(self.pgconn)
+            (result,) = self.wait(gen)
+            if result.status != pq.ExecStatus.TUPLES_OK:
+                raise e.error_from_result(result)
+
 
 class AsyncConnection(BaseConnection):
     """
@@ -261,3 +277,25 @@ class AsyncConnection(BaseConnection):
     @classmethod
     async def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV:
         return await wait_async(gen)
+
+    @property
+    def encoding(self) -> str:
+        return self.pgconn.parameter_status(b"client_encoding").decode("ascii")
+
+    @encoding.setter
+    def encoding(self, value: str) -> None:
+        raise e.NotSupportedError(
+            "you can't set 'encoding' on an async connection."
+            " Use 'await conn.set_encoding()' instead"
+        )
+
+    async def set_encoding(self, value: str) -> None:
+        async with self.lock:
+            self.pgconn.send_query_params(
+                b"select set_config('client_encoding', $1, false)",
+                [value.encode("ascii")],
+            )
+            gen = self._exec_gen(self.pgconn)
+            (result,) = await self.wait(gen)
+            if result.status != pq.ExecStatus.TUPLES_OK:
+                raise e.error_from_result(result)
index 4b7a96bf880d896bc397ba5b4fee84ed9face270..1e0e3b7e701bd021b8e7d50bfc95d1d19f07537c 100644 (file)
@@ -7,7 +7,7 @@ psycopg3 cursor objects
 from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
 
 from . import errors as e
-from .pq import error_message, DiagnosticField, ExecStatus, PGresult, Format
+from .pq import ExecStatus, PGresult, Format
 from .utils.queries import query2pg, reorder_params
 from .utils.typing import Query, Params
 
@@ -93,10 +93,7 @@ class BaseCursor:
             return
 
         if results[-1].status == ExecStatus.FATAL_ERROR:
-            ecls = e.class_for_state(
-                results[-1].error_field(DiagnosticField.SQLSTATE)
-            )
-            raise ecls(error_message(results[-1]))
+            raise e.error_from_result(results[-1])
 
         elif badstats & {
             ExecStatus.COPY_IN,
index c8ce84438264042a5e93d4b585084c01d4f32779..dd0207a54f40f508217ff85043de0f6ac692ea2b 100644 (file)
@@ -18,7 +18,7 @@ DBAPI-defined Exceptions are defined in the following hierarchy::
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Optional, Sequence, Type, TYPE_CHECKING
 
 if TYPE_CHECKING:
     from psycopg3.pq import PGresult  # noqa
@@ -107,6 +107,13 @@ class NotSupportedError(DatabaseError):
     """
 
 
-def class_for_state(sqlstate: bytes) -> type:
+def class_for_state(sqlstate: bytes) -> Type[Error]:
     # TODO: stub
     return DatabaseError
+
+
+def error_from_result(result: "PGresult") -> Error:
+    from psycopg3 import pq
+
+    cls = class_for_state(result.error_field(pq.DiagnosticField.SQLSTATE))
+    return cls(pq.error_message(result))
index ddf39ed36b23e2da3fd935623658c37350b3c38a..df1116566070b1c1817650dc5a562f9eab230604 100644 (file)
@@ -36,3 +36,33 @@ def test_rollback(loop, pq, aconn):
     assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE
     res = aconn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) is None
+
+
+def test_get_encoding(aconn, loop):
+    cur = aconn.cursor()
+    loop.run_until_complete(cur.execute("show client_encoding"))
+    (enc,) = cur.fetchone()
+    assert enc == aconn.encoding
+
+
+def test_set_encoding_noprop(aconn):
+    newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8"
+    assert aconn.encoding != newenc
+    with pytest.raises(psycopg3.NotSupportedError):
+        aconn.encoding = newenc
+
+
+def test_set_encoding(aconn, loop):
+    newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8"
+    assert aconn.encoding != newenc
+    loop.run_until_complete(aconn.set_encoding(newenc))
+    assert aconn.encoding == newenc
+    cur = aconn.cursor()
+    loop.run_until_complete(cur.execute("show client_encoding"))
+    (enc,) = cur.fetchone()
+    assert enc == newenc
+
+
+def test_set_encoding_bad(aconn, loop):
+    with pytest.raises(psycopg3.DatabaseError):
+        loop.run_until_complete(aconn.set_encoding("WAT"))
index c009c16fc5e9e4ba7679273f6de014b143c555df..c1c4ca6c92f20c48ddbd7a5c72ed10cf05d0b025 100644 (file)
@@ -36,3 +36,22 @@ def test_rollback(pq, conn):
     assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE
     res = conn.pgconn.exec_(b"select id from foo where id = 1")
     assert res.get_value(0, 0) is None
+
+
+def test_get_encoding(conn):
+    (enc,) = conn.cursor().execute("show client_encoding").fetchone()
+    assert enc == conn.encoding
+
+
+def test_set_encoding(conn):
+    newenc = "LATIN1" if conn.encoding != "LATIN1" else "UTF8"
+    assert conn.encoding != newenc
+    conn.encoding = newenc
+    assert conn.encoding == newenc
+    (enc,) = conn.cursor().execute("show client_encoding").fetchone()
+    assert enc == newenc
+
+
+def test_set_encoding_bad(conn):
+    with pytest.raises(psycopg3.DatabaseError):
+        conn.encoding = "WAT"