]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added cursor.callproc()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Oct 2020 14:23:14 +0000 (15:23 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 30 Oct 2020 14:23:14 +0000 (15:23 +0100)
psycopg3/psycopg3/cursor.py
tests/test_cursor.py
tests/test_cursor_async.py

index bc83bde0ab7b447d79275c911177c4e25bd296ba..d804403c9e4c16909f1d4d49c870d405d70c6b1f 100644 (file)
@@ -5,12 +5,13 @@ psycopg3 cursor objects
 # Copyright (C) 2020 The Psycopg Team
 
 from types import TracebackType
-from typing import Any, AsyncIterator, Callable, Iterator, List, Optional
-from typing import Sequence, Type, TYPE_CHECKING
+from typing import Any, AsyncIterator, Callable, Iterator, List, Mapping
+from typing import Optional, Sequence, Type, TYPE_CHECKING, Union
 from operator import attrgetter
 
 from . import errors as e
 from . import pq
+from . import sql
 from . import proto
 from .oids import builtins
 from .copy import Copy, AsyncCopy
@@ -207,6 +208,17 @@ class BaseCursor:
         # no-op
         pass
 
+    def nextset(self) -> Optional[bool]:
+        self._iresult += 1
+        if self._iresult < len(self._results):
+            self.pgresult = self._results[self._iresult]
+            self._pos = 0
+            nrows = self.pgresult.command_tuples
+            self._rowcount = nrows if nrows is not None else -1
+            return True
+        else:
+            return None
+
     def _start_query(self) -> None:
         from .adapt import Transformer
 
@@ -307,17 +319,6 @@ class BaseCursor:
             result_format=self.format,
         )
 
-    def nextset(self) -> Optional[bool]:
-        self._iresult += 1
-        if self._iresult < len(self._results):
-            self.pgresult = self._results[self._iresult]
-            self._pos = 0
-            nrows = self.pgresult.command_tuples
-            self._rowcount = nrows if nrows is not None else -1
-            return True
-        else:
-            return None
-
     def _check_result(self) -> None:
         res = self.pgresult
         if res is None:
@@ -327,6 +328,36 @@ class BaseCursor:
                 "the last operation didn't produce a result"
             )
 
+    def _callproc_sql(
+        self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+    ) -> sql.Composable:
+        qparts: List[sql.Composable] = [
+            sql.SQL("select * from "),
+            name if isinstance(name, sql.Identifier) else sql.Identifier(name),
+            sql.SQL("("),
+        ]
+
+        if isinstance(args, Sequence):
+            for i, item in enumerate(args):
+                if i:
+                    qparts.append(sql.SQL(","))
+                qparts.append(sql.Literal(item))
+        elif isinstance(args, Mapping):
+            for i, (k, v) in enumerate(args.items()):
+                if i:
+                    qparts.append(sql.SQL(","))
+                qparts.extend(
+                    [sql.Identifier(k), sql.SQL(":="), sql.Literal(v)]
+                )
+        elif args:
+            raise TypeError(
+                f"callproc parameters should be a sequence or a mapping,"
+                f" got {type(args).__name__}"
+            )
+
+        qparts.append(sql.SQL(")"))
+        return sql.Composed(qparts)
+
     def _check_copy_results(
         self, results: Sequence[pq.proto.PGresult]
     ) -> None:
@@ -412,6 +443,12 @@ class Cursor(BaseCursor):
 
         return self
 
+    def callproc(
+        self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+    ) -> Optional[Params]:
+        self.execute(self._callproc_sql(name, args))
+        return args
+
     def fetchone(self) -> Optional[Sequence[Any]]:
         self._check_result()
         rv = self._transformer.load_row(self._pos)
@@ -531,6 +568,12 @@ class AsyncCursor(BaseCursor):
 
         return self
 
+    async def callproc(
+        self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+    ) -> Optional[Params]:
+        await self.execute(self._callproc_sql(name, args))
+        return args
+
     async def fetchone(self) -> Optional[Sequence[Any]]:
         self._check_result()
         rv = self._transformer.load_row(self._pos)
index fc6779e3cbcdb0a707da59bd4596fb9f615aff9e..1ebebdd9e7d2ec334d93928d55a530f1471e6b80 100644 (file)
@@ -1,8 +1,10 @@
 import gc
 import pytest
 import weakref
+from collections import namedtuple
 
 import psycopg3
+from psycopg3 import sql
 from psycopg3.oids import builtins
 
 
@@ -185,6 +187,83 @@ def test_executemany_badquery(conn, query):
         cur.executemany(query, [(10, "hello"), (20, "world")])
 
 
