]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Split the internal state of prepared statements into names and counts
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 17 Nov 2021 13:15:20 +0000 (14:15 +0100)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Mon, 29 Nov 2021 08:50:50 +0000 (09:50 +0100)
psycopg/psycopg/_preparing.py
tests/test_prepared.py
tests/test_prepared_async.py

index 5ae79ae0cbdcbe937a461a576f97b768d27b0afc..b5097c166f52c39ab6b3d8c5b7f035de5e056494 100644 (file)
@@ -5,7 +5,7 @@ Support for prepared statements
 # Copyright (C) 2020-2021 The Psycopg Team
 
 from enum import IntEnum, auto
-from typing import Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import Optional, Sequence, Tuple, TYPE_CHECKING
 from collections import OrderedDict
 
 from .pq import ExecStatus
@@ -22,7 +22,6 @@ class Prepare(IntEnum):
 
 
 Key = Tuple[bytes, Tuple[int, ...]]
-Value = Union[int, bytes]
 
 
 class PrepareManager:
@@ -33,13 +32,11 @@ class PrepareManager:
     prepared_max: int = 100
 
     def __init__(self) -> None:
-        # Number of times each query was seen in order to prepare it.
-        # Map (query, types) -> name or number of times seen
-        #
-        # Note: with this implementation we keep the tally of up to 100
-        # queries, but most likely we will prepare way less than that. We might
-        # change that if we think it would be better.
-        self._prepared: OrderedDict[Key, Value] = OrderedDict()
+        # Map (query, types) to the number of times the query was seen.
+        self._counts: OrderedDict[Key, int] = OrderedDict()
+
+        # Map (query, types) to the name of the statement if  prepared.
+        self._names: OrderedDict[Key, bytes] = OrderedDict()
 
         # Counter to generate prepared statements names
         self._prepared_idx = 0
@@ -59,12 +56,13 @@ class PrepareManager:
             return Prepare.NO, b""
 
         key = self.key(query)
-        value: Union[bytes, int] = self._prepared.get(key, 0)
-        if isinstance(value, bytes):
+        name = self._names.get(key)
+        if name:
             # The query was already prepared in this session
-            return Prepare.YES, value
+            return Prepare.YES, name
 
-        if value >= self.prepare_threshold or prepare:
+        count = self._counts.get(key, 0)
+        if count >= self.prepare_threshold or prepare:
             # The query has been executed enough times and needs to be prepared
             name = f"_pg3_{self._prepared_idx}".encode()
             self._prepared_idx += 1
@@ -80,7 +78,7 @@ class PrepareManager:
         rollback or on dropping objects, because the same object may get
         recreated and postgres would fail internal lookups.
         """
-        if self._prepared or prep == Prepare.SHOULD:
+        if self._names or prep == Prepare.SHOULD:
             for result in results:
                 if result.status != ExecStatus.COMMAND_OK:
                     continue
@@ -111,12 +109,12 @@ class PrepareManager:
         If it was prepared, deallocate it. Do it only once: if the cache was
         resized, deallocate gradually.
         """
-        if len(self._prepared) <= self.prepared_max:
-            return None
+        if len(self._counts) > self.prepared_max:
+            self._counts.popitem(last=False)
 
-        old_val = self._prepared.popitem(last=False)[1]
-        if isinstance(old_val, bytes):
-            return b"DEALLOCATE " + old_val
+        if len(self._names) > self.prepared_max:
+            name = self._names.popitem(last=False)[1]
+            return b"DEALLOCATE " + name
         else:
             return None
 
@@ -135,17 +133,24 @@ class PrepareManager:
             return None
 
         key = self.key(query)
-        if key in self._prepared:
-            if isinstance(self._prepared[key], int):
-                if prep is Prepare.SHOULD:
-                    self._prepared[key] = name
-                else:
-                    self._prepared[key] += 1  # type: ignore[operator]
-            self._prepared.move_to_end(key)
+        if key in self._counts:
+            if prep is Prepare.SHOULD:
+                del self._counts[key]
+                self._names[key] = name
+            else:
+                self._counts[key] += 1
+                self._counts.move_to_end(key)
+            return None
+
+        elif key in self._names:
+            self._names.move_to_end(key)
             return None
+
         else:
