# Copyright (C) 2020 The Psycopg Team
from types import TracebackType
-from typing import Any, AsyncIterator, Callable, Iterator, List, Optional
-from typing import Sequence, Type, TYPE_CHECKING
+from typing import Any, AsyncIterator, Callable, Iterator, List, Mapping
+from typing import Optional, Sequence, Type, TYPE_CHECKING, Union
from operator import attrgetter
from . import errors as e
from . import pq
+from . import sql
from . import proto
from .oids import builtins
from .copy import Copy, AsyncCopy
# no-op
pass
+ def nextset(self) -> Optional[bool]:
+ self._iresult += 1
+ if self._iresult < len(self._results):
+ self.pgresult = self._results[self._iresult]
+ self._pos = 0
+ nrows = self.pgresult.command_tuples
+ self._rowcount = nrows if nrows is not None else -1
+ return True
+ else:
+ return None
+
def _start_query(self) -> None:
from .adapt import Transformer
result_format=self.format,
)
- def nextset(self) -> Optional[bool]:
- self._iresult += 1
- if self._iresult < len(self._results):
- self.pgresult = self._results[self._iresult]
- self._pos = 0
- nrows = self.pgresult.command_tuples
- self._rowcount = nrows if nrows is not None else -1
- return True
- else:
- return None
-
def _check_result(self) -> None:
res = self.pgresult
if res is None:
"the last operation didn't produce a result"
)
+ def _callproc_sql(
+ self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+ ) -> sql.Composable:
+ qparts: List[sql.Composable] = [
+ sql.SQL("select * from "),
+ name if isinstance(name, sql.Identifier) else sql.Identifier(name),
+ sql.SQL("("),
+ ]
+
+ if isinstance(args, Sequence):
+ for i, item in enumerate(args):
+ if i:
+ qparts.append(sql.SQL(","))
+ qparts.append(sql.Literal(item))
+ elif isinstance(args, Mapping):
+ for i, (k, v) in enumerate(args.items()):
+ if i:
+ qparts.append(sql.SQL(","))
+ qparts.extend(
+ [sql.Identifier(k), sql.SQL(":="), sql.Literal(v)]
+ )
+ elif args:
+ raise TypeError(
+ f"callproc parameters should be a sequence or a mapping,"
+ f" got {type(args).__name__}"
+ )
+
+ qparts.append(sql.SQL(")"))
+ return sql.Composed(qparts)
+
def _check_copy_results(
self, results: Sequence[pq.proto.PGresult]
) -> None:
return self
+ def callproc(
+ self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+ ) -> Optional[Params]:
+ self.execute(self._callproc_sql(name, args))
+ return args
+
def fetchone(self) -> Optional[Sequence[Any]]:
self._check_result()
rv = self._transformer.load_row(self._pos)
return self
+ async def callproc(
+ self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+ ) -> Optional[Params]:
+ await self.execute(self._callproc_sql(name, args))
+ return args
+
async def fetchone(self) -> Optional[Sequence[Any]]:
self._check_result()
rv = self._transformer.load_row(self._pos)
import gc
import pytest
import weakref
+from collections import namedtuple
import psycopg3
+from psycopg3 import sql
from psycopg3.oids import builtins
cur.executemany(query, [(10, "hello"), (20, "world")])
+def test_callproc_args(conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create function testfunc(a int, b text) returns text[] language sql as
+ 'select array[$1::text, $2]'
+ """
+ )
+ assert cur.callproc("testfunc", [10, "twenty"]) == [10, "twenty"]
+ assert cur.fetchone() == (["10", "twenty"],)
+
+
+def test_callproc_badparam(conn):
+ cur = conn.cursor()
+ with pytest.raises(TypeError):
+ cur.callproc("lower", 42)
+ with pytest.raises(TypeError):
+ cur.callproc(42, ["lower"])
+
+
+def make_testfunc(conn):
+ # This parameter name tests for injection and quote escaping
+ paramname = """Robert'); drop table "students" --"""
+ procname = "randall"
+
+ # Set up the temporary function
+ stmt = (
+ sql.SQL(
+ """
+ create function {}({} numeric) returns numeric language sql as
+ 'select $1 * $1'
+ """
+ )
+ .format(sql.Identifier(procname), sql.Identifier(paramname))
+ .as_string(conn)
+ .encode(conn.codec.name)
+ )
+
+ # execute regardless of sync/async conn
+ conn.pgconn.exec_(stmt)
+
+ return namedtuple("Thang", "name, param")(procname, paramname)
+
+
+def test_callproc_dict(conn):
+
+ testfunc = make_testfunc(conn)
+ cur = conn.cursor()
+
+ cur.callproc(testfunc.name, [2])
+ assert cur.fetchone() == (4,)
+ cur.callproc(testfunc.name, {testfunc.param: 2})
+ assert cur.fetchone() == (4,)
+ cur.callproc(sql.Identifier(testfunc.name), {testfunc.param: 2})
+ assert cur.fetchone() == (4,)
+
+
+@pytest.mark.parametrize(
+ "args, exc",
+ [
+ ({"_p": 2, "foo": "bar"}, psycopg3.ProgrammingError),
+ ({"_p": "two"}, psycopg3.DataError),
+ ({"bj\xc3rn": 2}, psycopg3.ProgrammingError),
+ ({3: 2}, TypeError),
+ ({(): 2}, TypeError),
+ ],
+)
+def test_callproc_dict_bad(conn, args, exc):
+ testfunc = make_testfunc(conn)
+ if "_p" in args:
+ args[testfunc.param] = args.pop("_p")
+
+ cur = conn.cursor()
+ with pytest.raises(exc):
+ cur.callproc(testfunc.name, args)
+
+
def test_rowcount(conn):
cur = conn.cursor()
cur.execute("select 1 from generate_series(1, 42)")
import weakref
import psycopg3
+from psycopg3 import sql
+
+from .test_cursor import make_testfunc
pytestmark = pytest.mark.asyncio
await cur.executemany(query, [(10, "hello"), (20, "world")])
+async def test_callproc_args(aconn):
+ cur = aconn.cursor()
+ await cur.execute(
+ """
+ create function testfunc(a int, b text) returns text[] language sql as
+ 'select array[$1::text, $2]'
+ """
+ )
+ assert (await cur.callproc("testfunc", [10, "twenty"])) == [10, "twenty"]
+ assert (await cur.fetchone()) == (["10", "twenty"],)
+
+
+async def test_callproc_badparam(aconn):
+ cur = aconn.cursor()
+ with pytest.raises(TypeError):
+ await cur.callproc("lower", 42)
+ with pytest.raises(TypeError):
+ await cur.callproc(42, ["lower"])
+
+
+async def test_callproc_dict(aconn):
+ testfunc = make_testfunc(aconn)
+
+ cur = aconn.cursor()
+
+ await cur.callproc(testfunc.name, [2])
+ assert (await cur.fetchone()) == (4,)
+ await cur.callproc(testfunc.name, {testfunc.param: 2})
+ assert await (cur.fetchone()) == (4,)
+ await cur.callproc(sql.Identifier(testfunc.name), {testfunc.param: 2})
+ assert await (cur.fetchone()) == (4,)
+
+
+@pytest.mark.parametrize(
+ "args, exc",
+ [
+ ({"_p": 2, "foo": "bar"}, psycopg3.ProgrammingError),
+ ({"_p": "two"}, psycopg3.DataError),
+ ({"bj\xc3rn": 2}, psycopg3.ProgrammingError),
+ ({3: 2}, TypeError),
+ ({(): 2}, TypeError),
+ ],
+)
+async def test_callproc_dict_bad(aconn, args, exc):
+ testfunc = make_testfunc(aconn)
+ if "_p" in args:
+ args[testfunc.param] = args.pop("_p")
+
+ cur = aconn.cursor()
+ with pytest.raises(exc):
+ await cur.callproc(testfunc.name, args)
+
+
async def test_rowcount(aconn):
cur = aconn.cursor()