]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added two-steps dumpers implementation in C too
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 20 Jan 2021 17:25:05 +0000 (18:25 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Jan 2021 03:11:34 +0000 (04:11 +0100)
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/wrappers/__init__.py [new file with mode: 0644]
psycopg3/psycopg3/wrappers/numeric.py [new file with mode: 0644]
psycopg3_c/psycopg3_c/_psycopg3/adapt.pyx
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
psycopg3_c/psycopg3_c/types/numeric.pyx
tests/test_adapt.py

index a8e956fa20b1ffa7c81c86359accb3a9fdc24155..3663d8ec157924ee34d61554aba5f131f73e39ac 100644 (file)
@@ -150,7 +150,7 @@ class Transformer(AdaptContext):
         if key1 is key:
             return dumper
 
-        # If it doesn't ask the dumper to create its own upgraded version
+        # If it does, ask the dumper to create its own upgraded version
         try:
             return cache[key1]
         except KeyError:
index 7c9dc2ee9dd1778d45d16fe5ae591c7fb2efc15a..69b9ef51e6e909689dbfc1833f132d5da9500eca 100644 (file)
@@ -12,7 +12,7 @@ from . import array, composite
 from . import range
 
 # Wrapper objects
-from .numeric import Int2, Int4, Int8, IntNumeric, Oid
+from ..wrappers.numeric import Int2, Int4, Int8, IntNumeric, Oid
 from .json import Json, Jsonb
 from .range import Range, Int4Range, Int8Range, DecimalRange
 from .range import DateRange, DateTimeRange, DateTimeTZRange
index 803173d670dce1fc8eaf8dc813d605731dca5817..13c02252d5b3e8f0d543a83a25958fbecee0b8f4 100644 (file)
@@ -12,6 +12,7 @@ from ..pq import Format
 from ..oids import builtins
 from ..adapt import Buffer, Dumper, Loader
 from ..adapt import Format as Pg3Format
+from ..wrappers.numeric import Int2, Int4, Int8, IntNumeric
 
 _PackInt = Callable[[int], bytes]
 _PackFloat = Callable[[float], bytes]
@@ -34,31 +35,6 @@ _unpack_float8 = cast(_UnpackFloat, struct.Struct("!d").unpack)
 # Wrappers to force numbers to be cast as specific PostgreSQL types
 
 
-class Int2(int):
-    def __new__(cls, arg: int) -> "Int2":
-        return super().__new__(cls, arg)  # type: ignore
-
-
-class Int4(int):
-    def __new__(cls, arg: int) -> "Int4":
-        return super().__new__(cls, arg)  # type: ignore
-
-
-class Int8(int):
-    def __new__(cls, arg: int) -> "Int8":
-        return super().__new__(cls, arg)  # type: ignore
-
-
-class IntNumeric(int):
-    def __new__(cls, arg: int) -> "IntNumeric":
-        return super().__new__(cls, arg)  # type: ignore
-
-
-class Oid(int):
-    def __new__(cls, arg: int) -> "Oid":
-        return super().__new__(cls, arg)  # type: ignore
-
-
 class NumberDumper(Dumper):
 
     format = Format.TEXT
@@ -142,20 +118,12 @@ class IntDumper(Dumper):
 
     def dump(self, obj: Any) -> bytes:
         raise TypeError(
-            "dispatcher to find the int subclass: not supposed to be called"
+            f"{type(self).__name__} is a dispatcher to other dumpers:"
+            " dump() is not supposed to be called"
         )
 
-    def get_key(cls, obj: int, format: Pg3Format) -> type:
-        if -(2 ** 31) <= obj < 2 ** 31:
-            if -(2 ** 15) <= obj < 2 ** 15:
-                return Int2
-            else:
-                return Int4
-        else:
-            if -(2 ** 63) <= obj < 2 ** 63:
-                return Int8
-            else:
-                return IntNumeric
+    def get_key(self, obj: int, format: Pg3Format) -> type:
+        return self.upgrade(obj, format).cls
 
     _int2_dumper = Int2Dumper(Int2)
     _int4_dumper = Int4Dumper(Int4)
diff --git a/psycopg3/psycopg3/wrappers/__init__.py b/psycopg3/psycopg3/wrappers/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/psycopg3/psycopg3/wrappers/numeric.py b/psycopg3/psycopg3/wrappers/numeric.py
new file mode 100644 (file)
index 0000000..a60106b
--- /dev/null
@@ -0,0 +1,30 @@
+"""
+Wrappers to force numbers to be cast as specific PostgreSQL types
+"""
+
+# Copyright (C) 2020-2021 The Psycopg Team
+
+
+class Int2(int):
+    def __new__(cls, arg: int) -> "Int2":
+        return super().__new__(cls, arg)  # type: ignore
+
+
+class Int4(int):
+    def __new__(cls, arg: int) -> "Int4":
+        return super().__new__(cls, arg)  # type: ignore
+
+
+class Int8(int):
+    def __new__(cls, arg: int) -> "Int8":
+        return super().__new__(cls, arg)  # type: ignore
+
+
+class IntNumeric(int):
+    def __new__(cls, arg: int) -> "IntNumeric":
+        return super().__new__(cls, arg)  # type: ignore
+
+
+class Oid(int):
+    def __new__(cls, arg: int) -> "Oid":
+        return super().__new__(cls, arg)  # type: ignore
index 158f48fafadc827ed4975c917f165563046095a0..1a1c188c6c2e3f4d98b66cb55aceaa5980b5cd13 100644 (file)
@@ -31,7 +31,7 @@ logger = logging.getLogger("psycopg3.adapt")
 
 @cython.freelist(8)
 cdef class CDumper:
-    cdef object cls
+    cdef readonly object cls
     cdef public libpq.Oid oid
     cdef pq.PGconn _pgconn
 
@@ -98,6 +98,12 @@ cdef class CDumper:
 
         return rv
 
+    cdef object get_key(self, object obj, object format):
+        return self.cls
+
+    cdef object upgrade(self, object obj, object format):
+        return self
+
     @classmethod
     def register(
         this_cls,
index 68cf5edb8ab25c0c62a1883b8f321d2077375450..f92c94d50649c620d7200d01c2242c0beb944096 100644 (file)
@@ -158,14 +158,13 @@ cdef class Transformer:
         # Fast path: return a Dumper class already instantiated from the same type
         cdef PyObject *cache
         cdef PyObject *ptr
+        cdef PyObject *ptr1
+        cdef RowDumper row_dumper
 
-        cls = type(<object>obj)
-        if cls is not list:
-            key = cls
-        else:
-            subobj = _find_list_element(obj, set())
-            key = (cls, type(subobj))
+        # Normally, the type of the object dictates how to dump it
+        key = type(<object>obj)
 
+        # Establish where the dumper would be cached
         bfmt = PyUnicode_AsUTF8String(<object>fmt)
         cdef char cfmt = PyBytes_AS_STRING(bfmt)[0]
         if cfmt == b's':
@@ -184,50 +183,40 @@ cdef class Transformer:
             raise ValueError(
                 f"format should be a psycopg3.adapt.Format, not {<object>fmt}")
 
+        # Reuse an existing Dumper class for objects of the same type
         ptr = PyDict_GetItem(<object>cache, key)
-        if ptr != NULL:
+        if ptr == NULL:
+            dcls = PyObject_CallFunctionObjArgs(
+                self.adapters.get_dumper, <PyObject *>key, fmt, NULL)
+            dumper = PyObject_CallFunctionObjArgs(
+                dcls, <PyObject *>key, <PyObject *>self, NULL)
+
+            row_dumper = _as_row_dumper(dumper)
+            PyDict_SetItem(<object>cache, key, row_dumper)
+            ptr = <PyObject *>row_dumper
+
+        # Check if the dumper requires an upgrade to handle this specific value
+        if (<RowDumper>ptr).cdumper is not None:
+            key1 = (<RowDumper>ptr).cdumper.get_key(<object>obj, <object>fmt)
+        else:
+            key1 = PyObject_CallFunctionObjArgs(
+                (<RowDumper>ptr).pydumper.get_key, obj, fmt, NULL)
+        if key1 is key:
             return ptr
 
-        # When dumping a string with %s we may refer to any type actually,
-        # but the user surely passed a text format
-        if cls is str and cfmt == b's':
-            fmt = <PyObject *>PG_TEXT
-
-        cdef PyObject *sub_dumper = NULL
-        if cls is list:
-            # It's not possible to declare an empty unknown array, so force text
-            if subobj is None:
-                fmt = <PyObject *>PG_TEXT
-
-            # If we are dumping a list it's the sub-object which should dictate
-            # what format to use.
-            else:
-                sub_dumper = self.get_row_dumper(<PyObject *>subobj, fmt)
-                tmp = Pg3Format.from_pq((<RowDumper>sub_dumper).format)
-                fmt = <PyObject *>tmp
-
-        dcls = PyObject_CallFunctionObjArgs(
-            self.adapters.get_dumper, <PyObject *>cls, fmt, NULL)
-        if dcls is None:
-            raise e.ProgrammingError(
-                f"cannot adapt type {cls.__name__}"
-                f" to format {Pg3Format(<object>fmt).name}")
-
-        dumper = PyObject_CallFunctionObjArgs(
-            dcls, <PyObject *>cls, <PyObject *>self, NULL)
-        if sub_dumper != NULL:
-            dumper.set_sub_dumper((<RowDumper>sub_dumper).pydumper)
-
-        cdef RowDumper row_dumper = RowDumper()
+        # If it does, ask the dumper to create its own upgraded version
+        ptr1 = PyDict_GetItem(<object>cache, key1)
+        if ptr1 != NULL:
+            return ptr1
 
-        row_dumper.pydumper = dumper
-        row_dumper.dumpfunc = dumper.dump
-        row_dumper.oid = dumper.oid
-        row_dumper.format = dumper.format
-        if isinstance(dumper, CDumper):
-            row_dumper.cdumper = <CDumper>dumper
+        if (<RowDumper>ptr).cdumper is not None:
+            dumper = (<RowDumper>ptr).cdumper.upgrade(<object>obj, <object>fmt)
+        else:
+            dumper = PyObject_CallFunctionObjArgs(
+                (<RowDumper>ptr).pydumper.upgrade, obj, fmt, NULL)
 
-        PyDict_SetItem(<object>cache, key, row_dumper)
+        row_dumper = _as_row_dumper(dumper)
+        PyDict_SetItem(<object>cache, key1, row_dumper)
         return <PyObject *>row_dumper
 
     cpdef dump_sequence(self, object params, object formats):
@@ -467,25 +456,15 @@ cdef class Transformer:
         return <PyObject *>row_loader
 
 
-cdef object _find_list_element(PyObject *L, object seen):
-    """
-    Find the first non-null element of an eventually nested list
-    """
-    cdef object list_id = <long><PyObject *>L
-    if PySet_Contains(seen, list_id):
-        raise e.DataError("cannot dump a recursive list")
-
-    PySet_Add(seen, list_id)
-
-    cdef int i
-    cdef PyObject *it
-    for i in range(PyList_GET_SIZE(<object>L)):
-        it = PyList_GET_ITEM(<object>L, i)
-        if PyList_CheckExact(<object>it):
-            subit = _find_list_element(it, seen)
-            if subit is not None:
-                return subit
-        elif <object>it is not None:
-            return <object>it
-
-    return None
+cdef object _as_row_dumper(object dumper):
+    cdef RowDumper row_dumper = RowDumper()
+
+    row_dumper.pydumper = dumper
+    row_dumper.dumpfunc = dumper.dump
+    row_dumper.oid = dumper.oid
+    row_dumper.format = dumper.format
+
+    if isinstance(dumper, CDumper):
+        row_dumper.cdumper = <CDumper>dumper
+
+    return row_dumper
index c01f80c4660da79f2af81409ae84728343079447..11786afc3006cb92dd68b614cbc5eb2f652ef41b 100644 (file)
@@ -9,12 +9,15 @@ cimport cython
 from libc.stdint cimport *
 from libc.string cimport memcpy, strlen
 from cpython.mem cimport PyMem_Free
-from cpython.long cimport PyLong_FromString, PyLong_FromLong, PyLong_AsLongLong
-from cpython.long cimport PyLong_FromLongLong, PyLong_FromUnsignedLong
+from cpython.long cimport (
+    PyLong_FromString, PyLong_FromLong, PyLong_FromLongLong,
+    PyLong_FromUnsignedLong, PyLong_AsLongLong)
+from cpython.bytes cimport PyBytes_AsStringAndSize
 from cpython.float cimport PyFloat_FromDouble, PyFloat_AsDouble
 
-from psycopg3_c._psycopg3.endian cimport (
-    be16toh, be32toh, be64toh, htobe32, htobe64)
+from psycopg3_c._psycopg3 cimport endian
+
+from psycopg3.wrappers.numeric import Int2, Int4, Int8, IntNumeric
 
 cdef extern from "Python.h":
     # work around https://github.com/cython/cython/issues/3909
@@ -25,6 +28,7 @@ cdef extern from "Python.h":
     ) except NULL
     int PyOS_snprintf(char *str, size_t size, const char *format, ...)
     int Py_DTSF_ADD_DOT_0
+    long long PyLong_AsLongLongAndOverflow(object pylong, int *overflow) except? -1
 
 
 # defined in numutils.c
@@ -36,19 +40,29 @@ int pg_lltoa(int64_t value, char *a);
 
 DEF MAXINT8LEN = 20
 
-# @cython.final  # TODO? causes compile warnings
-cdef class IntDumper(CDumper):
 
-    format = PQ_TEXT
+cdef class _NumberDumper(CDumper):
 
-    def __cinit__(self):
-        self.oid = oids.INT8_OID
+    format = PQ_TEXT
 
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
-        cdef char *buf = CDumper.ensure_size(rv, offset, MAXINT8LEN + 1)
-        cdef long long val = PyLong_AsLongLong(obj)
-        cdef int written = pg_lltoa(val, buf)
-        return written
+        cdef long long val
+        cdef int overflow
+        cdef char *buf
+        cdef char *src
+        cdef Py_ssize_t length
+
+        val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+        if not overflow:
+            buf = CDumper.ensure_size(rv, offset, MAXINT8LEN + 1)
+            length = pg_lltoa(val, buf)
+        else:
+            b = bytes(str(obj), "utf-8")
+            PyBytes_AsStringAndSize(b, &src, &length)
+            buf = CDumper.ensure_size(rv, offset, length)
+            memcpy(buf, src, length)
+
+        return length
 
     def quote(self, obj) -> bytearray:
         cdef Py_ssize_t length
@@ -65,6 +79,52 @@ cdef class IntDumper(CDumper):
         return rv
 
 
+@cython.final
+cdef class Int2Dumper(_NumberDumper):
+
+    def __cinit__(self):
+        self.oid = oids.INT2_OID
+
+
+@cython.final
+cdef class Int4Dumper(_NumberDumper):
+
+    def __cinit__(self):
+        self.oid = oids.INT4_OID
+
+
+@cython.final
+cdef class Int8Dumper(_NumberDumper):
+
+    def __cinit__(self):
+        self.oid = oids.INT8_OID
+
+
+@cython.final
+cdef class IntNumericDumper(_NumberDumper):
+
+    def __cinit__(self):
+        self.oid = oids.NUMERIC_OID
+
+
+@cython.final
+cdef class Int2BinaryDumper(CDumper):
+
+    format = PQ_BINARY
+
+    def __cinit__(self):
+        self.oid = oids.INT2_OID
+
+    cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+        cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int16_t))
+        cdef int16_t val = PyLong_AsLongLong(obj)
+        # swap bytes if needed
+        cdef uint16_t *ptvar = <uint16_t *>(&val)
+        cdef int16_t beval = endian.htobe16(ptvar[0])
+        memcpy(buf, <void *>&beval, sizeof(int16_t))
+        return sizeof(int16_t)
+
+
 @cython.final
 cdef class Int4BinaryDumper(CDumper):
 
