]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Cache the result of the query mangling
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 16 May 2020 14:39:39 +0000 (02:39 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 17 May 2020 09:29:34 +0000 (21:29 +1200)
Code refactored so that only the PostgresQuery object is exposed
and mangling doesn't use the variables as input so it can be cached.

psycopg3/utils/queries.py
tests/test_query.py

index 8179fcf3b98f62d66d1b207c5e7819b3fa01b806..3b0550cf35093883a53a46f045d0a414cc4c4d52 100644 (file)
@@ -5,7 +5,7 @@ Utility module to manipulate queries
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from codecs import CodecInfo
+from functools import lru_cache
 from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
 from typing import Sequence, Tuple, Union, TYPE_CHECKING
 
@@ -17,11 +17,19 @@ if TYPE_CHECKING:
     from ..proto import Transformer
 
 
+class QueryPart(NamedTuple):
+    pre: bytes
+    item: Union[int, str]
+    format: Format
+
+
 class PostgresQuery:
     """
     Helper to convert a Python query and parameters into Postgres format.
     """
 
+    _parts: List[QueryPart]
+
     def __init__(self, transformer: "Transformer"):
         self._tx = transformer
         self.query: bytes = b""
@@ -39,13 +47,13 @@ class PostgresQuery:
         attributes (`query`, `params`, `types`, `formats`).
         """
         codec = self._tx.codec
-        if isinstance(query, str):
-            query = codec.encode(query)[0]
         if vars is not None:
-            self.query, self.formats, self._order = query2pg(
-                query, vars, codec
+            self.query, self.formats, self._order, self._parts = _query2pg(
+                query, codec.name
             )
         else:
+            if isinstance(query, str):
+                query = codec.encode(query)[0]
             self.query = query
             self.formats = self._order = None
 
@@ -57,21 +65,21 @@ class PostgresQuery:
 
         This method updates `params` and `types`.
         """
-        if vars:
-            if self._order is not None:
-                assert isinstance(vars, Mapping)
-                vars = reorder_params(vars, self._order)
-            assert isinstance(vars, Sequence)
+        if vars is not None:
+            params = _validate_and_reorder_params(
+                self._parts, vars, self._order
+            )
             self.params, self.types = self._tx.dump_sequence(
-                vars, self.formats or ()
+                params, self.formats or ()
             )
         else:
             self.params = self.types = None
 
 
-def query2pg(
-    query: bytes, vars: Params, codec: CodecInfo
-) -> Tuple[bytes, List[Format], Optional[List[str]]]:
+@lru_cache()
+def _query2pg(
+    query: Query, encoding: str
+) -> Tuple[bytes, List[Format], Optional[List[str]], List[QueryPart]]:
     """
     Convert Python query and params into something Postgres understands.
 
@@ -79,9 +87,11 @@ def query2pg(
       format (``$1``, ``$2``)
     - placeholders can be %s or %b (text or binary)
     - return ``query`` (bytes), ``formats`` (list of formats) ``order``
-      (sequence of names used in the query, in the position they appear, in
-
+      (sequence of names used in the query, in the position they appear)
+      ``parts`` (splits of queries and placeholders).
     """
+    if isinstance(query, str):
+        query = query.encode(encoding)
     if not isinstance(query, bytes):
         # encoding from str already happened
         raise TypeError(
@@ -89,33 +99,19 @@ def query2pg(
             f" got {type(query).__name__} instead"
         )
 
-    parts = split_query(query, codec.name)
+    parts = _split_query(query, encoding)
     order: Optional[List[str]] = None
     chunks: List[bytes] = []
     formats = []
 
-    if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
-        if len(vars) != len(parts) - 1:
-            raise e.ProgrammingError(
-                f"the query has {len(parts) - 1} placeholders but"
-                f" {len(vars)} parameters were passed"
-            )
-        if vars and not isinstance(parts[0].item, int):
-            raise TypeError(
-                "named placeholders require a mapping of parameters"
-            )
-
+    if isinstance(parts[0].item, int):
         for part in parts[:-1]:
             assert isinstance(part.item, int)
             chunks.append(part.pre)
             chunks.append(b"$%d" % (part.item + 1))
             formats.append(part.format)
 
-    elif isinstance(vars, Mapping):
-        if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
-            raise TypeError(
-                "positional placeholders (%s) require a sequence of parameters"
-            )
+    elif isinstance(parts[0].item, str):
         seen: Dict[str, Tuple[bytes, Format]] = {}
         order = []
         for part in parts[:-1]:
@@ -135,17 +131,49 @@ def query2pg(
                     )
                 chunks.append(seen[part.item][0])
 
+    # last part
+    chunks.append(parts[-1].pre)
+
+    return b"".join(chunks), formats, order, parts
+
+
+def _validate_and_reorder_params(
+    parts: List[QueryPart], vars: Params, order: Optional[List[str]]
+) -> Sequence[Any]:
+    """
+    Verify the compatibility between a query and a set of params.
+    """
+    if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+        if len(vars) != len(parts) - 1:
+            raise e.ProgrammingError(
+                f"the query has {len(parts) - 1} placeholders but"
+                f" {len(vars)} parameters were passed"
+            )
+        if vars and not isinstance(parts[0].item, int):
+            raise TypeError(
+                "named placeholders require a mapping of parameters"
+            )
+        return vars
+
+    elif isinstance(vars, Mapping):
+        if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
+            raise TypeError(
+                "positional placeholders (%s) require a sequence of parameters"
+            )
+        try:
+            return [vars[item] for item in order or ()]
+        except KeyError:
+            raise e.ProgrammingError(
+                f"query parameter missing:"
+                f" {', '.join(sorted(i for i in order or () if i not in vars))}"
+            )
+
     else:
         raise TypeError(
             f"query parameters should be a sequence or a mapping,"
             f" got {type(vars).__name__}"
         )
 
-    # last part
-    chunks.append(parts[-1].pre)
-
-    return b"".join(chunks), formats, order
-
 
 _re_placeholder = re.compile(
     rb"""(?x)
@@ -162,13 +190,7 @@ _re_placeholder = re.compile(
 )
 
 
-class QueryPart(NamedTuple):
-    pre: bytes
-    item: Union[int, str]
-    format: Format
-
-
-def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
+def _split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
     parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
     cur = 0
 
@@ -239,18 +261,3 @@ def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
         i += 1
 
     return rv
-
-
-def reorder_params(
-    params: Mapping[str, Any], order: Sequence[str]
-) -> List[str]:
-    """
-    Convert a mapping of parameters into an array in a specified order
-    """
-    try:
-        return [params[item] for item in order]
-    except KeyError:
-        raise e.ProgrammingError(
-            f"query parameter missing:"
-            f" {', '.join(sorted(i for i in order if i not in params))}"
-        )
index 7fa3e15915d89c82620c91ecf6d51670276aa8f7..ad7e6941b1f684eb4dc7c6256a0907717e2a2d41 100644 (file)
@@ -1,8 +1,8 @@
-import codecs
 import pytest
 
 import psycopg3
-from psycopg3.utils.queries import split_query, query2pg, reorder_params
+from psycopg3.adapt import Transformer
+from psycopg3.utils.queries import PostgresQuery, _split_query
 
 
 @pytest.mark.parametrize(
@@ -28,7 +28,7 @@ from psycopg3.utils.queries import split_query, query2pg, reorder_params
     ],
 )
 def test_split_query(input, want):
-    assert split_query(input) == want
+    assert _split_query(input) == want
 
 
 @pytest.mark.parametrize(
@@ -46,28 +46,30 @@ def test_split_query(input, want):
 )
 def test_split_query_bad(input):
     with pytest.raises(psycopg3.ProgrammingError):
-        split_query(input)
+        _split_query(input)
 
 
 @pytest.mark.parametrize(
-    "query, params, want, wformats",
+    "query, params, want, wformats, wparams",
     [
-        (b"", [], b"", []),
-        (b"%%", [], b"%", []),
-        (b"select %s", (1,), b"select $1", [False]),
-        (b"%s %% %s", (1, 2), b"$1 % $2", [False, False]),
-        (b"%b %% %s", (1, 2), b"$1 % $2", [True, False]),
+        (b"", None, b"", None, None),
+        (b"", [], b"", [], []),
+        (b"%%", [], b"%", [], []),
+        (b"select %s", (1,), b"select $1", [False], [b"1"]),
+        (b"%s %% %s", (1, 2), b"$1 % $2", [False, False], [b"1", b"2"]),
+        (b"%b %% %s", ("a", 2), b"$1 % $2", [True, False], [b"a", b"2"]),
     ],
 )
-def test_query2pg_seq(query, params, want, wformats):
-    out, formats, order = query2pg(query, params, codecs.lookup("utf-8"))
-    assert order is None
-    assert out == want
-    assert formats == wformats
+def test_pg_query_seq(query, params, want, wformats, wparams):
+    pq = PostgresQuery(Transformer())
+    pq.convert(query, params)
+    assert pq.query == want
+    assert pq.formats == wformats
+    assert pq.params == wparams
 
 
 @pytest.mark.parametrize(
-    "query, params, want, wformats, worder",
+    "query, params, want, wformats, wparams",
     [
         (b"", {}, b"", [], []),
         (b"hello %%", {"a": 1}, b"hello %", [], []),
@@ -76,22 +78,23 @@ def test_query2pg_seq(query, params, want, wformats):
             {"hello": 1, "world": 2},
             b"select $1",
             [False],
-            ["hello"],
+            [b"1"],
         ),
         (
             b"select %(hi)s %(there)b %(hi)s",
-            {"hi": 1, "there": 2},
+            {"hi": 1, "there": "a"},
             b"select $1 $2 $1",
             [False, True],
-            ["hi", "there"],
+            [b"1", b"a"],
         ),
     ],
 )
-def test_query2pg_map(query, params, want, wformats, worder):
-    out, formats, order = query2pg(query, params, codecs.lookup("utf-8"))
-    assert out == want
-    assert formats == wformats
-    assert order == worder
+def test_pg_query_map(query, params, want, wformats, wparams):
+    pq = PostgresQuery(Transformer())
+    pq.convert(query, params)
+    assert pq.query == want
+    assert pq.formats == wformats
+    assert pq.params == wparams
 
 
 @pytest.mark.parametrize(
@@ -103,13 +106,12 @@ def test_query2pg_map(query, params, want, wformats, worder):
         (b"select %s", 1),
         (b"select %s", b"a"),
         (b"select %s", set()),
-        ("select", []),
-        ("select", []),
     ],
 )
-def test_query2pg_badtype(query, params):
+def test_pq_query_badtype(query, params):
+    pq = PostgresQuery(Transformer())
     with pytest.raises(TypeError):
-        query2pg(query, params, codecs.lookup("utf-8"))
+        pq.convert(query, params)
 
 
 @pytest.mark.parametrize(
@@ -126,19 +128,7 @@ def test_query2pg_badtype(query, params):
         (b"select %(hi)s %(hi)b", {"hi": 1}),
     ],
 )
-def test_query2pg_badprog(query, params):
+def test_pq_query_badprog(query, params):
+    pq = PostgresQuery(Transformer())
     with pytest.raises(psycopg3.ProgrammingError):
-        query2pg(query, params, codecs.lookup("utf-8"))
-
-
-@pytest.mark.parametrize(
-    "params, order, want",
-    [
-        ({"foo": 1, "bar": 2}, [], []),
-        ({"foo": 1, "bar": 2}, ["foo"], [1]),
-        ({"foo": 1, "bar": 2}, ["bar", "foo"], [2, 1]),
-    ],
-)
-def test_reorder_params(params, order, want):
-    rv = reorder_params(params, order)
-    assert rv == want
+        pq.convert(query, params)