From bd2ca1d3db13a46969b68522a1c6824bbfc3bf26 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=B6rg=20Breitbart?= Date: Thu, 4 Sep 2025 16:27:13 +0200 Subject: [PATCH] fix(c): respect the set_types() definitions in text format Fix #1153 --- psycopg_c/psycopg_c/_psycopg/copy.pyx | 34 ++++++++++++++++++++-- psycopg_c/psycopg_c/_psycopg/transform.pyx | 2 ++ psycopg_c/psycopg_c/types/numeric.pyx | 21 ++++++++++--- tests/test_copy.py | 30 ++++++++++++++++--- tests/test_copy_async.py | 28 +++++++++++++++--- 5 files changed, 100 insertions(+), 15 deletions(-) diff --git a/psycopg_c/psycopg_c/_psycopg/copy.pyx b/psycopg_c/psycopg_c/_psycopg/copy.pyx index b130f69d4..d28e2e8a4 100644 --- a/psycopg_c/psycopg_c/_psycopg/copy.pyx +++ b/psycopg_c/psycopg_c/_psycopg/copy.pyx @@ -7,6 +7,7 @@ C optimised functions for the copy system. from libc.stdint cimport int32_t, uint16_t, uint32_t from libc.string cimport memcpy +from cpython.tuple cimport PyTuple_GET_SIZE from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_FromStringAndSize from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_Resize from cpython.memoryview cimport PyMemoryView_FromObject @@ -24,7 +25,13 @@ def format_row_binary( row: Sequence[Any], tx: Transformer, out: bytearray = None ) -> bytearray: """Convert a row of adapted data to the data to send for binary copy""" - cdef Py_ssize_t rowlen = len(row) + cdef Py_ssize_t rowlen + if type(row) is list: + rowlen = PyList_GET_SIZE(row) + elif type(row) is tuple: + rowlen = PyTuple_GET_SIZE(row) + else: + rowlen = len(row) cdef uint16_t berowlen = endian.htobe16(rowlen) cdef Py_ssize_t pos # offset in 'out' where to write @@ -51,6 +58,8 @@ def format_row_binary( tx._row_dumpers = PyList_New(rowlen) dumpers = tx._row_dumpers + if PyList_GET_SIZE(dumpers) != rowlen: + raise e.DataError(f"expected {len(dumpers)} values in row, got {rowlen}") for i in range(rowlen): item = row[i] @@ -111,7 +120,13 @@ def format_row_text( else: pos = PyByteArray_GET_SIZE(out) - cdef Py_ssize_t rowlen = len(row) + cdef Py_ssize_t rowlen + if type(row) is list: + rowlen = PyList_GET_SIZE(row) + elif type(row) is tuple: + rowlen = PyTuple_GET_SIZE(row) + else: + rowlen = len(row) if rowlen == 0: PyByteArray_Resize(out, pos + 1) @@ -127,6 +142,14 @@ def format_row_text( cdef PyObject *fmt = PG_TEXT cdef PyObject *row_dumper + # try to get preloaded dumpers from set_types + if not tx._row_dumpers: + tx._row_dumpers = PyList_New(rowlen) + + dumpers = tx._row_dumpers + if PyList_GET_SIZE(dumpers) != rowlen: + raise e.DataError(f"expected {len(dumpers)} values in row, got {rowlen}") + for i in range(rowlen): # Include the tab before the data, so it gets included in the resizes with_tab = i > 0 @@ -136,7 +159,12 @@ def format_row_text( _append_text_none(out, &pos, with_tab) continue - row_dumper = tx.get_row_dumper(item, fmt) + row_dumper = PyList_GET_ITEM(dumpers, i) + if not row_dumper: + row_dumper = tx.get_row_dumper(item, fmt) + Py_INCREF(row_dumper) + PyList_SET_ITEM(dumpers, i, row_dumper) + if (row_dumper).cdumper is not None: # A cdumper can resize if necessary and copy in place size = (row_dumper).cdumper.cdump( diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index f699fdc62..6f81f1f5b 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -170,6 +170,8 @@ cdef class Transformer: self._row_loaders = loaders def set_dumper_types(self, types: Sequence[int], format: PqFormat) -> None: + # NOTE: impl detail - final _row_dumpers must be a list type + # (assumed by format_row_binary and format_row_text) cdef Py_ssize_t ntypes = len(types) dumpers = PyList_New(ntypes) cdef int i diff --git a/psycopg_c/psycopg_c/types/numeric.pyx b/psycopg_c/psycopg_c/types/numeric.pyx index d807302f9..6cd79f021 100644 --- a/psycopg_c/psycopg_c/types/numeric.pyx +++ b/psycopg_c/psycopg_c/types/numeric.pyx @@ -48,6 +48,12 @@ int pg_lltoa(int64_t value, char *a); const int MAXINT8LEN +# global to hold types considered as integer python types +# _IntOrSubclassDumper and _MixedNumericDumper ctors +# change it to int or (int, numpy.integer) once +_int_classes = None + + cdef class _IntDumper(CDumper): format = PQ_TEXT @@ -74,6 +80,16 @@ cdef class _IntOrSubclassDumper(_IntDumper): format = PQ_TEXT + def __cinit__(self, cls, context: AdaptContext | None = None): + global _int_classes + + if _int_classes is None: + if "numpy" in sys.modules: + import numpy + _int_classes = (int, numpy.integer) + else: + _int_classes = int + cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: return dump_int_or_sub_to_text(obj, rv, offset) @@ -521,9 +537,6 @@ cdef class DecimalBinaryDumper(CDumper): return dump_decimal_to_numeric_binary(obj, rv, offset) -_int_classes = None - - cdef class _MixedNumericDumper(CDumper): oid = oids.NUMERIC_OID @@ -738,7 +751,7 @@ cdef Py_ssize_t dump_int_or_sub_to_text( # Ensure an int or a subclass. The 'is' type check is fast. # Passing a float must give an error, but passing an Enum should work. - if type(obj) is not int and not isinstance(obj, int): + if type(obj) is not int and not isinstance(obj, _int_classes): raise e.DataError(f"integer expected, got {type(obj).__name__!r}") val = PyLong_AsLongLongAndOverflow(obj, &overflow) diff --git a/tests/test_copy.py b/tests/test_copy.py index 185eb72ff..ccd9dadd9 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -118,11 +118,33 @@ def test_rows(conn, format): @pytest.mark.parametrize("format", pq.Format) def test_set_types(conn, format): + sample = ({"foo": "bar"}, 123) cur = conn.cursor() - ensure_table(cur, "id serial primary key, data jsonb") - with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy: - copy.set_types(["jsonb"]) - copy.write_row([{"foo": "bar"}]) + ensure_table(cur, "id serial primary key, data jsonb, data2 bigint") + with cur.copy( + f"copy copy_in (data, data2) from stdin (format {format.name})" + ) as copy: + copy.set_types(["jsonb", "bigint"]) + copy.write_row(sample) + cur.execute("select data, data2 from copy_in") + data = cur.fetchone() + assert data == sample + + +@pytest.mark.parametrize("format", pq.Format) +@pytest.mark.parametrize("use_set_types", [True, False]) +def test_segfault_rowlen_mismatch(conn, format, use_set_types): + samples = [[123, 456], [123, 456, 789]] + cur = conn.cursor() + ensure_table(cur, "id serial primary key, data integer, data2 integer") + with pytest.raises(Exception): + with cur.copy( + f"copy copy_in (data, data2) from stdin (format {format.name})" + ) as copy: + if use_set_types: + copy.set_types(["integer", "integer"]) + for row in samples: + copy.write_row(row) def test_set_custom_type(conn, hstore): diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index a4cc4d744..775d4d9ce 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -122,13 +122,33 @@ async def test_rows(aconn, format): @pytest.mark.parametrize("format", pq.Format) async def test_set_types(aconn, format): + sample = ({"foo": "bar"}, 123) cur = aconn.cursor() - await ensure_table_async(cur, "id serial primary key, data jsonb") + await ensure_table_async(cur, "id serial primary key, data jsonb, data2 bigint") async with cur.copy( - f"copy copy_in (data) from stdin (format {format.name})" + f"copy copy_in (data, data2) from stdin (format {format.name})" ) as copy: - copy.set_types(["jsonb"]) - await copy.write_row([{"foo": "bar"}]) + copy.set_types(["jsonb", "bigint"]) + await copy.write_row(sample) + await cur.execute("select data, data2 from copy_in") + data = await cur.fetchone() + assert data == sample + + +@pytest.mark.parametrize("format", pq.Format) +@pytest.mark.parametrize("use_set_types", [True, False]) +async def test_segfault_rowlen_mismatch(aconn, format, use_set_types): + samples = [[123, 456], [123, 456, 789]] + cur = aconn.cursor() + await ensure_table_async(cur, "id serial primary key, data integer, data2 integer") + with pytest.raises(Exception): + async with cur.copy( + f"copy copy_in (data, data2) from stdin (format {format.name})" + ) as copy: + if use_set_types: + copy.set_types(["integer", "integer"]) + for row in samples: + await copy.write_row(row) async def test_set_custom_type(aconn, hstore): -- 2.47.3