@@ -75,10 +135,10 @@ cdef class Int4BinaryDumper(CDumper):
 
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
         cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int32_t))
-        cdef long long val = PyLong_AsLongLong(obj)
+        cdef int32_t val = PyLong_AsLongLong(obj)
         # swap bytes if needed
         cdef uint32_t *ptvar = <uint32_t *>(&val)
-        cdef int32_t beval = htobe32(ptvar[0])
+        cdef int32_t beval = endian.htobe32(ptvar[0])
         memcpy(buf, <void *>&beval, sizeof(int32_t))
         return sizeof(int32_t)
 
@@ -93,14 +153,88 @@ cdef class Int8BinaryDumper(CDumper):
 
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
         cdef char *buf = CDumper.ensure_size(rv, offset, sizeof(int64_t))
-        cdef long long val = PyLong_AsLongLong(obj)
+        cdef int64_t val = PyLong_AsLongLong(obj)
         # swap bytes if needed
         cdef uint64_t *ptvar = <uint64_t *>(&val)
-        cdef int64_t beval = htobe64(ptvar[0])
+        cdef int64_t beval = endian.htobe64(ptvar[0])
         memcpy(buf, <void *>&beval, sizeof(int64_t))
         return sizeof(int64_t)
 
 
+@cython.final
+cdef class IntNumericBinaryDumper(CDumper):
+
+    format = PQ_BINARY
+
+    def __cinit__(self):
+        self.oid = oids.NUMERIC_OID
+
+    cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+        raise NotImplementedError("binary decimal dump not implemented yet")
+
+
+cdef class IntDumper(_NumberDumper):
+
+    cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
+        raise TypeError(
+            f"{type(self).__name__} is a dispatcher to other dumpers:"
+            " dump() is not supposed to be called"
+        )
+
+    cpdef get_key(self, obj, format):
+        cdef long long val
+        cdef int overflow
+
+        val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+        if overflow:
+            return IntNumeric
+
+        if INT32_MIN <= obj <= INT32_MAX:
+            if INT16_MIN <= obj <= INT16_MAX:
+                return Int2
+            else:
+                return Int4
+        else:
+            if INT64_MIN <= obj <= INT64_MAX:
+                return Int8
+            else:
+                return IntNumeric
+
+    _int2_dumper = Int2Dumper
+    _int4_dumper = Int4Dumper
+    _int8_dumper = Int8Dumper
+    _int_numeric_dumper = IntNumericDumper
+
+    cpdef upgrade(self, obj, format):
+        cdef long long val
+        cdef int overflow
+
+        val = PyLong_AsLongLongAndOverflow(obj, &overflow)
+        if overflow:
+            return self._int_numeric_dumper(IntNumeric)
+
+        if INT32_MIN <= obj <= INT32_MAX:
+            if INT16_MIN <= obj <= INT16_MAX:
+                return self._int2_dumper(Int2)
+            else:
+                return self._int4_dumper(Int4)
+        else:
+            if INT64_MIN <= obj <= INT64_MAX:
+                return self._int8_dumper(Int8)
+            else:
+                return self._int_numeric_dumper(IntNumeric)
+
+
+cdef class IntBinaryDumper(IntDumper):
+
+    format = PQ_BINARY
+
+    _int2_dumper = Int2BinaryDumper
+    _int4_dumper = Int4BinaryDumper
+    _int8_dumper = Int8BinaryDumper
+    _int_numeric_dumper = IntNumericBinaryDumper
+
+
 @cython.final
 cdef class IntLoader(CLoader):
 
