]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add C interval text loader
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 01:00:43 +0000 (02:00 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 01:29:13 +0000 (02:29 +0100)
psycopg3_c/psycopg3_c/types/date.pyx

index dda7944607c1cc0355fc646e11416f26c831d300..eba096b92d08d0a5de94a9a135743a2a7932ac5a 100644 (file)
@@ -45,6 +45,7 @@ DEF ORDER_PGMD = 4
 
 DEF INTERVALSTYLE_OTHERS = 0
 DEF INTERVALSTYLE_SQL_STANDARD = 1
+DEF INTERVALSTYLE_POSTGRES = 2
 
 DEF PG_DATE_EPOCH_DAYS = 730120  # date(2000, 1, 1).toordinal()
 DEF PY_DATE_MIN_DAYS = 1  # date.min.toordinal()
@@ -858,6 +859,102 @@ cdef class TimestamptzBinaryLoader(_BaseTimestamptzLoader):
                 ) from None
 
 
+@cython.final
+cdef class IntervalLoader(CLoader):
+
+    format = PQ_TEXT
+    cdef int _style
+
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+
+        cdef const char *ds = _get_intervalstyle(self._pgconn)
+        if ds[0] == b'p' and ds[8] == 0:  # postgres
+            self._style = INTERVALSTYLE_POSTGRES
+        else:  # iso_8601, sql_standard, postgres_verbose
+            self._style = INTERVALSTYLE_OTHERS
+
+    cdef object cload(self, const char *data, size_t length):
+        if self._style == INTERVALSTYLE_OTHERS:
+            return self._cload_notimpl(data, length)
+
+        cdef int days = 0, secs = 0, us = 0
+        cdef char sign
+        cdef int val
+        cdef const char *ptr = data
+        cdef const char *sep
+
+        # If there are spaces, there is a [+|-]n [days|months|years]
+        while True:
+            if ptr[0] == b'-' or ptr[0] == b'+':
+                sign = ptr[0]
+                ptr += 1
+            else:
+                sign = 0
+
+            sep = strchr(ptr, b' ')
+            if sep == NULL:
+                break
+
+            val = 0
+            ptr = _parse_date_values(ptr, &val, 1)
+            if ptr == NULL:
+                s = bytes(data).decode("utf8", "replace")
+                raise e.DataError(f"can't parse interval {s!r}")
+
+            if sign == b'-':
+                val = -val
+
+            if ptr[1] == b'y':
+                days = 365 * val
+            elif ptr[1] == b'm':
+                days = 30 * val
+            elif ptr[1] == b'd':
+                days = val
+            else:
+                s = bytes(data).decode("utf8", "replace")
+                raise e.DataError(f"can't parse interval {s!r}")
+
+            # Skip the date part word.
+            ptr = strchr(ptr + 1, b' ')
+            if ptr != NULL:
+                ptr += 1
+            else:
+                break
+
+        # Parse the time part. An eventual sign was already consumed in the loop
+        DEF NVALS = 3
+        cdef int vals[NVALS]
+        memset(vals, 0, sizeof(vals))
+        if ptr != NULL:
+            ptr = _parse_date_values(ptr, vals, NVALS)
+            if ptr == NULL:
+                s = bytes(data).decode("utf8", "replace")
+                raise e.DataError(f"can't parse interval {s!r}")
+
+            secs = vals[2] + 60 * (vals[1] + 60 * vals[0])
+
+            if ptr[0] == b'.':
+                ptr = _parse_micros(ptr + 1, &us)
+
+        if sign == b'-':
+            secs = -secs
+            us = -us
+
+        try:
+            return cdt.timedelta_new(days, secs, us)
+        except OverflowError as ex:
+            s = bytes(data).decode("utf8", "replace")
+            raise e.DataError(f"can't parse interval {s!r}: {ex}") from None
+
+    cdef object _cload_notimpl(self, const char *data, size_t length):
+        s = bytes(data).decode("utf8", "replace")
+        style = _get_intervalstyle(self.connection).decode("ascii")
+        raise NotImplementedError(
+            f"can't parse interval with IntervalStyle {style!r}: {s!r}"
+        )
+
+
 @cython.final
 cdef class IntervalBinaryLoader(CLoader):