From: Daniele Varrazzo Date: Thu, 2 Apr 2020 12:26:57 +0000 (+1300) Subject: Added connection.encoding X-Git-Tag: 3.0.dev0~626 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=ebcefde3345b6e829463c8e52c5670e777f247aa;p=thirdparty%2Fpsycopg.git Added connection.encoding --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 0de2b9d50..bfe1237ca 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -209,6 +209,22 @@ class Connection(BaseConnection): ) -> RV: return wait(gen, timeout=timeout) + @property + def encoding(self) -> str: + return self.pgconn.parameter_status(b"client_encoding").decode("ascii") + + @encoding.setter + def encoding(self, value: str) -> None: + with self.lock: + self.pgconn.send_query_params( + b"select set_config('client_encoding', $1, false)", + [value.encode("ascii")], + ) + gen = self._exec_gen(self.pgconn) + (result,) = self.wait(gen) + if result.status != pq.ExecStatus.TUPLES_OK: + raise e.error_from_result(result) + class AsyncConnection(BaseConnection): """ @@ -261,3 +277,25 @@ class AsyncConnection(BaseConnection): @classmethod async def wait(cls, gen: Generator[Tuple[int, Wait], Ready, RV]) -> RV: return await wait_async(gen) + + @property + def encoding(self) -> str: + return self.pgconn.parameter_status(b"client_encoding").decode("ascii") + + @encoding.setter + def encoding(self, value: str) -> None: + raise e.NotSupportedError( + "you can't set 'encoding' on an async connection." + " Use 'await conn.set_encoding()' instead" + ) + + async def set_encoding(self, value: str) -> None: + async with self.lock: + self.pgconn.send_query_params( + b"select set_config('client_encoding', $1, false)", + [value.encode("ascii")], + ) + gen = self._exec_gen(self.pgconn) + (result,) = await self.wait(gen) + if result.status != pq.ExecStatus.TUPLES_OK: + raise e.error_from_result(result) diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index 4b7a96bf8..1e0e3b7e7 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -7,7 +7,7 @@ psycopg3 cursor objects from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING from . import errors as e -from .pq import error_message, DiagnosticField, ExecStatus, PGresult, Format +from .pq import ExecStatus, PGresult, Format from .utils.queries import query2pg, reorder_params from .utils.typing import Query, Params @@ -93,10 +93,7 @@ class BaseCursor: return if results[-1].status == ExecStatus.FATAL_ERROR: - ecls = e.class_for_state( - results[-1].error_field(DiagnosticField.SQLSTATE) - ) - raise ecls(error_message(results[-1])) + raise e.error_from_result(results[-1]) elif badstats & { ExecStatus.COPY_IN, diff --git a/psycopg3/errors.py b/psycopg3/errors.py index c8ce84438..dd0207a54 100644 --- a/psycopg3/errors.py +++ b/psycopg3/errors.py @@ -18,7 +18,7 @@ DBAPI-defined Exceptions are defined in the following hierarchy:: # Copyright (C) 2020 The Psycopg Team -from typing import Any, Optional, Sequence, TYPE_CHECKING +from typing import Any, Optional, Sequence, Type, TYPE_CHECKING if TYPE_CHECKING: from psycopg3.pq import PGresult # noqa @@ -107,6 +107,13 @@ class NotSupportedError(DatabaseError): """ -def class_for_state(sqlstate: bytes) -> type: +def class_for_state(sqlstate: bytes) -> Type[Error]: # TODO: stub return DatabaseError + + +def error_from_result(result: "PGresult") -> Error: + from psycopg3 import pq + + cls = class_for_state(result.error_field(pq.DiagnosticField.SQLSTATE)) + return cls(pq.error_message(result)) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index ddf39ed36..df1116566 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -36,3 +36,33 @@ def test_rollback(loop, pq, aconn): assert aconn.pgconn.transaction_status == pq.TransactionStatus.IDLE res = aconn.pgconn.exec_(b"select id from foo where id = 1") assert res.get_value(0, 0) is None + + +def test_get_encoding(aconn, loop): + cur = aconn.cursor() + loop.run_until_complete(cur.execute("show client_encoding")) + (enc,) = cur.fetchone() + assert enc == aconn.encoding + + +def test_set_encoding_noprop(aconn): + newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8" + assert aconn.encoding != newenc + with pytest.raises(psycopg3.NotSupportedError): + aconn.encoding = newenc + + +def test_set_encoding(aconn, loop): + newenc = "LATIN1" if aconn.encoding != "LATIN1" else "UTF8" + assert aconn.encoding != newenc + loop.run_until_complete(aconn.set_encoding(newenc)) + assert aconn.encoding == newenc + cur = aconn.cursor() + loop.run_until_complete(cur.execute("show client_encoding")) + (enc,) = cur.fetchone() + assert enc == newenc + + +def test_set_encoding_bad(aconn, loop): + with pytest.raises(psycopg3.DatabaseError): + loop.run_until_complete(aconn.set_encoding("WAT")) diff --git a/tests/test_connection.py b/tests/test_connection.py index c009c16fc..c1c4ca6c9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -36,3 +36,22 @@ def test_rollback(pq, conn): assert conn.pgconn.transaction_status == pq.TransactionStatus.IDLE res = conn.pgconn.exec_(b"select id from foo where id = 1") assert res.get_value(0, 0) is None + + +def test_get_encoding(conn): + (enc,) = conn.cursor().execute("show client_encoding").fetchone() + assert enc == conn.encoding + + +def test_set_encoding(conn): + newenc = "LATIN1" if conn.encoding != "LATIN1" else "UTF8" + assert conn.encoding != newenc + conn.encoding = newenc + assert conn.encoding == newenc + (enc,) = conn.cursor().execute("show client_encoding").fetchone() + assert enc == newenc + + +def test_set_encoding_bad(conn): + with pytest.raises(psycopg3.DatabaseError): + conn.encoding = "WAT"