]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(prepare): make maintenance operation a PQGen generator
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Apr 2024 16:36:16 +0000 (16:36 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Apr 2024 20:23:27 +0000 (20:23 +0000)
This will allow us to replace the SQL commands with protocol messages.

psycopg/psycopg/_connection_base.py
psycopg/psycopg/_cursor_base.py
psycopg/psycopg/_preparing.py
psycopg/psycopg/transaction.py

index 12ae7395ff370f518fcd490c8b8d29341e0998c2..c5a6c2c1eca35b096ea4c7e838efff73a169a751 100644 (file)
@@ -552,8 +552,7 @@ class BaseConnection(Generic[Row]):
 
         yield from self._exec_command(b"ROLLBACK")
         self._prepared.clear()
-        for cmd in self._prepared.get_maintenance_commands():
-            yield from self._exec_command(cmd)
+        yield from self._prepared.maintain_gen(self)
 
         if self._pipeline:
             yield from self._pipeline._sync_gen()
index 1f53a88218ccd2022d964d5c1783084da128150e..24a2cb1a286a77c05218b74a672a79c06e786662 100644 (file)
@@ -195,9 +195,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             yield from self._conn._pipeline._communicate_gen()
 
         self._last_query = query
-
-        for cmd in self._conn._prepared.get_maintenance_commands():
-            yield from self._conn._exec_command(cmd)
+        yield from self._conn._prepared.maintain_gen(self._conn)
 
     def _executemany_gen_pipeline(
         self, query: Query, params_seq: Iterable[Params], returning: bool
@@ -232,8 +230,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         if returning:
             yield from pipeline._fetch_gen(flush=True)
 
-        for cmd in self._conn._prepared.get_maintenance_commands():
-            yield from self._conn._exec_command(cmd)
+        yield from self._conn._prepared.maintain_gen(self._conn)
 
     def _executemany_gen_no_pipeline(
         self, query: Query, params_seq: Iterable[Params], returning: bool
@@ -260,9 +257,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             yield from self._maybe_prepare_gen(pgq, prepare=True)
 
         self._last_query = query
-
-        for cmd in self._conn._prepared.get_maintenance_commands():
-            yield from self._conn._exec_command(cmd)
+        yield from self._conn._prepared.maintain_gen(self._conn)
 
     def _maybe_prepare_gen(
         self,
index 465de53a4899a2a4c7f13b814722e811d080272c..dd969f88d737d7a6a776a6503e4a0ec6c146f58e 100644 (file)
@@ -5,15 +5,18 @@ Support for prepared statements
 # Copyright (C) 2020 The Psycopg Team
 
 from enum import IntEnum, auto
-from typing import Iterator, Optional, Sequence, Tuple, TYPE_CHECKING
+from typing import Optional, Sequence, Tuple, TYPE_CHECKING
 from collections import OrderedDict
 
 from . import pq
+from .abc import PQGen
 from ._compat import Deque, TypeAlias
 from ._queries import PostgresQuery
 
 if TYPE_CHECKING:
+    from typing import Any
     from .pq.abc import PGresult
+    from ._connection_base import BaseConnection
 
 Key: TypeAlias = Tuple[bytes, Tuple[int, ...]]
 
@@ -185,9 +188,13 @@ class PrepareManager:
         else:
             return False
 
-    def get_maintenance_commands(self) -> Iterator[bytes]:
+    def maintain_gen(self, conn: "BaseConnection[Any]") -> PQGen[None]:
         """
-        Iterate over the commands needed to align the server state to our state
+        Generator to send the commands to perform periodic maintenance
+
+        Deallocate unneeded command in the server, or flush the prepared
+        statements server state entirely if necessary.
         """
         while self._maint_commands:
-            yield self._maint_commands.popleft()
+            cmd = self._maint_commands.popleft()
+            yield from conn._exec_command(cmd)
index c6405aa438ea25e6aae9965240ae9713ea59be57..56b7104459468369f0f2603b73b83299e2be3889 100644 (file)
@@ -145,6 +145,10 @@ class BaseTransaction(Generic[ConnectionType]):
         for command in self._get_rollback_commands():
             yield from self._conn._exec_command(command)
 
+        # Also clear the prepared statements cache.
+        self._conn._prepared.clear()
+        yield from self._conn._prepared.maintain_gen(self._conn)
+
         if isinstance(exc_val, Rollback):
             if not exc_val.transaction or exc_val.transaction is self:
                 return True  # Swallow the exception
@@ -191,10 +195,6 @@ class BaseTransaction(Generic[ConnectionType]):
             assert not self._conn._num_transactions
             yield b"ROLLBACK"
 
-        # Also clear the prepared statements cache.
-        if self._conn._prepared.clear():
-            yield from self._conn._prepared.get_maintenance_commands()
-
     def _push_savepoint(self) -> None:
         """
         Push the transaction on the connection transactions stack.