@@ -128,7 +262,7 @@ cdef class Int2BinaryLoader(CLoader):
     format = PQ_BINARY
 
     cdef object cload(self, const char *data, size_t length):
-        return PyLong_FromLong(<int16_t>be16toh((<uint16_t *>data)[0]))
+        return PyLong_FromLong(<int16_t>endian.be16toh((<uint16_t *>data)[0]))
 
 
 @cython.final
@@ -137,7 +271,7 @@ cdef class Int4BinaryLoader(CLoader):
     format = PQ_BINARY
 
     cdef object cload(self, const char *data, size_t length):
-        return PyLong_FromLong(<int32_t>be32toh((<uint32_t *>data)[0]))
+        return PyLong_FromLong(<int32_t>endian.be32toh((<uint32_t *>data)[0]))
 
 
 @cython.final
@@ -146,7 +280,7 @@ cdef class Int8BinaryLoader(CLoader):
     format = PQ_BINARY
 
     cdef object cload(self, const char *data, size_t length):
-        return PyLong_FromLongLong(<int64_t>be64toh((<uint64_t *>data)[0]))
+        return PyLong_FromLongLong(<int64_t>endian.be64toh((<uint64_t *>data)[0]))
 
 
 @cython.final
@@ -155,7 +289,7 @@ cdef class OidBinaryLoader(CLoader):
     format = PQ_BINARY
 
     cdef object cload(self, const char *data, size_t length):
