]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(copy): don't create a row maker on copy
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 15 Dec 2022 11:03:44 +0000 (11:03 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 15 Dec 2022 11:03:44 +0000 (11:03 +0000)
A COPY_OUT result has columns, but no names for the columns. This case
must be handled in cur.description (see #235) but we don't need to
handle it in copy. If we did handle it in copy, we would need a column
name fallback, which we forgot to handle, hence the problem in #460.

Close #460.

docs/news.rst
psycopg/psycopg/rows.py
tests/test_copy.py
tests/test_copy_async.py

index f26d4007bf4c6abd473ef3f80f8c2e0f6cf3bfba..318c34a88b25833f42d078a3417f17a6652cb5cc 100644 (file)
@@ -7,6 +7,12 @@
 ``psycopg`` release notes
 =========================
 
+Psycopg 3.1.6 (unreleased)
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Fix `cursor.copy()` with cursors using row factories (:ticket:`#460`).
+
+
 Current release
 ---------------
 
index 3bd921577003101d39d24504b01ac8e3730e20e1..4f96a1af00ad37e6cf56db135731e1880ecb4144 100644 (file)
@@ -244,7 +244,7 @@ def _get_nfields(res: "PGresult") -> Optional[int]:
     """
     nfields = res.nfields
 
-    if nfields or res.status == TUPLES_OK or res.status == SINGLE_TUPLE:
+    if res.status == TUPLES_OK or res.status == SINGLE_TUPLE:
         return nfields
     else:
         return None
index 74e190fceae42e70d1223f6737956cd8d48b0cc7..17cf2fc786b5b49cd6a3139a8de773273a271b5d 100644 (file)
@@ -71,19 +71,31 @@ def test_copy_out_read(conn, format):
 
 
 @pytest.mark.parametrize("format", Format)
-def test_copy_out_iter(conn, format):
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_iter(conn, format, row_factory):
     if format == pq.Format.TEXT:
         want = [row + b"\n" for row in sample_text.splitlines()]
     else:
         want = sample_binary_rows
 
-    cur = conn.cursor()
+    rf = getattr(psycopg.rows, row_factory)
+    cur = conn.cursor(row_factory=rf)
     with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         assert list(copy) == want
 
     assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
 
 
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_copy_out_no_result(conn, format, row_factory):
+    rf = getattr(psycopg.rows, row_factory)
+    cur = conn.cursor(row_factory=rf)
+    with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"):
+        with pytest.raises(e.ProgrammingError):
+            cur.fetchone()
+
+
 @pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
 def test_copy_out_param(conn, ph, params):
     cur = conn.cursor()
index 9b926a21322d053ca669f54fb7d4cc6938976599..59389dd7368cf43925567a24a3ae837748d44252 100644 (file)
@@ -53,13 +53,15 @@ async def test_copy_out_read(aconn, format):
 
 
 @pytest.mark.parametrize("format", Format)
-async def test_copy_out_iter(aconn, format):
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_copy_out_iter(aconn, format, row_factory):
     if format == pq.Format.TEXT:
         want = [row + b"\n" for row in sample_text.splitlines()]
     else:
         want = sample_binary_rows
 
-    cur = aconn.cursor()
+    rf = getattr(psycopg.rows, row_factory)
+    cur = aconn.cursor(row_factory=rf)
     async with cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
@@ -68,6 +70,16 @@ async def test_copy_out_iter(aconn, format):
     assert aconn.info.transaction_status == aconn.TransactionStatus.INTRANS
 
 
+@pytest.mark.parametrize("format", Format)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_copy_out_no_result(aconn, format, row_factory):
+    rf = getattr(psycopg.rows, row_factory)
+    cur = aconn.cursor(row_factory=rf)
+    async with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})"):
+        with pytest.raises(e.ProgrammingError):
+            await cur.fetchone()
+
+
 @pytest.mark.parametrize("ph, params", [("%s", (10,)), ("%(n)s", {"n": 10})])
 async def test_copy_out_param(aconn, ph, params):
     cur = aconn.cursor()