+def test_callproc_args(conn):
+    cur = conn.cursor()
+    cur.execute(
+        """
+        create function testfunc(a int, b text) returns text[] language sql as
+            'select array[$1::text, $2]'
+        """
+    )
+    assert cur.callproc("testfunc", [10, "twenty"]) == [10, "twenty"]
+    assert cur.fetchone() == (["10", "twenty"],)
+
+
+def test_callproc_badparam(conn):
+    cur = conn.cursor()
+    with pytest.raises(TypeError):
+        cur.callproc("lower", 42)
+    with pytest.raises(TypeError):
+        cur.callproc(42, ["lower"])
+
+
+def make_testfunc(conn):
+    # This parameter name tests for injection and quote escaping
+    paramname = """Robert'); drop table "students" --"""
+    procname = "randall"
+
+    # Set up the temporary function
+    stmt = (
+        sql.SQL(
+            """
+        create function {}({} numeric) returns numeric language sql as
+            'select $1 * $1'
+        """
+        )
+        .format(sql.Identifier(procname), sql.Identifier(paramname))
+        .as_string(conn)
+        .encode(conn.codec.name)
+    )
+
+    # execute regardless of sync/async conn
+    conn.pgconn.exec_(stmt)
+
+    return namedtuple("Thang", "name, param")(procname, paramname)
+
+
+def test_callproc_dict(conn):
+
+    testfunc = make_testfunc(conn)
+    cur = conn.cursor()
+
+    cur.callproc(testfunc.name, [2])
+    assert cur.fetchone() == (4,)
+    cur.callproc(testfunc.name, {testfunc.param: 2})
+    assert cur.fetchone() == (4,)
+    cur.callproc(sql.Identifier(testfunc.name), {testfunc.param: 2})
+    assert cur.fetchone() == (4,)
+
+
+@pytest.mark.parametrize(
+    "args, exc",
+    [
+        ({"_p": 2, "foo": "bar"}, psycopg3.ProgrammingError),
+        ({"_p": "two"}, psycopg3.DataError),
+        ({"bj\xc3rn": 2}, psycopg3.ProgrammingError),
+        ({3: 2}, TypeError),
+        ({(): 2}, TypeError),
+    ],
+)
+def test_callproc_dict_bad(conn, args, exc):
+    testfunc = make_testfunc(conn)
+    if "_p" in args:
+        args[testfunc.param] = args.pop("_p")
+
+    cur = conn.cursor()
+    with pytest.raises(exc):
+        cur.callproc(testfunc.name, args)
+
+
 def test_rowcount(conn):
     cur = conn.cursor()
     cur.execute("select 1 from generate_series(1, 42)")
index f0c028317e84fd9e01e5146bbd150a1d4e24456f..8aca638559dbaf1fcc8d6450d0100116ff4342b0 100644 (file)
@@ -3,6 +3,9 @@ import pytest
 import weakref
 
 import psycopg3
+from psycopg3 import sql
+
+from .test_cursor import make_testfunc
 
 pytestmark = pytest.mark.asyncio
 
@@ -184,6 +187,59 @@ async def test_executemany_badquery(aconn, query):
         await cur.executemany(query, [(10, "hello"), (20, "world")])
 
 
+async def test_callproc_args(aconn):
+    cur = aconn.cursor()
+    await cur.execute(
+        """
+        create function testfunc(a int, b text) returns text[] language sql as
+            'select array[$1::text, $2]'
+        """
+    )
+    assert (await cur.callproc("testfunc", [10, "twenty"])) == [10, "twenty"]
+    assert (await cur.fetchone()) == (["10", "twenty"],)
+
+
+async def test_callproc_badparam(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(TypeError):
+        await cur.callproc("lower", 42)
+    with pytest.raises(TypeError):
+        await cur.callproc(42, ["lower"])
+
+
+async def test_callproc_dict(aconn):
+    testfunc = make_testfunc(aconn)
+
+    cur = aconn.cursor()
+
+    await cur.callproc(testfunc.name, [2])
+    assert (await cur.fetchone()) == (4,)
+    await cur.callproc(testfunc.name, {testfunc.param: 2})
+    assert await (cur.fetchone()) == (4,)
+    await cur.callproc(sql.Identifier(testfunc.name), {testfunc.param: 2})
+    assert await (cur.fetchone()) == (4,)
+
+
+@pytest.mark.parametrize(
+    "args, exc",
+    [
+        ({"_p": 2, "foo": "bar"}, psycopg3.ProgrammingError),
+        ({"_p": "two"}, psycopg3.DataError),
+        ({"bj\xc3rn": 2}, psycopg3.ProgrammingError),
+        ({3: 2}, TypeError),
+        ({(): 2}, TypeError),
+    ],
+)
+async def test_callproc_dict_bad(aconn, args, exc):
+    testfunc = make_testfunc(aconn)
+    if "_p" in args:
+        args[testfunc.param] = args.pop("_p")
+
+    cur = aconn.cursor()
+    with pytest.raises(exc):
+        await cur.callproc(testfunc.name, args)
+
+
 async def test_rowcount(aconn):
     cur = aconn.cursor()