-        return PyLong_FromUnsignedLong(be32toh((<uint32_t *>data)[0]))
+        return PyLong_FromUnsignedLong(endian.be32toh((<uint32_t *>data)[0]))
 
 
 @cython.final
@@ -202,7 +336,7 @@ cdef class FloatBinaryDumper(CDumper):
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
         cdef double d = PyFloat_AsDouble(obj)
         cdef uint64_t *intptr = <uint64_t *>&d
-        cdef uint64_t swp = htobe64(intptr[0])
+        cdef uint64_t swp = endian.htobe64(intptr[0])
         cdef char *tgt = CDumper.ensure_size(rv, offset, sizeof(swp))
         memcpy(tgt, <void *>&swp, sizeof(swp))
         return sizeof(swp)
@@ -225,7 +359,7 @@ cdef class Float4BinaryLoader(CLoader):
     format = PQ_BINARY
 
     cdef object cload(self, const char *data, size_t length):
-        cdef uint32_t asint = be32toh((<uint32_t *>data)[0])
+        cdef uint32_t asint = endian.be32toh((<uint32_t *>data)[0])
         # avoid warning:
         # dereferencing type-punned pointer will break strict-aliasing rules
         cdef char *swp = <char *>&asint
@@ -238,6 +372,6 @@ cdef class Float8BinaryLoader(CLoader):
     format = PQ_BINARY
 
     cdef object cload(self, const char *data, size_t length):
-        cdef uint64_t asint = be64toh((<uint64_t *>data)[0])
+        cdef uint64_t asint = endian.be64toh((<uint64_t *>data)[0])
         cdef char *swp = <char *>&asint
         return PyFloat_FromDouble((<double *>swp)[0])
index 1782e939e102af985b4742c9a0f3685d51a4c5d5..135d60ec47512f15fd7fe8e2be214c2058bb2347 100644 (file)
@@ -300,6 +300,9 @@ def test_optimised_adapters():
             continue
         c_adapters.pop(obj.__name__, None)
 
+    # TODO: This dumper is not registered yet, as it is not implemented
+    del c_adapters["IntNumericBinaryDumper"]
+
     assert not c_adapters