]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Don't use enums for libpq arguments
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 27 Dec 2020 21:26:20 +0000 (22:26 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 29 Dec 2020 17:12:37 +0000 (18:12 +0100)
We might want to have speed there. Enums can be returned from high-level
objects such as the Connection.info object to be implemented.

psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/pq/proto.py
psycopg3_c/psycopg3_c/pq/pgconn.pyx
psycopg3_c/psycopg3_c/pq/pgresult.pyx

index 8e718a4a390419503554a9b9ecc807ea60fe5b6d..a7388149633245a0c0bdd0403b55a93fcfe45951 100644 (file)
@@ -82,7 +82,7 @@ class Transformer(AdaptContext):
         for i in range(nf):
             oid = result.ftype(i)
             fmt = result.fformat(i)
-            rc.append(self.get_loader(oid, fmt).load)
+            rc.append(self.get_loader(oid, fmt).load)  # type: ignore
 
     def set_row_types(
         self, types: Sequence[int], formats: Sequence[Format]
index f3e4b35dabea3615617b9c83c7ef389eb04e4d33..765add5baf895f86fbde66974098bcc70e238055 100644 (file)
@@ -132,7 +132,7 @@ class BaseConnection(AdaptContext):
         if status == TransactionStatus.UNKNOWN:
             return
 
-        status = TransactionStatus(status)  # in case we got an int
+        status = TransactionStatus(status)
         warnings.warn(
             f"connection {self} was deleted while still open."
             f" Please use 'with' or '.close()' to close the connection",
index 76cd25f8cd8bc1f7639c59f2864aa2498e91a810..be5a7efba3776504ac255b132d926e38c0daf9ed 100644 (file)
@@ -32,7 +32,7 @@ class BaseCopy(Generic[ConnectionType]):
         ), "The Transformer doesn't have a PGresult set"
         self._pgresult: "PGresult" = self.transformer.pgresult
 
-        self.format = self._pgresult.binary_tuples
+        self.format = Format(self._pgresult.binary_tuples)
         self._encoding = self.connection.client_encoding
         self._first_row = True
         self._finished = False
index bfb1a0086e0f77f3eb0f5d5a05704f17f38153d0..bbc67af19e3c5589cf1b4e0727980434f6487504 100644 (file)
@@ -95,7 +95,7 @@ class BaseCursor(Generic[ConnectionType]):
     def status(self) -> Optional[pq.ExecStatus]:
         # TODO: do we want this?
         res = self.pgresult
-        return res.status if res else None
+        return pq.ExecStatus(res.status) if res else None
 
     @property
     def query(self) -> Optional[bytes]:
@@ -289,16 +289,16 @@ class BaseCursor(Generic[ConnectionType]):
         pgq.convert(query, params)
         return pgq
 
-    _status_ok = {
+    _status_ok = (
         ExecStatus.TUPLES_OK,
         ExecStatus.COMMAND_OK,
         ExecStatus.EMPTY_QUERY,
-    }
-    _status_copy = {
+    )
+    _status_copy = (
         ExecStatus.COPY_IN,
         ExecStatus.COPY_OUT,
         ExecStatus.COPY_BOTH,
-    }
+    )
 
     def _execute_results(self, results: Sequence["PGresult"]) -> None:
         """
@@ -309,32 +309,36 @@ class BaseCursor(Generic[ConnectionType]):
         if not results:
             raise e.InternalError("got no result from the query")
 
-        statuses = {res.status for res in results}
-        badstats = statuses - self._status_ok
-        if not badstats:
-            self._results = list(results)
-            self.pgresult = results[0]
-            nrows = self.pgresult.command_tuples
-            if nrows is not None:
-                if self._rowcount < 0:
-                    self._rowcount = nrows
-                else:
-                    self._rowcount += nrows
+        for res in results:
+            if res.status not in self._status_ok:
+                return self._raise_from_results(results)
 
-            return
+        self._results = list(results)
+        self.pgresult = results[0]
+        nrows = self.pgresult.command_tuples
+        if nrows is not None:
+            if self._rowcount < 0:
+                self._rowcount = nrows
+            else:
+                self._rowcount += nrows
+
+        return
 
+    def _raise_from_results(self, results: Sequence["PGresult"]) -> None:
+        statuses = {res.status for res in results}
+        badstats = statuses.difference(self._status_ok)
         if results[-1].status == ExecStatus.FATAL_ERROR:
             raise e.error_from_result(
                 results[-1], encoding=self._conn.client_encoding
             )
-        elif badstats & self._status_copy:
+        elif statuses.intersection(self._status_copy):
             raise e.ProgrammingError(
                 "COPY cannot be used with execute(); use copy() insead"
             )
         else:
             raise e.InternalError(
                 f"got unexpected status from query:"
-                f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
+                f" {', '.join(sorted(ExecStatus(s).name for s in badstats))}"
             )
 
     def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
index cbf5d6dd65e706c3491320e33699c7ab5b1d303f..e442c2df1719b1062da2f2bdb08b9f7fdd8ad2b9 100644 (file)
@@ -35,7 +35,6 @@ def connect(conninfo: str) -> PQGenConn[PGconn]:
 
     """
     conn = pq.PGconn.connect_start(conninfo.encode("utf8"))
-    logger.debug("connection started, status %s", conn.status.name)
     while 1:
         if conn.status == ConnStatus.BAD:
             raise e.OperationalError(
@@ -43,7 +42,6 @@ def connect(conninfo: str) -> PQGenConn[PGconn]:
             )
 
         status = conn.connect_poll()
-        logger.debug("connection polled, status %s", conn.status.name)
         if status == PollingStatus.OK:
             break
         elif status == PollingStatus.READING:
index 27adb37394fb712f84788cd3da85a7274e8b6140..9c68148263983ef2c3dfdc792b3842f04d3942f1 100644 (file)
@@ -21,8 +21,7 @@ from typing import cast as t_cast, TYPE_CHECKING
 from . import _pq_ctypes as impl
 from .misc import PGnotify, ConninfoOption, PQerror, PGresAttDesc
 from .misc import error_message, connection_summary
-from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format
-from ._enums import Ping, PollingStatus, TransactionStatus
+from ._enums import Format
 
 if TYPE_CHECKING:
     from . import proto
@@ -113,9 +112,8 @@ class PGconn:
             raise MemoryError("couldn't allocate PGconn")
         return cls(pgconn_ptr)
 
-    def connect_poll(self) -> PollingStatus:
-        rv = self._call_int(impl.PQconnectPoll)
-        return PollingStatus(rv)
+    def connect_poll(self) -> int:
+        return self._call_int(impl.PQconnectPoll)
 
     def finish(self) -> None:
         self.pgconn_ptr, p = None, self.pgconn_ptr
@@ -141,17 +139,15 @@ class PGconn:
         if not impl.PQresetStart(self.pgconn_ptr):
             raise PQerror("couldn't reset connection")
 
-    def reset_poll(self) -> PollingStatus:
-        rv = self._call_int(impl.PQresetPoll)
-        return PollingStatus(rv)
+    def reset_poll(self) -> int:
+        return self._call_int(impl.PQresetPoll)
 
     @classmethod
-    def ping(self, conninfo: bytes) -> Ping:
+    def ping(self, conninfo: bytes) -> int:
         if not isinstance(conninfo, bytes):
             raise TypeError(f"bytes expected, got {type(conninfo)} instead")
 
-        rv = impl.PQping(conninfo)
-        return Ping(rv)
+        return impl.PQping(conninfo)
 
     @property
     def db(self) -> bytes:
@@ -186,14 +182,12 @@ class PGconn:
         return self._call_bytes(impl.PQoptions)
 
     @property
-    def status(self) -> ConnStatus:
-        rv = impl.PQstatus(self.pgconn_ptr)
-        return ConnStatus(rv)
+    def status(self) -> int:
+        return impl.PQstatus(self.pgconn_ptr)
 
     @property
-    def transaction_status(self) -> TransactionStatus:
-        rv = impl.PQtransactionStatus(self.pgconn_ptr)
-        return TransactionStatus(rv)
+    def transaction_status(self) -> int:
+        return impl.PQtransactionStatus(self.pgconn_ptr)
 
     def parameter_status(self, name: bytes) -> Optional[bytes]:
         self._ensure_pgconn()
@@ -252,8 +246,8 @@ class PGconn:
         command: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> "PGresult":
         args = self._query_params_args(
             command, param_values, param_types, param_formats, result_format
@@ -269,8 +263,8 @@ class PGconn:
         command: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> None:
         args = self._query_params_args(
             command, param_values, param_types, param_formats, result_format
@@ -307,8 +301,8 @@ class PGconn:
         self,
         name: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> None:
         # repurpose this function with a cheeky replacement of query with name,
         # drop the param_types from the result
@@ -328,8 +322,8 @@ class PGconn:
         command: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> Any:
         if not isinstance(command, bytes):
             raise TypeError(f"bytes expected, got {type(command)} instead")
@@ -549,7 +543,7 @@ class PGconn:
         else:
             return nbytes, b""
 
-    def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
+    def make_empty_result(self, exec_status: int) -> "PGresult":
         rv = impl.PQmakeEmptyPGresult(self.pgconn_ptr, exec_status)
         if not rv:
             raise MemoryError("couldn't allocate empty PGresult")
@@ -608,15 +602,14 @@ class PGresult:
             impl.PQclear(p)
 
     @property
-    def status(self) -> ExecStatus:
-        rv = impl.PQresultStatus(self.pgresult_ptr)
-        return ExecStatus(rv)
+    def status(self) -> int:
+        return impl.PQresultStatus(self.pgresult_ptr)
 
     @property
     def error_message(self) -> bytes:
         return impl.PQresultErrorMessage(self.pgresult_ptr)
 
-    def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]:
+    def error_field(self, fieldcode: int) -> Optional[bytes]:
         return impl.PQresultErrorField(self.pgresult_ptr, fieldcode)
 
     @property
@@ -636,8 +629,8 @@ class PGresult:
     def ftablecol(self, column_number: int) -> int:
         return impl.PQftablecol(self.pgresult_ptr, column_number)
 
-    def fformat(self, column_number: int) -> Format:
-        return Format(impl.PQfformat(self.pgresult_ptr, column_number))
+    def fformat(self, column_number: int) -> int:
+        return impl.PQfformat(self.pgresult_ptr, column_number)
 
     def ftype(self, column_number: int) -> int:
         return impl.PQftype(self.pgresult_ptr, column_number)
@@ -649,8 +642,8 @@ class PGresult:
         return impl.PQfsize(self.pgresult_ptr, column_number)
 
     @property
-    def binary_tuples(self) -> Format:
-        return Format(impl.PQbinaryTuples(self.pgresult_ptr))
+    def binary_tuples(self) -> int:
+        return impl.PQbinaryTuples(self.pgresult_ptr)
 
     def get_value(
         self, row_number: int, column_number: int
index 4405af201e501e94bc2878a72b1997fed85ebc73..f0afeb4767aea102f4bfe62040940265219890b7 100644 (file)
@@ -8,8 +8,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 from typing import TYPE_CHECKING
 from typing_extensions import Protocol
 
-from ._enums import ConnStatus, DiagnosticField, ExecStatus, Format
-from ._enums import Ping, PollingStatus, TransactionStatus
+from ._enums import Format
 
 if TYPE_CHECKING:
     from .misc import PGnotify, ConninfoOption, PGresAttDesc
@@ -31,7 +30,7 @@ class PGconn(Protocol):
     def connect_start(cls, conninfo: bytes) -> "PGconn":
         ...
 
-    def connect_poll(self) -> PollingStatus:
+    def connect_poll(self) -> int:
         ...
 
     def finish(self) -> None:
@@ -47,11 +46,11 @@ class PGconn(Protocol):
     def reset_start(self) -> None:
         ...
 
-    def reset_poll(self) -> PollingStatus:
+    def reset_poll(self) -> int:
         ...
 
     @classmethod
-    def ping(self, conninfo: bytes) -> Ping:
+    def ping(self, conninfo: bytes) -> int:
         ...
 
     @property
@@ -87,11 +86,11 @@ class PGconn(Protocol):
         ...
 
     @property
-    def status(self) -> ConnStatus:
+    def status(self) -> int:
         ...
 
     @property
-    def transaction_status(self) -> TransactionStatus:
+    def transaction_status(self) -> int:
         ...
 
     def parameter_status(self, name: bytes) -> Optional[bytes]:
@@ -140,8 +139,8 @@ class PGconn(Protocol):
         command: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> "PGresult":
         ...
 
@@ -150,8 +149,8 @@ class PGconn(Protocol):
         command: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> None:
         ...
 
@@ -167,8 +166,8 @@ class PGconn(Protocol):
         self,
         name: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> None:
         ...
 
@@ -230,7 +229,7 @@ class PGconn(Protocol):
     def get_copy_data(self, async_: int) -> Tuple[int, bytes]:
         ...
 
-    def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
+    def make_empty_result(self, exec_status: int) -> "PGresult":
         ...
 
 
@@ -239,14 +238,14 @@ class PGresult(Protocol):
         ...
 
     @property
-    def status(self) -> ExecStatus:
+    def status(self) -> int:
         ...
 
     @property
     def error_message(self) -> bytes:
         ...
 
-    def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]:
+    def error_field(self, fieldcode: int) -> Optional[bytes]:
         ...
 
     @property
@@ -266,7 +265,7 @@ class PGresult(Protocol):
     def ftablecol(self, column_number: int) -> int:
         ...
 
-    def fformat(self, column_number: int) -> Format:
+    def fformat(self, column_number: int) -> int:
         ...
 
     def ftype(self, column_number: int) -> int:
@@ -279,7 +278,7 @@ class PGresult(Protocol):
         ...
 
     @property
-    def binary_tuples(self) -> Format:
+    def binary_tuples(self) -> int:
         ...
 
     def get_value(
index f07ac79e0cfd0fe7f428d6eb3e9f819fc768c7e2..e337071dad0714a0a9326c6d742b16fab87cfe91 100644 (file)
@@ -55,9 +55,8 @@ cdef class PGconn:
 
         return PGconn._from_ptr(pgconn)
 
-    def connect_poll(self) -> PollingStatus:
-        cdef int rv = _call_int(self, <conn_int_f>libpq.PQconnectPoll)
-        return PollingStatus(rv)
+    def connect_poll(self) -> int:
+        return _call_int(self, <conn_int_f>libpq.PQconnectPoll)
 
     def finish(self) -> None:
         if self.pgconn_ptr is not NULL:
@@ -89,14 +88,12 @@ cdef class PGconn:
         if not libpq.PQresetStart(self.pgconn_ptr):
             raise PQerror("couldn't reset connection")
 
-    def reset_poll(self) -> PollingStatus:
-        cdef int rv = _call_int(self, <conn_int_f>libpq.PQresetPoll)
-        return PollingStatus(rv)
+    def reset_poll(self) -> int:
+        return _call_int(self, <conn_int_f>libpq.PQresetPoll)
 
     @classmethod
-    def ping(self, const char *conninfo) -> Ping:
-        cdef int rv = libpq.PQping(conninfo)
-        return Ping(rv)
+    def ping(self, const char *conninfo) -> int:
+        return libpq.PQping(conninfo)
 
     @property
     def db(self) -> bytes:
@@ -136,9 +133,8 @@ cdef class PGconn:
         return ConnStatus(rv)
 
     @property
-    def transaction_status(self) -> TransactionStatus:
-        cdef int rv = libpq.PQtransactionStatus(self.pgconn_ptr)
-        return TransactionStatus(rv)
+    def transaction_status(self) -> int:
+        return libpq.PQtransactionStatus(self.pgconn_ptr)
 
     def parameter_status(self, name: bytes) -> Optional[bytes]:
         _ensure_pgconn(self)
@@ -205,8 +201,8 @@ cdef class PGconn:
         command: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> PGresult:
         _ensure_pgconn(self)
 
@@ -235,8 +231,8 @@ cdef class PGconn:
         command: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> None:
         _ensure_pgconn(self)
 
@@ -294,8 +290,8 @@ cdef class PGconn:
         self,
         name: bytes,
         param_values: Optional[Sequence[Optional[bytes]]],
-        param_formats: Optional[Sequence[Format]] = None,
-        result_format: Format = Format.TEXT,
+        param_formats: Optional[Sequence[int]] = None,
+        result_format: int = Format.TEXT,
     ) -> None:
         _ensure_pgconn(self)
 
@@ -473,7 +469,7 @@ cdef class PGconn:
         else:
             return nbytes, b""
 
-    def make_empty_result(self, exec_status: ExecStatus) -> PGresult:
+    def make_empty_result(self, exec_status: int) -> PGresult:
         cdef libpq.PGresult *rv = libpq.PQmakeEmptyPGresult(
             self.pgconn_ptr, exec_status)
         if not rv:
@@ -525,7 +521,7 @@ cdef void notice_receiver(void *arg, const libpq.PGresult *res_ptr) with gil:
 cdef (int, libpq.Oid *, char * const*, int *, int *) _query_params_args(
     list param_values: Optional[Sequence[Optional[bytes]]],
     param_types: Optional[Sequence[int]],
-    list param_formats: Optional[Sequence[Format]],
+    list param_formats: Optional[Sequence[int]],
 ) except *:
     cdef int i
 
index 7c8caf8b3287e60734acbbcd9de28587b423fe71..07e792e95317c640ffe893ce6924f650dbb15ca0 100644 (file)
@@ -35,15 +35,14 @@ cdef class PGresult:
             return None
 
     @property
-    def status(self) -> ExecStatus:
-        cdef int rv = libpq.PQresultStatus(self.pgresult_ptr)
-        return ExecStatus(rv)
+    def status(self) -> int:
+        return libpq.PQresultStatus(self.pgresult_ptr)
 
     @property
     def error_message(self) -> bytes:
         return libpq.PQresultErrorMessage(self.pgresult_ptr)
 
-    def error_field(self, fieldcode: DiagnosticField) -> Optional[bytes]:
+    def error_field(self, int fieldcode) -> Optional[bytes]:
         cdef char * rv = libpq.PQresultErrorField(self.pgresult_ptr, fieldcode)
         if rv is not NULL:
             return rv