]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Allow passing a list of names to copy.set_types()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 15 Jan 2021 23:35:56 +0000 (00:35 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 16 Jan 2021 01:37:13 +0000 (02:37 +0100)
Also improved TypeRegistry by recognising a type name even if includes
the array marker '[]' and added get_oid() function.

No test for this object as it's considered internal and leveraged by
several other tests, but I'm sure I will regret it.

docs/copy.rst
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/oids.py
tests/test_copy.py

index 95f8551d07b4880926358220a96a11c9140b212f..c9665cb80587443b3ed002ef0373c7350da2d418 100644 (file)
@@ -99,19 +99,17 @@ you have to specify them yourselves.
 
 .. code:: python
 
-    from psycopg3.oids import builtins
-
     with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy:
-        copy.set_types([builtins["int4"].oid, builtins["date"].oid])
+        copy.set_types(["int4", "date"])
         for row in copy.rows():
             print(row)  # (10, datetime.date(2046, 12, 24))
 
 .. admonition:: TODO
 
-    Document the `!builtins` register... but more likely do something
-    better such as allowing to pass type names, unifying `TypeRegistry` and
-    `AdaptContext`, none of which I have documented, so you haven't seen
-    anything... ðŸ‘€
+    Currently only builtin names are recognised; custom types must be
+    specified by numeric oid. This wll change after the `TypeRegistry` and
+    `AdaptContext` get integrated, none of which I have documented, so you
+    haven't seen anything... ðŸ‘€
 
 
 Copying block-by-block
index d4b93869f5d02096de6386ef5cb2961af6a21096..763243412b20c934776a8bbc1af20f844d367a9e 100644 (file)
@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
 from . import pq
 from . import errors as e
 from .pq import ExecStatus
+from .oids import builtins
 from .adapt import Format
 from .proto import ConnectionType, PQGen, Transformer
 from .generators import copy_from, copy_to, copy_end
@@ -78,15 +79,29 @@ class BaseCopy(Generic[ConnectionType]):
         if self._finished:
             raise TypeError("copy blocks can be used only once")
 
-    def set_types(self, types: Sequence[int]) -> None:
+    def set_types(self, types: Sequence[Union[int, str]]) -> None:
         """
         Set the types expected out of a :sql:`COPY TO` operation.
 
         Without setting the types, the data from :sql:`COPY TO` will be
         returned as unparsed strings or bytes.
+
+        The types must be specified as a sequence of oid or PostgreSQL type
+        names (e.g. ``int4``, ``timestamptz[]``).
+
+        .. admonition:: TODO
+
+            Only builtin names are supprted for the moment. In order to specify
+            custom data types you must use their oid.
+
         """
+        # TODO: should allow names of non-builtin types
+        # Must put a types map on the context.
+        oids = [
+            t if isinstance(t, int) else builtins.get_oid(t) for t in types
+        ]
         self.formatter.transformer.set_row_types(
-            types, [self.formatter.format] * len(types)
+            oids, [self.formatter.format] * len(types)
         )
 
     # High level copy protocol generators (state change of the Copy object)
index 95e67a999b15cc6bdb26a6ecb489c3d5b71bc81a..d9eea7da043518fdece9d3168b6f9550bd1af95a 100644 (file)
@@ -62,6 +62,8 @@ class TypesRegistry:
 
     def __getitem__(self, key: Union[str, int]) -> TypeInfo:
         if isinstance(key, str):
+            if key.endswith("[]"):
+                key = key[:-2]
             return self._by_name[key]
         elif isinstance(key, int):
             return self._by_oid[key]
@@ -76,6 +78,13 @@ class TypesRegistry:
         except KeyError:
             return None
 
+    def get_oid(self, name: str) -> int:
+        t = self[name]
+        if name.endswith("[]"):
+            return t.array_oid
+        else:
+            return t.oid
+
 
 builtins = TypesRegistry()
 
index ea12f1e00faf0863c21cebf3b5e9387f0bceb17a..cfed04954a8ca42c1376a997a5c822f7b627f42c 100644 (file)
@@ -83,25 +83,22 @@ def test_copy_out_iter(conn, format):
 
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
-def test_read_rows(conn, format):
+@pytest.mark.parametrize("typetype", ["names", "oids"])
+def test_read_rows(conn, format, typetype):
     cur = conn.cursor()
     with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
+        f"""copy (
+            select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[]
+        ) to stdout (format {format.name})"""
     ) as copy:
-        # TODO: should be passed by name
-        # big refactoring to be had, to have builtins not global and merged
-        # to adaptation context I guess...
-        copy.set_types(
-            [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid]
-        )
-        rows = []
-        while 1:
-            row = copy.read_row()
-            if not row:
-                break
-            rows.append(row)
-
-    assert rows == sample_records
+        types = ["int4", "text", "float8[]"]
+        if typetype == "oids":
+            types = [builtins.get_oid(t) for t in types]
+        copy.set_types(types)
+        row = copy.read_row()
+        assert copy.read_row() is None
+
+    assert row == (10, "hello", [0.0, 1.0])
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS