From: Daniele Varrazzo Date: Sat, 16 May 2020 14:39:39 +0000 (+1200) Subject: Cache the result of the query mangling X-Git-Tag: 3.0.dev0~517 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2e819ded17da2974e91643868171061dd40d7bac;p=thirdparty%2Fpsycopg.git Cache the result of the query mangling Code refactored so that only the PostgresQuery object is exposed and mangling doesn't use the variables as input so it can be cached. --- diff --git a/psycopg3/utils/queries.py b/psycopg3/utils/queries.py index 8179fcf3b..3b0550cf3 100644 --- a/psycopg3/utils/queries.py +++ b/psycopg3/utils/queries.py @@ -5,7 +5,7 @@ Utility module to manipulate queries # Copyright (C) 2020 The Psycopg Team import re -from codecs import CodecInfo +from functools import lru_cache from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional from typing import Sequence, Tuple, Union, TYPE_CHECKING @@ -17,11 +17,19 @@ if TYPE_CHECKING: from ..proto import Transformer +class QueryPart(NamedTuple): + pre: bytes + item: Union[int, str] + format: Format + + class PostgresQuery: """ Helper to convert a Python query and parameters into Postgres format. """ + _parts: List[QueryPart] + def __init__(self, transformer: "Transformer"): self._tx = transformer self.query: bytes = b"" @@ -39,13 +47,13 @@ class PostgresQuery: attributes (`query`, `params`, `types`, `formats`). """ codec = self._tx.codec - if isinstance(query, str): - query = codec.encode(query)[0] if vars is not None: - self.query, self.formats, self._order = query2pg( - query, vars, codec + self.query, self.formats, self._order, self._parts = _query2pg( + query, codec.name ) else: + if isinstance(query, str): + query = codec.encode(query)[0] self.query = query self.formats = self._order = None @@ -57,21 +65,21 @@ class PostgresQuery: This method updates `params` and `types`. """ - if vars: - if self._order is not None: - assert isinstance(vars, Mapping) - vars = reorder_params(vars, self._order) - assert isinstance(vars, Sequence) + if vars is not None: + params = _validate_and_reorder_params( + self._parts, vars, self._order + ) self.params, self.types = self._tx.dump_sequence( - vars, self.formats or () + params, self.formats or () ) else: self.params = self.types = None -def query2pg( - query: bytes, vars: Params, codec: CodecInfo -) -> Tuple[bytes, List[Format], Optional[List[str]]]: +@lru_cache() +def _query2pg( + query: Query, encoding: str +) -> Tuple[bytes, List[Format], Optional[List[str]], List[QueryPart]]: """ Convert Python query and params into something Postgres understands. @@ -79,9 +87,11 @@ def query2pg( format (``$1``, ``$2``) - placeholders can be %s or %b (text or binary) - return ``query`` (bytes), ``formats`` (list of formats) ``order`` - (sequence of names used in the query, in the position they appear, in - + (sequence of names used in the query, in the position they appear) + ``parts`` (splits of queries and placeholders). """ + if isinstance(query, str): + query = query.encode(encoding) if not isinstance(query, bytes): # encoding from str already happened raise TypeError( @@ -89,33 +99,19 @@ def query2pg( f" got {type(query).__name__} instead" ) - parts = split_query(query, codec.name) + parts = _split_query(query, encoding) order: Optional[List[str]] = None chunks: List[bytes] = [] formats = [] - if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): - if len(vars) != len(parts) - 1: - raise e.ProgrammingError( - f"the query has {len(parts) - 1} placeholders but" - f" {len(vars)} parameters were passed" - ) - if vars and not isinstance(parts[0].item, int): - raise TypeError( - "named placeholders require a mapping of parameters" - ) - + if isinstance(parts[0].item, int): for part in parts[:-1]: assert isinstance(part.item, int) chunks.append(part.pre) chunks.append(b"$%d" % (part.item + 1)) formats.append(part.format) - elif isinstance(vars, Mapping): - if vars and len(parts) > 1 and not isinstance(parts[0][1], str): - raise TypeError( - "positional placeholders (%s) require a sequence of parameters" - ) + elif isinstance(parts[0].item, str): seen: Dict[str, Tuple[bytes, Format]] = {} order = [] for part in parts[:-1]: @@ -135,17 +131,49 @@ def query2pg( ) chunks.append(seen[part.item][0]) + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), formats, order, parts + + +def _validate_and_reorder_params( + parts: List[QueryPart], vars: Params, order: Optional[List[str]] +) -> Sequence[Any]: + """ + Verify the compatibility between a query and a set of params. + """ + if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): + if len(vars) != len(parts) - 1: + raise e.ProgrammingError( + f"the query has {len(parts) - 1} placeholders but" + f" {len(vars)} parameters were passed" + ) + if vars and not isinstance(parts[0].item, int): + raise TypeError( + "named placeholders require a mapping of parameters" + ) + return vars + + elif isinstance(vars, Mapping): + if vars and len(parts) > 1 and not isinstance(parts[0][1], str): + raise TypeError( + "positional placeholders (%s) require a sequence of parameters" + ) + try: + return [vars[item] for item in order or ()] + except KeyError: + raise e.ProgrammingError( + f"query parameter missing:" + f" {', '.join(sorted(i for i in order or () if i not in vars))}" + ) + else: raise TypeError( f"query parameters should be a sequence or a mapping," f" got {type(vars).__name__}" ) - # last part - chunks.append(parts[-1].pre) - - return b"".join(chunks), formats, order - _re_placeholder = re.compile( rb"""(?x) @@ -162,13 +190,7 @@ _re_placeholder = re.compile( ) -class QueryPart(NamedTuple): - pre: bytes - item: Union[int, str] - format: Format - - -def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]: +def _split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]: parts: List[Tuple[bytes, Optional[Match[bytes]]]] = [] cur = 0 @@ -239,18 +261,3 @@ def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]: i += 1 return rv - - -def reorder_params( - params: Mapping[str, Any], order: Sequence[str] -) -> List[str]: - """ - Convert a mapping of parameters into an array in a specified order - """ - try: - return [params[item] for item in order] - except KeyError: - raise e.ProgrammingError( - f"query parameter missing:" - f" {', '.join(sorted(i for i in order if i not in params))}" - ) diff --git a/tests/test_query.py b/tests/test_query.py index 7fa3e1591..ad7e6941b 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,8 +1,8 @@ -import codecs import pytest import psycopg3 -from psycopg3.utils.queries import split_query, query2pg, reorder_params +from psycopg3.adapt import Transformer +from psycopg3.utils.queries import PostgresQuery, _split_query @pytest.mark.parametrize( @@ -28,7 +28,7 @@ from psycopg3.utils.queries import split_query, query2pg, reorder_params ], ) def test_split_query(input, want): - assert split_query(input) == want + assert _split_query(input) == want @pytest.mark.parametrize( @@ -46,28 +46,30 @@ def test_split_query(input, want): ) def test_split_query_bad(input): with pytest.raises(psycopg3.ProgrammingError): - split_query(input) + _split_query(input) @pytest.mark.parametrize( - "query, params, want, wformats", + "query, params, want, wformats, wparams", [ - (b"", [], b"", []), - (b"%%", [], b"%", []), - (b"select %s", (1,), b"select $1", [False]), - (b"%s %% %s", (1, 2), b"$1 % $2", [False, False]), - (b"%b %% %s", (1, 2), b"$1 % $2", [True, False]), + (b"", None, b"", None, None), + (b"", [], b"", [], []), + (b"%%", [], b"%", [], []), + (b"select %s", (1,), b"select $1", [False], [b"1"]), + (b"%s %% %s", (1, 2), b"$1 % $2", [False, False], [b"1", b"2"]), + (b"%b %% %s", ("a", 2), b"$1 % $2", [True, False], [b"a", b"2"]), ], ) -def test_query2pg_seq(query, params, want, wformats): - out, formats, order = query2pg(query, params, codecs.lookup("utf-8")) - assert order is None - assert out == want - assert formats == wformats +def test_pg_query_seq(query, params, want, wformats, wparams): + pq = PostgresQuery(Transformer()) + pq.convert(query, params) + assert pq.query == want + assert pq.formats == wformats + assert pq.params == wparams @pytest.mark.parametrize( - "query, params, want, wformats, worder", + "query, params, want, wformats, wparams", [ (b"", {}, b"", [], []), (b"hello %%", {"a": 1}, b"hello %", [], []), @@ -76,22 +78,23 @@ def test_query2pg_seq(query, params, want, wformats): {"hello": 1, "world": 2}, b"select $1", [False], - ["hello"], + [b"1"], ), ( b"select %(hi)s %(there)b %(hi)s", - {"hi": 1, "there": 2}, + {"hi": 1, "there": "a"}, b"select $1 $2 $1", [False, True], - ["hi", "there"], + [b"1", b"a"], ), ], ) -def test_query2pg_map(query, params, want, wformats, worder): - out, formats, order = query2pg(query, params, codecs.lookup("utf-8")) - assert out == want - assert formats == wformats - assert order == worder +def test_pg_query_map(query, params, want, wformats, wparams): + pq = PostgresQuery(Transformer()) + pq.convert(query, params) + assert pq.query == want + assert pq.formats == wformats + assert pq.params == wparams @pytest.mark.parametrize( @@ -103,13 +106,12 @@ def test_query2pg_map(query, params, want, wformats, worder): (b"select %s", 1), (b"select %s", b"a"), (b"select %s", set()), - ("select", []), - ("select", []), ], ) -def test_query2pg_badtype(query, params): +def test_pq_query_badtype(query, params): + pq = PostgresQuery(Transformer()) with pytest.raises(TypeError): - query2pg(query, params, codecs.lookup("utf-8")) + pq.convert(query, params) @pytest.mark.parametrize( @@ -126,19 +128,7 @@ def test_query2pg_badtype(query, params): (b"select %(hi)s %(hi)b", {"hi": 1}), ], ) -def test_query2pg_badprog(query, params): +def test_pq_query_badprog(query, params): + pq = PostgresQuery(Transformer()) with pytest.raises(psycopg3.ProgrammingError): - query2pg(query, params, codecs.lookup("utf-8")) - - -@pytest.mark.parametrize( - "params, order, want", - [ - ({"foo": 1, "bar": 2}, [], []), - ({"foo": 1, "bar": 2}, ["foo"], [1]), - ({"foo": 1, "bar": 2}, ["bar", "foo"], [2, 1]), - ], -) -def test_reorder_params(params, order, want): - rv = reorder_params(params, order) - assert rv == want + pq.convert(query, params)