]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added query mangling and basic cursor execute
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 21 Mar 2020 21:27:05 +0000 (10:27 +1300)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Mar 2020 08:05:50 +0000 (21:05 +1300)
psycopg3/connection.py
psycopg3/cursor.py [new file with mode: 0644]
psycopg3/utils/__init__.py [new file with mode: 0644]
psycopg3/utils/queries.py [new file with mode: 0644]
tests/test_cursor.py [new file with mode: 0644]
tests/test_query.py [new file with mode: 0644]

index bb3f86f5730ce6be150b764c039f26d4e1daabae..60516b69d124b85df097db4001684f5200f7ff28 100644 (file)
@@ -4,12 +4,14 @@ psycopg3 connection objects
 
 # Copyright (C) 2020 The Psycopg Team
 
+import codecs
 import logging
 import asyncio
 import threading
 
 from . import pq
 from . import exceptions as exc
+from . import cursor
 from .conninfo import make_conninfo
 from .waiting import wait_select, wait_async, Wait, Ready
 
@@ -26,6 +28,28 @@ class BaseConnection:
 
     def __init__(self, pgconn):
         self.pgconn = pgconn
+        # callable used by cursor() to build a cursor for this connection;
+        # None here, the concrete connection classes assign it
+        self.cursor_factory = None
+        # name of the postgres encoding (in bytes)
+        self._pgenc = None
+
+    def cursor(self, name=None):
+        """
+        Return a new cursor built by calling `cursor_factory` on this
+        connection.
+
+        The `name` argument is accepted but not used yet.
+        """
+        factory = self.cursor_factory
+        return factory(self)
+
+    @property
+    def codec(self):
+        """
+        The Python codec matching the connection ``client_encoding``.
+
+        The codec is looked up again only when the encoding returned by
+        ``pgconn.parameter_status()`` differs from the one seen when the
+        cached codec was created.
+        """
+        # TODO: utf8 fastpath?
+        pgenc = self.pgconn.parameter_status(b"client_encoding")
+        if self._pgenc != pgenc:
+            # for unknown encodings and SQL_ASCII be strict and use ascii
+            pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii")
+            self._codec = codecs.lookup(pyenc)
+            # remember which encoding the cached codec belongs to, otherwise
+            # the check above never matches and the lookup runs every time
+            self._pgenc = pgenc
+        return self._codec
+
+    def encode(self, s):
+        """
+        Encode `s` with the connection codec and return the encoded data.
+        """
+        # the codec returns a pair; only its first item is wanted here
+        encoded = self.codec.encode(s)
+        return encoded[0]
+
+    def decode(self, b):
+        """
+        Decode `b` with the connection codec and return the decoded data.
+        """
+        # the codec returns a pair; only its first item is wanted here
+        decoded = self.codec.decode(b)
+        return decoded[0]
 
     @classmethod
     def _connect_gen(cls, conninfo):
@@ -122,6 +146,7 @@ class Connection(BaseConnection):
     def __init__(self, pgconn):
         super().__init__(pgconn)
         self.lock = threading.Lock()
+        # cursor() will create instances of this class
+        self.cursor_factory = cursor.Cursor
 
     @classmethod
     def connect(cls, conninfo, connection_factory=None, **kwargs):
@@ -168,6 +193,7 @@ class AsyncConnection(BaseConnection):
     def __init__(self, pgconn):
         super().__init__(pgconn)
         self.lock = asyncio.Lock()
+        # cursor() will create instances of this class
+        self.cursor_factory = cursor.AsyncCursor
 
     @classmethod
     async def connect(cls, conninfo, **kwargs):
diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py
new file mode 100644 (file)
index 0000000..e23fc5e
--- /dev/null
@@ -0,0 +1,110 @@
+"""
+psycopg3 cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from . import exceptions as exc
+from .pq import error_message, DiagnosticField, ExecStatus
+from .utils.queries import query2pg, reorder_params
+
+
+class BaseCursor:
+    """
+    State common to the cursor classes: the connection and the results.
+    """
+
+    def __init__(self, conn, binary=False):
+        # the connection the cursor belongs to and the `binary` flag as passed
+        self.conn, self.binary = conn, binary
+        # index of the current result, the current result, all the results
+        self._iresult = 0
+        self._result = None
+        self._results = []
+
+
+class Cursor(BaseCursor):
+    """
+    Cursor sending queries on a connection and storing their results.
+    """
+
+    def execute(self, query, vars=None):
+        """
+        Execute a query, optionally with parameters, and store its results.
+
+        :param query: the statement to run; a `str` is encoded with the
+            connection codec before being sent.
+        :param vars: sequence or mapping of parameters, or `None`. If not
+            `None` (even if empty) the query goes through `query2pg()`.
+        :return: the cursor itself.
+        :raise InternalError: if no result is returned or a result has an
+            unexpected status.
+        :raise ProgrammingError: if the query results in a COPY status.
+
+        If the last result is a fatal error, raise the exception class
+        returned by `exc.class_for_state()` for its SQLSTATE.
+        """
+        with self.conn.lock:
+            self._results = []
+            self._result = None
+            self._iresult = 0
+            codec = self.conn.codec
+
+            if isinstance(query, str):
+                query = codec.encode(query)[0]
+
+            # process %% -> % only if there are parameters, even if empty list
+            if vars is not None:
+                query, order = query2pg(query, vars, codec)
+            if vars:
+                # `order` is bound here: a non-empty `vars` is not None
+                if order is not None:
+                    vars = reorder_params(vars, order)
+                params, formats = self._adapt_sequence(vars)
+                self.conn.pgconn.send_query_params(
+                    query, params, param_formats=formats
+                )
+            else:
+                self.conn.pgconn.send_query(query)
+
+            results = self.conn.wait(self.conn._exec_gen(self.conn.pgconn))
+            if not results:
+                raise exc.InternalError("got no result from the query")
+
+            # statuses found in the results that execute() cannot handle
+            badstats = {res.status for res in results} - {
+                ExecStatus.TUPLES_OK,
+                ExecStatus.COMMAND_OK,
+                ExecStatus.EMPTY_QUERY,
+            }
+            if not badstats:
+                self._results = results
+                self._result = results[0]
+                return self
+
+            if results[-1].status == ExecStatus.FATAL_ERROR:
+                ecls = exc.class_for_state(
+                    results[-1].error_field(DiagnosticField.SQLSTATE)
+                )
+                raise ecls(error_message(results[-1]))
+
+            elif badstats & {
+                ExecStatus.COPY_IN,
+                ExecStatus.COPY_OUT,
+                ExecStatus.COPY_BOTH,
+            }:
+                raise exc.ProgrammingError(
+                    "COPY cannot be used with execute(); use copy() instead"
+                )
+            else:
+                # sorting the names is enough to have a stable message
+                raise exc.InternalError(
+                    f"got unexpected status from query:"
+                    f" {', '.join(sorted(s.name for s in badstats))}"
+                )
+
+    def nextset(self):
+        """
+        Move to the next result of the last query.
+
+        Return `True` if there is a further result, which becomes the
+        current one, else `None`.
+        """
+        self._iresult += 1
+        if self._iresult < len(self._results):
+            self._result = self._results[self._iresult]
+            return True
+        return None
+
+    def _adapt_sequence(self, vars):
+        """
+        Convert a sequence of parameters into what is passed to the query.
+
+        Return a pair ``(values, formats)``: each value is `None` or the
+        `str()` of the parameter encoded with the connection codec; each
+        format is 0.
+        """
+        # TODO: stub. Need adaptation layer.
+        codec = self.conn.codec
+        out = [
+            codec.encode(str(v))[0] if v is not None else None for v in vars
+        ]
+        fmt = [0] * len(out)
+        return out, fmt
+
+
+class AsyncCursor(BaseCursor):
+    """
+    Cursor for asynchronous connections. Query execution is still a stub.
+    """
+
+    async def execute(self, query, vars=None):
+        """
+        Acquire the connection lock; nothing is sent to the server yet.
+        """
+        # the lock of an async connection is an asyncio.Lock, which can only
+        # be entered with "async with": a plain "with" fails at runtime
+        async with self.conn.lock:
+            pass
+
+
+class NamedCursorMixin:
+    # Placeholder: no behaviour is defined here yet.
+    pass
+
+
+class NamedCursor(NamedCursorMixin, Cursor):
+    # Placeholder: combines the mixin with Cursor, adds nothing of its own.
+    pass
+
+
+class AsyncNamedCursor(NamedCursorMixin, AsyncCursor):
+    # Placeholder: combines the mixin with AsyncCursor, adds nothing of its own.
+    pass
diff --git a/psycopg3/utils/__init__.py b/psycopg3/utils/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/psycopg3/utils/queries.py b/psycopg3/utils/queries.py
new file mode 100644 (file)
index 0000000..c91f5d4
--- /dev/null
@@ -0,0 +1,145 @@
+"""
+Utility module to manipulate queries
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from collections.abc import Sequence, Mapping
+
+from .. import exceptions as exc
+
+
+def query2pg(query, vars, codec):
+    """
+    Convert Python query and params into something Postgres understands.
+
+    - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
+      format (``$1``, ``$2``)
+    - return ``query`` (bytes), ``order`` (sequence of names used in the
+      query, in the position they appear, in case of named params, else None)
+
+    :param query: the query to convert, as bytes.
+    :param vars: the query parameters: a sequence (but not `str` or `bytes`)
+        or a mapping.
+    :param codec: codec used to decode the placeholder names; its ``name``
+        is passed to `split_query()` too.
+    :raise TypeError: if `query` is not bytes, if `vars` is neither a
+        sequence nor a mapping, or if the placeholders style doesn't match
+        the type of `vars`.
+    :raise ProgrammingError: if the number of items in a sequence `vars`
+        differs from the number of placeholders.
+    """
+    if not isinstance(query, bytes):
+        # encoding from str already happened
+        raise TypeError(
+            f"the query should be str or bytes,"
+            f" got {type(query).__name__} instead"
+        )
+
+    parts = split_query(query, codec.name)
+
+    if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+        if len(vars) != len(parts) - 1:
+            raise exc.ProgrammingError(
+                f"the query has {len(parts) - 1} placeholders but"
+                f" {len(vars)} parameters were passed"
+            )
+        # if vars is not empty the check above ensures parts[0] is a
+        # placeholder, not the final fragment
+        if vars and not isinstance(parts[0][1], int):
+            raise TypeError(
+                "named placeholders require a mapping of parameters"
+            )
+        order = None
+
+    elif isinstance(vars, Mapping):
+        # Check the placeholders style even if the mapping is empty:
+        # otherwise a positional index would reach codec.decode() below
+        # and fail with an unrelated error.
+        if len(parts) > 1 and not isinstance(parts[0][1], bytes):
+            raise TypeError(
+                "positional placeholders (%s) require a sequence of parameters"
+            )
+        # map each name to the index of its first appearance in the query
+        seen = {}
+        order = []
+        for part in parts[:-1]:
+            name = codec.decode(part[1])[0]
+            if name not in seen:
+                part[1] = seen[name] = len(seen)
+                order.append(name)
+            else:
+                part[1] = seen[name]
+
+    else:
+        raise TypeError("parameters should be a sequence or a mapping")
+
+    # Assemble query and parameters
+    rv = []
+    for part in parts[:-1]:
+        rv.append(part[0])
+        # placeholders indexes are 0-based, Postgres ones are 1-based
+        rv.append(b"$%d" % (part[1] + 1))
+    rv.append(parts[-1][0])
+
+    return b"".join(rv), order
+
+
+def split_query(query, encoding="ascii"):
+    """
+    Split a query on its placeholders, validating them.
+
+    Return a list of ``[fragment, placeholder]`` pairs, where ``fragment``
+    is the chunk of query (bytes) before the placeholder and ``placeholder``
+    is its 0-based index (for ``%s``) or its name as bytes (for
+    ``%(name)s``). The last pair has the trailing fragment and `None`.
+    ``%%`` is unescaped to ``%`` and merged into the fragments.
+
+    :param query: the query to split, as bytes.
+    :param encoding: encoding used to decode bits of the query to include
+        in error messages.
+    :raise ProgrammingError: if a placeholder is incomplete or not
+        supported, or if positional and named placeholders are mixed.
+    """
+    parts = []
+    cur = 0
+
+    # pairs [(fragment, match)], with the last match None
+    # The final \Z alternative catches a lone '%' at the very end of the
+    # query, which none of the other alternatives can match.
+    for m in re.finditer(
+        rb"%(?:(?:[^(])|(?:\(([^)]+)\).)|(?:.)|\Z)", query
+    ):
+        pre = query[cur : m.span(0)[0]]
+        parts.append([pre, m])
+        cur = m.span(0)[1]
+    # cur is still 0 if nothing matched, so this is the whole query
+    parts.append([query[cur:], None])
+
+    # drop the "%%", validate
+    i = 0
+    phtype = None
+    while i < len(parts):
+        m = parts[i][1]
+        if m is None:
+            break  # last part
+        ph = m.group(0)
+        if ph == b"%%":
+            # unescape '%%' to '%' and merge the parts
+            parts[i + 1][0] = parts[i][0] + b"%" + parts[i + 1][0]
+            del parts[i]
+            continue
+        if ph == b"%(":
+            raise exc.ProgrammingError(
+                f"incomplete placeholder:"
+                f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
+            )
+        elif ph in (b"% ", b"%"):
+            # explicit message for a typical error
+            raise exc.ProgrammingError(
+                "incomplete placeholder: '%'; if you want to use '%' as an"
+                " operator you can double it up, i.e. use '%%'"
+            )
+        elif ph[-1:] != b"s":
+            raise exc.ProgrammingError(
+                f"only '%s' and '%(name)s' placeholders allowed, got"
+                f" {m.group(0).decode(encoding)}"
+            )
+
+        # Index or name
+        if m.group(1) is None:
+            parts[i][1] = i
+        else:
+            parts[i][1] = m.group(1)
+
+        # int for positional, bytes for named: all must be of the same type
+        if phtype is None:
+            phtype = type(parts[i][1])
+        else:
+            if phtype is not type(parts[i][1]):  # noqa
+                raise exc.ProgrammingError(
+                    "positional and named placeholders cannot be mixed"
+                )
+
+        i += 1
+
+    return parts
+
+
+def reorder_params(params, order):
+    """
+    Return the values of the mapping `params` as a list, following `order`.
+
+    Raise `ProgrammingError`, listing the missing names, if a name in
+    `order` is not a key of `params`.
+    """
+    values = []
+    try:
+        for name in order:
+            values.append(params[name])
+    except KeyError:
+        missing = sorted(name for name in order if name not in params)
+        raise exc.ProgrammingError(
+            f"query parameter missing: {', '.join(missing)}"
+        )
+    return values
diff --git a/tests/test_cursor.py b/tests/test_cursor.py
new file mode 100644 (file)
index 0000000..6841c1c
--- /dev/null
@@ -0,0 +1,18 @@
+def test_execute_many(conn):
+    # Two statements in one execute(): both results are stored and
+    # nextset() moves from the first to the second, then returns None.
+    cur = conn.cursor()
+    cur.execute("select 'foo'; select 'bar'")
+    assert len(cur._results) == 2
+    assert cur._result.get_value(0, 0) == b"foo"
+    assert cur.nextset()
+    assert cur._result.get_value(0, 0) == b"bar"
+    assert cur.nextset() is None
+
+
+def test_execute_sequence(conn):
+    # Positional parameters passed as a list: values come back as bytes,
+    # None comes back as None.
+    cur = conn.cursor()
+    cur.execute("select %s, %s, %s", [1, "foo", None])
+    assert len(cur._results) == 1
+    assert cur._result.get_value(0, 0) == b"1"
+    assert cur._result.get_value(0, 1) == b"foo"
+    assert cur._result.get_value(0, 2) is None
+    assert cur.nextset() is None
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644 (file)
index 0000000..a7994cc
--- /dev/null
@@ -0,0 +1,133 @@
+import codecs
+import pytest
+
+import psycopg3
+from psycopg3.utils.queries import split_query, query2pg, reorder_params
+
+
+# Valid queries: each case pairs a query with the [fragment, placeholder]
+# list expected from split_query().
+@pytest.mark.parametrize(
+    "input, want",
+    [
+        (b"", [[b"", None]]),
+        (b"foo bar", [[b"foo bar", None]]),
+        (b"foo %% bar", [[b"foo % bar", None]]),
+        (b"%s", [[b"", 0], [b"", None]]),
+        (b"%s foo", [[b"", 0], [b" foo", None]]),
+        (b"foo %s", [[b"foo ", 0], [b"", None]]),
+        (b"foo %%%s bar", [[b"foo %", 0], [b" bar", None]]),
+        (b"foo %(name)s bar", [[b"foo ", b"name"], [b" bar", None]]),
+        (
+            b"foo %s%s bar %s baz",
+            [[b"foo ", 0], [b"", 1], [b" bar ", 2], [b" baz", None]],
+        ),
+    ],
+)
+def test_split_query(input, want):
+    assert split_query(input) == want
+
+
+# Queries with unsupported, incomplete or mixed placeholders: split_query()
+# must raise ProgrammingError.
+@pytest.mark.parametrize(
+    "input",
+    [
+        b"foo %d bar",
+        b"foo % bar",
+        b"foo %%% bar",
+        b"foo %(foo)d bar",
+        b"foo %(foo)s bar %s baz",
+        b"foo %(foo) bar",
+        b"foo %(foo bar",
+        b"3%2",
+    ],
+)
+def test_split_query_bad(input):
+    with pytest.raises(psycopg3.ProgrammingError):
+        split_query(input)
+
+
+# Sequence parameters: placeholders become $1, $2... and no order is returned.
+@pytest.mark.parametrize(
+    "query, params, want",
+    [
+        (b"", [], b""),
+        (b"%%", [], b"%"),
+        (b"select %s", (1,), b"select $1"),
+        (b"%s %% %s", (1, 2), b"$1 % $2"),
+    ],
+)
+def test_query2pg_seq(query, params, want):
+    out, order = query2pg(query, params, codecs.lookup("utf-8"))
+    assert order is None
+    assert out == want
+
+
+# Mapping parameters: a repeated name reuses the same $n; the order returned
+# lists the names by first appearance in the query.
+@pytest.mark.parametrize(
+    "query, params, want, worder",
+    [
+        (b"", {}, b"", []),
+        (b"hello %%", {"a": 1}, b"hello %", []),
+        (
+            b"select %(hello)s",
+            {"hello": 1, "world": 2},
+            b"select $1",
+            ["hello"],
+        ),
+        (
+            b"select %(hi)s %(there)s %(hi)s",
+            {"hi": 1, "there": 2},
+            b"select $1 $2 $1",
+            ["hi", "there"],
+        ),
+    ],
+)
+def test_query2pg_map(query, params, want, worder):
+    out, order = query2pg(query, params, codecs.lookup("utf-8"))
+    assert out == want
+    assert order == worder
+
+
+# Wrong types for the query or the parameters: query2pg() must raise
+# TypeError.
+# NOTE(review): the last two cases are identical - confirm whether a
+# different case was intended.
+@pytest.mark.parametrize(
+    "query, params",
+    [
+        (b"select %s", {"a": 1}),
+        (b"select %(name)s", [1]),
+        (b"select %s", "a"),
+        (b"select %s", 1),
+        (b"select %s", b"a"),
+        (b"select %s", set()),
+        ("select", []),
+        ("select", []),
+    ],
+)
+def test_query2pg_badtype(query, params):
+    with pytest.raises(TypeError):
+        query2pg(query, params, codecs.lookup("utf-8"))
+
+
+# Parameter count mismatches and malformed placeholders: query2pg() must
+# raise ProgrammingError.
+@pytest.mark.parametrize(
+    "query, params",
+    [
+        (b"", [1]),
+        (b"%s", []),
+        (b"%%", [1]),
+        (b"$1", [1]),
+        (b"select %(", {"a": 1}),
+        (b"select %(a", {"a": 1}),
+        (b"select %(a)", {"a": 1}),
+        (b"select %s %(hi)s", 1),
+    ],
+)
+def test_query2pg_badprog(query, params):
+    with pytest.raises(psycopg3.ProgrammingError):
+        query2pg(query, params, codecs.lookup("utf-8"))
+
+
+# reorder_params() returns the mapping values in the order requested.
+@pytest.mark.parametrize(
+    "params, order, want",
+    [
+        ({"foo": 1, "bar": 2}, [], []),
+        ({"foo": 1, "bar": 2}, ["foo"], [1]),
+        ({"foo": 1, "bar": 2}, ["bar", "foo"], [2, 1]),
+    ],
+)
+def test_reorder_params(params, order, want):
+    rv = reorder_params(params, order)
+    assert rv == want