From: Daniele Varrazzo Date: Fri, 30 Oct 2020 14:23:14 +0000 (+0100) Subject: Added cursor.callproc() X-Git-Tag: 3.0.dev0~404 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b37f8365b1cae7b75028941b1696a61b8fb29cbf;p=thirdparty%2Fpsycopg.git Added cursor.callproc() --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index bc83bde0a..d804403c9 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -5,12 +5,13 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team from types import TracebackType -from typing import Any, AsyncIterator, Callable, Iterator, List, Optional -from typing import Sequence, Type, TYPE_CHECKING +from typing import Any, AsyncIterator, Callable, Iterator, List, Mapping +from typing import Optional, Sequence, Type, TYPE_CHECKING, Union from operator import attrgetter from . import errors as e from . import pq +from . import sql from . import proto from .oids import builtins from .copy import Copy, AsyncCopy @@ -207,6 +208,17 @@ class BaseCursor: # no-op pass + def nextset(self) -> Optional[bool]: + self._iresult += 1 + if self._iresult < len(self._results): + self.pgresult = self._results[self._iresult] + self._pos = 0 + nrows = self.pgresult.command_tuples + self._rowcount = nrows if nrows is not None else -1 + return True + else: + return None + def _start_query(self) -> None: from .adapt import Transformer @@ -307,17 +319,6 @@ class BaseCursor: result_format=self.format, ) - def nextset(self) -> Optional[bool]: - self._iresult += 1 - if self._iresult < len(self._results): - self.pgresult = self._results[self._iresult] - self._pos = 0 - nrows = self.pgresult.command_tuples - self._rowcount = nrows if nrows is not None else -1 - return True - else: - return None - def _check_result(self) -> None: res = self.pgresult if res is None: @@ -327,6 +328,36 @@ class BaseCursor: "the last operation didn't produce a result" ) + def _callproc_sql( + self, name: Union[str, sql.Identifier], args: Optional[Params] = None + ) -> sql.Composable: + qparts: List[sql.Composable] = [ + sql.SQL("select * from "), + name if isinstance(name, sql.Identifier) else sql.Identifier(name), + sql.SQL("("), + ] + + if isinstance(args, Sequence): + for i, item in enumerate(args): + if i: + qparts.append(sql.SQL(",")) + qparts.append(sql.Literal(item)) + elif isinstance(args, Mapping): + for i, (k, v) in enumerate(args.items()): + if i: + qparts.append(sql.SQL(",")) + qparts.extend( + [sql.Identifier(k), sql.SQL(":="), sql.Literal(v)] + ) + elif args: + raise TypeError( + f"callproc parameters should be a sequence or a mapping," + f" got {type(args).__name__}" + ) + + qparts.append(sql.SQL(")")) + return sql.Composed(qparts) + def _check_copy_results( self, results: Sequence[pq.proto.PGresult] ) -> None: @@ -412,6 +443,12 @@ class Cursor(BaseCursor): return self + def callproc( + self, name: Union[str, sql.Identifier], args: Optional[Params] = None + ) -> Optional[Params]: + self.execute(self._callproc_sql(name, args)) + return args + def fetchone(self) -> Optional[Sequence[Any]]: self._check_result() rv = self._transformer.load_row(self._pos) @@ -531,6 +568,12 @@ class AsyncCursor(BaseCursor): return self + async def callproc( + self, name: Union[str, sql.Identifier], args: Optional[Params] = None + ) -> Optional[Params]: + await self.execute(self._callproc_sql(name, args)) + return args + async def fetchone(self) -> Optional[Sequence[Any]]: self._check_result() rv = self._transformer.load_row(self._pos) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index fc6779e3c..1ebebdd9e 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,8 +1,10 @@ import gc import pytest import weakref +from collections import namedtuple import psycopg3 +from psycopg3 import sql from psycopg3.oids import builtins @@ -185,6 +187,83 @@ def test_executemany_badquery(conn, query): cur.executemany(query, [(10, "hello"), (20, "world")]) +def test_callproc_args(conn): + cur = conn.cursor() + cur.execute( + """ + create function testfunc(a int, b text) returns text[] language sql as + 'select array[$1::text, $2]' + """ + ) + assert cur.callproc("testfunc", [10, "twenty"]) == [10, "twenty"] + assert cur.fetchone() == (["10", "twenty"],) + + +def test_callproc_badparam(conn): + cur = conn.cursor() + with pytest.raises(TypeError): + cur.callproc("lower", 42) + with pytest.raises(TypeError): + cur.callproc(42, ["lower"]) + + +def make_testfunc(conn): + # This parameter name tests for injection and quote escaping + paramname = """Robert'); drop table "students" --""" + procname = "randall" + + # Set up the temporary function + stmt = ( + sql.SQL( + """ + create function {}({} numeric) returns numeric language sql as + 'select $1 * $1' + """ + ) + .format(sql.Identifier(procname), sql.Identifier(paramname)) + .as_string(conn) + .encode(conn.codec.name) + ) + + # execute regardless of sync/async conn + conn.pgconn.exec_(stmt) + + return namedtuple("Thang", "name, param")(procname, paramname) + + +def test_callproc_dict(conn): + + testfunc = make_testfunc(conn) + cur = conn.cursor() + + cur.callproc(testfunc.name, [2]) + assert cur.fetchone() == (4,) + cur.callproc(testfunc.name, {testfunc.param: 2}) + assert cur.fetchone() == (4,) + cur.callproc(sql.Identifier(testfunc.name), {testfunc.param: 2}) + assert cur.fetchone() == (4,) + + +@pytest.mark.parametrize( + "args, exc", + [ + ({"_p": 2, "foo": "bar"}, psycopg3.ProgrammingError), + ({"_p": "two"}, psycopg3.DataError), + ({"bj\xc3rn": 2}, psycopg3.ProgrammingError), + ({3: 2}, TypeError), + ({(): 2}, TypeError), + ], +) +def test_callproc_dict_bad(conn, args, exc): + testfunc = make_testfunc(conn) + if "_p" in args: + args[testfunc.param] = args.pop("_p") + + cur = conn.cursor() + with pytest.raises(exc): + cur.callproc(testfunc.name, args) + + def test_rowcount(conn): cur = conn.cursor() cur.execute("select 1 from generate_series(1, 42)") diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index f0c028317..8aca63855 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -3,6 +3,9 @@ import pytest import weakref import psycopg3 +from psycopg3 import sql + +from .test_cursor import make_testfunc pytestmark = pytest.mark.asyncio @@ -184,6 +187,59 @@ async def test_executemany_badquery(aconn, query): await cur.executemany(query, [(10, "hello"), (20, "world")]) +async def test_callproc_args(aconn): + cur = aconn.cursor() + await cur.execute( + """ + create function testfunc(a int, b text) returns text[] language sql as + 'select array[$1::text, $2]' + """ + ) + assert (await cur.callproc("testfunc", [10, "twenty"])) == [10, "twenty"] + assert (await cur.fetchone()) == (["10", "twenty"],) + + +async def test_callproc_badparam(aconn): + cur = aconn.cursor() + with pytest.raises(TypeError): + await cur.callproc("lower", 42) + with pytest.raises(TypeError): + await cur.callproc(42, ["lower"]) + + +async def test_callproc_dict(aconn): + testfunc = make_testfunc(aconn) + + cur = aconn.cursor() + + await cur.callproc(testfunc.name, [2]) + assert (await cur.fetchone()) == (4,) + await cur.callproc(testfunc.name, {testfunc.param: 2}) + assert await (cur.fetchone()) == (4,) + await cur.callproc(sql.Identifier(testfunc.name), {testfunc.param: 2}) + assert await (cur.fetchone()) == (4,) + + +@pytest.mark.parametrize( + "args, exc", + [ + ({"_p": 2, "foo": "bar"}, psycopg3.ProgrammingError), + ({"_p": "two"}, psycopg3.DataError), + ({"bj\xc3rn": 2}, psycopg3.ProgrammingError), + ({3: 2}, TypeError), + ({(): 2}, TypeError), + ], +) +async def test_callproc_dict_bad(aconn, args, exc): + testfunc = make_testfunc(aconn) + if "_p" in args: + args[testfunc.param] = args.pop("_p") + + cur = aconn.cursor() + with pytest.raises(exc): + await cur.callproc(testfunc.name, args) + + async def test_rowcount(aconn): cur = aconn.cursor()