# Copyright (C) 2020 The Psycopg Team
+import codecs
import logging
import asyncio
import threading
from . import pq
from . import exceptions as exc
+from . import cursor
from .conninfo import make_conninfo
from .waiting import wait_select, wait_async, Wait, Ready
def __init__(self, pgconn):
self.pgconn = pgconn
+ self.cursor_factory = None
+ # name of the postgres encoding (in bytes)
+ self._pgenc = None
+
+ def cursor(self, name=None):
+ return self.cursor_factory(self)
+
+ @property
+ def codec(self):
+ # TODO: utf8 fastpath?
+ pgenc = self.pgconn.parameter_status(b"client_encoding")
+ if self._pgenc != pgenc:
+ # for unknown encodings and SQL_ASCII be strict and use ascii
+ pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii")
+ self._codec = codecs.lookup(pyenc)
+ return self._codec
+
+ def encode(self, s):
+ return self.codec.encode(s)[0]
+
+ def decode(self, b):
+ return self.codec.decode(b)[0]
@classmethod
def _connect_gen(cls, conninfo):
def __init__(self, pgconn):
super().__init__(pgconn)
self.lock = threading.Lock()
+ self.cursor_factory = cursor.Cursor
@classmethod
def connect(cls, conninfo, connection_factory=None, **kwargs):
def __init__(self, pgconn):
super().__init__(pgconn)
self.lock = asyncio.Lock()
+ self.cursor_factory = cursor.AsyncCursor
@classmethod
async def connect(cls, conninfo, **kwargs):
--- /dev/null
+"""
+psycopg3 cursor objects
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from . import exceptions as exc
+from .pq import error_message, DiagnosticField, ExecStatus
+from .utils.queries import query2pg, reorder_params
+
+
+class BaseCursor:
+ def __init__(self, conn, binary=False):
+ self.conn = conn
+ self.binary = binary
+ self._results = []
+ self._result = None
+ self._iresult = 0
+
+
+class Cursor(BaseCursor):
+ def execute(self, query, vars=None):
+ with self.conn.lock:
+ self._results = []
+ self._result = None
+ self._iresult = 0
+ codec = self.conn.codec
+
+ if isinstance(query, str):
+ query = codec.encode(query)[0]
+
+ # process %% -> % only if there are paramters, even if empty list
+ if vars is not None:
+ query, order = query2pg(query, vars, codec)
+ if vars:
+ if order is not None:
+ vars = reorder_params(vars, order)
+ params, formats = self._adapt_sequence(vars)
+ self.conn.pgconn.send_query_params(
+ query, params, param_formats=formats
+ )
+ else:
+ self.conn.pgconn.send_query(query)
+
+ results = self.conn.wait(self.conn._exec_gen(self.conn.pgconn))
+ if not results:
+ raise exc.InternalError("got no result from the query")
+
+ badstats = {res.status for res in results} - {
+ ExecStatus.TUPLES_OK,
+ ExecStatus.COMMAND_OK,
+ ExecStatus.EMPTY_QUERY,
+ }
+ if not badstats:
+ self._results = results
+ self._result = results[0]
+ return self
+
+ if results[-1].status == ExecStatus.FATAL_ERROR:
+ ecls = exc.class_for_state(
+ results[-1].error_field(DiagnosticField.SQLSTATE)
+ )
+ raise ecls(error_message(results[-1]))
+
+ elif badstats & {
+ ExecStatus.COPY_IN,
+ ExecStatus.COPY_OUT,
+ ExecStatus.COPY_BOTH,
+ }:
+ raise exc.ProgrammingError(
+ "COPY cannot be used with execute(); use copy() insead"
+ )
+ else:
+ raise exc.InternalError(
+ f"got unexpected status from query:"
+ f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
+ )
+
+ def nextset(self):
+ self._iresult += 1
+ if self._iresult < len(self._results):
+ self._result = self._results[self._iresult]
+ return True
+
+ def _adapt_sequence(self, vars):
+ # TODO: stub. Need adaptation layer.
+ codec = self.conn.codec
+ out = [
+ codec.encode(str(v))[0] if v is not None else None for v in vars
+ ]
+ fmt = [0] * len(out)
+ return out, fmt
+
+
+class AsyncCursor(BaseCursor):
+ async def execute(self, query, vars=None):
+ with self.conn.lock:
+ pass
+
+
+class NamedCursorMixin:
+ pass
+
+
+class NamedCursor(NamedCursorMixin, Cursor):
+ pass
+
+
+class AsyncNamedCursor(NamedCursorMixin, AsyncCursor):
+ pass
--- /dev/null
+"""
+Utility module to manipulate queries
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+from collections.abc import Sequence, Mapping
+
+from .. import exceptions as exc
+
+
+def query2pg(query, vars, codec):
+ """
+ Convert Python query and params into something Postgres understands.
+
+ - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
+ format (``$1``, ``$2``)
+ - return ``query`` (bytes), ``order`` (sequence of names used in the
+ query, in the position they appear, in case of named params, else None)
+ """
+ if not isinstance(query, bytes):
+ # encoding from str already happened
+ raise TypeError(
+ f"the query should be str or bytes,"
+ f" got {type(query).__name__} instead"
+ )
+
+ parts = split_query(query, codec.name)
+
+ if isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+ if len(vars) != len(parts) - 1:
+ raise exc.ProgrammingError(
+ f"the query has {len(parts) - 1} placeholders but"
+ f" {len(vars)} parameters were passed"
+ )
+ if vars and not isinstance(parts[0][1], int):
+ raise TypeError(
+ "named placeholders require a mapping of parameters"
+ )
+ order = None
+
+ elif isinstance(vars, Mapping):
+ if vars and len(parts) > 1 and not isinstance(parts[0][1], bytes):
+ raise TypeError(
+ "positional placeholders (%s) require a sequence of parameters"
+ )
+ seen = {}
+ order = []
+ for part in parts[:-1]:
+ name = codec.decode(part[1])[0]
+ if name not in seen:
+ part[1] = seen[name] = len(seen)
+ order.append(name)
+ else:
+ part[1] = seen[name]
+
+ else:
+ raise TypeError("parameters should be a sequence or a mapping")
+
+ # Assemble query and parameters
+ rv = []
+ for part in parts[:-1]:
+ rv.append(part[0])
+ rv.append(b"$%d" % (part[1] + 1))
+ rv.append(parts[-1][0])
+
+ return b"".join(rv), order
+
+
+def split_query(query, encoding="ascii"):
+ parts = []
+ cur = 0
+
+ # pairs [(fragment, match)], with the last match None
+ m = None
+ for m in re.finditer(rb"%(?:(?:[^(])|(?:\(([^)]+)\).)|(?:.))", query):
+ pre = query[cur : m.span(0)[0]]
+ parts.append([pre, m])
+ cur = m.span(0)[1]
+ if m is None:
+ parts.append([query, None])
+ else:
+ parts.append([query[cur:], None])
+
+ # drop the "%%", validate
+ i = 0
+ phtype = None
+ while i < len(parts):
+ m = parts[i][1]
+ if m is None:
+ break # last part
+ ph = m.group(0)
+ if ph == b"%%":
+ # unescape '%%' to '%' and merge the parts
+ parts[i + 1][0] = parts[i][0] + b"%" + parts[i + 1][0]
+ del parts[i]
+ continue
+ if ph == b"%(":
+ raise exc.ProgrammingError(
+ f"incomplete placeholder:"
+ f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
+ )
+ elif ph == b"% ":
+ # explicit messasge for a typical error
+ raise exc.ProgrammingError(
+ "incomplete placeholder: '%'; if you want to use '%' as an"
+ " operator you can double it up, i.e. use '%%'"
+ )
+ elif ph[-1:] != b"s":
+ raise exc.ProgrammingError(
+ f"only '%s' and '%(name)s' placeholders allowed, got"
+ f" {m.group(0).decode(encoding)}"
+ )
+
+ # Index or name
+ if m.group(1) is None:
+ parts[i][1] = i
+ else:
+ parts[i][1] = m.group(1)
+
+ if phtype is None:
+ phtype = type(parts[i][1])
+ else:
+ if phtype is not type(parts[i][1]): # noqa
+ raise exc.ProgrammingError(
+ "positional and named placeholders cannot be mixed"
+ )
+
+ i += 1
+
+ return parts
+
+
+def reorder_params(params, order):
+ """
+ Convert a mapping of parameters into an array in a specified order
+ """
+ try:
+ return [params[item] for item in order]
+ except KeyError:
+ raise exc.ProgrammingError(
+ f"query parameter missing:"
+ f" {', '.join(sorted(i for i in order if i not in params))}"
+ )
--- /dev/null
+def test_execute_many(conn):
+ cur = conn.cursor()
+ cur.execute("select 'foo'; select 'bar'")
+ assert len(cur._results) == 2
+ assert cur._result.get_value(0, 0) == b"foo"
+ assert cur.nextset()
+ assert cur._result.get_value(0, 0) == b"bar"
+ assert cur.nextset() is None
+
+
+def test_execute_sequence(conn):
+ cur = conn.cursor()
+ cur.execute("select %s, %s, %s", [1, "foo", None])
+ assert len(cur._results) == 1
+ assert cur._result.get_value(0, 0) == b"1"
+ assert cur._result.get_value(0, 1) == b"foo"
+ assert cur._result.get_value(0, 2) is None
+ assert cur.nextset() is None
--- /dev/null
+import codecs
+import pytest
+
+import psycopg3
+from psycopg3.utils.queries import split_query, query2pg, reorder_params
+
+
+@pytest.mark.parametrize(
+ "input, want",
+ [
+ (b"", [[b"", None]]),
+ (b"foo bar", [[b"foo bar", None]]),
+ (b"foo %% bar", [[b"foo % bar", None]]),
+ (b"%s", [[b"", 0], [b"", None]]),
+ (b"%s foo", [[b"", 0], [b" foo", None]]),
+ (b"foo %s", [[b"foo ", 0], [b"", None]]),
+ (b"foo %%%s bar", [[b"foo %", 0], [b" bar", None]]),
+ (b"foo %(name)s bar", [[b"foo ", b"name"], [b" bar", None]]),
+ (
+ b"foo %s%s bar %s baz",
+ [[b"foo ", 0], [b"", 1], [b" bar ", 2], [b" baz", None]],
+ ),
+ ],
+)
+def test_split_query(input, want):
+ assert split_query(input) == want
+
+
+@pytest.mark.parametrize(
+ "input",
+ [
+ b"foo %d bar",
+ b"foo % bar",
+ b"foo %%% bar",
+ b"foo %(foo)d bar",
+ b"foo %(foo)s bar %s baz",
+ b"foo %(foo) bar",
+ b"foo %(foo bar",
+ b"3%2",
+ ],
+)
+def test_split_query_bad(input):
+ with pytest.raises(psycopg3.ProgrammingError):
+ split_query(input)
+
+
+@pytest.mark.parametrize(
+ "query, params, want",
+ [
+ (b"", [], b""),
+ (b"%%", [], b"%"),
+ (b"select %s", (1,), b"select $1"),
+ (b"%s %% %s", (1, 2), b"$1 % $2"),
+ ],
+)
+def test_query2pg_seq(query, params, want):
+ out, order = query2pg(query, params, codecs.lookup("utf-8"))
+ assert order is None
+ assert out == want
+
+
+@pytest.mark.parametrize(
+ "query, params, want, worder",
+ [
+ (b"", {}, b"", []),
+ (b"hello %%", {"a": 1}, b"hello %", []),
+ (
+ b"select %(hello)s",
+ {"hello": 1, "world": 2},
+ b"select $1",
+ ["hello"],
+ ),
+ (
+ b"select %(hi)s %(there)s %(hi)s",
+ {"hi": 1, "there": 2},
+ b"select $1 $2 $1",
+ ["hi", "there"],
+ ),
+ ],
+)
+def test_query2pg_map(query, params, want, worder):
+ out, order = query2pg(query, params, codecs.lookup("utf-8"))
+ assert out == want
+ assert order == worder
+
+
+@pytest.mark.parametrize(
+ "query, params",
+ [
+ (b"select %s", {"a": 1}),
+ (b"select %(name)s", [1]),
+ (b"select %s", "a"),
+ (b"select %s", 1),
+ (b"select %s", b"a"),
+ (b"select %s", set()),
+ ("select", []),
+ ("select", []),
+ ],
+)
+def test_query2pg_badtype(query, params):
+ with pytest.raises(TypeError):
+ query2pg(query, params, codecs.lookup("utf-8"))
+
+
+@pytest.mark.parametrize(
+ "query, params",
+ [
+ (b"", [1]),
+ (b"%s", []),
+ (b"%%", [1]),
+ (b"$1", [1]),
+ (b"select %(", {"a": 1}),
+ (b"select %(a", {"a": 1}),
+ (b"select %(a)", {"a": 1}),
+ (b"select %s %(hi)s", 1),
+ ],
+)
+def test_query2pg_badprog(query, params):
+ with pytest.raises(psycopg3.ProgrammingError):
+ query2pg(query, params, codecs.lookup("utf-8"))
+
+
+@pytest.mark.parametrize(
+ "params, order, want",
+ [
+ ({"foo": 1, "bar": 2}, [], []),
+ ({"foo": 1, "bar": 2}, ["foo"], [1]),
+ ({"foo": 1, "bar": 2}, ["bar", "foo"], [2, 1]),
+ ],
+)
+def test_reorder_params(params, order, want):
+ rv = reorder_params(params, order)
+ assert rv == want