]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Transformer.dump_sequence returns a list of formats too
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Jan 2021 14:28:06 +0000 (15:28 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 13 Jan 2021 14:28:06 +0000 (15:28 +0100)
psycopg3/psycopg3/_queries.py
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/proto.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx

index 76ccb1ac10c01bed39c1212bbfda1df6edfffc7b..97155f69d5366a764010f4db96e2470897c02740 100644 (file)
@@ -31,7 +31,7 @@ class PostgresQuery:
 
     __slots__ = """
         params types formats
-        _tx _unknown_oid _parts query _encoding _order
+        _tx _want_formats _parts query _encoding _order
         """.split()
 
     def __init__(self, transformer: "Transformer"):
@@ -40,7 +40,10 @@ class PostgresQuery:
         self.params: Optional[List[Optional[bytes]]] = None
         # these are tuples so they can be used as keys e.g. in prepared stmts
         self.types: Tuple[int, ...] = ()
-        self.formats: Optional[List[Format]] = None
+
+        # The format requested by the user and the ones to really pass Postgres
+        self._want_formats: Optional[List[Format]] = None
+        self.formats: Optional[Sequence[Format]] = None
 
         self._parts: List[QueryPart]
         self.query = b""
@@ -62,20 +65,23 @@ class PostgresQuery:
             query = query.as_bytes(self._tx)
 
         if vars is not None:
-            self.query, self.formats, self._order, self._parts = _query2pg(
-                query, self._encoding
-            )
+            (
+                self.query,
+                self._want_formats,
+                self._order,
+                self._parts,
+            ) = _query2pg(query, self._encoding)
         else:
             if isinstance(query, str):
                 query = query.encode(self._encoding)
             self.query = query
-            self.formats = self._order = None
+            self._want_formats = self._order = None
 
         self.dump(vars)
 
     def dump(self, vars: Optional[Params]) -> None:
         """
-        Process a new set of variables on the same query as before.
+        Process a new set of variables on the query processed by `convert()`.
 
         This method updates `params` and `types`.
         """
@@ -83,13 +89,14 @@ class PostgresQuery:
             params = _validate_and_reorder_params(
                 self._parts, vars, self._order
             )
-            assert self.formats is not None
-            self.params, self.types = self._tx.dump_sequence(
-                params, self.formats
+            assert self._want_formats is not None
+            self.params, self.types, self.formats = self._tx.dump_sequence(
+                params, self._want_formats
             )
         else:
             self.params = None
             self.types = ()
+            self.formats = None
 
 
 @lru_cache()
index 4a0b11f5d77d6c29ec917b89a1a600b4f657f856..27c53c0b4d8b7059c2a6e822eccc1f7be03febff 100644 (file)
@@ -104,7 +104,7 @@ class Transformer(AdaptContext):
 
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[Format]
-    ) -> Tuple[List[Any], Tuple[int, ...]]:
+    ) -> Tuple[List[Any], Tuple[int, ...], Sequence[Format]]:
         ps: List[Optional[bytes]] = [None] * len(params)
         ts = [INVALID_OID] * len(params)
 
@@ -121,7 +121,7 @@ class Transformer(AdaptContext):
                 ps[i] = dumper.dump(param)
                 ts[i] = dumper.oid
 
-        return ps, tuple(ts)
+        return ps, tuple(ts), formats
 
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
         # Fast path: return a Dumper class already instantiated from the same type
index fbaa6eea4a60dfa5de19acef24fdb96f593cb872..2cc5ab66565eb67c5532752d9e5af2d59bc15b8f 100644 (file)
@@ -93,7 +93,7 @@ class Transformer(Protocol):
 
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[Format]
-    ) -> Tuple[List[Any], Tuple[int, ...]]:
+    ) -> Tuple[List[Any], Tuple[int, ...], Sequence[Format]]:
         ...
 
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
index dad01fabb0f6dd68637fa7c052f3c69e5ab094dd..50a49a877da217b4097f831dceb20c5155ce3bc8 100644 (file)
@@ -30,7 +30,7 @@ class Transformer(proto.AdaptContext):
     ) -> None: ...
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[Format]
-    ) -> Tuple[List[Any], Tuple[int, ...]]: ...
+    ) -> Tuple[List[Any], Tuple[int, ...], Sequence[Format]]: ...
     def get_dumper(self, obj: Any, format: Format) -> Dumper: ...
     def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]: ...
     def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ...
index 9c8b18fd043a3306a4cd0a99064e51dfe4addd0e..03323064820af6011aeac36fb3937ea6aa17fa3b 100644 (file)
@@ -258,7 +258,7 @@ cdef class Transformer:
             Py_INCREF(oid)
             PyTuple_SET_ITEM(ts, i, oid)
 
-        return ps, ts
+        return ps, ts, formats
 
     cdef RowDumper _get_row_dumper(self, object param, object fmt):
         cdef RowDumper row_dumper = RowDumper()