]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Cache generated begin statement on the connection
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Jul 2021 15:11:14 +0000 (17:11 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Jul 2021 15:56:14 +0000 (17:56 +0200)
psycopg/psycopg/connection.py

index 0cdd558462da137a9ae544ed01440136236ae763..c1a5d5c75c0ab95da4c0e2501bb2dea250afae56 100644 (file)
@@ -133,6 +133,7 @@ class BaseConnection(Generic[Row]):
         self._isolation_level: Optional[IsolationLevel] = None
         self._read_only: Optional[bool] = None
         self._deferrable: Optional[bool] = None
+        self._begin_statement = b""
 
     def __del__(self) -> None:
         # If fails on connection we might not have this attribute yet
@@ -206,6 +207,7 @@ class BaseConnection(Generic[Row]):
         self._isolation_level = (
             IsolationLevel(value) if value is not None else None
         )
+        self._begin_statement = b""
 
     @property
     def read_only(self) -> Optional[bool]:
@@ -223,6 +225,7 @@ class BaseConnection(Generic[Row]):
         # Subclasses must call it holding a lock
         self._check_intrans("read_only")
         self._read_only = value
+        self._begin_statement = b""
 
     @property
     def deferrable(self) -> Optional[bool]:
@@ -240,6 +243,7 @@ class BaseConnection(Generic[Row]):
         # Subclasses must call it holding a lock
         self._check_intrans("deferrable")
         self._deferrable = value
+        self._begin_statement = b""
 
     def _check_intrans(self, attribute: str) -> None:
         # Raise an exception if we are in a transaction
@@ -472,6 +476,9 @@ class BaseConnection(Generic[Row]):
         yield from self._exec_command(self._get_tx_start_command())
 
     def _get_tx_start_command(self) -> bytes:
+        if self._begin_statement:
+            return self._begin_statement
+
         parts = [b"begin"]
 
         if self.isolation_level is not None:
@@ -487,7 +494,8 @@ class BaseConnection(Generic[Row]):
                 b"deferrable" if self.deferrable else b"not deferrable"
             )
 
-        return b" ".join(parts)
+        self._begin_statement = b" ".join(parts)
+        return self._begin_statement
 
     def _commit_gen(self) -> PQGen[None]:
         """Generator implementing `Connection.commit()`."""