-            value: Value = name if prep is Prepare.SHOULD else 1
-            self._prepared[key] = value
+            if prep is Prepare.SHOULD:
+                self._names[key] = name
+            else:
+                self._counts[key] = 1
             return key
 
     def validate(
@@ -164,15 +169,17 @@ class PrepareManager:
         cmd = self._should_discard(prep, results)
         if cmd:
             return cmd
+
         if not self._check_results(results):
-            del self._prepared[key]
+            self._names.pop(key, None)
+            self._counts.pop(key, None)
             return None
+
         return self._rotate()
 
     def clear(self) -> Optional[bytes]:
-        if self._prepared_idx:
-            self._prepared.clear()
-            self._prepared_idx = 0
+        if self._names:
+            self._names.clear()
             return b"DEALLOCATE ALL"
         else:
             return None
index ef7b7ee4eb8e26d270fc762ae0937d53cf4dfed3..fec835547efce664e59ab0fb56b4a33bad5f5db6 100644 (file)
@@ -77,7 +77,8 @@ def test_prepare_disable(conn):
         res.append(cur.fetchone()[0])
 
     assert res == [0] * 10
-    assert not conn._prepared._prepared
+    assert not conn._prepared._names
+    assert not conn._prepared._counts
 
 
 def test_no_prepare_multi(conn):
@@ -140,10 +141,10 @@ def test_evict_lru(conn):
         conn.execute("select 'a'")
         conn.execute(f"select {i}")
 
-    assert len(conn._prepared._prepared) == 5
-    assert conn._prepared._prepared[b"select 'a'", ()] == b"_pg3_0"
+    assert len(conn._prepared._names) == 1
+    assert conn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
     for i in [9, 8, 7, 6]:
-        assert conn._prepared._prepared[f"select {i}".encode(), ()] == 1
+        assert conn._prepared._counts[f"select {i}".encode(), ()] == 1
 
     cur = conn.execute("select statement from pg_prepared_statements")
     assert cur.fetchall() == [("select 'a'",)]
@@ -156,9 +157,9 @@ def test_evict_lru_deallocate(conn):
         conn.execute("select 'a'")
         conn.execute(f"select {i}")
 
-    assert len(conn._prepared._prepared) == 5
+    assert len(conn._prepared._names) == 5
     for j in [9, 8, 7, 6, "'a'"]:
-        name = conn._prepared._prepared[f"select {j}".encode(), ()]
+        name = conn._prepared._names[f"select {j}".encode(), ()]
         assert name.startswith(b"_pg3_")
 
     cur = conn.execute(
index b3477bc5265a1a238ab5c9ec4e8a9d5848f72ce3..d6217b42c42d7595d4a8925b8cb6caee094d0faf 100644 (file)
@@ -83,7 +83,8 @@ async def test_prepare_disable(aconn):
         res.append((await cur.fetchone())[0])
 
     assert res == [0] * 10
-    assert not aconn._prepared._prepared
+    assert not aconn._prepared._names
+    assert not aconn._prepared._counts
 
 
 async def test_no_prepare_multi(aconn):
@@ -148,10 +149,10 @@ async def test_evict_lru(aconn):
         await aconn.execute("select 'a'")
         await aconn.execute(f"select {i}")
 
-    assert len(aconn._prepared._prepared) == 5
-    assert aconn._prepared._prepared[b"select 'a'", ()] == b"_pg3_0"
+    assert len(aconn._prepared._names) == 1
+    assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
     for i in [9, 8, 7, 6]:
-        assert aconn._prepared._prepared[f"select {i}".encode(), ()] == 1
+        assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1
 
     cur = await aconn.execute("select statement from pg_prepared_statements")
     assert await cur.fetchall() == [("select 'a'",)]
@@ -164,11 +165,10 @@ async def test_evict_lru_deallocate(aconn):
         await aconn.execute("select 'a'")
         await aconn.execute(f"select {i}")
 
-    assert len(aconn._prepared._prepared) == 5
+    assert len(aconn._prepared._names) == 5
     for j in [9, 8, 7, 6, "'a'"]:
-        assert aconn._prepared._prepared[
-            f"select {j}".encode(), ()
-        ].startswith(b"_pg3_")
+        name = aconn._prepared._names[f"select {j}".encode(), ()]
+        assert name.startswith(b"_pg3_")
 
     cur = await aconn.execute(
         "select statement from pg_prepared_statements order by prepare_time",