from libc.stdint cimport int32_t, uint16_t, uint32_t
from libc.string cimport memcpy
+from cpython.tuple cimport PyTuple_GET_SIZE
from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_FromStringAndSize
from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_Resize
from cpython.memoryview cimport PyMemoryView_FromObject
row: Sequence[Any], tx: Transformer, out: bytearray = None
) -> bytearray:
"""Convert a row of adapted data to the data to send for binary copy"""
- cdef Py_ssize_t rowlen = len(row)
+ cdef Py_ssize_t rowlen
+ if type(row) is list:
+ rowlen = PyList_GET_SIZE(row)
+ elif type(row) is tuple:
+ rowlen = PyTuple_GET_SIZE(row)
+ else:
+ rowlen = len(row)
cdef uint16_t berowlen = endian.htobe16(<int16_t>rowlen)
cdef Py_ssize_t pos # offset in 'out' where to write
tx._row_dumpers = PyList_New(rowlen)
dumpers = tx._row_dumpers
+ if PyList_GET_SIZE(dumpers) != rowlen:
+ raise e.DataError(f"expected {len(dumpers)} values in row, got {rowlen}")
for i in range(rowlen):
item = row[i]
else:
pos = PyByteArray_GET_SIZE(out)
- cdef Py_ssize_t rowlen = len(row)
+ cdef Py_ssize_t rowlen
+ if type(row) is list:
+ rowlen = PyList_GET_SIZE(row)
+ elif type(row) is tuple:
+ rowlen = PyTuple_GET_SIZE(row)
+ else:
+ rowlen = len(row)
if rowlen == 0:
PyByteArray_Resize(out, pos + 1)
cdef PyObject *fmt = <PyObject *>PG_TEXT
cdef PyObject *row_dumper
+ # try to get preloaded dumpers from set_types
+ if not tx._row_dumpers:
+ tx._row_dumpers = PyList_New(rowlen)
+
+ dumpers = tx._row_dumpers
+ if PyList_GET_SIZE(dumpers) != rowlen:
+ raise e.DataError(f"expected {len(dumpers)} values in row, got {rowlen}")
+
for i in range(rowlen):
# Include the tab before the data, so it gets included in the resizes
with_tab = i > 0
_append_text_none(out, &pos, with_tab)
continue
- row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+ row_dumper = PyList_GET_ITEM(dumpers, i)
+ if not row_dumper:
+ row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+ Py_INCREF(<object>row_dumper)
+ PyList_SET_ITEM(dumpers, i, <object>row_dumper)
+
if (<RowDumper>row_dumper).cdumper is not None:
# A cdumper can resize if necessary and copy in place
size = (<RowDumper>row_dumper).cdumper.cdump(
self._row_loaders = loaders
def set_dumper_types(self, types: Sequence[int], format: PqFormat) -> None:
+ # NOTE: impl detail - final _row_dumpers must be a list type
+ # (assumed by format_row_binary and format_row_text)
cdef Py_ssize_t ntypes = len(types)
dumpers = PyList_New(ntypes)
cdef int i
const int MAXINT8LEN
+# Global holding the Python types considered integers. It starts as None;
+# the _IntOrSubclassDumper and _MixedNumericDumper constructors set it once,
+# to either `int` or `(int, numpy.integer)` depending on numpy availability.
+_int_classes = None
+
+
cdef class _IntDumper(CDumper):
format = PQ_TEXT
format = PQ_TEXT
+ def __cinit__(self, cls, context: AdaptContext | None = None):
+ global _int_classes
+
+ if _int_classes is None:
+ if "numpy" in sys.modules:
+ import numpy
+ _int_classes = (int, numpy.integer)
+ else:
+ _int_classes = int
+
cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
return dump_int_or_sub_to_text(obj, rv, offset)
return dump_decimal_to_numeric_binary(obj, rv, offset)
-_int_classes = None
-
-
cdef class _MixedNumericDumper(CDumper):
oid = oids.NUMERIC_OID
# Ensure an int or a subclass. The 'is' type check is fast.
# Passing a float must give an error, but passing an Enum should work.
- if type(obj) is not int and not isinstance(obj, int):
+ if type(obj) is not int and not isinstance(obj, _int_classes):
raise e.DataError(f"integer expected, got {type(obj).__name__!r}")
val = PyLong_AsLongLongAndOverflow(obj, &overflow)
@pytest.mark.parametrize("format", pq.Format)
def test_set_types(conn, format):
+    # Use two columns of different types so the test verifies that the
+    # dumpers preloaded by set_types() are matched positionally to the
+    # row values (a single column could not catch an ordering bug).
+    sample = ({"foo": "bar"}, 123)
cur = conn.cursor()
-    ensure_table(cur, "id serial primary key, data jsonb")
-    with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
-        copy.set_types(["jsonb"])
-        copy.write_row([{"foo": "bar"}])
+    ensure_table(cur, "id serial primary key, data jsonb, data2 bigint")
+    with cur.copy(
+        f"copy copy_in (data, data2) from stdin (format {format.name})"
+    ) as copy:
+        copy.set_types(["jsonb", "bigint"])
+        copy.write_row(sample)
+    # Round-trip: read the row back and check the values survived the copy.
+    cur.execute("select data, data2 from copy_in")
+    data = cur.fetchone()
+    assert data == sample
+
+
+@pytest.mark.parametrize("format", pq.Format)
+@pytest.mark.parametrize("use_set_types", [True, False])
+def test_segfault_rowlen_mismatch(conn, format, use_set_types):
+    # Regression test: writing rows of inconsistent lengths previously
+    # overran the cached dumpers list and could segfault the interpreter.
+    # It must now fail with a Python-level exception instead (the broad
+    # pytest.raises is deliberate: the exact error type may vary by path).
+    # The second row has one value too many for the two-column COPY.
+    samples = [[123, 456], [123, 456, 789]]
+    cur = conn.cursor()
+    ensure_table(cur, "id serial primary key, data integer, data2 integer")
+    with pytest.raises(Exception):
+        with cur.copy(
+            f"copy copy_in (data, data2) from stdin (format {format.name})"
+        ) as copy:
+            # Exercise both code paths: with preloaded dumpers (set_types)
+            # and with dumpers discovered lazily per row.
+            if use_set_types:
+                copy.set_types(["integer", "integer"])
+            for row in samples:
+                copy.write_row(row)
def test_set_custom_type(conn, hstore):
@pytest.mark.parametrize("format", pq.Format)
async def test_set_types(aconn, format):
+    # Async twin of the sync test: two columns of different types so the
+    # dumpers preloaded by set_types() are checked positionally against
+    # the row values.
+    sample = ({"foo": "bar"}, 123)
cur = aconn.cursor()
-    await ensure_table_async(cur, "id serial primary key, data jsonb")
+    await ensure_table_async(cur, "id serial primary key, data jsonb, data2 bigint")
async with cur.copy(
-        f"copy copy_in (data) from stdin (format {format.name})"
+        f"copy copy_in (data, data2) from stdin (format {format.name})"
) as copy:
-        copy.set_types(["jsonb"])
-        await copy.write_row([{"foo": "bar"}])
+        copy.set_types(["jsonb", "bigint"])
+        await copy.write_row(sample)
+    # Round-trip: read the row back and check the values survived the copy.
+    await cur.execute("select data, data2 from copy_in")
+    data = await cur.fetchone()
+    assert data == sample
+
+
+@pytest.mark.parametrize("format", pq.Format)
+@pytest.mark.parametrize("use_set_types", [True, False])
+async def test_segfault_rowlen_mismatch(aconn, format, use_set_types):
+    # Async twin of the sync regression test: rows of inconsistent lengths
+    # previously overran the cached dumpers list and could segfault; they
+    # must now raise a Python-level exception (broad pytest.raises is
+    # deliberate since the exact error type may vary by code path).
+    samples = [[123, 456], [123, 456, 789]]
+    cur = aconn.cursor()
+    await ensure_table_async(cur, "id serial primary key, data integer, data2 integer")
+    with pytest.raises(Exception):
+        async with cur.copy(
+            f"copy copy_in (data, data2) from stdin (format {format.name})"
+        ) as copy:
+            # Cover both the preloaded-dumpers path and the lazy path.
+            if use_set_types:
+                copy.set_types(["integer", "integer"])
+            for row in samples:
+                await copy.write_row(row)
async def test_set_custom_type(aconn, hstore):