]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
perf: micro optimise attribute access to ExecStatus
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 17 May 2022 19:03:22 +0000 (21:03 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 17 May 2022 19:28:29 +0000 (21:28 +0200)
psycopg/psycopg/_pipeline.py
psycopg/psycopg/_preparing.py
psycopg/psycopg/connection.py
psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/generators.py
psycopg/psycopg/server_cursor.py

index c524a10b09bc7d6f07bccc824853eee53b6fcf68..69c2a21d3aa09772f2083d02f7287ff4ef2090dd 100644 (file)
@@ -10,7 +10,7 @@ from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING
 
 from . import pq
 from . import errors as e
-from .pq import ConnStatus, ExecStatus
+from .pq import ConnStatus
 from .abc import PipelineCommand, PQGen
 from ._compat import Deque, TypeAlias
 from ._cmodule import _psycopg
@@ -39,6 +39,9 @@ PendingResult: TypeAlias = Union[
     None, Tuple["BaseCursor[Any, Any]", Optional[Tuple[Key, Prepare, bytes]]]
 ]
 
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+
 logger = logging.getLogger("psycopg")
 
 
@@ -150,9 +153,9 @@ class BasePipeline:
         """
         if queued is None:
             (result,) = results
-            if result.status == ExecStatus.FATAL_ERROR:
+            if result.status == FATAL_ERROR:
                 raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
-            elif result.status == ExecStatus.PIPELINE_ABORTED:
+            elif result.status == PIPELINE_ABORTED:
                 raise e.PipelineAborted("pipeline aborted")
         else:
             cursor, prepinfo = queued
index 5cd35d7aa88d9283a68cc355acb952f3f5c4288c..622aeff41ad584dc61fb70343beda1f9f566295f 100644 (file)
@@ -8,7 +8,7 @@ from enum import IntEnum, auto
 from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
 from collections import OrderedDict
 
-from .pq import ExecStatus
+from . import pq
 from ._compat import Deque, TypeAlias
 from ._queries import PostgresQuery
 
@@ -17,6 +17,9 @@ if TYPE_CHECKING:
 
 Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]
 
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+
 
 class Prepare(IntEnum):
     NO = auto()
@@ -80,7 +83,7 @@ class PrepareManager:
         """
         if self._names or prep == Prepare.SHOULD:
             for result in results:
-                if result.status != ExecStatus.COMMAND_OK:
+                if result.status != COMMAND_OK:
                     continue
                 cmdstat = result.command_status
                 if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
@@ -95,7 +98,7 @@ class PrepareManager:
             return False
 
         status = results[0].status
-        if ExecStatus.COMMAND_OK != status != ExecStatus.TUPLES_OK:
+        if COMMAND_OK != status != TUPLES_OK:
             # We don't prepare failed queries or other weird results
             return False
 
index d7e23f8195210860f007ed2d33352954c3c68ad8..7abe06737202cb7d68b7c443c31bc463c226e51c 100644 (file)
@@ -19,7 +19,7 @@ from . import pq
 from . import errors as e
 from . import waiting
 from . import postgres
-from .pq import ConnStatus, ExecStatus, TransactionStatus
+from .pq import ConnStatus, TransactionStatus
 from .abc import AdaptContext, ConnectionType, Params, Query, RV
 from .abc import PQGen, PQGenConn
 from .sql import Composable, SQL
@@ -52,8 +52,6 @@ else:
     connect = generators.connect
     execute = generators.execute
 
-logger = logging.getLogger("psycopg")
-
 # Row Type variable for Cursor (when it needs to be distinguished from the
 # connection's one)
 CursorRow = TypeVar("CursorRow")
@@ -61,6 +59,12 @@ CursorRow = TypeVar("CursorRow")
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+
+logger = logging.getLogger("psycopg")
+
 
 class Notify(NamedTuple):
     """An asynchronous notification received from the database."""
@@ -461,12 +465,12 @@ class BaseConnection(Generic[Row]):
             self.pgconn.send_query_params(command, None, result_format=result_format)
 
         result = (yield from execute(self.pgconn))[-1]
-        if result.status not in (ExecStatus.COMMAND_OK, ExecStatus.TUPLES_OK):
-            if result.status == ExecStatus.FATAL_ERROR:
+        if result.status != COMMAND_OK and result.status != TUPLES_OK:
+            if result.status == FATAL_ERROR:
                 raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
             else:
                 raise e.InterfaceError(
-                    f"unexpected result {ExecStatus(result.status).name}"
+                    f"unexpected result {pq.ExecStatus(result.status).name}"
                     f" from command {command.decode()!r}"
                 )
         return result
index ade337ac688528898014f25c8b616c08192bd961..bd654a753b5d491c5166a5b871761e91e316c196 100644 (file)
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
 
 from . import pq
 from . import errors as e
-from .pq import ExecStatus
 from .abc import Buffer, ConnectionType, PQGen, Transformer
 from .adapt import PyFormat
 from ._compat import create_task
@@ -31,11 +30,14 @@ if TYPE_CHECKING:
     from .connection import Connection  # noqa: F401
     from .connection_async import AsyncConnection  # noqa: F401
 
-TEXT = pq.Format.TEXT
-BINARY = pq.Format.BINARY
 PY_TEXT = PyFormat.TEXT
 PY_BINARY = PyFormat.BINARY
 
+TEXT = pq.Format.TEXT
+BINARY = pq.Format.BINARY
+
+COPY_IN = pq.ExecStatus.COPY_IN
+
 
 class BaseCopy(Generic[ConnectionType]):
     """
@@ -119,7 +121,7 @@ class BaseCopy(Generic[ConnectionType]):
         registry = self.cursor.adapters.types
         oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
 
-        if self._pgresult.status == ExecStatus.COPY_IN:
+        if self._pgresult.status == COPY_IN:
             self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
         else:
             self.formatter.transformer.set_loader_types(oids, self.formatter.format)
@@ -276,7 +278,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         by exit. It is available if, despite what is documented, you end up
         using the `Copy` object outside a block.
         """
-        if self._pgresult.status == ExecStatus.COPY_IN:
+        if self._pgresult.status == COPY_IN:
             self._write_end()
             self.connection.wait(self._end_copy_in_gen(exc))
         else:
@@ -391,7 +393,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         await self._write(data)
 
     async def finish(self, exc: Optional[BaseException]) -> None:
-        if self._pgresult.status == ExecStatus.COPY_IN:
+        if self._pgresult.status == COPY_IN:
             await self._write_end()
             await self.connection.wait(self._end_copy_in_gen(exc))
         else:
index 3f9fe2f8be0d1bce027aa11aff93bc0e62d5a1f4..2032c44d0f17db4338f43e0be5c6a66171d63efa 100644 (file)
@@ -14,7 +14,6 @@ from contextlib import contextmanager
 from . import pq
 from . import adapt
 from . import errors as e
-from .pq import ExecStatus
 from .abc import ConnectionType, Query, Params, PQGen
 from .copy import Copy
 from .rows import Row, RowMaker, RowFactory
@@ -47,6 +46,16 @@ _C = TypeVar("_C", bound="Cursor[Any]")
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
+EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+TUPLES_OK = pq.ExecStatus.TUPLES_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
+SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
+PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
+
 
 class BaseCursor(Generic[ConnectionType, Row]):
     __slots__ = """
@@ -123,9 +132,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         # the query said we got tuples (mostly to handle the super useful
         # query "SELECT ;"
         if res and (
-            res.nfields
-            or res.status == ExecStatus.TUPLES_OK
-            or res.status == ExecStatus.SINGLE_TUPLE
+            res.nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE
         ):
             return [Column(self, i) for i in range(res.nfields)]
         else:
@@ -305,7 +312,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
                 self._send_prepare(name, pgq)
                 if not self._conn._pipeline:
                     (result,) = yield from execute(self._pgconn)
-                    if result.status == ExecStatus.FATAL_ERROR:
+                    if result.status == FATAL_ERROR:
                         raise e.error_from_result(result, encoding=self._encoding)
             # Then execute it.
             self._send_query_prepared(name, pgq, binary=binary)
@@ -355,19 +362,19 @@ class BaseCursor(Generic[ConnectionType, Row]):
         if res is None:
             return None
 
-        elif res.status == ExecStatus.SINGLE_TUPLE:
+        status = res.status
+        if status == SINGLE_TUPLE:
             self.pgresult = res
             self._tx.set_pgresult(res, set_loaders=first)
             if first:
                 self._make_row = self._make_row_maker()
             return res
 
-        elif res.status in (ExecStatus.TUPLES_OK, ExecStatus.COMMAND_OK):
+        elif status == TUPLES_OK or status == COMMAND_OK:
             # End of single row results
-            status = res.status
             while res:
                 res = yield from fetch(self._pgconn)
-            if status != ExecStatus.TUPLES_OK:
+            if status != TUPLES_OK:
                 raise e.ProgrammingError(
                     "the operation in stream() didn't produce a result"
                 )
@@ -478,17 +485,6 @@ class BaseCursor(Generic[ConnectionType, Row]):
         pgq.convert(query, params)
         return pgq
 
-    _status_ok = (
-        ExecStatus.TUPLES_OK,
-        ExecStatus.COMMAND_OK,
-        ExecStatus.EMPTY_QUERY,
-    )
-    _status_copy = (
-        ExecStatus.COPY_IN,
-        ExecStatus.COPY_OUT,
-        ExecStatus.COPY_BOTH,
-    )
-
     def _check_results(self, results: List["PGresult"]) -> None:
         """
         Verify that the results of a query are valid.
@@ -500,24 +496,27 @@ class BaseCursor(Generic[ConnectionType, Row]):
             raise e.InternalError("got no result from the query")
 
         for res in results:
-            if res.status not in self._status_ok:
+            status = res.status
+            if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
                 self._raise_for_result(res)
 
     def _raise_for_result(self, result: "PGresult") -> NoReturn:
         """
         Raise an appropriate error message for an unexpected database result
         """
-        if result.status == ExecStatus.FATAL_ERROR:
+        status = result.status
+        if status == FATAL_ERROR:
             raise e.error_from_result(result, encoding=self._encoding)
-        elif result.status == ExecStatus.PIPELINE_ABORTED:
+        elif status == PIPELINE_ABORTED:
             raise e.PipelineAborted("pipeline aborted")
-        elif result.status in self._status_copy:
+        elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
             raise e.ProgrammingError(
                 "COPY cannot be used with this method; use copy() insead"
             )
         else:
             raise e.InternalError(
-                f"unexpected result status from query: {ExecStatus(result.status).name}"
+                "unexpected result status from query:"
+                f" {pq.ExecStatus(status).name}"
             )
 
     def _set_current_result(self, i: int, format: Optional[pq.Format] = None) -> None:
@@ -607,11 +606,15 @@ class BaseCursor(Generic[ConnectionType, Row]):
         res = self.pgresult
         if not res:
             raise e.ProgrammingError("no result available")
-        elif res.status == ExecStatus.FATAL_ERROR:
+
+        status = res.status
+        if status == TUPLES_OK:
+            return
+        elif status == FATAL_ERROR:
             raise e.error_from_result(res, encoding=pgconn_encoding(self._pgconn))
-        elif res.status == ExecStatus.PIPELINE_ABORTED:
+        elif status == PIPELINE_ABORTED:
             raise e.PipelineAborted("pipeline aborted")
-        elif res.status != ExecStatus.TUPLES_OK:
+        else:
             raise e.ProgrammingError("the last operation didn't produce a result")
 
     def _check_copy_result(self, result: "PGresult") -> None:
@@ -619,14 +622,14 @@ class BaseCursor(Generic[ConnectionType, Row]):
         Check that the value returned in a copy() operation is a legit COPY.
         """
         status = result.status
-        if status in (ExecStatus.COPY_IN, ExecStatus.COPY_OUT):
+        if status == COPY_IN or status == COPY_OUT:
             return
-        elif status == ExecStatus.FATAL_ERROR:
+        elif status == FATAL_ERROR:
             raise e.error_from_result(result, encoding=self._encoding)
         else:
             raise e.ProgrammingError(
                 "copy() should be used only with COPY ... TO STDOUT or COPY ..."
-                f" FROM STDIN statements, got {ExecStatus(status).name}"
+                f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
             )
 
     def _scroll(self, value: int, mode: str) -> None:
index 6c27e5cb82342ef64f6f37e2c28f3106f90a1f27..c788cadefcf7b1cbe8e82bbd8ffdd01b42953f9d 100644 (file)
@@ -20,13 +20,19 @@ from typing import List, Optional, Union
 
 from . import pq
 from . import errors as e
-from .pq import ConnStatus, PollingStatus, ExecStatus
+from .pq import ConnStatus, PollingStatus
 from .abc import PipelineCommand, PQGen, PQGenConn
 from .pq.abc import PGconn, PGresult
 from .waiting import Wait, Ready
 from ._compat import Deque
 from ._encodings import pgconn_encoding, conninfo_encoding
 
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+COPY_OUT = pq.ExecStatus.COPY_OUT
+COPY_IN = pq.ExecStatus.COPY_IN
+COPY_BOTH = pq.ExecStatus.COPY_BOTH
+PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC
+
 logger = logging.getLogger(__name__)
 
 
@@ -120,12 +126,13 @@ def fetch_many(pgconn: PGconn) -> PQGen[List[PGresult]]:
             break
 
         results.append(res)
-        if res.status in _copy_statuses:
+        status = res.status
+        if status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
             # After entering copy mode the libpq will create a phony result
             # for every request so let's break the endless loop.
             break
 
-        if res.status == pq.ExecStatus.PIPELINE_SYNC:
+        if status == PIPELINE_SYNC:
             # PIPELINE_SYNC is not followed by a NULL, but we return it alone
             # similarly to other result sets.
             assert len(results) == 1, results
@@ -192,7 +199,7 @@ def pipeline_communicate(
                         break
                     results.append(res)
                     res = []
-                elif r.status == pq.ExecStatus.PIPELINE_SYNC:
+                elif r.status == PIPELINE_SYNC:
                     assert not res
                     results.append([r])
                 else:
@@ -207,13 +214,6 @@ def pipeline_communicate(
     return results
 
 
-_copy_statuses = (
-    ExecStatus.COPY_IN,
-    ExecStatus.COPY_OUT,
-    ExecStatus.COPY_BOTH,
-)
-
-
 def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
     yield Wait.R
     pgconn.consume_input()
@@ -249,7 +249,7 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
         # TODO: too brutal? Copy worked.
         raise e.ProgrammingError("you cannot mix COPY with other operations")
     result = results[0]
-    if result.status != ExecStatus.COMMAND_OK:
+    if result.status != COMMAND_OK:
         encoding = pgconn_encoding(pgconn)
         raise e.error_from_result(result, encoding=encoding)
 
@@ -281,7 +281,7 @@ def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
 
     # Retrieve the final result of copy
     (result,) = yield from fetch_many(pgconn)
-    if result.status != ExecStatus.COMMAND_OK:
+    if result.status != COMMAND_OK:
         encoding = pgconn_encoding(pgconn)
         raise e.error_from_result(result, encoding=encoding)
 
index 3d443f8c3d2c445b6948bf8d669b900eeb13cf4d..6db3a49a2b361a47ebef6250a80c41800c96a0c8 100644 (file)
@@ -25,6 +25,8 @@ DEFAULT_ITERSIZE = 100
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
+
 
 class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
     """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
@@ -91,7 +93,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
         pgq = self._convert_query(query, params)
         self._execute_send(pgq, no_pqexec=True)
         results = yield from execute(self._conn.pgconn)
-        if results[-1].status != pq.ExecStatus.COMMAND_OK:
+        if results[-1].status != COMMAND_OK:
             self._raise_for_result(results[-1])
 
         # Set the format, which will be used by describe and fetch operations