]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(c): respect the set_types() definitions in text format
authorJörg Breitbart <jerch@rockborn.de>
Thu, 4 Sep 2025 14:27:13 +0000 (16:27 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Oct 2025 23:54:10 +0000 (01:54 +0200)
Fix #1153

psycopg_c/psycopg_c/_psycopg/copy.pyx
psycopg_c/psycopg_c/_psycopg/transform.pyx
psycopg_c/psycopg_c/types/numeric.pyx
tests/test_copy.py
tests/test_copy_async.py

index b130f69d46834d3f3fd54c9e2d348d7d333054ff..d28e2e8a48545af697906be0a2c9ad183c12fe3b 100644 (file)
@@ -7,6 +7,7 @@ C optimised functions for the copy system.
 
 from libc.stdint cimport int32_t, uint16_t, uint32_t
 from libc.string cimport memcpy
+from cpython.tuple cimport PyTuple_GET_SIZE
 from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_FromStringAndSize
 from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_Resize
 from cpython.memoryview cimport PyMemoryView_FromObject
@@ -24,7 +25,13 @@ def format_row_binary(
     row: Sequence[Any], tx: Transformer, out: bytearray = None
 ) -> bytearray:
     """Convert a row of adapted data to the data to send for binary copy"""
-    cdef Py_ssize_t rowlen = len(row)
+    cdef Py_ssize_t rowlen
+    if type(row) is list:
+        rowlen = PyList_GET_SIZE(row)
+    elif type(row) is tuple:
+        rowlen = PyTuple_GET_SIZE(row)
+    else:
+        rowlen = len(row)
     cdef uint16_t berowlen = endian.htobe16(<int16_t>rowlen)
 
     cdef Py_ssize_t pos  # offset in 'out' where to write
@@ -51,6 +58,8 @@ def format_row_binary(
         tx._row_dumpers = PyList_New(rowlen)
 
     dumpers = tx._row_dumpers
+    if PyList_GET_SIZE(dumpers) != rowlen:
+        raise e.DataError(f"expected {len(dumpers)} values in row, got {rowlen}")
 
     for i in range(rowlen):
         item = row[i]
@@ -111,7 +120,13 @@ def format_row_text(
     else:
         pos = PyByteArray_GET_SIZE(out)
 
-    cdef Py_ssize_t rowlen = len(row)
+    cdef Py_ssize_t rowlen
+    if type(row) is list:
+        rowlen = PyList_GET_SIZE(row)
+    elif type(row) is tuple:
+        rowlen = PyTuple_GET_SIZE(row)
+    else:
+        rowlen = len(row)
 
     if rowlen == 0:
         PyByteArray_Resize(out, pos + 1)
@@ -127,6 +142,14 @@ def format_row_text(
     cdef PyObject *fmt = <PyObject *>PG_TEXT
     cdef PyObject *row_dumper
 
+    # try to get preloaded dumpers from set_types
+    if not tx._row_dumpers:
+        tx._row_dumpers = PyList_New(rowlen)
+
+    dumpers = tx._row_dumpers
+    if PyList_GET_SIZE(dumpers) != rowlen:
+        raise e.DataError(f"expected {len(dumpers)} values in row, got {rowlen}")
+
     for i in range(rowlen):
         # Include the tab before the data, so it gets included in the resizes
         with_tab = i > 0
@@ -136,7 +159,12 @@ def format_row_text(
             _append_text_none(out, &pos, with_tab)
             continue
 
-        row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+        row_dumper = PyList_GET_ITEM(dumpers, i)
+        if not row_dumper:
+            row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+            Py_INCREF(<object>row_dumper)
+            PyList_SET_ITEM(dumpers, i, <object>row_dumper)
+
         if (<RowDumper>row_dumper).cdumper is not None:
             # A cdumper can resize if necessary and copy in place
             size = (<RowDumper>row_dumper).cdumper.cdump(
index f699fdc6218b9d47f2c6de3dd0145d8cc26cb3fd..6f81f1f5b73dba15b502f07893a3dcf2573e0fef 100644 (file)
@@ -170,6 +170,8 @@ cdef class Transformer:
         self._row_loaders = loaders
 
     def set_dumper_types(self, types: Sequence[int], format: PqFormat) -> None:
+        # NOTE: impl detail - final _row_dumpers must be a list type
+        # (assumed by format_row_binary and format_row_text)
         cdef Py_ssize_t ntypes = len(types)
         dumpers = PyList_New(ntypes)
         cdef int i
index d807302f97bcae7de7365e96ae9a65957724e5ad..6cd79f0217f3e31d909af7b803dc1f962a66d0e8 100644 (file)
@@ -48,6 +48,12 @@ int pg_lltoa(int64_t value, char *a);
     const int MAXINT8LEN
 
 
+# global to hold types considered as integer python types
+# _IntOrSubclassDumper and _MixedNumericDumper ctors
+# change it to int or (int, numpy.integer) once
+_int_classes = None
+
+
 cdef class _IntDumper(CDumper):
 
     format = PQ_TEXT
@@ -74,6 +80,16 @@ cdef class _IntOrSubclassDumper(_IntDumper):
 
     format = PQ_TEXT
 
+    def __cinit__(self, cls, context: AdaptContext | None = None):
+        global _int_classes
+
+        if _int_classes is None:
+            if "numpy" in sys.modules:
+                import numpy
+                _int_classes = (int, numpy.integer)
+            else:
+                _int_classes = int
+
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
         return dump_int_or_sub_to_text(obj, rv, offset)
 
@@ -521,9 +537,6 @@ cdef class DecimalBinaryDumper(CDumper):
         return dump_decimal_to_numeric_binary(obj, rv, offset)
 
 
-_int_classes = None
-
-
 cdef class _MixedNumericDumper(CDumper):
 
     oid = oids.NUMERIC_OID
@@ -738,7 +751,7 @@ cdef Py_ssize_t dump_int_or_sub_to_text(
 
     # Ensure an int or a subclass. The 'is' type check is fast.
     # Passing a float must give an error, but passing an Enum should work.
-    if type(obj) is not int and not isinstance(obj, int):
+    if type(obj) is not int and not isinstance(obj, _int_classes):
         raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
 
     val = PyLong_AsLongLongAndOverflow(obj, &overflow)
index 185eb72ff6530bc376bff9b486b8ab747403e93c..ccd9dadd9152144920a6768b103a19174eeffd63 100644 (file)
@@ -118,11 +118,33 @@ def test_rows(conn, format):
 
 @pytest.mark.parametrize("format", pq.Format)
 def test_set_types(conn, format):
+    sample = ({"foo": "bar"}, 123)
     cur = conn.cursor()
-    ensure_table(cur, "id serial primary key, data jsonb")
-    with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
-        copy.set_types(["jsonb"])
-        copy.write_row([{"foo": "bar"}])
+    ensure_table(cur, "id serial primary key, data jsonb, data2 bigint")
+    with cur.copy(
+        f"copy copy_in (data, data2) from stdin (format {format.name})"
+    ) as copy:
+        copy.set_types(["jsonb", "bigint"])
+        copy.write_row(sample)
+    cur.execute("select data, data2 from copy_in")
+    data = cur.fetchone()
+    assert data == sample
+
+
+@pytest.mark.parametrize("format", pq.Format)
+@pytest.mark.parametrize("use_set_types", [True, False])
+def test_segfault_rowlen_mismatch(conn, format, use_set_types):
+    samples = [[123, 456], [123, 456, 789]]
+    cur = conn.cursor()
+    ensure_table(cur, "id serial primary key, data integer, data2 integer")
+    with pytest.raises(Exception):
+        with cur.copy(
+            f"copy copy_in (data, data2) from stdin (format {format.name})"
+        ) as copy:
+            if use_set_types:
+                copy.set_types(["integer", "integer"])
+            for row in samples:
+                copy.write_row(row)
 
 
 def test_set_custom_type(conn, hstore):
index a4cc4d744e20ada51e20139b1c8903afdb87f2ce..775d4d9cebe6088bc13e8f20f00750f9644f75e6 100644 (file)
@@ -122,13 +122,33 @@ async def test_rows(aconn, format):
 
 @pytest.mark.parametrize("format", pq.Format)
 async def test_set_types(aconn, format):
+    sample = ({"foo": "bar"}, 123)
     cur = aconn.cursor()
-    await ensure_table_async(cur, "id serial primary key, data jsonb")
+    await ensure_table_async(cur, "id serial primary key, data jsonb, data2 bigint")
     async with cur.copy(
-        f"copy copy_in (data) from stdin (format {format.name})"
+        f"copy copy_in (data, data2) from stdin (format {format.name})"
     ) as copy:
-        copy.set_types(["jsonb"])
-        await copy.write_row([{"foo": "bar"}])
+        copy.set_types(["jsonb", "bigint"])
+        await copy.write_row(sample)
+    await cur.execute("select data, data2 from copy_in")
+    data = await cur.fetchone()
+    assert data == sample
+
+
+@pytest.mark.parametrize("format", pq.Format)
+@pytest.mark.parametrize("use_set_types", [True, False])
+async def test_segfault_rowlen_mismatch(aconn, format, use_set_types):
+    samples = [[123, 456], [123, 456, 789]]
+    cur = aconn.cursor()
+    await ensure_table_async(cur, "id serial primary key, data integer, data2 integer")
+    with pytest.raises(Exception):
+        async with cur.copy(
+            f"copy copy_in (data, data2) from stdin (format {format.name})"
+        ) as copy:
+            if use_set_types:
+                copy.set_types(["integer", "integer"])
+            for row in samples:
+                await copy.write_row(row)
 
 
 async def test_set_custom_type(aconn, hstore):