From: Daniele Varrazzo Date: Fri, 15 Jan 2021 23:35:56 +0000 (+0100) Subject: Allow passing a list of names to copy.set_types() X-Git-Tag: 3.0.dev0~151 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7d8b2fda0b2adb9839bee8750a7426c74f1dc826;p=thirdparty%2Fpsycopg.git Allow passing a list of names to copy.set_types() Also improved TypeRegistry by recognising a type name even if includes the array marker '[]' and added get_oid() function. No test for this object as it's considered internal and leveraged by several other tests, but I'm sure I will regret it. --- diff --git a/docs/copy.rst b/docs/copy.rst index 95f8551d0..c9665cb80 100644 --- a/docs/copy.rst +++ b/docs/copy.rst @@ -99,19 +99,17 @@ you have to specify them yourselves. .. code:: python - from psycopg3.oids import builtins - with cur.copy("COPY (VALUES (10::int, current_date)) TO STDOUT") as copy: - copy.set_types([builtins["int4"].oid, builtins["date"].oid]) + copy.set_types(["int4", "date"]) for row in copy.rows(): print(row) # (10, datetime.date(2046, 12, 24)) .. admonition:: TODO - Document the `!builtins` register... but more likely do something - better such as allowing to pass type names, unifying `TypeRegistry` and - `AdaptContext`, none of which I have documented, so you haven't seen - anything... 👀 + Currently only builtin names are recognised; custom types must be + specified by numeric oid. This wll change after the `TypeRegistry` and + `AdaptContext` get integrated, none of which I have documented, so you + haven't seen anything... 👀 Copying block-by-block diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index d4b93869f..763243412 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple from . import pq from . import errors as e from .pq import ExecStatus +from .oids import builtins from .adapt import Format from .proto import ConnectionType, PQGen, Transformer from .generators import copy_from, copy_to, copy_end @@ -78,15 +79,29 @@ class BaseCopy(Generic[ConnectionType]): if self._finished: raise TypeError("copy blocks can be used only once") - def set_types(self, types: Sequence[int]) -> None: + def set_types(self, types: Sequence[Union[int, str]]) -> None: """ Set the types expected out of a :sql:`COPY TO` operation. Without setting the types, the data from :sql:`COPY TO` will be returned as unparsed strings or bytes. + + The types must be specified as a sequence of oid or PostgreSQL type + names (e.g. ``int4``, ``timestamptz[]``). + + .. admonition:: TODO + + Only builtin names are supprted for the moment. In order to specify + custom data types you must use their oid. + """ + # TODO: should allow names of non-builtin types + # Must put a types map on the context. + oids = [ + t if isinstance(t, int) else builtins.get_oid(t) for t in types + ] self.formatter.transformer.set_row_types( - types, [self.formatter.format] * len(types) + oids, [self.formatter.format] * len(types) ) # High level copy protocol generators (state change of the Copy object) diff --git a/psycopg3/psycopg3/oids.py b/psycopg3/psycopg3/oids.py index 95e67a999..d9eea7da0 100644 --- a/psycopg3/psycopg3/oids.py +++ b/psycopg3/psycopg3/oids.py @@ -62,6 +62,8 @@ class TypesRegistry: def __getitem__(self, key: Union[str, int]) -> TypeInfo: if isinstance(key, str): + if key.endswith("[]"): + key = key[:-2] return self._by_name[key] elif isinstance(key, int): return self._by_oid[key] @@ -76,6 +78,13 @@ class TypesRegistry: except KeyError: return None + def get_oid(self, name: str) -> int: + t = self[name] + if name.endswith("[]"): + return t.array_oid + else: + return t.oid + builtins = TypesRegistry() diff --git a/tests/test_copy.py b/tests/test_copy.py index ea12f1e00..cfed04954 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -83,25 +83,22 @@ def test_copy_out_iter(conn, format): @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) -def test_read_rows(conn, format): +@pytest.mark.parametrize("typetype", ["names", "oids"]) +def test_read_rows(conn, format, typetype): cur = conn.cursor() with cur.copy( - f"copy ({sample_values}) to stdout (format {format.name})" + f"""copy ( + select 10::int4, 'hello'::text, '{{0.0,1.0}}'::float8[] + ) to stdout (format {format.name})""" ) as copy: - # TODO: should be passed by name - # big refactoring to be had, to have builtins not global and merged - # to adaptation context I guess... - copy.set_types( - [builtins["int4"].oid, builtins["int4"].oid, builtins["text"].oid] - ) - rows = [] - while 1: - row = copy.read_row() - if not row: - break - rows.append(row) - - assert rows == sample_records + types = ["int4", "text", "float8[]"] + if typetype == "oids": + types = [builtins.get_oid(t) for t in types] + copy.set_types(types) + row = copy.read_row() + assert copy.read_row() is None + + assert row == (10, "hello", [0.0, 1.0]) assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS