# Copyright (C) 2020 The Psycopg Team
import re
-from codecs import CodecInfo
+from functools import lru_cache
from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
from typing import Sequence, Tuple, Union, TYPE_CHECKING
from ..proto import Transformer
+class QueryPart(NamedTuple):
+ pre: bytes
+ item: Union[int, str]
+ format: Format
+
+
class PostgresQuery:
"""
Helper to convert a Python query and parameters into Postgres format.
"""
+ _parts: List[QueryPart]
+
def __init__(self, transformer: "Transformer"):
self._tx = transformer
self.query: bytes = b""
attributes (`query`, `params`, `types`, `formats`).
"""
codec = self._tx.codec
- if isinstance(query, str):
- query = codec.encode(query)[0]
if vars is not None:
- self.query, self.formats, self._order = query2pg(
- query, vars, codec
+ self.query, self.formats, self._order, self._parts = _query2pg(
+ query, codec.name
)
else:
+ if isinstance(query, str):
+ query = codec.encode(query)[0]
self.query = query
self.formats = self._order = None
This method updates `params` and `types`.
"""
- if vars:
- if self._order is not None:
- assert isinstance(vars, Mapping)
- vars = reorder_params(vars, self._order)
- assert isinstance(vars, Sequence)
+ if vars is not None:
+ params = _validate_and_reorder_params(
+ self._parts, vars, self._order
+ )
self.params, self.types = self._tx.dump_sequence(
- vars, self.formats or ()
+ params, self.formats or ()
)
else:
self.params = self.types = None
-def query2pg(
- query: bytes, vars: Params, codec: CodecInfo
-) -> Tuple[bytes, List[Format], Optional[List[str]]]:
+@lru_cache()
+def _query2pg(
+ query: Query, encoding: str
+) -> Tuple[bytes, List[Format], Optional[List[str]], List[QueryPart]]:
"""
Convert Python query and params into something Postgres understands.
format (``$1``, ``$2``)
- placeholders can be %s or %b (text or binary)
- return ``query`` (bytes), ``formats`` (list of formats) ``order``
- (sequence of names used in the query, in the position they appear, in
-
+ (sequence of names used in the query, in the position they appear)
+ ``parts`` (splits of queries and placeholders).
"""
+ if isinstance(query, str):
+ query = query.encode(encoding)
if not isinstance(query, bytes):
# encoding from str already happened
raise TypeError(
f" got {type(query).__name__} instead"
)
- parts = split_query(query, codec.name)
+ parts = _split_query(query, encoding)
order: Optional[List[str]] = None
chunks: List[bytes] = []
formats = []
- if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
- if len(vars) != len(parts) - 1:
- raise e.ProgrammingError(
- f"the query has {len(parts) - 1} placeholders but"
- f" {len(vars)} parameters were passed"
- )
- if vars and not isinstance(parts[0].item, int):
- raise TypeError(
- "named placeholders require a mapping of parameters"
- )
-
+ if isinstance(parts[0].item, int):
for part in parts[:-1]:
assert isinstance(part.item, int)
chunks.append(part.pre)
chunks.append(b"$%d" % (part.item + 1))
formats.append(part.format)
- elif isinstance(vars, Mapping):
- if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
- raise TypeError(
- "positional placeholders (%s) require a sequence of parameters"
- )
+ elif isinstance(parts[0].item, str):
seen: Dict[str, Tuple[bytes, Format]] = {}
order = []
for part in parts[:-1]:
)
chunks.append(seen[part.item][0])
+ # last part
+ chunks.append(parts[-1].pre)
+
+ return b"".join(chunks), formats, order, parts
+
+
+def _validate_and_reorder_params(
+ parts: List[QueryPart], vars: Params, order: Optional[List[str]]
+) -> Sequence[Any]:
+ """
+ Verify the compatibility between a query and a set of params.
+ """
+ if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+ if len(vars) != len(parts) - 1:
+ raise e.ProgrammingError(
+ f"the query has {len(parts) - 1} placeholders but"
+ f" {len(vars)} parameters were passed"
+ )
+ if vars and not isinstance(parts[0].item, int):
+ raise TypeError(
+ "named placeholders require a mapping of parameters"
+ )
+ return vars
+
+ elif isinstance(vars, Mapping):
+ if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
+ raise TypeError(
+ "positional placeholders (%s) require a sequence of parameters"
+ )
+ try:
+ return [vars[item] for item in order or ()]
+ except KeyError:
+ raise e.ProgrammingError(
+ f"query parameter missing:"
+ f" {', '.join(sorted(i for i in order or () if i not in vars))}"
+ )
+
else:
raise TypeError(
f"query parameters should be a sequence or a mapping,"
f" got {type(vars).__name__}"
)
- # last part
- chunks.append(parts[-1].pre)
-
- return b"".join(chunks), formats, order
-
_re_placeholder = re.compile(
rb"""(?x)
)
-class QueryPart(NamedTuple):
- pre: bytes
- item: Union[int, str]
- format: Format
-
-
-def split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
+def _split_query(query: bytes, encoding: str = "ascii") -> List[QueryPart]:
parts: List[Tuple[bytes, Optional[Match[bytes]]]] = []
cur = 0
i += 1
return rv
-
-
-def reorder_params(
- params: Mapping[str, Any], order: Sequence[str]
-) -> List[str]:
- """
- Convert a mapping of parameters into an array in a specified order
- """
- try:
- return [params[item] for item in order]
- except KeyError:
- raise e.ProgrammingError(
- f"query parameter missing:"
- f" {', '.join(sorted(i for i in order if i not in params))}"
- )
-import codecs
import pytest
import psycopg3
-from psycopg3.utils.queries import split_query, query2pg, reorder_params
+from psycopg3.adapt import Transformer
+from psycopg3.utils.queries import PostgresQuery, _split_query
@pytest.mark.parametrize(
],
)
def test_split_query(input, want):
- assert split_query(input) == want
+ assert _split_query(input) == want
@pytest.mark.parametrize(
)
def test_split_query_bad(input):
with pytest.raises(psycopg3.ProgrammingError):
- split_query(input)
+ _split_query(input)
@pytest.mark.parametrize(
- "query, params, want, wformats",
+ "query, params, want, wformats, wparams",
[
- (b"", [], b"", []),
- (b"%%", [], b"%", []),
- (b"select %s", (1,), b"select $1", [False]),
- (b"%s %% %s", (1, 2), b"$1 % $2", [False, False]),
- (b"%b %% %s", (1, 2), b"$1 % $2", [True, False]),
+ (b"", None, b"", None, None),
+ (b"", [], b"", [], []),
+ (b"%%", [], b"%", [], []),
+ (b"select %s", (1,), b"select $1", [False], [b"1"]),
+ (b"%s %% %s", (1, 2), b"$1 % $2", [False, False], [b"1", b"2"]),
+ (b"%b %% %s", ("a", 2), b"$1 % $2", [True, False], [b"a", b"2"]),
],
)
-def test_query2pg_seq(query, params, want, wformats):
- out, formats, order = query2pg(query, params, codecs.lookup("utf-8"))
- assert order is None
- assert out == want
- assert formats == wformats
+def test_pg_query_seq(query, params, want, wformats, wparams):
+ pq = PostgresQuery(Transformer())
+ pq.convert(query, params)
+ assert pq.query == want
+ assert pq.formats == wformats
+ assert pq.params == wparams
@pytest.mark.parametrize(
- "query, params, want, wformats, worder",
+ "query, params, want, wformats, wparams",
[
(b"", {}, b"", [], []),
(b"hello %%", {"a": 1}, b"hello %", [], []),
{"hello": 1, "world": 2},
b"select $1",
[False],
- ["hello"],
+ [b"1"],
),
(
b"select %(hi)s %(there)b %(hi)s",
- {"hi": 1, "there": 2},
+ {"hi": 1, "there": "a"},
b"select $1 $2 $1",
[False, True],
- ["hi", "there"],
+ [b"1", b"a"],
),
],
)
-def test_query2pg_map(query, params, want, wformats, worder):
- out, formats, order = query2pg(query, params, codecs.lookup("utf-8"))
- assert out == want
- assert formats == wformats
- assert order == worder
+def test_pg_query_map(query, params, want, wformats, wparams):
+ pq = PostgresQuery(Transformer())
+ pq.convert(query, params)
+ assert pq.query == want
+ assert pq.formats == wformats
+ assert pq.params == wparams
@pytest.mark.parametrize(
(b"select %s", 1),
(b"select %s", b"a"),
(b"select %s", set()),
- ("select", []),
- ("select", []),
],
)
-def test_query2pg_badtype(query, params):
+def test_pq_query_badtype(query, params):
+ pq = PostgresQuery(Transformer())
with pytest.raises(TypeError):
- query2pg(query, params, codecs.lookup("utf-8"))
+ pq.convert(query, params)
@pytest.mark.parametrize(
(b"select %(hi)s %(hi)b", {"hi": 1}),
],
)
-def test_query2pg_badprog(query, params):
+def test_pq_query_badprog(query, params):
+ pq = PostgresQuery(Transformer())
with pytest.raises(psycopg3.ProgrammingError):
- query2pg(query, params, codecs.lookup("utf-8"))
-
-
-@pytest.mark.parametrize(
- "params, order, want",
- [
- ({"foo": 1, "bar": 2}, [], []),
- ({"foo": 1, "bar": 2}, ["foo"], [1]),
- ({"foo": 1, "bar": 2}, ["bar", "foo"], [2, 1]),
- ],
-)
-def test_reorder_params(params, order, want):
- rv = reorder_params(params, order)
- assert rv == want
+ pq.convert(query, params)