]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Remove prepared statement maintenance command out of the normal loop
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 17 Nov 2021 14:26:43 +0000 (15:26 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 29 Nov 2021 10:53:18 +0000 (11:53 +0100)
Accumulate the maintenance commands in the state of the prepared
processor manager and decide when to process them.

psycopg/psycopg/_preparing.py
psycopg/psycopg/connection.py
psycopg/psycopg/cursor.py
psycopg/psycopg/transaction.py

index b5097c166f52c39ab6b3d8c5b7f035de5e056494..783a25e4fdb9bdce7abe130ba152124ca9fb0794 100644 (file)
@@ -5,10 +5,11 @@ Support for prepared statements
 # Copyright (C) 2020-2021 The Psycopg Team
 
 from enum import IntEnum, auto
-from typing import Optional, Sequence, Tuple, TYPE_CHECKING
+from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
 from collections import OrderedDict
 
 from .pq import ExecStatus
+from ._compat import Deque
 from ._queries import PostgresQuery
 
 if TYPE_CHECKING:
@@ -41,6 +42,8 @@ class PrepareManager:
         # Counter to generate prepared statements names
         self._prepared_idx = 0
 
+        self._maint_commands = Deque[bytes]()
+
     @staticmethod
     def key(query: PostgresQuery) -> Key:
         return (query.query, query.types)
@@ -73,7 +76,7 @@ class PrepareManager:
 
     def _should_discard(
         self, prep: Prepare, results: Sequence["PGresult"]
-    ) -> Optional[bytes]:
+    ) -> bool:
         """Check if we need to discard our entire state: it should happen on
         rollback or on dropping objects, because the same object may get
         recreated and postgres would fail internal lookups.
@@ -87,7 +90,7 @@ class PrepareManager:
                     cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"
                 ):
                     return self.clear()
-        return None
+        return False
 
     @staticmethod
     def _check_results(results: Sequence["PGresult"]) -> bool:
@@ -103,7 +106,7 @@ class PrepareManager:
 
         return True
 
-    def _rotate(self) -> Optional[bytes]:
+    def _rotate(self) -> None:
         """Evict an old value from the cache.
 
         If it was prepared, deallocate it. Do it only once: if the cache was
@@ -114,9 +117,7 @@ class PrepareManager:
 
         if len(self._names) > self.prepared_max:
             name = self._names.popitem(last=False)[1]
-            return b"DEALLOCATE " + name
-        else:
-            return None
+            self._maint_commands.append(b"DEALLOCATE " + name)
 
     def maybe_add_to_cache(
         self, query: PostgresQuery, prep: Prepare, name: bytes
@@ -159,27 +160,39 @@ class PrepareManager:
         prep: Prepare,
         name: bytes,
         results: Sequence["PGresult"],
-    ) -> Optional[bytes]:
+    ) -> None:
         """Validate cached entry with 'key' by checking query 'results'.
 
         Possibly return a command to perform maintainance on database side.
 
         Note: this method is only called in pipeline mode.
         """
-        cmd = self._should_discard(prep, results)
-        if cmd:
-            return cmd
+        if self._should_discard(prep, results):
+            return
 
         if not self._check_results(results):
             self._names.pop(key, None)
             self._counts.pop(key, None)
-            return None
+        else:
+            self._rotate()
 
-        return self._rotate()
+    def clear(self) -> bool:
+        """Clear the cache of the maintenance commands.
 
-    def clear(self) -> Optional[bytes]:
+        Clear the internal state and prepare a command to clear the state of
+        the server.
+        """
         if self._names:
             self._names.clear()
-            return b"DEALLOCATE ALL"
+            self._maint_commands.clear()
+            self._maint_commands.append(b"DEALLOCATE ALL")
+            return True
         else:
-            return None
+            return False
+
+    def get_maintenance_commands(self) -> Iterator[bytes]:
+        """
+        Iterate over the commands needed to align the server state to our state
+        """
+        while self._maint_commands:
+            yield self._maint_commands.popleft()
index 39bfb17856bbe91a33f750337e1c73054a41eb38..652252d2ad7aff63ba1c01b85807f994e55073fb 100644 (file)
@@ -509,8 +509,8 @@ class BaseConnection(Generic[Row]):
             return
 
         yield from self._exec_command(b"ROLLBACK")
-        cmd = self._prepared.clear()
-        if cmd:
+        self._prepared.clear()
+        for cmd in self._prepared.get_maintenance_commands():
             yield from self._exec_command(cmd)
 
     def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
index f4cafe387d9e50ea88cfdbbce6b39bd3749923c1..93f5d8fa968b233d04c5b2de5a6ab2a5ce449b98 100644 (file)
@@ -207,6 +207,9 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._execute_results(results)
         self._last_query = query
 
+        for cmd in self._conn._prepared.get_maintenance_commands():
+            yield from self._conn._exec_command(cmd)
+
     def _executemany_gen(
         self, query: Query, params_seq: Iterable[Params]
     ) -> PQGen[None]:
@@ -226,6 +229,9 @@ class BaseCursor(Generic[ConnectionType, Row]):
 
         self._last_query = query
 
+        for cmd in self._conn._prepared.get_maintenance_commands():
+            yield from self._conn._exec_command(cmd)
+
     def _maybe_prepare_gen(
         self,
         pgq: PostgresQuery,
@@ -257,12 +263,11 @@ class BaseCursor(Generic[ConnectionType, Row]):
         results = yield from execute(self._pgconn)
 
         # Update the prepare state of the query.
-        # If an operation requires to flush our prepared statements cache, do it.
+        # If an operation requires to flush our prepared statements cache,
+        # it will be added to the maintenance commands to execute later.
         key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)
         if key is not None:
-            cmd = self._conn._prepared.validate(key, prep, name, results)
-            if cmd:
-                yield from self._conn._exec_command(cmd)
+            self._conn._prepared.validate(key, prep, name, results)
 
         return results
 
index ebdd8cbd38890378a780f74f14c04770531cdbd9..4d01e836f68a3ad9ffda55aa1729a4718f168494 100644 (file)
@@ -175,9 +175,9 @@ class BaseTransaction(Generic[ConnectionType]):
             commands.append(b"ROLLBACK")
 
         # Also clear the prepared statements cache.
-        cmd = self._conn._prepared.clear()
-        if cmd:
-            commands.append(cmd)
+        if self._conn._prepared.clear():
+            for cmd in self._conn._prepared.get_maintenance_commands():
+                commands.append(cmd)
 
         yield from self._conn._exec_command(b"; ".join(commands))