]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Callproc can take both args and kwargs
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Nov 2020 01:17:27 +0000 (02:17 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Nov 2020 01:17:27 +0000 (02:17 +0100)
psycopg3/psycopg3/cursor.py

index d804403c9e4c16909f1d4d49c870d405d70c6b1f..77b1ecd867f0b19d2295747f8af298d1b27fb1de 100644 (file)
@@ -329,31 +329,50 @@ class BaseCursor:
             )
 
     def _callproc_sql(
-        self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+        self,
+        name: Union[str, sql.Identifier],
+        args: Optional[Params] = None,
+        kwargs: Optional[Mapping[str, Any]] = None,
     ) -> sql.Composable:
+        if args and not isinstance(args, (Sequence, Mapping)):
+            raise TypeError(
+                f"callproc args should be a sequence or a mapping,"
+                f" got {type(args).__name__}"
+            )
+        if isinstance(args, Mapping) and kwargs:
+            raise TypeError(
+                "callproc supports only one args sequence and one kwargs mapping"
+            )
+
+        if not kwargs and isinstance(args, Mapping):
+            kwargs = args
+            args = None
+
+        if kwargs and not isinstance(kwargs, Mapping):
+            raise TypeError(
+                f"callproc kwargs should be a mapping,"
+                f" got {type(kwargs).__name__}"
+            )
+
         qparts: List[sql.Composable] = [
             sql.SQL("select * from "),
             name if isinstance(name, sql.Identifier) else sql.Identifier(name),
             sql.SQL("("),
         ]
 
-        if isinstance(args, Sequence):
+        if args:
             for i, item in enumerate(args):
                 if i:
                     qparts.append(sql.SQL(","))
                 qparts.append(sql.Literal(item))
-        elif isinstance(args, Mapping):
-            for i, (k, v) in enumerate(args.items()):
+
+        if kwargs:
+            for i, (k, v) in enumerate(kwargs.items()):
                 if i:
                     qparts.append(sql.SQL(","))
                 qparts.extend(
                     [sql.Identifier(k), sql.SQL(":="), sql.Literal(v)]
                 )
-        elif args:
-            raise TypeError(
-                f"callproc parameters should be a sequence or a mapping,"
-                f" got {type(args).__name__}"
-            )
 
         qparts.append(sql.SQL(")"))
         return sql.Composed(qparts)
@@ -444,7 +463,10 @@ class Cursor(BaseCursor):
         return self
 
     def callproc(
-        self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+        self,
+        name: Union[str, sql.Identifier],
+        args: Optional[Params] = None,
+        kwargs: Optional[Mapping[str, Any]] = None,
     ) -> Optional[Params]:
         self.execute(self._callproc_sql(name, args))
         return args
@@ -569,9 +591,12 @@ class AsyncCursor(BaseCursor):
         return self
 
     async def callproc(
-        self, name: Union[str, sql.Identifier], args: Optional[Params] = None
+        self,
+        name: Union[str, sql.Identifier],
+        args: Optional[Params] = None,
+        kwargs: Optional[Mapping[str, Any]] = None,
     ) -> Optional[Params]:
-        await self.execute(self._callproc_sql(name, args))
+        await self.execute(self._callproc_sql(name, args, kwargs))
         return args
 
     async def fetchone(self) -> Optional[Sequence[Any]]: