]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
perf: micro optimise attribute access to TransactionStatus
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 17 May 2022 19:14:51 +0000 (21:14 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 17 May 2022 19:28:29 +0000 (21:28 +0200)
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/server_cursor.py
psycopg/psycopg/transaction.py

index 7abe06737202cb7d68b7c443c31bc463c226e51c..b7ce3134a987f53b75798eeb99eb881ddbd47bdb 100644 (file)
@@ -19,7 +19,7 @@ from . import pq
 from . import errors as e
 from . import waiting
 from . import postgres
-from .pq import ConnStatus, TransactionStatus
+from .pq import ConnStatus
 from .abc import AdaptContext, ConnectionType, Params, Query, RV
 from .abc import PQGen, PQGenConn
 from .sql import Composable, SQL
@@ -63,6 +63,10 @@ COMMAND_OK = pq.ExecStatus.COMMAND_OK
 TUPLES_OK = pq.ExecStatus.TUPLES_OK
 FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
 
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
+
 logger = logging.getLogger("psycopg")
 
 
@@ -256,10 +260,10 @@ class BaseConnection(Generic[Row]):
     def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
         # Raise an exception if we are in a transaction
         status = self.pgconn.transaction_status
-        if status == TransactionStatus.IDLE and self._pipeline:
+        if status == IDLE and self._pipeline:
             yield from self._pipeline._sync_gen()
             status = self.pgconn.transaction_status
-        if status != TransactionStatus.IDLE:
+        if status != IDLE:
             if self._num_transactions:
                 raise e.ProgrammingError(
                     f"can't change {attribute!r} now: "
@@ -269,7 +273,7 @@ class BaseConnection(Generic[Row]):
                 raise e.ProgrammingError(
                     f"can't change {attribute!r} now: "
                     "connection in transaction status "
-                    f"{TransactionStatus(status).name}"
+                    f"{pq.TransactionStatus(status).name}"
                 )
 
     @property
@@ -491,7 +495,7 @@ class BaseConnection(Generic[Row]):
         if self._autocommit:
             return
 
-        if self.pgconn.transaction_status != TransactionStatus.IDLE:
+        if self.pgconn.transaction_status != IDLE:
             return
 
         yield from self._exec_command(self._get_tx_start_command())
@@ -528,7 +532,7 @@ class BaseConnection(Generic[Row]):
             raise e.ProgrammingError(
                 "commit() cannot be used during a two-phase transaction"
             )
-        if self.pgconn.transaction_status == TransactionStatus.IDLE:
+        if self.pgconn.transaction_status == IDLE:
             return
 
         yield from self._exec_command(b"COMMIT")
@@ -553,7 +557,7 @@ class BaseConnection(Generic[Row]):
         if self._pipeline and self.pgconn.pipeline_status == pq.PipelineStatus.ABORTED:
             yield from self._pipeline._sync_gen()
 
-        if self.pgconn.transaction_status == TransactionStatus.IDLE:
+        if self.pgconn.transaction_status == IDLE:
             return
 
         yield from self._exec_command(b"ROLLBACK")
@@ -580,10 +584,10 @@ class BaseConnection(Generic[Row]):
         if not isinstance(xid, Xid):
             xid = Xid.from_string(xid)
 
-        if self.pgconn.transaction_status != TransactionStatus.IDLE:
+        if self.pgconn.transaction_status != IDLE:
             raise e.ProgrammingError(
                 "can't start two-phase transaction: connection in status"
-                f" {TransactionStatus(self.pgconn.transaction_status).name}"
+                f" {pq.TransactionStatus(self.pgconn.transaction_status).name}"
             )
 
         if self._autocommit:
@@ -1011,10 +1015,7 @@ class Connection(BaseConnection[Row]):
             cur.execute(Xid._get_recover_query())
             res = cur.fetchall()
 
-        if (
-            status == TransactionStatus.IDLE
-            and self.info.transaction_status == TransactionStatus.INTRANS
-        ):
+        if status == IDLE and self.info.transaction_status == INTRANS:
             self.rollback()
 
         return res
index 5d7a636f5f4883721487a0edd527670afb1bd7a4..38a0cb306b4c2c113d4ccbfa083f2457ab183b29 100644 (file)
@@ -15,7 +15,6 @@ from contextlib import asynccontextmanager
 from . import pq
 from . import errors as e
 from . import waiting
-from .pq import TransactionStatus
 from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
@@ -36,6 +35,9 @@ if TYPE_CHECKING:
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
 logger = logging.getLogger("psycopg")
 
 
@@ -412,10 +414,7 @@ class AsyncConnection(BaseConnection[Row]):
             await cur.execute(Xid._get_recover_query())
             res = await cur.fetchall()
 
-        if (
-            status == TransactionStatus.IDLE
-            and self.info.transaction_status == TransactionStatus.INTRANS
-        ):
+        if status == IDLE and self.info.transaction_status == INTRANS:
             await self.rollback()
 
         return res
index bd654a753b5d491c5166a5b871761e91e316c196..778a850be9a1ef3d2040f7790ed0caa2de5ef289 100644 (file)
@@ -38,6 +38,8 @@ BINARY = pq.Format.BINARY
 
 COPY_IN = pq.ExecStatus.COPY_IN
 
+ACTIVE = pq.TransactionStatus.ACTIVE
+
 
 class BaseCopy(Generic[ConnectionType]):
     """
@@ -174,7 +176,7 @@ class BaseCopy(Generic[ConnectionType]):
         if not exc:
             return
 
-        if self.connection.pgconn.transaction_status != pq.TransactionStatus.ACTIVE:
+        if self.connection.pgconn.transaction_status != ACTIVE:
             # The server has already finished to send copy data. The connection
             # is already in a good state.
             return
index 2032c44d0f17db4338f43e0be5c6a66171d63efa..d211372dd8ccecaa0179840728074c5a95714b4b 100644 (file)
@@ -515,8 +515,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             )
         else:
             raise e.InternalError(
-                "unexpected result status from query:"
-                f" {pq.ExecStatus(status).name}"
+                "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
             )
 
     def _set_current_result(self, i: int, format: Optional[pq.Format] = None) -> None:
index 6db3a49a2b361a47ebef6250a80c41800c96a0c8..76e1cb19f8c1a12167a256dd164271452ccf887c 100644 (file)
@@ -27,6 +27,9 @@ BINARY = pq.Format.BINARY
 
 COMMAND_OK = pq.ExecStatus.COMMAND_OK
 
+IDLE = pq.TransactionStatus.IDLE
+INTRANS = pq.TransactionStatus.INTRANS
+
 
 class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
     """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
@@ -117,11 +120,11 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
         ts = self._conn.pgconn.transaction_status
 
         # if the connection is not in a sane state, don't even try
-        if ts not in (pq.TransactionStatus.IDLE, pq.TransactionStatus.INTRANS):
+        if ts != IDLE and ts != INTRANS:
             return
 
         # If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
-        if not self._withhold and ts == pq.TransactionStatus.IDLE:
+        if not self._withhold and ts == IDLE:
             return
 
         # if we didn't declare the cursor ourselves we still have to close it
index a7bfa2b501a7f08af191569648a44f732ff49e6c..77a576dfb568ae8430e73d8e39af1745f8e6cc11 100644 (file)
@@ -12,7 +12,7 @@ from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING
 from . import pq
 from . import sql
 from . import errors as e
-from .pq import TransactionStatus, ConnStatus
+from .pq import ConnStatus
 from .abc import ConnectionType, PQGen
 
 if TYPE_CHECKING:
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
     from .connection import Connection
     from .connection_async import AsyncConnection
 
+IDLE = pq.TransactionStatus.IDLE
+
 logger = logging.getLogger(__name__)
 
 
@@ -196,9 +198,7 @@ class BaseTransaction(Generic[ConnectionType]):
 
         Also set the internal state of the object and verify consistency.
         """
-        self._outer_transaction = (
-            self.pgconn.transaction_status == TransactionStatus.IDLE
-        )
+        self._outer_transaction = self.pgconn.transaction_status == IDLE
         if self._outer_transaction:
             # outer transaction: if no name it's only a begin, else
             # there will be an additional savepoint