]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Drop TODO point now working
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 15:20:09 +0000 (17:20 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 15:54:03 +0000 (17:54 +0200)
Add tests to prove it works.

psycopg/psycopg/copy.py
tests/fix_db.py
tests/test_copy.py
tests/test_copy_async.py
tests/types/test_hstore.py

index 1d52451dfe00e43364e405d2293362a5d6941d6a..0e889dacb7a560dcd5bb0046cca11f24cc116b08 100644 (file)
@@ -91,12 +91,6 @@ class BaseCopy(Generic[ConnectionType]):
 
         The types must be specified as a sequence of oid or PostgreSQL type
         names (e.g. ``int4``, ``timestamptz[]``).
-
-        .. admonition:: TODO
-
-            Only builtin names are supprted for the moment. In order to specify
-            custom data types you must use their oid.
-
         """
         registry = self.cursor.adapters.types
         oids = [
index d826b2a5f61327644f1906fcc189d6f7a9df8d07..6d301dc5fc374def235314dca20ba58e69dfd077 100644 (file)
@@ -138,3 +138,14 @@ def check_connection_version(got, function):
     want = [m.args[0] for m in function.pytestmark if m.name == "pg"]
     if want:
         return check_server_version(got, want[0])
+
+
+@pytest.fixture
+def hstore(svcconn):
+    from psycopg import Error
+
+    try:
+        with svcconn.transaction():
+            svcconn.execute("create extension if not exists hstore")
+    except Error as e:
+        pytest.skip(str(e))
index 0f794584e9f696ea93bfe8df5e56b6779686e737..f0d15dcf8a3cf01ba564883ba30a21e932f063cd 100644 (file)
@@ -12,6 +12,8 @@ from psycopg import sql
 from psycopg import errors as e
 from psycopg.pq import Format
 from psycopg.adapt import PyFormat as PgFormat
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import register_adapters as register_hstore
 from psycopg.types.numeric import Int4
 
 from .utils import gc_collect
@@ -116,6 +118,23 @@ def test_rows(conn, format):
     assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
 
 
+def test_set_custom_type(conn, hstore):
+    command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout"""
+    cur = conn.cursor()
+
+    with cur.copy(command) as copy:
+        rows = list(copy.rows())
+
+    assert rows == [('"a"=>"1", "b"=>"2"',)]
+
+    register_hstore(TypeInfo.fetch(conn, "hstore"), cur)
+    with cur.copy(command) as copy:
+        copy.set_types(["hstore"])
+        rows = list(copy.rows())
+
+    assert rows == [({"a": "1", "b": "2"},)]
+
+
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 def test_copy_out_allchars(conn, format):
     cur = conn.cursor()
index 7c755ba0e044cc48d44dd78251e7b0f99af4e1ef..2efb7a1a36440c343568f3d2b3f460834ff5ca40 100644 (file)
@@ -11,7 +11,9 @@ from psycopg import pq
 from psycopg import sql
 from psycopg import errors as e
 from psycopg.pq import Format
+from psycopg.types import TypeInfo
 from psycopg.adapt import PyFormat as PgFormat
+from psycopg.types.hstore import register_adapters as register_hstore
 
 from .utils import gc_collect
 from .test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
@@ -99,6 +101,23 @@ async def test_rows(aconn, format):
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
 
 
+async def test_set_custom_type(aconn, hstore):
+    command = """copy (select '"a"=>"1", "b"=>"2"'::hstore) to stdout"""
+    cur = aconn.cursor()
+
+    async with cur.copy(command) as copy:
+        rows = [row async for row in copy.rows()]
+
+    assert rows == [('"a"=>"1", "b"=>"2"',)]
+
+    register_hstore(await TypeInfo.fetch_async(aconn, "hstore"), cur)
+    async with cur.copy(command) as copy:
+        copy.set_types(["hstore"])
+        rows = [row async for row in copy.rows()]
+
+    assert rows == [({"a": "1", "b": "2"},)]
+
+
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 async def test_copy_out_allchars(aconn, format):
     cur = aconn.cursor()
index e52aae7e75c599b52d40b0045ad5ca16f60a2b30..f81a492912eef9913cf5179d4a79c7ebc59d2ce9 100644 (file)
@@ -97,12 +97,3 @@ def test_roundtrip_array(hstore, conn):
     register_adapters(TypeInfo.fetch(conn, "hstore"), conn)
     samp1 = conn.execute("select %s", (samp,)).fetchone()[0]
     assert samp1 == samp
-
-
-@pytest.fixture
-def hstore(svcconn):
-    try:
-        with svcconn.transaction():
-            svcconn.execute("create extension if not exists hstore")
-    except psycopg.Error as e:
-        pytest.skip(str(e))