From: Daniele Varrazzo Date: Sat, 21 Mar 2020 21:27:05 +0000 (+1300) Subject: Added query mangling and basic cursor execute X-Git-Tag: 3.0.dev0~683 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=3e7213cf31418e2e85247335e13b38fa699f194d;p=thirdparty%2Fpsycopg.git Added query mangling and basic cursor execute --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index bb3f86f57..60516b69d 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -4,12 +4,14 @@ psycopg3 connection objects # Copyright (C) 2020 The Psycopg Team +import codecs import logging import asyncio import threading from . import pq from . import exceptions as exc +from . import cursor from .conninfo import make_conninfo from .waiting import wait_select, wait_async, Wait, Ready @@ -26,6 +28,28 @@ class BaseConnection: def __init__(self, pgconn): self.pgconn = pgconn + self.cursor_factory = None + # name of the postgres encoding (in bytes) + self._pgenc = None + + def cursor(self, name=None): + return self.cursor_factory(self) + + @property + def codec(self): + # TODO: utf8 fastpath? + pgenc = self.pgconn.parameter_status(b"client_encoding") + if self._pgenc != pgenc: + # for unknown encodings and SQL_ASCII be strict and use ascii + pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii") + self._codec = codecs.lookup(pyenc) + return self._codec + + def encode(self, s): + return self.codec.encode(s)[0] + + def decode(self, b): + return self.codec.decode(b)[0] @classmethod def _connect_gen(cls, conninfo): @@ -122,6 +146,7 @@ class Connection(BaseConnection): def __init__(self, pgconn): super().__init__(pgconn) self.lock = threading.Lock() + self.cursor_factory = cursor.Cursor @classmethod def connect(cls, conninfo, connection_factory=None, **kwargs): @@ -168,6 +193,7 @@ class AsyncConnection(BaseConnection): def __init__(self, pgconn): super().__init__(pgconn) self.lock = asyncio.Lock() + self.cursor_factory = cursor.AsyncCursor @classmethod async def connect(cls, conninfo, **kwargs): diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py new file mode 100644 index 000000000..e23fc5e01 --- /dev/null +++ b/psycopg3/cursor.py @@ -0,0 +1,110 @@ +""" +psycopg3 cursor objects +""" + +# Copyright (C) 2020 The Psycopg Team + +from . import exceptions as exc +from .pq import error_message, DiagnosticField, ExecStatus +from .utils.queries import query2pg, reorder_params + + +class BaseCursor: + def __init__(self, conn, binary=False): + self.conn = conn + self.binary = binary + self._results = [] + self._result = None + self._iresult = 0 + + +class Cursor(BaseCursor): + def execute(self, query, vars=None): + with self.conn.lock: + self._results = [] + self._result = None + self._iresult = 0 + codec = self.conn.codec + + if isinstance(query, str): + query = codec.encode(query)[0] + + # process %% -> % only if there are paramters, even if empty list + if vars is not None: + query, order = query2pg(query, vars, codec) + if vars: + if order is not None: + vars = reorder_params(vars, order) + params, formats = self._adapt_sequence(vars) + self.conn.pgconn.send_query_params( + query, params, param_formats=formats + ) + else: + self.conn.pgconn.send_query(query) + + results = self.conn.wait(self.conn._exec_gen(self.conn.pgconn)) + if not results: + raise exc.InternalError("got no result from the query") + + badstats = {res.status for res in results} - { + ExecStatus.TUPLES_OK, + ExecStatus.COMMAND_OK, + ExecStatus.EMPTY_QUERY, + } + if not badstats: + self._results = results + self._result = results[0] + return self + + if results[-1].status == ExecStatus.FATAL_ERROR: + ecls = exc.class_for_state( + results[-1].error_field(DiagnosticField.SQLSTATE) + ) + raise ecls(error_message(results[-1])) + + elif badstats & { + ExecStatus.COPY_IN, + ExecStatus.COPY_OUT, + ExecStatus.COPY_BOTH, + }: + raise exc.ProgrammingError( + "COPY cannot be used with execute(); use copy() insead" + ) + else: + raise exc.InternalError( + f"got unexpected status from query:" + f" {', '.join(sorted(s.name for s in sorted(badstats)))}" + ) + + def nextset(self): + self._iresult += 1 + if self._iresult < len(self._results): + self._result = self._results[self._iresult] + return True + + def _adapt_sequence(self, vars): + # TODO: stub. Need adaptation layer. + codec = self.conn.codec + out = [ + codec.encode(str(v))[0] if v is not None else None for v in vars + ] + fmt = [0] * len(out) + return out, fmt + + +class AsyncCursor(BaseCursor): + async def execute(self, query, vars=None): + with self.conn.lock: + pass + + +class NamedCursorMixin: + pass + + +class NamedCursor(NamedCursorMixin, Cursor): + pass + + +class AsyncNamedCursor(NamedCursorMixin, AsyncCursor): + pass diff --git a/psycopg3/utils/__init__.py b/psycopg3/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/psycopg3/utils/queries.py b/psycopg3/utils/queries.py new file mode 100644 index 000000000..c91f5d4e0 --- /dev/null +++ b/psycopg3/utils/queries.py @@ -0,0 +1,145 @@ +""" +Utility module to manipulate queries +""" + +# Copyright (C) 2020 The Psycopg Team + +import re +from collections.abc import Sequence, Mapping + +from .. import exceptions as exc + + +def query2pg(query, vars, codec): + """ + Convert Python query and params into something Postgres understands. + + - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres + format (``$1``, ``$2``) + - return ``query`` (bytes), ``order`` (sequence of names used in the + query, in the position they appear, in case of named params, else None) + """ + if not isinstance(query, bytes): + # encoding from str already happened + raise TypeError( + f"the query should be str or bytes," + f" got {type(query).__name__} instead" + ) + + parts = split_query(query, codec.name) + + if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): + if len(vars) != len(parts) - 1: + raise exc.ProgrammingError( + f"the query has {len(parts) - 1} placeholders but" + f" {len(vars)} parameters were passed" + ) + if vars and not isinstance(parts[0][1], int): + raise TypeError( + "named placeholders require a mapping of parameters" + ) + order = None + + elif isinstance(vars, Mapping): + if vars and len(parts) > 1 and not isinstance(parts[0][1], bytes): + raise TypeError( + "positional placeholders (%s) require a sequence of parameters" + ) + seen = {} + order = [] + for part in parts[:-1]: + name = codec.decode(part[1])[0] + if name not in seen: + part[1] = seen[name] = len(seen) + order.append(name) + else: + part[1] = seen[name] + + else: + raise TypeError("parameters should be a sequence or a mapping") + + # Assemble query and parameters + rv = [] + for part in parts[:-1]: + rv.append(part[0]) + rv.append(b"$%d" % (part[1] + 1)) + rv.append(parts[-1][0]) + + return b"".join(rv), order + + +def split_query(query, encoding="ascii"): + parts = [] + cur = 0 + + # pairs [(fragment, match)], with the last match None + m = None + for m in re.finditer(rb"%(?:(?:[^(])|(?:\(([^)]+)\).)|(?:.))", query): + pre = query[cur : m.span(0)[0]] + parts.append([pre, m]) + cur = m.span(0)[1] + if m is None: + parts.append([query, None]) + else: + parts.append([query[cur:], None]) + + # drop the "%%", validate + i = 0 + phtype = None + while i < len(parts): + m = parts[i][1] + if m is None: + break # last part + ph = m.group(0) + if ph == b"%%": + # unescape '%%' to '%' and merge the parts + parts[i + 1][0] = parts[i][0] + b"%" + parts[i + 1][0] + del parts[i] + continue + if ph == b"%(": + raise exc.ProgrammingError( + f"incomplete placeholder:" + f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'" + ) + elif ph == b"% ": + # explicit messasge for a typical error + raise exc.ProgrammingError( + "incomplete placeholder: '%'; if you want to use '%' as an" + " operator you can double it up, i.e. use '%%'" + ) + elif ph[-1:] != b"s": + raise exc.ProgrammingError( + f"only '%s' and '%(name)s' placeholders allowed, got" + f" {m.group(0).decode(encoding)}" + ) + + # Index or name + if m.group(1) is None: + parts[i][1] = i + else: + parts[i][1] = m.group(1) + + if phtype is None: + phtype = type(parts[i][1]) + else: + if phtype is not type(parts[i][1]): # noqa + raise exc.ProgrammingError( + "positional and named placeholders cannot be mixed" + ) + + i += 1 + + return parts + + +def reorder_params(params, order): + """ + Convert a mapping of parameters into an array in a specified order + """ + try: + return [params[item] for item in order] + except KeyError: + raise exc.ProgrammingError( + f"query parameter missing:" + f" {', '.join(sorted(i for i in order if i not in params))}" + ) diff --git a/tests/test_cursor.py b/tests/test_cursor.py new file mode 100644 index 000000000..6841c1c0f --- /dev/null +++ b/tests/test_cursor.py @@ -0,0 +1,18 @@ +def test_execute_many(conn): + cur = conn.cursor() + cur.execute("select 'foo'; select 'bar'") + assert len(cur._results) == 2 + assert cur._result.get_value(0, 0) == b"foo" + assert cur.nextset() + assert cur._result.get_value(0, 0) == b"bar" + assert cur.nextset() is None + + +def test_execute_sequence(conn): + cur = conn.cursor() + cur.execute("select %s, %s, %s", [1, "foo", None]) + assert len(cur._results) == 1 + assert cur._result.get_value(0, 0) == b"1" + assert cur._result.get_value(0, 1) == b"foo" + assert cur._result.get_value(0, 2) is None + assert cur.nextset() is None diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 000000000..a7994cc34 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,133 @@ +import codecs +import pytest + +import psycopg3 +from psycopg3.utils.queries import split_query, query2pg, reorder_params + + +@pytest.mark.parametrize( + "input, want", + [ + (b"", [[b"", None]]), + (b"foo bar", [[b"foo bar", None]]), + (b"foo %% bar", [[b"foo % bar", None]]), + (b"%s", [[b"", 0], [b"", None]]), + (b"%s foo", [[b"", 0], [b" foo", None]]), + (b"foo %s", [[b"foo ", 0], [b"", None]]), + (b"foo %%%s bar", [[b"foo %", 0], [b" bar", None]]), + (b"foo %(name)s bar", [[b"foo ", b"name"], [b" bar", None]]), + ( + b"foo %s%s bar %s baz", + [[b"foo ", 0], [b"", 1], [b" bar ", 2], [b" baz", None]], + ), + ], +) +def test_split_query(input, want): + assert split_query(input) == want + + +@pytest.mark.parametrize( + "input", + [ + b"foo %d bar", + b"foo % bar", + b"foo %%% bar", + b"foo %(foo)d bar", + b"foo %(foo)s bar %s baz", + b"foo %(foo) bar", + b"foo %(foo bar", + b"3%2", + ], +) +def test_split_query_bad(input): + with pytest.raises(psycopg3.ProgrammingError): + split_query(input) + + +@pytest.mark.parametrize( + "query, params, want", + [ + (b"", [], b""), + (b"%%", [], b"%"), + (b"select %s", (1,), b"select $1"), + (b"%s %% %s", (1, 2), b"$1 % $2"), + ], +) +def test_query2pg_seq(query, params, want): + out, order = query2pg(query, params, codecs.lookup("utf-8")) + assert order is None + assert out == want + + +@pytest.mark.parametrize( + "query, params, want, worder", + [ + (b"", {}, b"", []), + (b"hello %%", {"a": 1}, b"hello %", []), + ( + b"select %(hello)s", + {"hello": 1, "world": 2}, + b"select $1", + ["hello"], + ), + ( + b"select %(hi)s %(there)s %(hi)s", + {"hi": 1, "there": 2}, + b"select $1 $2 $1", + ["hi", "there"], + ), + ], +) +def test_query2pg_map(query, params, want, worder): + out, order = query2pg(query, params, codecs.lookup("utf-8")) + assert out == want + assert order == worder + + +@pytest.mark.parametrize( + "query, params", + [ + (b"select %s", {"a": 1}), + (b"select %(name)s", [1]), + (b"select %s", "a"), + (b"select %s", 1), + (b"select %s", b"a"), + (b"select %s", set()), + ("select", []), + ("select", []), + ], +) +def test_query2pg_badtype(query, params): + with pytest.raises(TypeError): + query2pg(query, params, codecs.lookup("utf-8")) + + +@pytest.mark.parametrize( + "query, params", + [ + (b"", [1]), + (b"%s", []), + (b"%%", [1]), + (b"$1", [1]), + (b"select %(", {"a": 1}), + (b"select %(a", {"a": 1}), + (b"select %(a)", {"a": 1}), + (b"select %s %(hi)s", 1), + ], +) +def test_query2pg_badprog(query, params): + with pytest.raises(psycopg3.ProgrammingError): + query2pg(query, params, codecs.lookup("utf-8")) + + +@pytest.mark.parametrize( + "params, order, want", + [ + ({"foo": 1, "bar": 2}, [], []), + ({"foo": 1, "bar": 2}, ["foo"], [1]), + ({"foo": 1, "bar": 2}, ["bar", "foo"], [2, 1]), + ], +) +def test_reorder_params(params, order, want): + rv = reorder_params(params, order) + assert rv == want