From: Daniele Varrazzo Date: Sun, 27 Dec 2020 21:26:20 +0000 (+0100) Subject: Don't use enums for libpq arguments X-Git-Tag: 3.0.dev0~228^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0c90f0a38828a9fbf53bb0e75654042cc9a5cce3;p=thirdparty%2Fpsycopg.git Don't use enums for libpq arguments We might want to have speed there. Enums can be returned from high-level objects such as the Connection.info object to be implemented. --- diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py index 8e718a4a3..a73881496 100644 --- a/psycopg3/psycopg3/_transform.py +++ b/psycopg3/psycopg3/_transform.py @@ -82,7 +82,7 @@ class Transformer(AdaptContext): for i in range(nf): oid = result.ftype(i) fmt = result.fformat(i) - rc.append(self.get_loader(oid, fmt).load) + rc.append(self.get_loader(oid, fmt).load) # type: ignore def set_row_types( self, types: Sequence[int], formats: Sequence[Format] diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index f3e4b35da..765add5ba 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -132,7 +132,7 @@ class BaseConnection(AdaptContext): if status == TransactionStatus.UNKNOWN: return - status = TransactionStatus(status) # in case we got an int + status = TransactionStatus(status) warnings.warn( f"connection {self} was deleted while still open." f" Please use 'with' or '.close()' to close the connection", diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index 76cd25f8c..be5a7efba 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -32,7 +32,7 @@ class BaseCopy(Generic[ConnectionType]): ), "The Transformer doesn't have a PGresult set" self._pgresult: "PGresult" = self.transformer.pgresult - self.format = self._pgresult.binary_tuples + self.format = Format(self._pgresult.binary_tuples) self._encoding = self.connection.client_encoding self._first_row = True self._finished = False diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index bfb1a0086..bbc67af19 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -95,7 +95,7 @@ class BaseCursor(Generic[ConnectionType]): def status(self) -> Optional[pq.ExecStatus]: # TODO: do we want this? res = self.pgresult - return res.status if res else None + return pq.ExecStatus(res.status) if res else None @property def query(self) -> Optional[bytes]: @@ -289,16 +289,16 @@ class BaseCursor(Generic[ConnectionType]): pgq.convert(query, params) return pgq - _status_ok = { + _status_ok = ( ExecStatus.TUPLES_OK, ExecStatus.COMMAND_OK, ExecStatus.EMPTY_QUERY, - } - _status_copy = { + ) + _status_copy = ( ExecStatus.COPY_IN, ExecStatus.COPY_OUT, ExecStatus.COPY_BOTH, - } + ) def _execute_results(self, results: Sequence["PGresult"]) -> None: """ @@ -309,32 +309,36 @@ class BaseCursor(Generic[ConnectionType]): if not results: raise e.InternalError("got no result from the query") - statuses = {res.status for res in results} - badstats = statuses - self._status_ok - if not badstats: - self._results = list(results) - self.pgresult = results[0] - nrows = self.pgresult.command_tuples - if nrows is not None: - if self._rowcount < 0: - self._rowcount = nrows - else: - self._rowcount += nrows + for res in results: + if res.status not in self._status_ok: + return self._raise_from_results(results) - return + self._results = list(results) + self.pgresult = results[0] + nrows = self.pgresult.command_tuples + if nrows is not None: + if self._rowcount < 0: + self._rowcount = nrows + else: + self._rowcount += nrows + + return + def _raise_from_results(self, results: Sequence["PGresult"]) -> None: + statuses = {res.status for res in results} + badstats = statuses.difference(self._status_ok) if results[-1].status == ExecStatus.FATAL_ERROR: raise e.error_from_result( results[-1], encoding=self._conn.client_encoding ) - elif badstats & self._status_copy: + elif statuses.intersection(self._status_copy): raise e.ProgrammingError( "COPY cannot be used with execute(); use copy() insead" ) else: raise e.InternalError( f"got unexpected status from query:" - f" {', '.join(sorted(s.name for s in sorted(badstats)))}" + f" {', '.join(sorted(ExecStatus(s).name for s in badstats))}" ) def _send_prepare(self, name: bytes, query: PostgresQuery) -> None: diff --git a/psycopg3/psycopg3/generators.py b/psycopg3/psycopg3/generators.py index cbf5d6dd6..e442c2df1 100644 --- a/psycopg3/psycopg3/generators.py +++ b/psycopg3/psycopg3/generators.py @@ -35,7 +35,6 @@ def connect(conninfo: str) -> PQGenConn[PGconn]: """ conn = pq.PGconn.connect_start(conninfo.encode("utf8")) - logger.debug("connection started, status %s", conn.status.name) while 1: if conn.status == ConnStatus.BAD: raise e.OperationalError( @@ -43,7 +42,6 @@ def connect(conninfo: str) -> PQGenConn[PGconn]: ) status = conn.connect_poll() - logger.debug("connection polled, status %s", conn.status.name) if status == PollingStatus.OK: break elif status == PollingStatus.READING: diff --git a/psycopg3/psycopg3/pq/pq_ctypes.py b/psycopg3/psycopg3/pq/pq_ctypes.py index 27adb3739..9c6814826 100644 --- a/psycopg3/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/psycopg3/pq/pq_ctypes.py @@ -21,8 +21,7 @@ from typing import cast as t_cast, TYPE_CHECKING from . import _pq_ctypes as impl from .misc import PGnotify, ConninfoOption, PQerror, PGresAttDesc from .misc import error_message, connection_summary -from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format -from ._enums import Ping, PollingStatus, TransactionStatus +from ._enums import Format if TYPE_CHECKING: from . import proto @@ -113,9 +112,8 @@ class PGconn: raise MemoryError("couldn't allocate PGconn") return cls(pgconn_ptr) - def connect_poll(self) -> PollingStatus: - rv = self._call_int(impl.PQconnectPoll) - return PollingStatus(rv) + def connect_poll(self) -> int: + return self._call_int(impl.PQconnectPoll) def finish(self) -> None: self.pgconn_ptr, p = None, self.pgconn_ptr @@ -141,17 +139,15 @@ class PGconn: if not impl.PQresetStart(self.pgconn_ptr): raise PQerror("couldn't reset connection") - def reset_poll(self) -> PollingStatus: - rv = self._call_int(impl.PQresetPoll) - return PollingStatus(rv) + def reset_poll(self) -> int: + return self._call_int(impl.PQresetPoll) @classmethod - def ping(self, conninfo: bytes) -> Ping: + def ping(self, conninfo: bytes) -> int: if not isinstance(conninfo, bytes): raise TypeError(f"bytes expected, got {type(conninfo)} instead") - rv = impl.PQping(conninfo) - return Ping(rv) + return impl.PQping(conninfo) @property def db(self) -> bytes: @@ -186,14 +182,12 @@ class PGconn: return self._call_bytes(impl.PQoptions) @property - def status(self) -> ConnStatus: - rv = impl.PQstatus(self.pgconn_ptr) - return ConnStatus(rv) + def status(self) -> int: + return impl.PQstatus(self.pgconn_ptr) @property - def transaction_status(self) -> TransactionStatus: - rv = impl.PQtransactionStatus(self.pgconn_ptr) - return TransactionStatus(rv) + def transaction_status(self) -> int: + return impl.PQtransactionStatus(self.pgconn_ptr) def parameter_status(self, name: bytes) -> Optional[bytes]: self._ensure_pgconn() @@ -252,8 +246,8 @@ class PGconn: command: bytes, param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]] = None, - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> "PGresult": args = self._query_params_args( command, param_values, param_types, param_formats, result_format @@ -269,8 +263,8 @@ class PGconn: command: bytes, param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]] = None, - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> None: args = self._query_params_args( command, param_values, param_types, param_formats, result_format @@ -307,8 +301,8 @@ class PGconn: self, name: bytes, param_values: Optional[Sequence[Optional[bytes]]], - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> None: # repurpose this function with a cheeky replacement of query with name, # drop the param_types from the result @@ -328,8 +322,8 @@ class PGconn: command: bytes, param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]] = None, - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> Any: if not isinstance(command, bytes): raise TypeError(f"bytes expected, got {type(command)} instead") @@ -549,7 +543,7 @@ class PGconn: else: return nbytes, b"" - def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": + def make_empty_result(self, exec_status: int) -> "PGresult": rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status) if not rv: raise MemoryError("couldn't allocate empty PGresult") @@ -608,15 +602,14 @@ class PGresult: impl.PQclear(p) @property - def status(self) -> ExecStatus: - rv = impl.PQresultStatus(self.pgresult_ptr) - return ExecStatus(rv) + def status(self) -> int: + return impl.PQresultStatus(self.pgresult_ptr) @property def error_message(self) -> bytes: return impl.PQresultErrorMessage(self.pgresult_ptr) - def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]: + def error_field(self, fieldcode: int) -> Optional[bytes]: return impl.PQresultErrorField(self.pgresult_ptr, fieldcode) @property @@ -636,8 +629,8 @@ class PGresult: def ftablecol(self, column_number: int) -> int: return impl.PQftablecol(self.pgresult_ptr, column_number) - def fformat(self, column_number: int) -> Format: - return Format(impl.PQfformat(self.pgresult_ptr, column_number)) + def fformat(self, column_number: int) -> int: + return impl.PQfformat(self.pgresult_ptr, column_number) def ftype(self, column_number: int) -> int: return impl.PQftype(self.pgresult_ptr, column_number) @@ -649,8 +642,8 @@ class PGresult: return impl.PQfsize(self.pgresult_ptr, column_number) @property - def binary_tuples(self) -> Format: - return Format(impl.PQbinaryTuples(self.pgresult_ptr)) + def binary_tuples(self) -> int: + return impl.PQbinaryTuples(self.pgresult_ptr) def get_value( self, row_number: int, column_number: int diff --git a/psycopg3/psycopg3/pq/proto.py b/psycopg3/psycopg3/pq/proto.py index 4405af201..f0afeb476 100644 --- a/psycopg3/psycopg3/pq/proto.py +++ b/psycopg3/psycopg3/pq/proto.py @@ -8,8 +8,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING from typing_extensions import Protocol -from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format -from ._enums import Ping, PollingStatus, TransactionStatus +from ._enums import Format if TYPE_CHECKING: from .misc import PGnotify, ConninfoOption, PGresAttDesc @@ -31,7 +30,7 @@ class PGconn(Protocol): def connect_start(cls, conninfo: bytes) -> "PGconn": ... - def connect_poll(self) -> PollingStatus: + def connect_poll(self) -> int: ... def finish(self) -> None: @@ -47,11 +46,11 @@ class PGconn(Protocol): def reset_start(self) -> None: ... - def reset_poll(self) -> PollingStatus: + def reset_poll(self) -> int: ... @classmethod - def ping(self, conninfo: bytes) -> Ping: + def ping(self, conninfo: bytes) -> int: ... @property @@ -87,11 +86,11 @@ class PGconn(Protocol): ... @property - def status(self) -> ConnStatus: + def status(self) -> int: ... @property - def transaction_status(self) -> TransactionStatus: + def transaction_status(self) -> int: ... def parameter_status(self, name: bytes) -> Optional[bytes]: @@ -140,8 +139,8 @@ class PGconn(Protocol): command: bytes, param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]] = None, - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> "PGresult": ... @@ -150,8 +149,8 @@ class PGconn(Protocol): command: bytes, param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]] = None, - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> None: ... @@ -167,8 +166,8 @@ class PGconn(Protocol): self, name: bytes, param_values: Optional[Sequence[Optional[bytes]]], - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> None: ... @@ -230,7 +229,7 @@ class PGconn(Protocol): def get_copy_data(self, async_: int) -> Tuple[int, bytes]: ... - def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": + def make_empty_result(self, exec_status: int) -> "PGresult": ... @@ -239,14 +238,14 @@ class PGresult(Protocol): ... @property - def status(self) -> ExecStatus: + def status(self) -> int: ... @property def error_message(self) -> bytes: ... - def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]: + def error_field(self, fieldcode: int) -> Optional[bytes]: ... @property @@ -266,7 +265,7 @@ class PGresult(Protocol): def ftablecol(self, column_number: int) -> int: ... - def fformat(self, column_number: int) -> Format: + def fformat(self, column_number: int) -> int: ... def ftype(self, column_number: int) -> int: @@ -279,7 +278,7 @@ class PGresult(Protocol): ... @property - def binary_tuples(self) -> Format: + def binary_tuples(self) -> int: ... def get_value( diff --git a/psycopg3_c/psycopg3_c/pq/pgconn.pyx b/psycopg3_c/psycopg3_c/pq/pgconn.pyx index f07ac79e0..e337071da 100644 --- a/psycopg3_c/psycopg3_c/pq/pgconn.pyx +++ b/psycopg3_c/psycopg3_c/pq/pgconn.pyx @@ -55,9 +55,8 @@ cdef class PGconn: return PGconn._from_ptr(pgconn) - def connect_poll(self) -> PollingStatus: - cdef int rv = _call_int(self, libpq.PQconnectPoll) - return PollingStatus(rv) + def connect_poll(self) -> int: + return _call_int(self, libpq.PQconnectPoll) def finish(self) -> None: if self.pgconn_ptr is not NULL: @@ -89,14 +88,12 @@ cdef class PGconn: if not libpq.PQresetStart(self.pgconn_ptr): raise PQerror("couldn't reset connection") - def reset_poll(self) -> PollingStatus: - cdef int rv = _call_int(self, libpq.PQresetPoll) - return PollingStatus(rv) + def reset_poll(self) -> int: + return _call_int(self, libpq.PQresetPoll) @classmethod - def ping(self, const char *conninfo) -> Ping: - cdef int rv = libpq.PQping(conninfo) - return Ping(rv) + def ping(self, const char *conninfo) -> int: + return libpq.PQping(conninfo) @property def db(self) -> bytes: @@ -136,9 +133,8 @@ cdef class PGconn: return ConnStatus(rv) @property - def transaction_status(self) -> TransactionStatus: - cdef int rv = libpq.PQtransactionStatus(self.pgconn_ptr) - return TransactionStatus(rv) + def transaction_status(self) -> int: + return libpq.PQtransactionStatus(self.pgconn_ptr) def parameter_status(self, name: bytes) -> Optional[bytes]: _ensure_pgconn(self) @@ -205,8 +201,8 @@ cdef class PGconn: command: bytes, param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]] = None, - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> PGresult: _ensure_pgconn(self) @@ -235,8 +231,8 @@ cdef class PGconn: command: bytes, param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]] = None, - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> None: _ensure_pgconn(self) @@ -294,8 +290,8 @@ cdef class PGconn: self, name: bytes, param_values: Optional[Sequence[Optional[bytes]]], - param_formats: Optional[Sequence[Format]] = None, - result_format: Format = Format.TEXT, + param_formats: Optional[Sequence[int]] = None, + result_format: int = Format.TEXT, ) -> None: _ensure_pgconn(self) @@ -473,7 +469,7 @@ cdef class PGconn: else: return nbytes, b"" - def make_empty_result(self, exec_status: ExecStatus) -> PGresult: + def make_empty_result(self, exec_status: int) -> PGresult: cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult( self.pgconn_ptr, exec_status) if not rv: @@ -525,7 +521,7 @@ cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) with gil: cdef (int, libpq.Oid *, char * const*, int *, int *) _query_params_args( list param_values: Optional[Sequence[Optional[bytes]]], param_types: Optional[Sequence[int]], - list param_formats: Optional[Sequence[Format]], + list param_formats: Optional[Sequence[int]], ) except *: cdef int i diff --git a/psycopg3_c/psycopg3_c/pq/pgresult.pyx b/psycopg3_c/psycopg3_c/pq/pgresult.pyx index 7c8caf8b3..07e792e95 100644 --- a/psycopg3_c/psycopg3_c/pq/pgresult.pyx +++ b/psycopg3_c/psycopg3_c/pq/pgresult.pyx @@ -35,15 +35,14 @@ cdef class PGresult: return None @property - def status(self) -> ExecStatus: - cdef int rv = libpq.PQresultStatus(self.pgresult_ptr) - return ExecStatus(rv) + def status(self) -> int: + return libpq.PQresultStatus(self.pgresult_ptr) @property def error_message(self) -> bytes: return libpq.PQresultErrorMessage(self.pgresult_ptr) - def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]: + def error_field(self, int fieldcode) -> Optional[bytes]: cdef char * rv = libpq.PQresultErrorField(self.pgresult_ptr, fieldcode) if rv is not NULL: return rv